Skip to content

Commit e43f128

Browse files
authored
add token refreshing mechanism (#62)
* add token refreshing mechanism * update token_expires to token_expire_time * Fix implicit type conversion exception catch
1 parent 6fbddba commit e43f128

File tree

8 files changed

+162
-48
lines changed

8 files changed

+162
-48
lines changed

backend/app/api/v1/auth/auth.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,47 @@
11
#!/usr/bin/env python3
22
# -*- coding: utf-8 -*-
3-
from fastapi import APIRouter, Depends
3+
from fastapi import APIRouter, Depends, Request
44
from fastapi.security import OAuth2PasswordRequestForm
55

6-
from backend.app.common.jwt import DependsUser, JwtAuthentication
7-
from backend.app.common.redis import redis_client
6+
from backend.app.common.jwt import DependsUser, get_token, jwt_decode, CurrentJwtAuth
87
from backend.app.common.response.response_schema import response_base
9-
from backend.app.schemas.token import Token
8+
from backend.app.schemas.token import RefreshToken, LoginToken, SwaggerToken
109
from backend.app.schemas.user import Auth
1110
from backend.app.services.user_service import UserService
1211

1312
router = APIRouter()
1413

1514

1615
@router.post('/swagger_login', summary='swagger 表单登录', description='form 格式登录,仅用于 swagger 文档调试接口')
17-
async def swagger_user_login(form_data: OAuth2PasswordRequestForm = Depends()) -> Token:
16+
async def swagger_user_login(form_data: OAuth2PasswordRequestForm = Depends()) -> SwaggerToken:
1817
token, user = await UserService.swagger_login(form_data)
19-
return Token(access_token=token, user=user)
18+
return SwaggerToken(access_token=token, user=user)
2019

2120

2221
@router.post('/login', summary='用户登录', description='json 格式登录, 仅支持在第三方api工具调试接口, 例如: postman')
2322
async def user_login(obj: Auth):
24-
token, user = await UserService.login(obj)
25-
# TODO: token 存储
26-
data = Token(access_token=token, user=user)
23+
access_token, refresh_token, access_expire, refresh_expire, user = await UserService.login(obj)
24+
data = LoginToken(
25+
access_token=access_token,
26+
refresh_token=refresh_token,
27+
access_token_expire_time=access_expire,
28+
refresh_token_expire_time=refresh_expire,
29+
user=user,
30+
)
31+
return response_base.success(data=data)
32+
33+
34+
@router.post('/refresh_token', summary='刷新 token', dependencies=[DependsUser])
35+
async def get_refresh_token(request: Request):
36+
token = get_token(request)
37+
user_id, _ = jwt_decode(token)
38+
refresh_token, refresh_expire = await UserService.refresh_token(user_id)
39+
data = RefreshToken(refresh_token=refresh_token, refresh_token_expire_time=refresh_expire)
2740
return response_base.success(data=data)
2841

2942

3043
@router.post('/logout', summary='用户登出', dependencies=[DependsUser])
31-
async def user_logout(jwt: JwtAuthentication):
32-
user_id = jwt.get('payload').get('sub')
33-
token = jwt.get('token')
34-
key = f'token:{user_id}:{token}'
35-
await redis_client.delete(key)
44+
async def user_logout(jwt: CurrentJwtAuth):
45+
user_id = jwt.get('sub')
46+
await UserService.logout(user_id)
3647
return response_base.success()

backend/app/common/casbin_rbac.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def get_casbin_enforcer(self) -> casbin.Enforcer:
2222
2323
:return:
2424
"""
25+
# TODO: https://github.com/pycasbin/async-sqlalchemy-adapter/issues/4
2526
adapter = casbin_sqlalchemy_adapter.Adapter(self._CASBIN_DATABASE_URL, db_class=CasbinRule)
2627

2728
enforcer = casbin.Enforcer(RBAC_MODEL_CONF, adapter)

backend/app/common/jwt.py

Lines changed: 69 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
#!/usr/bin/env python3
22
# -*- coding: utf-8 -*-
33
from datetime import datetime, timedelta
4-
from typing import Any
54

6-
from fastapi import Depends
5+
from fastapi import Depends, Request
76
from fastapi.security import OAuth2PasswordBearer
7+
from fastapi.security.utils import get_authorization_scheme_param
88
from jose import jwt
99
from passlib.context import CryptContext
1010
from pydantic import ValidationError
@@ -43,7 +43,7 @@ def password_verify(plain_password: str, hashed_password: str) -> bool:
4343
return pwd_context.verify(plain_password, hashed_password)
4444

4545

46-
async def create_access_token(sub: int | Any, expires_delta: timedelta | None = None, **kwargs) -> str:
46+
async def create_access_token(sub: str, expires_delta: timedelta | None = None, **kwargs) -> tuple[str, datetime]:
4747
"""
4848
Generate encryption token
4949
@@ -52,41 +52,86 @@ async def create_access_token(sub: int | Any, expires_delta: timedelta | None =
5252
:return:
5353
"""
5454
if expires_delta:
55-
expires = datetime.utcnow() + expires_delta
56-
expire_seconds = expires_delta.total_seconds()
55+
expire = datetime.utcnow() + expires_delta
56+
expire_seconds = int(expires_delta.total_seconds())
5757
else:
58-
expires = datetime.utcnow() + timedelta(seconds=settings.TOKEN_EXPIRE_SECONDS)
58+
expire = datetime.utcnow() + timedelta(seconds=settings.TOKEN_EXPIRE_SECONDS)
5959
expire_seconds = settings.TOKEN_EXPIRE_SECONDS
60-
to_encode = {'exp': expires, 'sub': str(sub), **kwargs}
60+
to_encode = {'exp': expire, 'sub': sub, **kwargs}
6161
token = jwt.encode(to_encode, settings.TOKEN_SECRET_KEY, settings.TOKEN_ALGORITHM)
6262
if sub not in settings.TOKEN_WHITE_LIST:
63-
await redis_client.delete(f'token:{sub}:*')
64-
key = f'token:{sub}:{token}'
63+
await redis_client.delete_prefix(f'{settings.TOKEN_REDIS_PREFIX}:{sub}:')
64+
key = f'{settings.TOKEN_REDIS_PREFIX}:{sub}:{token}'
6565
await redis_client.setex(key, expire_seconds, token)
66-
return token
66+
return token, expire
6767

6868

69-
async def jwt_authentication(token: str = Depends(oauth2_schema)):
69+
async def create_refresh_token(sub: str, expire_time: datetime | None = None, **kwargs) -> tuple[str, datetime]:
7070
"""
71-
JWT authentication
71+
Generate encryption refresh token
72+
73+
:param sub: The subject/userid of the JWT
74+
:param expire_time: expiry time
75+
:return:
76+
"""
77+
if expire_time:
78+
expires = expire_time + timedelta(seconds=settings.TOKEN_EXPIRE_SECONDS)
79+
expire_seconds = int((expires - datetime.utcnow()).total_seconds())
80+
else:
81+
expires = datetime.utcnow() + timedelta(seconds=settings.TOKEN_EXPIRE_SECONDS)
82+
expire_seconds = settings.TOKEN_EXPIRE_SECONDS
83+
to_encode = {'exp': expires, 'sub': sub, **kwargs}
84+
token = jwt.encode(to_encode, settings.TOKEN_SECRET_KEY, settings.TOKEN_ALGORITHM)
85+
# 刷新 token 时,保持旧 token 有效,不执行删除操作
86+
key = f'{settings.TOKEN_REDIS_PREFIX}:{sub}:{token}'
87+
await redis_client.setex(key, expire_seconds, token)
88+
return token, expires
89+
90+
91+
def get_token(request: Request) -> str:
92+
"""
93+
Get token for request header
94+
95+
:return:
96+
"""
97+
authorization = request.headers.get('Authorization')
98+
scheme, param = get_authorization_scheme_param(authorization)
99+
if not authorization or scheme.lower() != 'bearer':
100+
raise TokenError
101+
return param
102+
103+
104+
def jwt_decode(token: str) -> tuple[int, list[int]]:
105+
"""
106+
Decode token
72107
73108
:param token:
74109
:return:
75110
"""
76111
try:
77112
payload = jwt.decode(token, settings.TOKEN_SECRET_KEY, algorithms=[settings.TOKEN_ALGORITHM])
78-
user_id = payload.get('sub')
79-
user_role = payload.get('role_ids')
80-
if not user_id or not user_role:
81-
raise TokenError
82-
# 验证token是否有效
83-
key = f'token:{user_id}:{token}'
84-
valid_token = await redis_client.get(key)
85-
if not valid_token:
113+
user_id = int(payload.get('sub'))
114+
user_roles = list(payload.get('role_ids'))
115+
if not user_id or not user_roles:
86116
raise TokenError
87-
return {'payload': payload, 'token': token}
88-
except (jwt.JWTError, ValidationError):
117+
except (jwt.JWTError, ValidationError, Exception):
118+
raise TokenError
119+
return user_id, user_roles
120+
121+
122+
async def jwt_authentication(token: str = Depends(oauth2_schema)) -> dict[str, int]:
123+
"""
124+
JWT authentication
125+
126+
:param token:
127+
:return:
128+
"""
129+
user_id, _ = jwt_decode(token)
130+
key = f'{settings.TOKEN_REDIS_PREFIX}:{user_id}:{token}'
131+
token_verify = await redis_client.get(key)
132+
if not token_verify:
89133
raise TokenError
134+
return {'sub': user_id}
90135

91136

92137
async def get_current_user(db: CurrentSession, data: dict = Depends(jwt_authentication)) -> User:
@@ -97,7 +142,7 @@ async def get_current_user(db: CurrentSession, data: dict = Depends(jwt_authenti
97142
:param data:
98143
:return:
99144
"""
100-
user_id = data.get('payload').get('sub')
145+
user_id = data.get('sub')
101146
user = await UserDao.get_user_with_relation(db, user_id=user_id)
102147
if not user:
103148
raise TokenError
@@ -121,7 +166,7 @@ async def get_current_is_superuser(user: User = Depends(get_current_user)):
121166
CurrentUser = Annotated[User, Depends(get_current_user)]
122167
CurrentSuperUser = Annotated[bool, Depends(get_current_is_superuser)]
123168
# Token dependency injection
124-
JwtAuthentication = Annotated[dict, Depends(jwt_authentication)]
169+
CurrentJwtAuth = Annotated[dict, Depends(jwt_authentication)]
125170
# Permission dependency injection
126171
DependsUser = Depends(get_current_user)
127172
DependsSuperUser = Depends(get_current_is_superuser)

backend/app/common/redis.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,17 @@ async def open(self):
3737
log.error('❌ 数据库 redis 连接异常 {}', e)
3838
sys.exit()
3939

40+
async def delete_prefix(self, key: str):
41+
"""
42+
删除指定前缀的所有key
43+
44+
:param key:
45+
:return:
46+
"""
47+
keys = await self.keys(f'{key}*')
48+
if keys:
49+
await self.delete(*keys)
50+
4051

4152
# 创建redis连接对象
4253
redis_client = RedisCli()

backend/app/core/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def validator_api_url(cls, values):
7373
TOKEN_ALGORITHM: str = 'HS256' # 算法
7474
TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 * 1 # 过期时间,单位:秒
7575
TOKEN_URL_SWAGGER: str = '/v1/auth/users/swagger_login'
76+
TOKEN_REDIS_PREFIX: str = 'fba_token'
7677

7778
# Log
7879
LOG_FILE_NAME: str = 'fba.log'

backend/app/crud/crud_user.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ async def update_avatar(self, db: AsyncSession, current_user: User, avatar: Avat
5555
async def delete_user(self, db: AsyncSession, user_id: int) -> int:
5656
return await self.delete(db, user_id)
5757

58-
async def check_email(self, db: AsyncSession, email: str) -> User:
58+
async def check_email(self, db: AsyncSession, email: str) -> User | None:
5959
mail = await db.execute(select(self.model).where(self.model.email == email))
6060
return mail.scalars().first()
6161

@@ -101,7 +101,9 @@ async def get_user_role_ids(self, db: AsyncSession, user_id: int) -> list[int]:
101101
roles_id = [role.id for role in user.scalars().first().roles]
102102
return roles_id
103103

104-
async def get_user_with_relation(self, db: AsyncSession, *, user_id: int = None, username: str = None) -> User:
104+
async def get_user_with_relation(
105+
self, db: AsyncSession, *, user_id: int = None, username: str = None
106+
) -> User | None:
105107
where = 'condition'
106108
if user_id:
107109
where = 'self.model.id == user_id'

backend/app/schemas/token.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,29 @@
11
#!/usr/bin/env python3
22
# -*- coding: utf-8 -*-
3+
from datetime import datetime
4+
35
from pydantic import BaseModel
46

57
from backend.app.schemas.user import GetUserInfoNoRelation
68

79

8-
class Token(BaseModel):
10+
class SwaggerToken(BaseModel):
911
access_token: str
1012
token_type: str = 'Bearer'
1113
user: GetUserInfoNoRelation
14+
15+
16+
class LoginToken(BaseModel):
17+
access_token: str
18+
access_token_type: str = 'Bearer'
19+
access_token_expire_time: datetime
20+
refresh_token: str
21+
refresh_token_type: str = 'Bearer'
22+
refresh_token_expire_time: datetime
23+
user: GetUserInfoNoRelation
24+
25+
26+
class RefreshToken(BaseModel):
27+
refresh_token: str
28+
refresh_token_type: str = 'Bearer'
29+
refresh_token_expire_time: datetime

backend/app/services/user_service.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
from backend.app.common import jwt
77
from backend.app.common.exception import errors
8+
from backend.app.common.redis import redis_client
9+
from backend.app.core.conf import settings
810
from backend.app.crud.crud_dept import DeptDao
911
from backend.app.crud.crud_role import RoleDao
1012
from backend.app.crud.crud_user import UserDao
@@ -20,36 +22,59 @@ async def swagger_login(form_data: OAuth2PasswordRequestForm):
2022
async with async_db_session() as db:
2123
current_user = await UserDao.get_user_by_username(db, form_data.username)
2224
if not current_user:
23-
raise errors.NotFoundError(msg='用户名不存在')
25+
raise errors.NotFoundError(msg='用户不存在')
2426
elif not jwt.password_verify(form_data.password, current_user.password):
2527
raise errors.AuthorizationError(msg='密码错误')
2628
elif not current_user.is_active:
27-
raise errors.AuthorizationError(msg='该用户已被锁定,无法登录')
29+
raise errors.AuthorizationError(msg='用户已锁定, 登陆失败')
2830
# 更新登陆时间
2931
await UserDao.update_user_login_time(db, form_data.username)
3032
# 查询用户角色
3133
user_role_ids = await UserDao.get_user_role_ids(db, current_user.id)
3234
# 获取最新用户信息
3335
user = await UserDao.get_user_by_id(db, current_user.id)
3436
# 创建token
35-
access_token = await jwt.create_access_token(user.id, role_ids=user_role_ids)
37+
access_token, _ = await jwt.create_access_token(str(user.id), role_ids=user_role_ids)
3638
return access_token, user
3739

3840
@staticmethod
3941
async def login(obj: Auth):
4042
async with async_db_session() as db:
4143
current_user = await UserDao.get_user_by_username(db, obj.username)
4244
if not current_user:
43-
raise errors.NotFoundError(msg='用户名不存在')
45+
raise errors.NotFoundError(msg='用户不存在')
4446
elif not jwt.password_verify(obj.password, current_user.password):
4547
raise errors.AuthorizationError(msg='密码错误')
4648
elif not current_user.is_active:
47-
raise errors.AuthorizationError(msg='该用户已被锁定,无法登录')
49+
raise errors.AuthorizationError(msg='用户已锁定, 登陆失败')
4850
await UserDao.update_user_login_time(db, obj.username)
4951
user_role_ids = await UserDao.get_user_role_ids(db, current_user.id)
5052
user = await UserDao.get_user_by_id(db, current_user.id)
51-
access_token = await jwt.create_access_token(user.id, role_ids=user_role_ids)
52-
return access_token, user
53+
access_token, access_token_expire_time = await jwt.create_access_token(str(user.id), role_ids=user_role_ids)
54+
refresh_token, refresh_token_expire_time = await jwt.create_refresh_token(
55+
str(user.id), access_token_expire_time, role_ids=user_role_ids
56+
)
57+
return access_token, refresh_token, access_token_expire_time, refresh_token_expire_time, user
58+
59+
@staticmethod
60+
async def refresh_token(user_id: int):
61+
async with async_db_session() as db:
62+
current_user = await UserDao.get_user_by_id(db, user_id)
63+
if not current_user:
64+
raise errors.NotFoundError(msg='用户不存在')
65+
elif not current_user.is_active:
66+
raise errors.AuthorizationError(msg='用户已锁定, 获取失败')
67+
user_role_ids = await UserDao.get_user_role_ids(db, current_user.id)
68+
refresh_token, refresh_token_expire_time = await jwt.create_refresh_token(
69+
str(current_user.id), role_ids=user_role_ids
70+
)
71+
return refresh_token, refresh_token_expire_time
72+
73+
@staticmethod
74+
async def logout(user_id: int):
75+
key = f'{settings.TOKEN_REDIS_PREFIX}:{user_id}:'
76+
await redis_client.delete_prefix(key)
77+
return
5378

5479
@staticmethod
5580
async def register(obj: CreateUser):

0 commit comments

Comments
 (0)