From 4dfec4368194a551e2fc4e4a8937b45e0636c19f Mon Sep 17 00:00:00 2001 From: Wu Clan Date: Sat, 21 Jun 2025 15:30:13 +0800 Subject: [PATCH] Optimize token detection and caching logic --- backend/app/admin/api/v1/monitor/online.py | 7 +- backend/app/admin/service/auth_service.py | 47 ++++++------- backend/common/dataclasses.py | 16 +++-- backend/common/security/jwt.py | 67 ++++++++++++------- .../plugin/oauth2/service/oauth2_service.py | 10 +-- backend/utils/server_info.py | 4 +- backend/utils/timezone.py | 39 +++++------ 7 files changed, 105 insertions(+), 85 deletions(-) diff --git a/backend/app/admin/api/v1/monitor/online.py b/backend/app/admin/api/v1/monitor/online.py index ed3945de..5324c0d3 100644 --- a/backend/app/admin/api/v1/monitor/online.py +++ b/backend/app/admin/api/v1/monitor/online.py @@ -44,9 +44,10 @@ def append_token_detail() -> None: for key in token_keys: token = await redis_client.get(key) token_payload = jwt_decode(token) + user_id = token_payload.id session_uuid = token_payload.session_uuid token_detail = GetTokenDetail( - id=token_payload.id, + id=user_id, session_uuid=session_uuid, username='未知', nickname='未知', @@ -58,7 +59,7 @@ def append_token_detail() -> None: last_login_time='未知', expire_time=token_payload.expire_time, ) - extra_info = await redis_client.get(f'{settings.TOKEN_EXTRA_INFO_REDIS_PREFIX}:{session_uuid}') + extra_info = await redis_client.get(f'{settings.TOKEN_EXTRA_INFO_REDIS_PREFIX}:{user_id}:{session_uuid}') if extra_info: extra_info = json.loads(extra_info) # 排除 swagger 登录生成的 token @@ -87,5 +88,5 @@ async def delete_session( session_uuid: Annotated[str, Query(description='会话 UUID')], ) -> ResponseModel: superuser_verify(request) - await revoke_token(str(pk), session_uuid) + await revoke_token(pk, session_uuid) return response_base.success() diff --git a/backend/app/admin/service/auth_service.py b/backend/app/admin/service/auth_service.py index 3395a6fb..85d57d28 100644 --- a/backend/app/admin/service/auth_service.py +++ b/backend/app/admin/service/auth_service.py @@ -67,13 +67,13 @@ async def swagger_login(self, *, obj: HTTPBasicCredentials) -> tuple[str, User]: async with async_db_session.begin() as db: user = await self.user_verify(db, obj.username, obj.password) await user_dao.update_login_time(db, obj.username) - a_token = await create_access_token( - str(user.id), + access_token = await create_access_token( + user.id, user.is_multi_login, # extra info swagger=True, ) - return a_token.access_token, user + return access_token.access_token, user async def login( self, *, request: Request, response: Response, obj: AuthLoginParam, background_tasks: BackgroundTasks @@ -99,24 +99,24 @@ async def login( await redis_client.delete(f'{settings.CAPTCHA_LOGIN_REDIS_PREFIX}:{request.state.ip}') await user_dao.update_login_time(db, obj.username) await db.refresh(user) - a_token = await create_access_token( - str(user.id), + access_token = await create_access_token( + user.id, user.is_multi_login, # extra info username=user.username, nickname=user.nickname, - last_login_time=timezone.t_str(user.last_login_time), + last_login_time=timezone.to_str(user.last_login_time), ip=request.state.ip, os=request.state.os, browser=request.state.browser, device=request.state.device, ) - r_token = await create_refresh_token(str(user.id), user.is_multi_login) + refresh_token = await create_refresh_token(access_token.session_uuid, user.id, user.is_multi_login) response.set_cookie( key=settings.COOKIE_REFRESH_TOKEN_KEY, - value=r_token.refresh_token, + value=refresh_token.refresh_token, max_age=settings.COOKIE_REFRESH_TOKEN_EXPIRE_SECONDS, - expires=timezone.f_utc(r_token.refresh_token_expire_time), + expires=timezone.to_utc(refresh_token.refresh_token_expire_time), httponly=True, ) except errors.NotFoundError as e: @@ -155,9 +155,9 @@ async def login( ), ) data = GetLoginToken( - access_token=a_token.access_token, - access_token_expire_time=a_token.access_token_expire_time, - session_uuid=a_token.session_uuid, + access_token=access_token.access_token, + access_token_expire_time=access_token.access_token_expire_time, + session_uuid=access_token.session_uuid, user=user, # type: ignore ) return data @@ -198,24 +198,22 @@ async def refresh_token(*, request: Request) -> GetNewToken: refresh_token = request.cookies.get(settings.COOKIE_REFRESH_TOKEN_KEY) if not refresh_token: raise errors.TokenError(msg='Refresh Token 已过期,请重新登录') - try: - user_id = jwt_decode(refresh_token).id - except Exception: - raise errors.TokenError(msg='Refresh Token 无效') + token_payload = jwt_decode(refresh_token) async with async_db_session() as db: - user = await user_dao.get(db, user_id) + user = await user_dao.get(db, token_payload.id) if not user: - raise errors.NotFoundError(msg='用户名或密码有误') + raise errors.NotFoundError(msg='用户不存在') elif not user.status: raise errors.AuthorizationError(msg='用户已被锁定, 请联系统管理员') new_token = await create_new_token( - user_id=str(user.id), - refresh_token=refresh_token, - multi_login=user.is_multi_login, + refresh_token, + token_payload.session_uuid, + user.id, + user.is_multi_login, # extra info username=user.username, nickname=user.nickname, - last_login_time=timezone.t_str(user.last_login_time), + last_login_time=timezone.to_str(user.last_login_time), ip=request.state.ip, os=request.state.os, browser=request.state.browser, @@ -241,6 +239,7 @@ async def logout(*, request: Request, response: Response) -> None: token = get_token(request) token_payload = jwt_decode(token) user_id = token_payload.id + session_uuid = token_payload.session_uuid refresh_token = request.cookies.get(settings.COOKIE_REFRESH_TOKEN_KEY) except errors.TokenError: return @@ -249,13 +248,15 @@ async def logout(*, request: Request, response: Response) -> None: # 清理缓存 if request.user.is_multi_login: - await redis_client.delete(f'{settings.TOKEN_REDIS_PREFIX}:{user_id}:{token_payload.session_uuid}') + await redis_client.delete(f'{settings.TOKEN_REDIS_PREFIX}:{user_id}:{session_uuid}') + await redis_client.delete(f'{settings.TOKEN_EXTRA_INFO_REDIS_PREFIX}:{user_id}:{session_uuid}') if refresh_token: await redis_client.delete(f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{user_id}:{refresh_token}') else: key_prefix = [ f'{settings.TOKEN_REDIS_PREFIX}:{user_id}:', f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{user_id}:', + f'{settings.TOKEN_EXTRA_INFO_REDIS_PREFIX}:{user_id}:', ] for prefix in key_prefix: await redis_client.delete_prefix(prefix) diff --git a/backend/common/dataclasses.py b/backend/common/dataclasses.py index 35ec6a94..aed70b32 100644 --- a/backend/common/dataclasses.py +++ b/backend/common/dataclasses.py @@ -34,13 +34,6 @@ class RequestCallNext: response: Response -@dataclasses.dataclass -class NewToken: - new_access_token: str - new_access_token_expire_time: datetime - session_uuid: str - - @dataclasses.dataclass class AccessToken: access_token: str @@ -54,6 +47,15 @@ class RefreshToken: refresh_token_expire_time: datetime +@dataclasses.dataclass +class NewToken: + new_access_token: str + new_access_token_expire_time: datetime + new_refresh_token: str + new_refresh_token_expire_time: datetime + session_uuid: str + + @dataclasses.dataclass class TokenPayload: id: int diff --git a/backend/common/security/jwt.py b/backend/common/security/jwt.py index ec106065..81c6d738 100644 --- a/backend/common/security/jwt.py +++ b/backend/common/security/jwt.py @@ -60,11 +60,7 @@ def jwt_encode(payload: dict[str, Any]) -> str: :param payload: 载荷 :return: """ - return jwt.encode( - payload, - settings.TOKEN_SECRET_KEY, - settings.TOKEN_ALGORITHM, - ) + return jwt.encode(payload, settings.TOKEN_SECRET_KEY, settings.TOKEN_ALGORITHM) def jwt_decode(token: str) -> TokenPayload: @@ -75,20 +71,27 @@ def jwt_decode(token: str) -> TokenPayload: :return: """ try: - payload = jwt.decode(token, settings.TOKEN_SECRET_KEY, algorithms=[settings.TOKEN_ALGORITHM]) - session_uuid = payload.get('session_uuid') or 'debug' + payload = jwt.decode( + token, + settings.TOKEN_SECRET_KEY, + algorithms=[settings.TOKEN_ALGORITHM], + options={'verify_exp': True}, + ) + session_uuid = payload.get('session_uuid') user_id = payload.get('sub') - expire_time = payload.get('exp') - if not user_id: + expire = payload.get('exp') + if not session_uuid or not user_id or not expire: raise errors.TokenError(msg='Token 无效') except ExpiredSignatureError: raise errors.TokenError(msg='Token 已过期') except (JWTError, Exception): raise errors.TokenError(msg='Token 无效') - return TokenPayload(id=int(user_id), session_uuid=session_uuid, expire_time=expire_time) + return TokenPayload( + id=int(user_id), session_uuid=session_uuid, expire_time=timezone.from_datetime(timezone.to_utc(expire)) + ) -async def create_access_token(user_id: str, multi_login: bool, **kwargs) -> AccessToken: +async def create_access_token(user_id: int, multi_login: bool, **kwargs) -> AccessToken: """ 生成加密 token @@ -101,8 +104,8 @@ async def create_access_token(user_id: str, multi_login: bool, **kwargs) -> Acce session_uuid = str(uuid4()) access_token = jwt_encode({ 'session_uuid': session_uuid, - 'exp': expire, - 'sub': user_id, + 'exp': timezone.to_utc(expire).timestamp(), + 'sub': str(user_id), }) if not multi_login: @@ -117,7 +120,7 @@ async def create_access_token(user_id: str, multi_login: bool, **kwargs) -> Acce # Token 附加信息单独存储 if kwargs: await redis_client.setex( - f'{settings.TOKEN_EXTRA_INFO_REDIS_PREFIX}:{session_uuid}', + f'{settings.TOKEN_EXTRA_INFO_REDIS_PREFIX}:{user_id}:{session_uuid}', settings.TOKEN_EXPIRE_SECONDS, json.dumps(kwargs, ensure_ascii=False), ) @@ -125,51 +128,65 @@ async def create_access_token(user_id: str, multi_login: bool, **kwargs) -> Acce return AccessToken(access_token=access_token, access_token_expire_time=expire, session_uuid=session_uuid) -async def create_refresh_token(user_id: str, multi_login: bool) -> RefreshToken: +async def create_refresh_token(session_uuid: str, user_id: int, multi_login: bool) -> RefreshToken: """ 生成加密刷新 token,仅用于创建新的 token + :param session_uuid: 会话 UUID :param user_id: 用户 ID :param multi_login: 是否允许多端登录 :return: """ expire = timezone.now() + timedelta(seconds=settings.TOKEN_REFRESH_EXPIRE_SECONDS) - refresh_token = jwt_encode({'exp': expire, 'sub': user_id}) + refresh_token = jwt_encode({ + 'session_uuid': session_uuid, + 'exp': timezone.to_utc(expire).timestamp(), + 'sub': str(user_id), + }) if not multi_login: - key_prefix = f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{user_id}' - await redis_client.delete_prefix(key_prefix) + await redis_client.delete_prefix(f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{user_id}') await redis_client.setex( - f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{user_id}:{refresh_token}', + f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{user_id}:{session_uuid}', settings.TOKEN_REFRESH_EXPIRE_SECONDS, refresh_token, ) return RefreshToken(refresh_token=refresh_token, refresh_token_expire_time=expire) -async def create_new_token(user_id: str, refresh_token: str, multi_login: bool, **kwargs) -> NewToken: +async def create_new_token( + refresh_token: str, session_uuid: str, user_id: int, multi_login: bool, **kwargs +) -> NewToken: """ 生成新的 token - :param user_id: 用户 ID :param refresh_token: 刷新 token + :param session_uuid: 会话 UUID + :param user_id: 用户 ID :param multi_login: 是否允许多端登录 :param kwargs: token 附加信息 :return: """ - redis_refresh_token = await redis_client.get(f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{user_id}:{refresh_token}') + redis_refresh_token = await redis_client.get(f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{user_id}:{session_uuid}') if not redis_refresh_token or redis_refresh_token != refresh_token: raise errors.TokenError(msg='Refresh Token 已过期,请重新登录') + + await redis_client.delete(f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{user_id}:{session_uuid}') + await redis_client.delete(f'{settings.TOKEN_REDIS_PREFIX}:{user_id}:{session_uuid}') + new_access_token = await create_access_token(user_id, multi_login, **kwargs) + new_refresh_token = await create_refresh_token(new_access_token.session_uuid, user_id, multi_login) return NewToken( new_access_token=new_access_token.access_token, new_access_token_expire_time=new_access_token.access_token_expire_time, + new_refresh_token=new_refresh_token.refresh_token, + new_refresh_token_expire_time=new_refresh_token.refresh_token_expire_time, session_uuid=new_access_token.session_uuid, ) -async def revoke_token(user_id: str, session_uuid: str) -> None: +async def revoke_token(user_id: int, session_uuid: str) -> None: """ 撤销 token @@ -177,8 +194,8 @@ async def revoke_token(user_id: str, session_uuid: str) -> None: :param session_uuid: 会话 ID :return: """ - token_key = f'{settings.TOKEN_REDIS_PREFIX}:{user_id}:{session_uuid}' - await redis_client.delete(token_key) + await redis_client.delete(f'{settings.TOKEN_REDIS_PREFIX}:{user_id}:{session_uuid}') + await redis_client.delete(f'{settings.TOKEN_EXTRA_INFO_REDIS_PREFIX}:{user_id}:{session_uuid}') def get_token(request: Request) -> str: diff --git a/backend/plugin/oauth2/service/oauth2_service.py b/backend/plugin/oauth2/service/oauth2_service.py index 461dedcb..7ccdd045 100644 --- a/backend/plugin/oauth2/service/oauth2_service.py +++ b/backend/plugin/oauth2/service/oauth2_service.py @@ -89,18 +89,20 @@ async def create_with_login( # 创建 token access_token = await jwt.create_access_token( - str(sys_user.id), + sys_user.id, sys_user.is_multi_login, # extra info username=sys_user.username, nickname=sys_user.nickname or f'#{text_captcha(5)}', - last_login_time=timezone.t_str(timezone.now()), + last_login_time=timezone.to_str(timezone.now()), ip=request.state.ip, os=request.state.os, browser=request.state.browser, device=request.state.device, ) - refresh_token = await jwt.create_refresh_token(str(sys_user.id), multi_login=sys_user.is_multi_login) + refresh_token = await jwt.create_refresh_token( + access_token.session_uuid, sys_user.id, sys_user.is_multi_login + ) await user_dao.update_login_time(db, sys_user.username) await db.refresh(sys_user) login_log = dict( @@ -118,7 +120,7 @@ async def create_with_login( key=settings.COOKIE_REFRESH_TOKEN_KEY, value=refresh_token.refresh_token, max_age=settings.COOKIE_REFRESH_TOKEN_EXPIRE_SECONDS, - expires=timezone.f_utc(refresh_token.refresh_token_expire_time), + expires=timezone.to_utc(refresh_token.refresh_token_expire_time), httponly=True, ) data = GetLoginToken( diff --git a/backend/utils/server_info.py b/backend/utils/server_info.py index 9f861154..bebd3c69 100644 --- a/backend/utils/server_info.py +++ b/backend/utils/server_info.py @@ -150,7 +150,7 @@ def get_service_info() -> dict[str, str | datetime]: try: create_time = datetime.fromtimestamp(process.create_time(), tz=tz.utc) - start_time = timezone.f_datetime(create_time) + start_time = timezone.from_datetime(create_time) except (psutil.NoSuchProcess, OSError): start_time = timezone.now() @@ -164,7 +164,7 @@ def get_service_info() -> dict[str, str | datetime]: 'mem_vms': ServerInfo.format_bytes(mem_info.vms), 'mem_rss': ServerInfo.format_bytes(mem_info.rss), 'mem_free': ServerInfo.format_bytes(mem_info.vms - mem_info.rss), - 'startup': timezone.t_str(start_time), + 'startup': timezone.to_str(start_time), 'elapsed': elapsed, } diff --git a/backend/utils/timezone.py b/backend/utils/timezone.py index a2d587a8..cf255f49 100644 --- a/backend/utils/timezone.py +++ b/backend/utils/timezone.py @@ -9,58 +9,55 @@ class TimeZone: - def __init__(self, tz: str = settings.DATETIME_TIMEZONE) -> None: - """ - 初始化时区转换器 - - :param tz: 时区名称,默认为 settings.DATETIME_TIMEZONE - :return: - """ - self.tz_info = zoneinfo.ZoneInfo(tz) + def __init__(self) -> None: + """初始化时区转换器""" + self.tz_info = zoneinfo.ZoneInfo(settings.DATETIME_TIMEZONE) def now(self) -> datetime: """获取当前时区时间""" return datetime.now(self.tz_info) - def f_datetime(self, dt: datetime) -> datetime: + def from_datetime(self, t: datetime) -> datetime: """ 将 datetime 对象转换为当前时区时间 - :param dt: 需要转换的 datetime 对象 + :param t: 需要转换的 datetime 对象 :return: """ - return dt.astimezone(self.tz_info) + return t.astimezone(self.tz_info) - def f_str(self, date_str: str, format_str: str = settings.DATETIME_FORMAT) -> datetime: + def from_str(self, t_str: str, format_str: str = settings.DATETIME_FORMAT) -> datetime: """ 将时间字符串转换为当前时区的 datetime 对象 - :param date_str: 时间字符串 + :param t_str: 时间字符串 :param format_str: 时间格式字符串,默认为 settings.DATETIME_FORMAT :return: """ - return datetime.strptime(date_str, format_str).replace(tzinfo=self.tz_info) + return datetime.strptime(t_str, format_str).replace(tzinfo=self.tz_info) @staticmethod - def t_str(dt: datetime, format_str: str = settings.DATETIME_FORMAT) -> str: + def to_str(t: datetime, format_str: str = settings.DATETIME_FORMAT) -> str: """ 将 datetime 对象转换为指定格式的时间字符串 - :param dt: datetime 对象 + :param t: datetime 对象 :param format_str: 时间格式字符串,默认为 settings.DATETIME_FORMAT :return: """ - return dt.strftime(format_str) + return t.strftime(format_str) @staticmethod - def f_utc(dt: datetime) -> datetime: + def to_utc(t: datetime | int) -> datetime: """ - 将 datetime 对象转换为 UTC (GMT) 时区时间 + 将 datetime 对象或时间戳转换为 UTC 时区时间 - :param dt: 需要转换的 datetime 对象 + :param t: 需要转换的 datetime 对象或时间戳 :return: """ - return dt.astimezone(datetime_timezone.utc) + if isinstance(t, datetime): + return t.astimezone(datetime_timezone.utc) + return datetime.fromtimestamp(t, tz=datetime_timezone.utc) timezone: TimeZone = TimeZone()