services.IdentityProviderContext

src/idserver/services/IdentityProviderContext.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import collections
import collections.abc
import logging
from typing import Optional, List, Dict, Tuple, Any, Iterable, Mapping, MutableMapping, Iterator

import requests

from models import IdentityProvider

logger = logging.getLogger(__name__)


class OIDCData(collections.abc.MutableMapping):
    """
    Dict-backed mapping for OIDC key/value data.

    Entries live in ``self.store``. Both ``str()`` and ``repr()`` replace any ``client_secret``
    value with ``'<masked>'``, so instances can be logged without leaking the credential.
    Instances are always truthy, even when empty.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """
        Args:
            args: at most one mapping or iterable of key-value pairs, as accepted by ``dict()``
            kwargs: additional key-value pairs to store
        """
        self.store: Dict[str, Any] = {}
        self.update(dict(*args, **kwargs))

    def __getitem__(self, key: str) -> Any:
        return self.store[key]

    def __setitem__(self, key: str, value: Any) -> None:
        self.store[key] = value

    def __delitem__(self, key: str) -> None:
        del self.store[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self.store)

    def __len__(self) -> int:
        return len(self.store)

    def _masked_dict(self) -> Dict[str, Any]:
        """Return a shallow copy of the data with ``client_secret`` replaced by ``'<masked>'``."""
        data = dict(self.store)
        if 'client_secret' in data:
            data['client_secret'] = '<masked>'
        return data

    def __str__(self) -> str:
        return str(self._masked_dict())

    def __repr__(self) -> str:
        # Mask here too: repr() is used when an instance is nested in a container or
        # formatted with %r, and an unmasked repr would leak the client secret into logs.
        return str(self._masked_dict())

    def __bool__(self) -> bool:
        # Always truthy, even when empty; `if not obj` therefore cannot detect an empty instance.
        return True

    def copy(self, **kwargs):
        """Not supported; use the module-level ``OIDCData_copy(source, **overrides)`` instead.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Use OIDCData_copy instead")

    def to_dict(self) -> Dict[str, Any]:
        """Return a shallow copy of the stored data as a plain dict."""
        return dict(self.store)

def OIDCData_copy(source: OIDCData, **kwargs) -> OIDCData:
    """Return a new ``OIDCData`` holding ``source``'s entries, overridden by ``kwargs``.

    ``source`` itself is not modified.

    Args:
        source: data to copy.
        kwargs: entries to add or replace in the copy.
    """
    values = source.to_dict()
    values.update(kwargs)
    # Pass the merged mapping positionally. The previous `OIDCData(kwargs=values)` stored
    # everything under a single literal key 'kwargs' instead of copying the entries.
    return OIDCData(values)


class ProviderMetadata(OIDCData):
    """OpenID Provider metadata; the ``issuer``, ``authorization_endpoint`` and ``jwks_uri`` keys are always present."""

    def __init__(self, issuer=None, authorization_endpoint=None, jwks_uri=None, **kwargs):
        # The named fields go first so key order matches the constructor signature.
        base_fields = {
            'issuer': issuer,
            'authorization_endpoint': authorization_endpoint,
            'jwks_uri': jwks_uri,
        }
        super().__init__(**base_fields, **kwargs)


class ClientRegistrationInfo(OIDCData):
    """Client metadata to send in a dynamic client registration request; adds nothing to ``OIDCData``."""


class ClientMetadata(OIDCData):
    """Registered client metadata; the ``client_id`` and ``client_secret`` keys are always present."""

    def __init__(self, client_id=None, client_secret=None, **kwargs):
        # Credentials first, then any extra metadata, preserving key order.
        credentials = {'client_id': client_id, 'client_secret': client_secret}
        super().__init__(**credentials, **kwargs)


