services.OpenIdService

src/idserver/services/OpenIdService.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from models import CertificatesSetting, CertificateEntry
from Cryptodome.PublicKey import RSA
from Cryptodome.PublicKey.RSA import RsaKey
import os
from typing import Optional, Dict, List, Any
import base64
import hashlib
from flask import Flask, request, redirect
from oic import rndstr
from oic.oic.provider import Provider, AuthorizationEndpoint, TokenEndpoint, EndSessionEndpoint
from oic.utils.sdb import create_session_db
from oic.utils.http_util import get_or_post, Response, SeeOther
from models import IdServerConfiguration
from oic.utils.authz import AuthzHandling
from oic.utils.keyio import keyjar_init
from oic.utils.authn.authn_context import AuthnBroker
from oic.utils.authn.client import verify_client
import json

def pyoidcWrapper(func):
    """Common wrapper for the underlying pyoidc library functions.

    Reads GET params and POST data before passing them on to the library and
    converts the response from oic.utils.http_util into something Flask can
    serve.

    :param func: underlying library function; called with the keyword
        arguments ``request`` (the data returned by ``get_or_post`` for the
        current request) and ``cookie`` (the raw ``Cookie`` header, or ``""``
        when the request has none).
    :returns: a zero-argument Flask view function.
    """
    from functools import wraps

    @wraps(func)
    def wrapper():
        data = get_or_post(request.environ)
        cookies = request.environ.get("HTTP_COOKIE", "")
        resp = func(request=data, cookie=cookies)
        if isinstance(resp, SeeOther):
            # Use the same numeric status the non-redirect branch uses instead of
            # re-parsing the textual status line.
            flask_resp = redirect(resp.message, int(resp.status_code))
            # Keep any extra headers pyoidc attached to the redirect (they used to
            # be dropped). Flask's redirect already supplies Location and
            # Content-Type, so those are not duplicated.
            for name, value in resp.headers:
                if name.lower() not in ("location", "content-type"):
                    flask_resp.headers.add(name, value)
            return flask_resp
        return resp.message, resp.status_code, resp.headers

    return wrapper

class KeyEntry:
    """An RSA private key / certificate pair described by a ``CertificateEntry``.

    Nothing is read on construction; call :meth:`load_certificate` before using
    :meth:`get_private_key` or :meth:`get_certificate`.
    """

    def __init__(self, certificate_config: CertificateEntry) -> None:
        """Remember *certificate_config*; no files are read yet.

        :param certificate_config: entry whose ``private_key`` and
            ``certificate`` attributes are file paths and whose ``default``
            attribute marks the default key.
        """
        self.certificate_config = certificate_config

        # Explicitly initialised so the getters can report "not defined" instead
        # of failing with AttributeError when load_certificate() was not called.
        self.private_key: Optional[RsaKey] = None
        self.certificate: Optional[RsaKey] = None

    def load_certificate(self) -> None:
        """Read the configured private key and certificate files.

        :raises FileNotFoundError: if either configured path does not exist.
            Parse errors from ``RSA.import_key`` are propagated unchanged.
        """
        if not os.path.exists(self.certificate_config.private_key):
            raise FileNotFoundError("Private key wasn't found")
        if not os.path.exists(self.certificate_config.certificate):
            raise FileNotFoundError("Certificate wasn't found")

        with open(self.certificate_config.private_key, 'rb') as fp_key:
            self.private_key = RSA.import_key(fp_key.read())

        # PyCryptodome's RSA.import_key also accepts an X.509 certificate and
        # yields the public key it contains.
        with open(self.certificate_config.certificate, 'rb') as fp_cert:
            self.certificate = RSA.import_key(fp_cert.read())

    def get_certificate(self) -> RsaKey:
        """Return the key loaded from the certificate file.

        :raises Exception: if :meth:`load_certificate` has not loaded it.
        """
        if self.certificate is None:
            raise Exception("Certificate not defined")
        return self.certificate

    def get_private_key(self) -> RsaKey:
        """Return the key loaded from the private key file.

        :raises Exception: if :meth:`load_certificate` has not loaded it.
        """
        if self.private_key is None:
            raise Exception("Private key not defined")
        return self.private_key

    def is_default(self) -> bool:
        """Return the truthiness of the entry's ``default`` setting."""
        return bool(self.certificate_config.default)

