Skip to content

Commit 227d76c

Browse files
authored
update token handling logic (#83)
* Update token store, refresh, whitelist * Update the token handling logic when update a user's multi-login status * Delete all tokens when the user delete * Fix the user logout interface * Fix the get refresh token interface * Fix multi-point login judgement when creat token * Update the refresh_token interface * Fix redis prefix deletion exclusion * Fix token deletion error when user deleted * Update the token time base to datetime.now() * Update limiter storage prefix to settings * Fix user login time not updated to database * Allowing a user to have multiple refresh tokens. * Add code comment to user multi-login update method * Remove refresh token get and create interface * Add user update multipoint login delete refresh token
1 parent fcc874b commit 227d76c

File tree

13 files changed

+255
-165
lines changed

13 files changed

+255
-165
lines changed

backend/app/.env.example

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,3 @@ APS_REDIS_PASSWORD=''
1717
APS_REDIS_DATABASE=1
1818
# Token
1919
TOKEN_SECRET_KEY='1VkVF75nsNABBjK_7-qz7GtzNy3AMvktc9TCPwKczCk'
20-
TOKEN_WHITE_LIST=[1]

backend/app/api/v1/auth/auth.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,24 @@
11
#!/usr/bin/env python3
22
# -*- coding: utf-8 -*-
3-
from fastapi import APIRouter, Depends, Request
3+
from typing import Annotated
4+
5+
from fastapi import APIRouter, Depends, Request, Query
46
from fastapi.security import OAuth2PasswordRequestForm
57
from fastapi_limiter.depends import RateLimiter
68
from starlette.background import BackgroundTasks
79

8-
from backend.app.common.jwt import DependsUser, get_token, jwt_decode, CurrentJwtAuth
10+
from backend.app.common.jwt import DependsUser, CurrentUser
911
from backend.app.common.response.response_schema import response_base
10-
from backend.app.schemas.token import RefreshToken, LoginToken, SwaggerToken, RefreshTokenTime
12+
from backend.app.schemas.token import LoginToken, SwaggerToken, NewToken
1113
from backend.app.schemas.user import Auth
12-
from backend.app.services.user_service import UserService
14+
from backend.app.services.auth_service import AuthService
1315

1416
router = APIRouter()
1517

1618

1719
@router.post('/swagger_login', summary='swagger 表单登录', description='form 格式登录,仅用于 swagger 文档调试接口')
1820
async def swagger_user_login(form_data: OAuth2PasswordRequestForm = Depends()) -> SwaggerToken:
19-
token, user = await UserService().swagger_login(form_data)
21+
token, user = await AuthService().swagger_login(form_data)
2022
return SwaggerToken(access_token=token, user=user)
2123

2224

@@ -27,7 +29,7 @@ async def swagger_user_login(form_data: OAuth2PasswordRequestForm = Depends()) -
2729
dependencies=[Depends(RateLimiter(times=5, minutes=15))],
2830
)
2931
async def user_login(request: Request, obj: Auth, background_tasks: BackgroundTasks):
30-
access_token, refresh_token, access_expire, refresh_expire, user = await UserService().login(
32+
access_token, refresh_token, access_expire, refresh_expire, user = await AuthService().login(
3133
request=request, obj=obj, background_tasks=background_tasks
3234
)
3335
data = LoginToken(
@@ -40,17 +42,14 @@ async def user_login(request: Request, obj: Auth, background_tasks: BackgroundTa
4042
return response_base.success(data=data)
4143

4244

43-
@router.post('/refresh_token', summary='刷新 token', dependencies=[DependsUser])
44-
async def get_refresh_token(request: Request, custom_time: RefreshTokenTime):
45-
token = get_token(request)
46-
user_id, _ = jwt_decode(token)
47-
refresh_token, refresh_expire = await UserService.refresh_token(user_id=user_id, custom_time=custom_time)
48-
data = RefreshToken(refresh_token=refresh_token, refresh_token_expire_time=refresh_expire)
45+
@router.post('/new_token', summary='创建新 token', dependencies=[DependsUser])
46+
async def create_new_token(refresh_token: Annotated[str, Query(...)]):
47+
access_token, access_expire = await AuthService.new_token(refresh_token)
48+
data = NewToken(access_token=access_token, access_token_expire_time=access_expire)
4949
return response_base.success(data=data)
5050

5151

52-
@router.post('/logout', summary='用户登出', dependencies=[DependsUser])
53-
async def user_logout(jwt: CurrentJwtAuth):
54-
user_id = jwt.get('sub')
55-
await UserService.logout(user_id)
52+
@router.post('/logout', summary='用户登出')
53+
async def user_logout(request: Request, current_user: CurrentUser):
54+
await AuthService.logout(request=request, current_user=current_user)
5655
return response_base.success()

backend/app/api/v1/user.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# -*- coding: utf-8 -*-
33
from typing import Annotated
44

5-
from fastapi import APIRouter, Query
5+
from fastapi import APIRouter, Query, Request
66

77
from backend.app.common.jwt import DependsUser, CurrentUser, DependsSuperUser
88
from backend.app.common.pagination import paging_data, PageDepends
@@ -54,10 +54,10 @@ async def update_avatar(username: str, avatar: Avatar, current_user: CurrentUser
5454

5555
@router.get('', summary='(模糊条件)分页获取所有用户', dependencies=[DependsUser, PageDepends])
5656
async def get_all_users(
57-
db: CurrentSession,
58-
username: Annotated[str | None, Query()] = None,
59-
phone: Annotated[str | None, Query()] = None,
60-
status: Annotated[int | None, Query()] = None,
57+
db: CurrentSession,
58+
username: Annotated[str | None, Query()] = None,
59+
phone: Annotated[str | None, Query()] = None,
60+
status: Annotated[int | None, Query()] = None,
6161
):
6262
user_select = await UserService.get_select(username=username, phone=phone, status=status)
6363
page_data = await paging_data(db, user_select, GetAllUserInfo)
@@ -80,6 +80,14 @@ async def active_set(pk: int):
8080
return response_base.fail()
8181

8282

83+
@router.post('/{pk}/multi', summary='修改用户多点登录状态')
84+
async def multi_set(request: Request, pk: int, current_user: CurrentUser):
85+
count = await UserService.update_multi_login(request=request, pk=pk, current_user=current_user)
86+
if count > 0:
87+
return response_base.success()
88+
return response_base.fail()
89+
90+
8391
@router.delete('/{username}', summary='用户注销', description='用户注销 != 用户退出,注销之后用户将从数据库删除')
8492
async def delete_user(username: str, current_user: CurrentUser):
8593
count = await UserService.delete(username=username, current_user=current_user)

backend/app/common/jwt.py

Lines changed: 41 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from backend.app.crud.crud_user import UserDao
1717
from backend.app.database.db_mysql import CurrentSession
1818
from backend.app.models import User
19-
from backend.app.schemas.token import RefreshTokenTime
2019

2120
pwd_context = CryptContext(schemes=['bcrypt'], deprecated='auto')
2221

@@ -53,46 +52,60 @@ async def create_access_token(sub: str, expires_delta: timedelta | None = None,
5352
:return:
5453
"""
5554
if expires_delta:
56-
expire = datetime.utcnow() + expires_delta
55+
expire = datetime.now() + expires_delta
5756
expire_seconds = int(expires_delta.total_seconds())
5857
else:
59-
expire = datetime.utcnow() + timedelta(seconds=settings.TOKEN_EXPIRE_SECONDS)
58+
expire = datetime.now() + timedelta(seconds=settings.TOKEN_EXPIRE_SECONDS)
6059
expire_seconds = settings.TOKEN_EXPIRE_SECONDS
60+
multi_login = kwargs.pop('multi_login', None)
6161
to_encode = {'exp': expire, 'sub': sub, **kwargs}
6262
token = jwt.encode(to_encode, settings.TOKEN_SECRET_KEY, settings.TOKEN_ALGORITHM)
63-
if sub not in settings.TOKEN_WHITE_LIST:
64-
await redis_client.delete_prefix(f'{settings.TOKEN_REDIS_PREFIX}:{sub}:')
63+
if multi_login is False:
64+
prefix = f'{settings.TOKEN_REDIS_PREFIX}:{sub}:'
65+
await redis_client.delete_prefix(prefix)
6566
key = f'{settings.TOKEN_REDIS_PREFIX}:{sub}:{token}'
6667
await redis_client.setex(key, expire_seconds, token)
6768
return token, expire
6869

6970

70-
async def create_refresh_token(
71-
sub: str, expire_time: datetime | None = None, custom_expire_time: RefreshTokenTime | None = None, **kwargs
72-
) -> tuple[str, datetime]:
71+
async def create_refresh_token(sub: str, expire_time: datetime | None = None, **kwargs) -> tuple[str, datetime]:
7372
"""
74-
Generate encryption refresh token
73+
Generate encryption refresh token, only used to create a new token
7574
7675
:param sub: The subject/userid of the JWT
7776
:param expire_time: expiry time
78-
:param custom_expire_time: custom expiry time
7977
:return:
8078
"""
8179
if expire_time:
82-
expire = expire_time + timedelta(seconds=settings.TOKEN_EXPIRE_SECONDS)
83-
expire_seconds = int((expire - datetime.utcnow()).total_seconds())
84-
elif custom_expire_time:
85-
expire = custom_expire_time.expire_time
86-
expire_seconds = int((expire - datetime.utcnow()).total_seconds())
80+
expire = expire_time + timedelta(seconds=settings.TOKEN_REFRESH_EXPIRE_SECONDS)
81+
expire_seconds = int((expire - datetime.now()).total_seconds())
8782
else:
88-
expire = datetime.utcnow() + timedelta(seconds=settings.TOKEN_EXPIRE_SECONDS)
89-
expire_seconds = settings.TOKEN_EXPIRE_SECONDS
83+
expire = datetime.now() + timedelta(seconds=settings.TOKEN_REFRESH_EXPIRE_SECONDS)
84+
expire_seconds = settings.TOKEN_REFRESH_EXPIRE_SECONDS
85+
multi_login = kwargs.pop('multi_login', None)
9086
to_encode = {'exp': expire, 'sub': sub, **kwargs}
91-
token = jwt.encode(to_encode, settings.TOKEN_SECRET_KEY, settings.TOKEN_ALGORITHM)
92-
# 刷新 token 时,保持旧 token 有效,不执行删除操作
93-
key = f'{settings.TOKEN_REDIS_PREFIX}:{sub}:{token}'
94-
await redis_client.setex(key, expire_seconds, token)
95-
return token, expire
87+
refresh_token = jwt.encode(to_encode, settings.TOKEN_SECRET_KEY, settings.TOKEN_ALGORITHM)
88+
if multi_login is False:
89+
prefix = f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{sub}:'
90+
await redis_client.delete_prefix(prefix)
91+
key = f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{sub}:{refresh_token}'
92+
await redis_client.setex(key, expire_seconds, refresh_token)
93+
return refresh_token, expire
94+
95+
96+
async def create_new_token(sub: str, refresh_token: str, **kwargs) -> tuple[str, datetime]:
97+
"""
98+
Generate new token
99+
100+
:param sub:
101+
:param refresh_token:
102+
:return:
103+
"""
104+
redis_refresh_token = await redis_client.get(f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{sub}:{refresh_token}')
105+
if not redis_refresh_token or redis_refresh_token != refresh_token:
106+
raise TokenError(msg='refresh_token 已过期')
107+
new_token, expire = await create_access_token(sub, **kwargs)
108+
return new_token, expire
96109

97110

98111
def get_token(request: Request) -> str:
@@ -102,10 +115,10 @@ def get_token(request: Request) -> str:
102115
:return:
103116
"""
104117
authorization = request.headers.get('Authorization')
105-
scheme, param = get_authorization_scheme_param(authorization)
118+
scheme, token = get_authorization_scheme_param(authorization)
106119
if not authorization or scheme.lower() != 'bearer':
107120
raise TokenError
108-
return param
121+
return token
109122

110123

111124
def jwt_decode(token: str) -> tuple[int, list[int]]:
@@ -118,12 +131,12 @@ def jwt_decode(token: str) -> tuple[int, list[int]]:
118131
try:
119132
payload = jwt.decode(token, settings.TOKEN_SECRET_KEY, algorithms=[settings.TOKEN_ALGORITHM])
120133
user_id = int(payload.get('sub'))
121-
user_roles = list(payload.get('role_ids'))
122-
if not user_id or not user_roles:
134+
role_ids = list(payload.get('role_ids'))
135+
if not user_id or not role_ids:
123136
raise TokenError
124137
except (jwt.JWTError, ValidationError, Exception):
125138
raise TokenError
126-
return user_id, user_roles
139+
return user_id, role_ids
127140

128141

129142
async def jwt_authentication(token: str = Depends(oauth2_schema)) -> dict[str, int]:
@@ -137,7 +150,7 @@ async def jwt_authentication(token: str = Depends(oauth2_schema)) -> dict[str, i
137150
key = f'{settings.TOKEN_REDIS_PREFIX}:{user_id}:{token}'
138151
token_verify = await redis_client.get(key)
139152
if not token_verify:
140-
raise TokenError
153+
raise TokenError(msg='token 已过期')
141154
return {'sub': user_id}
142155

143156

backend/app/common/redis.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,26 @@ async def open(self):
3737
log.error('❌ 数据库 redis 连接异常 {}', e)
3838
sys.exit()
3939

40-
async def delete_prefix(self, key: str):
40+
async def delete_prefix(self, prefix: str, exclude: str | list = None):
4141
"""
4242
删除指定前缀的所有key
4343
44-
:param key:
44+
:param prefix:
45+
:param exclude:
4546
:return:
4647
"""
47-
keys = await self.keys(f'{key}*')
48-
if keys:
49-
await self.delete(*keys)
48+
keys = []
49+
async for key in self.scan_iter(match=f'{prefix}*'):
50+
if isinstance(exclude, str):
51+
if key != exclude:
52+
keys.append(key)
53+
elif isinstance(exclude, list):
54+
if key not in exclude:
55+
keys.append(key)
56+
else:
57+
keys.append(key)
58+
for key in keys:
59+
await self.delete(key)
5060

5161

5262
# 创建redis连接对象

backend/app/core/conf.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ class Settings(BaseSettings):
3030

3131
# Env Token
3232
TOKEN_SECRET_KEY: str # 密钥 secrets.token_urlsafe(32))
33-
TOKEN_WHITE_LIST: list[str] # 白名单用户ID,可多点登录
3433

3534
# FastAPI
3635
API_V1_STR: str = '/v1'
@@ -58,6 +57,9 @@ def validator_api_url(cls, values):
5857
# Location Parse
5958
LOCATION_PARSE: Literal['online', 'offline', 'false'] = 'offline'
6059

60+
# Limiter
61+
LIMITER_REDIS_PREFIX: str = 'fba_limiter'
62+
6163
# MySQL
6264
DB_ECHO: bool = False
6365
DB_DATABASE: str = 'fba'
@@ -77,8 +79,10 @@ def validator_api_url(cls, values):
7779
# Token
7880
TOKEN_ALGORITHM: str = 'HS256' # 算法
7981
TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 * 1 # 过期时间,单位:秒
82+
TOKEN_REFRESH_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # 刷新过期时间,单位:秒
8083
TOKEN_URL_SWAGGER: str = f'{API_V1_STR}/auth/swagger_login'
8184
TOKEN_REDIS_PREFIX: str = 'fba_token'
85+
TOKEN_REFRESH_REDIS_PREFIX: str = 'fba_refresh_token'
8286

8387
# Log
8488
LOG_STDOUT_FILENAME: str = 'fba_access.log'

backend/app/core/registrar.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ async def register_init(app: FastAPI):
2828
# 连接 redis
2929
await redis_client.open()
3030
# 初始化 limiter
31-
await FastAPILimiter.init(redis_client, prefix='fba_limiter')
31+
await FastAPILimiter.init(redis_client, prefix=settings.LIMITER_REDIS_PREFIX)
3232
# 启动定时任务
3333
scheduler.start()
3434

backend/app/crud/crud_user.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ async def get_by_username(self, db: AsyncSession, username: str) -> User | None:
2424

2525
async def update_login_time(self, db: AsyncSession, username: str, login_time: datetime) -> int:
2626
user = await db.execute(update(self.model).where(self.model.username == username).values(last_login=login_time))
27+
await db.commit()
2728
return user.rowcount
2829

2930
async def create(self, db: AsyncSession, create: CreateUser) -> NoReturn:
@@ -91,6 +92,10 @@ async def get_active(self, db: AsyncSession, user_id: int) -> bool:
9192
user = await self.get(db, user_id)
9293
return user.is_active
9394

95+
async def get_multi_login(self, db: AsyncSession, user_id: int) -> bool:
96+
user = await self.get(db, user_id)
97+
return user.is_multi_login
98+
9499
async def set_super(self, db: AsyncSession, user_id: int) -> int:
95100
super_status = await self.get_super(db, user_id)
96101
user = await db.execute(
@@ -105,6 +110,13 @@ async def set_active(self, db: AsyncSession, user_id: int) -> int:
105110
)
106111
return user.rowcount
107112

113+
async def set_multi_login(self, db: AsyncSession, user_id: int) -> int:
114+
multi_login = await self.get_multi_login(db, user_id)
115+
user = await db.execute(
116+
update(self.model).where(self.model.id == user_id).values(is_multi_login=False if multi_login else True)
117+
)
118+
return user.rowcount
119+
108120
async def get_role_ids(self, db: AsyncSession, user_id: int) -> list[int]:
109121
user = await db.execute(
110122
select(self.model).where(self.model.id == user_id).options(selectinload(self.model.roles))

backend/app/models/sys_user.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ class User(DataClassBase):
2323
email: Mapped[str] = mapped_column(String(50), unique=True, index=True, comment='邮箱')
2424
is_superuser: Mapped[bool] = mapped_column(default=False, comment='超级权限(0否 1是)')
2525
is_active: Mapped[bool] = mapped_column(default=True, comment='用户账号状态(0停用 1正常)')
26+
is_multi_login: Mapped[bool] = mapped_column(default=False, comment='是否重复登陆(0否 1是)')
2627
avatar: Mapped[str | None] = mapped_column(String(255), default=None, comment='头像')
2728
phone: Mapped[str | None] = mapped_column(String(11), default=None, comment='手机号')
2829
time_joined: Mapped[datetime] = mapped_column(init=False, default=func.now(), comment='注册时间')

backend/app/schemas/token.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
#!/usr/bin/env python3
22
# -*- coding: utf-8 -*-
3-
from datetime import datetime, timedelta
3+
from datetime import datetime
44

5-
from pydantic import BaseModel, Field, validator
6-
from pydantic.datetime_parse import parse_datetime
5+
from pydantic import BaseModel
76

87
from backend.app.schemas.user import GetUserInfoNoRelation
98

@@ -14,10 +13,13 @@ class SwaggerToken(BaseModel):
1413
user: GetUserInfoNoRelation
1514

1615

17-
class LoginToken(BaseModel):
16+
class AccessToken(BaseModel):
1817
access_token: str
1918
access_token_type: str = 'Bearer'
2019
access_token_expire_time: datetime
20+
21+
22+
class LoginToken(AccessToken):
2123
refresh_token: str
2224
refresh_token_type: str = 'Bearer'
2325
refresh_token_expire_time: datetime
@@ -30,20 +32,5 @@ class RefreshToken(BaseModel):
3032
refresh_token_expire_time: datetime
3133

3234

33-
class RefreshTokenTime(BaseModel):
34-
expire_time: datetime | None = Field(None, description='自定义刷新令牌过期时间')
35-
36-
@validator('expire_time', pre=True)
37-
def validate_expire_time(cls, v):
38-
if v is None:
39-
return None
40-
if not isinstance(v, str) or 'T' not in v:
41-
raise ValueError('输入时间格式错误')
42-
v = parse_datetime(v)
43-
utcnow = datetime.utcnow()
44-
no_tz_v = v.replace(tzinfo=None)
45-
if no_tz_v < utcnow:
46-
raise ValueError('输入时间小于当前时间')
47-
if no_tz_v > utcnow + timedelta(days=7):
48-
raise ValueError('输入时间大于当前时间上限 7 天')
49-
return no_tz_v
35+
class NewToken(AccessToken):
36+
pass

0 commit comments

Comments
 (0)