class IdentityProviderContext:
    """OIDC client-side view of a single configured identity provider.

    Holds the client credentials taken from an ``IdentityProvider`` record and lazily fetches
    the provider metadata via OpenID Connect Discovery. It can also perform dynamic client
    registration when no ``client_id`` is configured.
    """

    # Timeout (seconds) passed to every HTTP request made by this class.
    DEFAULT_REQUEST_TIMEOUT = 5

    def __init__(self, identity_provider: IdentityProvider,
                 requests_session: Optional[requests.Session] = None) -> None:
        """
        Args:
            identity_provider (IdentityProvider): provider record. Its ``redirect_uri``,
                ``issuer``, ``client_id`` and ``client_secret`` attributes are read here.
            requests_session (Optional[requests.Session]): session used for discovery and
                registration requests (e.g. one configured with retry handling). A new
                ``requests.Session`` is created when omitted.
        """
        self.identity_provider = identity_provider
        # Leading slashes are stripped; presumably so it can be joined onto a base route — TODO confirm.
        self.redirect_uri_endpoint = identity_provider.redirect_uri.lstrip('/')
        self.issuer = identity_provider.issuer
        self.client_metadata = ClientMetadata(client_id=identity_provider.client_id,
                                              client_secret=identity_provider.client_secret)

        # Populated lazily by ensure_provider_metadata().
        self.provider_metadata = None

        # Base request body for dynamic registration. Nothing sets it here, so register_client
        # then sends only redirect_uris plus any extra_parameters.
        self.client_registration_info = None

        self.userinfo_endpoint_method = "GET"
        self.auth_request_params: Dict[str, Any] = {}
        self.session_refresh_interval_seconds = None

        self.requests_session = requests_session or requests.Session()

    def ensure_provider_metadata(self):
        """Return the provider metadata, fetching it via OIDC Discovery on first use.

        The fetched metadata is cached on ``self.provider_metadata``.

        Raises:
            ValueError: if no issuer is configured.
            requests.HTTPError: if the discovery endpoint answers with an error status.
        """
        if self.provider_metadata is None:
            if not self.issuer:
                raise ValueError("Can't fetch provider metadata: identity provider has no 'issuer'.")
            # OIDC Discovery 1.0 §4: remove any terminating '/' from the issuer before appending
            # the well-known path, otherwise 'https://op/' would yield 'https://op//.well-known/...'.
            discovery_url = self.issuer.rstrip('/') + '/.well-known/openid-configuration'
            resp = self.requests_session.get(discovery_url, timeout=self.DEFAULT_REQUEST_TIMEOUT)
            logger.debug('Received discovery response: %s', resp.text)
            # Fail loudly instead of caching an error body as metadata: once set, the cached
            # value is never refetched.
            resp.raise_for_status()
            self.provider_metadata = ProviderMetadata(**resp.json())

        return self.provider_metadata

    @property
    def registered_client_metadata(self):
        """The client metadata currently in use (static or obtained by registration)."""
        return self.client_metadata

    def register_client(self, redirect_uris, extra_parameters=None):
        """Return the client metadata, registering the client dynamically if it has no client_id.

        Args:
            redirect_uris: redirect URIs to register.
            extra_parameters (Optional[dict]): extra fields merged into the registration request.

        Returns:
            ClientMetadata: the existing metadata, or the metadata returned by the provider.

        Raises:
            ValueError: if registration is needed but the provider exposes no 'registration_endpoint'.
            requests.HTTPError: if a discovery or registration request fails.
        """
        # OIDCData.__bool__ always returns True, so `not self.client_metadata` could never
        # trigger registration. Check for a missing client_id instead.
        if self.client_metadata is not None and self.client_metadata.get('client_id'):
            return self.client_metadata

        provider_metadata = self.ensure_provider_metadata()
        registration_endpoint = provider_metadata.get('registration_endpoint')
        if not registration_endpoint:
            raise ValueError("Can't use dynamic client registration, provider metadata is missing "
                             "'registration_endpoint'.")

        if self.client_registration_info is not None:
            registration_request = self.client_registration_info.to_dict()
        else:
            registration_request = {}
        registration_request['redirect_uris'] = redirect_uris
        if extra_parameters:
            registration_request.update(extra_parameters)

        resp = self.requests_session.post(registration_endpoint,
                                          json=registration_request,
                                          timeout=self.DEFAULT_REQUEST_TIMEOUT)
        resp.raise_for_status()
        registered = resp.json()
        # RFC 7591 §3.2.1: the response echoes the registered metadata, redirect_uris included.
        # Passing redirect_uris both explicitly and via **registered raised a duplicate-keyword
        # TypeError, so only fill it in when the provider omitted it.
        registered.setdefault('redirect_uris', redirect_uris)
        self.client_metadata = ClientMetadata(**registered)
        logger.debug('Received registration response: client_id=%s', self.client_metadata.get('client_id'))

        return self.client_metadata