class OpenIdService:
    """Wires a pyoidc ``Provider`` into a Flask application.

    Construct it with the server configuration, then call :meth:`init_app`
    to create the provider and register the OpenID Connect routes.
    :meth:`init_keys` separately loads the configured key pairs into
    ``self.keys``; :meth:`init_app` does not read ``self.keys``.
    """

    def __init__(self, configuration: IdServerConfiguration) -> None:
        """Store *configuration*; no keys are loaded and no provider exists yet.

        :param configuration: server configuration. Its ``address`` is used as
            the provider name and base URL, and ``certificates.certificates``
            lists the configured key pairs.
        """
        self.certificates_setting = configuration.certificates
        self.configuration = configuration
        # Key id -> loaded key pair; filled by init_keys().
        self.keys: Dict[str, KeyEntry] = {}
        # Created by init_app().
        self.provider: Optional[Provider] = None
        # JWKS document returned by keyjar_init in init_app() and served at the jwks route.
        self.jwks: Dict[str, List[Dict[str, str]]] = {}

    def init_keys(self) -> None:
        """Load every configured key pair into ``self.keys``.

        Each entry is keyed by the standard-base64 SHA-256 digest of the DER
        export of the key read from its certificate.

        :raises FileNotFoundError: if a configured key or certificate file is missing.
        """
        for certificate in self.certificates_setting.certificates:
            key_entry = KeyEntry(certificate)
            key_entry.load_certificate()

            der = key_entry.get_certificate().export_key('DER')
            key_id = base64.b64encode(hashlib.sha256(der).digest()).decode('utf-8')

            self.keys[key_id] = key_entry

    def _key_conf(self) -> List[Dict[str, Any]]:
        """Build ``keyjar_init`` key specs (one signing RSA key per configured entry)."""
        # The keys are declared for signing, which needs the private key file;
        # a certificate only carries the public half.
        return [
            {"type": "RSA", "key": certificate.private_key, "use": ["sig"]}
            for certificate in self.certificates_setting.certificates
        ]

    def init_app(self, app: Flask) -> None:
        """Create the pyoidc ``Provider`` and register its endpoints on *app*.

        Registers:

        * ``/.well-known/openid-configuration`` -- discovery document (GET)
        * ``/.well-known/openid-configuration/jwks`` -- JWKS document (GET)
        * the token, authorization and end-session endpoints (GET/POST)

        :param app: Flask application to add the URL rules to.
        """
        # TODO: fetch clients from the database instead of hard-coding them.
        # NOTE(review): hard-coded test client with a trivial secret -- make sure
        # it is not reachable in a production deployment.
        client_db: Dict[str, Any] = {
            "test": {
                "client_id": "test",
                "client_secret": "test",
                "redirect_uris": [("http://localhost:5001/signin-oidc", None)],
            },
        }

        # Random secrets are generated on every call; presumably this invalidates
        # outstanding sessions/tokens on restart -- confirm that is intended.
        session_db = create_session_db(self.configuration.address, secret=rndstr(32), password=rndstr(32))

        # NOTE(review): per pyoidc's Provider signature the two None arguments are
        # the authn broker and the userinfo source -- confirm the authorization
        # endpoint works without them.
        self.provider = Provider(
            self.configuration.address, session_db, client_db, None, None,
            AuthzHandling(), verify_client, None,
        )
        self.provider.baseurl = self.configuration.address

        openid_configuration_endpoint = '.well-known/openid-configuration'
        app.add_url_rule(f"/{openid_configuration_endpoint}",
                         openid_configuration_endpoint,
                         self._well_known_openid_configuration,
                         methods=['GET'])

        # keyjar_init loads the keys into the provider and returns the JWKS to publish.
        self.jwks = keyjar_init(self.provider, self._key_conf())

        jwks_uri = ".well-known/openid-configuration/jwks"
        self.provider.jwks_uri = f"{self.configuration.address.rstrip('/')}/{jwks_uri}"
        app.add_url_rule(f"/{jwks_uri}", jwks_uri, self._well_known_jwks, methods=['GET'])

        for endpoint, handler in (
            (TokenEndpoint, self.provider.token_endpoint),
            (AuthorizationEndpoint, self.provider.authorization_endpoint),
            (EndSessionEndpoint, self.provider.endsession_endpoint),
        ):
            app.add_url_rule(f"/{endpoint.url}", endpoint.url, pyoidcWrapper(handler), methods=['POST', 'GET'])

    def _token(self):
        """Run the token endpoint for the current Flask request.

        Not registered as a route by init_app (that route wraps the provider's
        token endpoint directly). It goes through the same request/response
        adapter, because pyoidc expects the request data, not Flask's request
        object.
        """
        return pyoidcWrapper(self.provider.token_endpoint)()

    def _well_known_jwks(self):
        """Serve the JWKS document as JSON."""
        return json.dumps(self.jwks), 200, {"Content-Type": "application/json"}

    def _well_known_openid_configuration(self):
        """Serve the provider's discovery document.

        The pyoidc response is converted to a Flask ``(body, status, headers)``
        tuple, the same way ``pyoidcWrapper`` converts the other endpoints'
        responses.
        """
        resp = self.provider.providerinfo_endpoint()
        return resp.message, resp.status_code, resp.headers