1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
|
import collections
import logging
from oic.oic.message import Message, AuthorizationResponse, AuthorizationErrorResponse
from typing import Mapping, Union
logger = logging.getLogger(__name__)

# Result of AuthResponseHandler.process_auth_response; fields, in positional order:
#   access_token     - access token from the auth response or the token endpoint
#   id_token_claims  - ID Token claims as a dict
#   id_token_jwt     - raw ID Token JWT
#   userinfo_claims  - claims from the userinfo endpoint as a dict
#   refresh_token    - refresh token, if one was issued
AuthenticationResult = collections.namedtuple(
    'AuthenticationResult',
    'access_token id_token_claims id_token_jwt userinfo_claims refresh_token',
)
class AuthResponseProcessError(ValueError):
    """Base class for the errors this module raises while processing an authentication response."""
class AuthResponseUnexpectedStateError(AuthResponseProcessError):
    """The 'state' of the authentication response differs from the one sent in the request."""
class AuthResponseUnexpectedNonceError(AuthResponseProcessError):
    """The 'nonce' of the received ID Token differs from the one sent in the authentication request."""
class AuthResponseMismatchingSubjectError(AuthResponseProcessError):
    """The 'sub' claim returned by the userinfo endpoint does not match the 'sub' of the ID Token."""
class AuthResponseErrorResponseError(AuthResponseProcessError):
    """The provider answered with an OAuth error response (from the authorization or token endpoint)."""

    def __init__(self, error_response: Mapping[str, str]):
        """
        Args:
            error_response (Mapping[str, str]): OAuth error response containing 'error' and,
                optionally, 'error_description'
        """
        # Initialise the exception base explicitly so args/str()/pickling don't rely on
        # BaseException.__new__ having captured the constructor arguments.
        super().__init__(error_response)
        self.error_response = error_response
class AuthResponseHandler:
    """Validates an OpenID Connect authentication response and collects the resulting tokens and claims."""

    # response_type combinations of the implicit and hybrid flows; without an explicit
    # 'response_mode' their responses are expected to be fragment encoded.
    _FRAGMENT_ENCODED_RESPONSE_TYPES = frozenset({
        frozenset({'id_token'}),
        frozenset({'id_token', 'token'}),
        frozenset({'code', 'id_token'}),
        frozenset({'code', 'token'}),
        frozenset({'code', 'id_token', 'token'}),
    })

    def __init__(self, client):
        """
        Args:
            client (flask_pyoidc.pyoidc_facade.PyoidcFacade): Client proxy to make requests to the provider
        """
        self._client = client

    def process_auth_response(self, auth_response: Union[AuthorizationResponse, AuthorizationErrorResponse], expected_state: str, expected_nonce: str=None):
        """
        Validate an authentication response and gather all tokens and claims it yields.

        Tokens returned directly in the authentication response (implicit/hybrid flow) are used
        as a starting point. If the response contains an authorization code, it is exchanged via
        the client's token request, and the token response values replace them. Finally a userinfo
        request is made with the resulting access token.

        Args:
            auth_response (Union[AuthorizationResponse, AuthorizationErrorResponse]): parsed OIDC auth response
            expected_state (str): state value included in the OIDC auth request
            expected_nonce (str): nonce value included in the OIDC auth request
        Returns:
            AuthenticationResult: All relevant data associated with the authenticated user
        Raises:
            AuthResponseErrorResponseError: the auth response or the token response is an error response
            AuthResponseUnexpectedStateError: the auth response 'state' is missing or differs from expected_state
            AuthResponseUnexpectedNonceError: the ID Token from the token response has a missing or different 'nonce'
            AuthResponseMismatchingSubjectError: the userinfo 'sub' is missing or differs from the ID Token 'sub'
        """
        if isinstance(auth_response, AuthorizationErrorResponse):
            raise AuthResponseErrorResponseError(auth_response.to_dict())

        # A missing 'state' is treated as a mismatch rather than escaping as a KeyError.
        if auth_response.get('state') != expected_state:
            raise AuthResponseUnexpectedStateError()

        # implicit/hybrid flow may return tokens in the auth response
        access_token = auth_response.get('access_token')
        refresh_token = auth_response.get('refresh_token')
        id_token_claims = auth_response['id_token'].to_dict() if 'id_token' in auth_response else None
        id_token_jwt = auth_response.get('id_token_jwt')

        if 'code' in auth_response:
            token_resp = self._client.token_request(auth_response['code'])
            if token_resp:
                if 'error' in token_resp:
                    raise AuthResponseErrorResponseError(token_resp.to_dict())

                access_token = token_resp['access_token']

                if 'id_token' in token_resp:
                    id_token = token_resp['id_token']
                    logger.debug('received id token: %s', id_token.to_json())
                    # A missing 'nonce' is treated as a mismatch rather than escaping as a KeyError.
                    if id_token.get('nonce') != expected_nonce:
                        raise AuthResponseUnexpectedNonceError()
                    id_token_claims = id_token.to_dict()
                    id_token_jwt = token_resp.get('id_token_jwt')

                if 'refresh_token' in token_resp:
                    refresh_token = token_resp['refresh_token']
                    # Do not write the token value itself to the log: it is a reusable credential.
                    logger.debug('received refresh token')

        # do userinfo request
        userinfo = self._client.userinfo_request(access_token)
        userinfo_claims = userinfo.to_dict() if userinfo else None

        # The userinfo 'sub' must match the ID Token 'sub'; a missing userinfo 'sub' counts as a mismatch.
        if id_token_claims and userinfo_claims and userinfo_claims.get('sub') != id_token_claims['sub']:
            raise AuthResponseMismatchingSubjectError("The 'sub' of userinfo does not match 'sub' of ID Token.")

        return AuthenticationResult(access_token, id_token_claims, id_token_jwt, userinfo_claims, refresh_token)

    @classmethod
    def expect_fragment_encoded_response(cls, auth_request):
        """
        Determine whether the authentication response is expected in the URL fragment.

        An explicit 'response_mode' decides on its own. Otherwise a fragment-encoded response is
        expected for the implicit and hybrid flow response types.

        Args:
            auth_request (Mapping): authentication request parameters; 'response_type' may be a
                space-separated string or a sequence of individual response type values
        Returns:
            bool: True if a fragment-encoded response is expected
        """
        if 'response_mode' in auth_request:
            return auth_request['response_mode'] == 'fragment'

        response_type = auth_request['response_type']
        if isinstance(response_type, str):
            # split() without an argument also tolerates repeated whitespace between values
            response_type = response_type.split()
        return frozenset(response_type) in cls._FRAGMENT_ENCODED_RESPONSE_TYPES
|