services.AuthorizationResponseHandler

src/idserver/services/AuthorizationResponseHandler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import collections
import logging
from oic.oic.message import Message, AuthorizationResponse, AuthorizationErrorResponse
from typing import Mapping, Union


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Immutable record handed back by AuthResponseHandler.process_auth_response.
# Any field may be None when the corresponding item was not received.
AuthenticationResult = collections.namedtuple(
    'AuthenticationResult',
    (
        'access_token',
        'id_token_claims',
        'id_token_jwt',
        'userinfo_claims',
        'refresh_token',
    ),
)


class AuthResponseProcessError(ValueError):
    """Common base class for the errors raised while processing an authentication response."""


class AuthResponseUnexpectedStateError(AuthResponseProcessError):
    """The 'state' of the authentication response is not the expected one."""


class AuthResponseUnexpectedNonceError(AuthResponseProcessError):
    """The 'nonce' of the received ID Token is not the expected one."""


class AuthResponseMismatchingSubjectError(AuthResponseProcessError):
    """The 'sub' of the userinfo response differs from the 'sub' of the ID Token."""


class AuthResponseErrorResponseError(AuthResponseProcessError):
    """Raised when the authentication response or the token response is an error response."""

    def __init__(self, error_response: Mapping[str, str]):
        """
        Args:
            error_response (Mapping[str, str]): OAuth error response containing 'error' and 'error_description'
        """
        # Hand the payload to the base exception explicitly: without this call
        # ``args`` (and therefore ``str()``/``repr()`` of the exception) stays empty
        # whenever the instance is created with a keyword argument.
        super().__init__(error_response)
        self.error_response = error_response


class AuthResponseHandler:
    """
    Processes a parsed OIDC authentication response: checks the 'state', redeems the
    authorization code (if any), fetches userinfo and bundles everything in an
    ``AuthenticationResult``.
    """

    def __init__(self, client):
        """
        Args:
            client (flask_pyoidc.pyoidc_facade.PyoidcFacade): Client proxy to make requests to the provider
        """
        self._client = client

    def process_auth_response(self, auth_response: Union[AuthorizationResponse, AuthorizationErrorResponse], expected_state: str, expected_nonce: Union[str, None] = None):
        """
        Args:
            auth_response (Union[AuthorizationResponse, AuthorizationErrorResponse]): parsed OIDC auth response
            expected_state (str): state value included in the OIDC auth request
            expected_nonce (str): nonce value included in the OIDC auth request
        Returns:
            AuthenticationResult: All relevant data associated with the authenticated user
        Raises:
            AuthResponseErrorResponseError: the auth response is an error response, or the
                token response contains an 'error'
            AuthResponseUnexpectedStateError: the 'state' is missing from the auth response or
                differs from ``expected_state``
            AuthResponseUnexpectedNonceError: the 'nonce' of the ID Token obtained from the
                token request differs from ``expected_nonce``
            AuthResponseMismatchingSubjectError: userinfo and ID Token disagree on 'sub'
        """
        if isinstance(auth_response, AuthorizationErrorResponse):
            raise AuthResponseErrorResponseError(auth_response.to_dict())

        # A response lacking 'state' is rejected with the same domain error as a wrong
        # value, instead of leaking a bare KeyError to the caller.
        if 'state' not in auth_response or auth_response['state'] != expected_state:
            raise AuthResponseUnexpectedStateError()

        # implicit/hybrid flow may return tokens in the auth response
        # NOTE(review): the nonce of an ID Token delivered directly in the auth response is
        # not compared with expected_nonce here - presumably it is verified when the
        # response is parsed; confirm.
        access_token = auth_response.get('access_token')
        refresh_token = auth_response.get('refresh_token')
        id_token_claims = auth_response['id_token'].to_dict() if 'id_token' in auth_response else None
        id_token_jwt = auth_response.get('id_token_jwt')

        if 'code' in auth_response:
            token_resp = self._client.token_request(auth_response['code'])
            if token_resp:
                if 'error' in token_resp:
                    raise AuthResponseErrorResponseError(token_resp.to_dict())

                access_token = token_resp['access_token']

                if 'id_token' in token_resp:
                    id_token = token_resp['id_token']
                    logger.debug('received id token: %s', id_token.to_json())

                    # An ID Token without 'nonce' counts as nonce None: accepted only when
                    # no nonce was expected, otherwise reported as a nonce mismatch rather
                    # than raising KeyError.
                    received_claims = id_token.to_dict()
                    if received_claims.get('nonce') != expected_nonce:
                        raise AuthResponseUnexpectedNonceError()

                    id_token_claims = received_claims
                    id_token_jwt = token_resp.get('id_token_jwt')

                if 'refresh_token' in token_resp:
                    refresh_token = token_resp['refresh_token']
                    # The token itself is a credential, so its value is kept out of the logs.
                    logger.debug('received refresh token')

        # do userinfo request
        userinfo = self._client.userinfo_request(access_token)
        userinfo_claims = None
        if userinfo:
            userinfo_claims = userinfo.to_dict()

        if id_token_claims and userinfo_claims and userinfo_claims['sub'] != id_token_claims['sub']:
            raise AuthResponseMismatchingSubjectError('The \'sub\' of userinfo does not match \'sub\' of ID Token.')

        return AuthenticationResult(access_token, id_token_claims, id_token_jwt, userinfo_claims, refresh_token)

    @classmethod
    def expect_fragment_encoded_response(cls, auth_request):
        """
        Tell whether the response to the given auth request is expected to be fragment encoded.

        Args:
            auth_request: mapping-like auth request; 'response_mode' is honoured when present,
                otherwise the decision is derived from 'response_type'
        Returns:
            bool: True if 'response_mode' is 'fragment', or - when no 'response_mode' is
                given - if 'response_type' is one of the implicit/hybrid flow combinations
        """
        if 'response_mode' in auth_request:
            return auth_request['response_mode'] == 'fragment'

        # split() without an argument ignores repeated/leading/trailing whitespace, so no
        # empty token ends up in the set and spoils the comparisons below.
        response_type = set(auth_request['response_type'].split())
        is_implicit_flow = response_type in (
            {'id_token'},
            {'id_token', 'token'},
        )
        is_hybrid_flow = response_type in (
            {'code', 'id_token'},
            {'code', 'token'},
            {'code', 'id_token', 'token'},
        )

        return is_implicit_flow or is_hybrid_flow