stores.UserStore

src/idserver/stores/UserStore.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from models import User, DbUser, DbIdentityProvider, DbUserClaim
from typing import Optional, Dict
from uuid import UUID, uuid4
from .DbEngine import DbEngine
import time
from datetime import datetime
from sqlalchemy import Column

class UserStore:
    """Storage interface for users.

    Concrete stores (e.g. ``SqlAlchemyUserStore``) override these methods.
    The base implementations are no-ops that return ``None``.
    """

    def get_user_by_email(self, provider_name, user_email) -> Optional[User]:
        """Return the user with ``user_email`` registered under ``provider_name``, or ``None``."""

    def add_user(self, user: User) -> UUID:
        """Persist ``user`` and return the subject identifier assigned to it."""

    def update_user(self, user: User) -> None:
        """Write changes from ``user`` back to the stored user with the same subject."""

class SqlAlchemyUserStore(UserStore):
    """SQLAlchemy-backed ``UserStore``.

    Each public method opens its own session via ``db_engine.session_factory()``
    and closes it before returning, including when an exception is raised.
    """

    def __init__(self, db_engine: DbEngine) -> None:
        """Bind the store to ``db_engine``.

        :param db_engine: wrapper whose ``session_factory()`` produces the
            sessions used by this store.
        """
        super().__init__()
        self.db_engine = db_engine

    def get_user_by_email(self, provider_name, user_email) -> Optional[User]:
        """Look up a user by email within a single identity provider.

        :param provider_name: name of the ``DbIdentityProvider`` to search in.
        :param user_email: email address of the user.
        :returns: the matching user converted with ``DbUser.toUser()``, or
            ``None`` if either the provider or the user does not exist.
        """
        db_session = self.db_engine.session_factory()
        try:
            provider: Optional[DbIdentityProvider] = (
                db_session.query(DbIdentityProvider)
                .filter(DbIdentityProvider.provider_name == provider_name)
                .one_or_none()
            )
            if not provider:
                return None

            # Pass the criteria as separate arguments: filter() ANDs them in SQL.
            # Joining them with Python's `and` evaluates the truthiness of the
            # first expression and silently drops the provider condition.
            db_user: Optional[DbUser] = (
                db_session.query(DbUser)
                .filter(
                    DbUser.email == user_email,
                    DbUser.provider_id == provider.provider_id,
                )
                .one_or_none()
            )

            # Convert while the session is still open, in case toUser() touches
            # lazily-loaded attributes that a detached instance cannot load.
            return db_user.toUser() if db_user else None
        finally:
            db_session.close()

    def add_user(self, user: User) -> UUID:
        """Insert ``user`` under the provider named by ``user.provider_name``.

        :param user: user to store; its ``provider_subject``, ``email``,
            ``given_name`` and ``family_name`` are copied into the new row.
        :returns: the newly generated subject UUID of the stored user.
        :raises Exception: if no identity provider with that name exists.
        """
        db_session = self.db_engine.session_factory()
        try:
            provider = (
                db_session.query(DbIdentityProvider)
                .filter(DbIdentityProvider.provider_name == user.provider_name)
                .one_or_none()
            )
            if not provider:
                raise Exception("Provider not found !")

            # Keep the generated subject locally: after commit() and close()
            # the ORM instance is detached (and expired under SQLAlchemy's
            # default expire_on_commit), so reading it back is not safe.
            subject = uuid4()
            db_user = DbUser(subject=subject,
                             provider_subject=user.provider_subject,
                             email=user.email,
                             provider_id=provider.provider_id,
                             given_name=user.given_name,
                             family_name=user.family_name,
                             created_on=datetime.utcnow())

            db_session.add(db_user)
            db_session.commit()
        finally:
            db_session.close()

        return subject

    def update_user(self, user: User) -> None:
        """Synchronise the stored user having ``user.subject`` with ``user``.

        Only ``given_name``, ``family_name``, ``last_update`` and the claim set
        are written. Claims in ``user.claims`` that are not stored are
        inserted, stored claims whose value differs are updated, and stored
        claims missing from ``user.claims`` are deleted. If no user has that
        subject, nothing is changed.
        """
        db_session = self.db_engine.session_factory()
        try:
            db_user = (
                db_session.query(DbUser)
                .filter(DbUser.subject == user.subject)
                .one_or_none()
            )

            if db_user:
                db_user.given_name = user.given_name
                db_user.family_name = user.family_name
                db_user.last_update = datetime.utcnow()

                stored_claims: Dict[str, DbUserClaim] = {
                    claim.claim_name: claim for claim in db_user.claims
                }

                # Insert new claims and update changed values.
                for claim_name in user.claims:
                    new_value = user.claims[claim_name]
                    stored = stored_claims.get(claim_name)
                    if stored is None:
                        db_session.add(DbUserClaim(user_subject=db_user.subject,
                                                   claim_name=claim_name,
                                                   claim_value=new_value))
                    elif stored.claim_value != new_value:
                        stored.claim_value = new_value

                # Delete claims the caller no longer has. The dict is only read
                # here (never mutated while iterating), and ORM rows are removed
                # through the session.
                for claim_name, stored in stored_claims.items():
                    if claim_name not in user.claims:
                        db_session.delete(stored)

            db_session.commit()
        finally:
            db_session.close()