Skip to content

Optimize token detection and caching logic #677

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions backend/app/admin/api/v1/monitor/online.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@ def append_token_detail() -> None:
for key in token_keys:
token = await redis_client.get(key)
token_payload = jwt_decode(token)
user_id = token_payload.id
session_uuid = token_payload.session_uuid
token_detail = GetTokenDetail(
id=token_payload.id,
id=user_id,
session_uuid=session_uuid,
username='未知',
nickname='未知',
Expand All @@ -58,7 +59,7 @@ def append_token_detail() -> None:
last_login_time='未知',
expire_time=token_payload.expire_time,
)
extra_info = await redis_client.get(f'{settings.TOKEN_EXTRA_INFO_REDIS_PREFIX}:{session_uuid}')
extra_info = await redis_client.get(f'{settings.TOKEN_EXTRA_INFO_REDIS_PREFIX}:{user_id}:{session_uuid}')
if extra_info:
extra_info = json.loads(extra_info)
# 排除 swagger 登录生成的 token
Expand Down Expand Up @@ -87,5 +88,5 @@ async def delete_session(
session_uuid: Annotated[str, Query(description='会话 UUID')],
) -> ResponseModel:
superuser_verify(request)
await revoke_token(str(pk), session_uuid)
await revoke_token(pk, session_uuid)
return response_base.success()
47 changes: 24 additions & 23 deletions backend/app/admin/service/auth_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,13 @@ async def swagger_login(self, *, obj: HTTPBasicCredentials) -> tuple[str, User]:
async with async_db_session.begin() as db:
user = await self.user_verify(db, obj.username, obj.password)
await user_dao.update_login_time(db, obj.username)
a_token = await create_access_token(
str(user.id),
access_token = await create_access_token(
user.id,
user.is_multi_login,
# extra info
swagger=True,
)
return a_token.access_token, user
return access_token.access_token, user

async def login(
self, *, request: Request, response: Response, obj: AuthLoginParam, background_tasks: BackgroundTasks
Expand All @@ -99,24 +99,24 @@ async def login(
await redis_client.delete(f'{settings.CAPTCHA_LOGIN_REDIS_PREFIX}:{request.state.ip}')
await user_dao.update_login_time(db, obj.username)
await db.refresh(user)
a_token = await create_access_token(
str(user.id),
access_token = await create_access_token(
user.id,
user.is_multi_login,
# extra info
username=user.username,
nickname=user.nickname,
last_login_time=timezone.t_str(user.last_login_time),
last_login_time=timezone.to_str(user.last_login_time),
ip=request.state.ip,
os=request.state.os,
browser=request.state.browser,
device=request.state.device,
)
r_token = await create_refresh_token(str(user.id), user.is_multi_login)
refresh_token = await create_refresh_token(access_token.session_uuid, user.id, user.is_multi_login)
response.set_cookie(
key=settings.COOKIE_REFRESH_TOKEN_KEY,
value=r_token.refresh_token,
value=refresh_token.refresh_token,
max_age=settings.COOKIE_REFRESH_TOKEN_EXPIRE_SECONDS,
expires=timezone.f_utc(r_token.refresh_token_expire_time),
expires=timezone.to_utc(refresh_token.refresh_token_expire_time),
httponly=True,
)
except errors.NotFoundError as e:
Expand Down Expand Up @@ -155,9 +155,9 @@ async def login(
),
)
data = GetLoginToken(
access_token=a_token.access_token,
access_token_expire_time=a_token.access_token_expire_time,
session_uuid=a_token.session_uuid,
access_token=access_token.access_token,
access_token_expire_time=access_token.access_token_expire_time,
session_uuid=access_token.session_uuid,
user=user, # type: ignore
)
return data
Expand Down Expand Up @@ -198,24 +198,22 @@ async def refresh_token(*, request: Request) -> GetNewToken:
refresh_token = request.cookies.get(settings.COOKIE_REFRESH_TOKEN_KEY)
if not refresh_token:
raise errors.TokenError(msg='Refresh Token 已过期,请重新登录')
try:
user_id = jwt_decode(refresh_token).id
except Exception:
raise errors.TokenError(msg='Refresh Token 无效')
token_payload = jwt_decode(refresh_token)
async with async_db_session() as db:
user = await user_dao.get(db, user_id)
user = await user_dao.get(db, token_payload.id)
if not user:
raise errors.NotFoundError(msg='用户名或密码有误')
raise errors.NotFoundError(msg='用户不存在')
elif not user.status:
raise errors.AuthorizationError(msg='用户已被锁定, 请联系统管理员')
new_token = await create_new_token(
user_id=str(user.id),
refresh_token=refresh_token,
multi_login=user.is_multi_login,
refresh_token,
token_payload.session_uuid,
user.id,
user.is_multi_login,
# extra info
username=user.username,
nickname=user.nickname,
last_login_time=timezone.t_str(user.last_login_time),
last_login_time=timezone.to_str(user.last_login_time),
ip=request.state.ip,
os=request.state.os,
browser=request.state.browser,
Expand All @@ -241,6 +239,7 @@ async def logout(*, request: Request, response: Response) -> None:
token = get_token(request)
token_payload = jwt_decode(token)
user_id = token_payload.id
session_uuid = token_payload.session_uuid
refresh_token = request.cookies.get(settings.COOKIE_REFRESH_TOKEN_KEY)
except errors.TokenError:
return
Expand All @@ -249,13 +248,15 @@ async def logout(*, request: Request, response: Response) -> None:

# 清理缓存
if request.user.is_multi_login:
await redis_client.delete(f'{settings.TOKEN_REDIS_PREFIX}:{user_id}:{token_payload.session_uuid}')
await redis_client.delete(f'{settings.TOKEN_REDIS_PREFIX}:{user_id}:{session_uuid}')
await redis_client.delete(f'{settings.TOKEN_EXTRA_INFO_REDIS_PREFIX}:{user_id}:{session_uuid}')
if refresh_token:
await redis_client.delete(f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{user_id}:{refresh_token}')
else:
key_prefix = [
f'{settings.TOKEN_REDIS_PREFIX}:{user_id}:',
f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{user_id}:',
f'{settings.TOKEN_EXTRA_INFO_REDIS_PREFIX}:{user_id}:',
]
for prefix in key_prefix:
await redis_client.delete_prefix(prefix)
Expand Down
16 changes: 9 additions & 7 deletions backend/common/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,6 @@ class RequestCallNext:
response: Response


@dataclasses.dataclass
class NewToken:
new_access_token: str
new_access_token_expire_time: datetime
session_uuid: str


@dataclasses.dataclass
class AccessToken:
access_token: str
Expand All @@ -54,6 +47,15 @@ class RefreshToken:
refresh_token_expire_time: datetime


@dataclasses.dataclass
class NewToken:
new_access_token: str
new_access_token_expire_time: datetime
new_refresh_token: str
new_refresh_token_expire_time: datetime
session_uuid: str


@dataclasses.dataclass
class TokenPayload:
id: int
Expand Down
67 changes: 42 additions & 25 deletions backend/common/security/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,7 @@ def jwt_encode(payload: dict[str, Any]) -> str:
:param payload: 载荷
:return:
"""
return jwt.encode(
payload,
settings.TOKEN_SECRET_KEY,
settings.TOKEN_ALGORITHM,
)
return jwt.encode(payload, settings.TOKEN_SECRET_KEY, settings.TOKEN_ALGORITHM)


def jwt_decode(token: str) -> TokenPayload:
Expand All @@ -75,20 +71,27 @@ def jwt_decode(token: str) -> TokenPayload:
:return:
"""
try:
payload = jwt.decode(token, settings.TOKEN_SECRET_KEY, algorithms=[settings.TOKEN_ALGORITHM])
session_uuid = payload.get('session_uuid') or 'debug'
payload = jwt.decode(
token,
settings.TOKEN_SECRET_KEY,
algorithms=[settings.TOKEN_ALGORITHM],
options={'verify_exp': True},
)
session_uuid = payload.get('session_uuid')
user_id = payload.get('sub')
expire_time = payload.get('exp')
if not user_id:
expire = payload.get('exp')
if not session_uuid or not user_id or not expire:
raise errors.TokenError(msg='Token 无效')
except ExpiredSignatureError:
raise errors.TokenError(msg='Token 已过期')
except (JWTError, Exception):
raise errors.TokenError(msg='Token 无效')
return TokenPayload(id=int(user_id), session_uuid=session_uuid, expire_time=expire_time)
return TokenPayload(
id=int(user_id), session_uuid=session_uuid, expire_time=timezone.from_datetime(timezone.to_utc(expire))
)


async def create_access_token(user_id: str, multi_login: bool, **kwargs) -> AccessToken:
async def create_access_token(user_id: int, multi_login: bool, **kwargs) -> AccessToken:
"""
生成加密 token

Expand All @@ -101,8 +104,8 @@ async def create_access_token(user_id: str, multi_login: bool, **kwargs) -> Acce
session_uuid = str(uuid4())
access_token = jwt_encode({
'session_uuid': session_uuid,
'exp': expire,
'sub': user_id,
'exp': timezone.to_utc(expire).timestamp(),
'sub': str(user_id),
})

if not multi_login:
Expand All @@ -117,68 +120,82 @@ async def create_access_token(user_id: str, multi_login: bool, **kwargs) -> Acce
# Token 附加信息单独存储
if kwargs:
await redis_client.setex(
f'{settings.TOKEN_EXTRA_INFO_REDIS_PREFIX}:{session_uuid}',
f'{settings.TOKEN_EXTRA_INFO_REDIS_PREFIX}:{user_id}:{session_uuid}',
settings.TOKEN_EXPIRE_SECONDS,
json.dumps(kwargs, ensure_ascii=False),
)

return AccessToken(access_token=access_token, access_token_expire_time=expire, session_uuid=session_uuid)


async def create_refresh_token(user_id: str, multi_login: bool) -> RefreshToken:
async def create_refresh_token(session_uuid: str, user_id: int, multi_login: bool) -> RefreshToken:
"""
生成加密刷新 token,仅用于创建新的 token

:param session_uuid: 会话 UUID
:param user_id: 用户 ID
:param multi_login: 是否允许多端登录
:return:
"""
expire = timezone.now() + timedelta(seconds=settings.TOKEN_REFRESH_EXPIRE_SECONDS)
refresh_token = jwt_encode({'exp': expire, 'sub': user_id})
refresh_token = jwt_encode({
'session_uuid': session_uuid,
'exp': timezone.to_utc(expire).timestamp(),
'sub': str(user_id),
})

if not multi_login:
key_prefix = f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{user_id}'
await redis_client.delete_prefix(key_prefix)
await redis_client.delete_prefix(f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{user_id}')

await redis_client.setex(
f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{user_id}:{refresh_token}',
f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{user_id}:{session_uuid}',
settings.TOKEN_REFRESH_EXPIRE_SECONDS,
refresh_token,
)
return RefreshToken(refresh_token=refresh_token, refresh_token_expire_time=expire)


async def create_new_token(user_id: str, refresh_token: str, multi_login: bool, **kwargs) -> NewToken:
async def create_new_token(
refresh_token: str, session_uuid: str, user_id: int, multi_login: bool, **kwargs
) -> NewToken:
"""
生成新的 token

:param user_id: 用户 ID
:param refresh_token: 刷新 token
:param session_uuid: 会话 UUID
:param user_id: 用户 ID
:param multi_login: 是否允许多端登录
:param kwargs: token 附加信息
:return:
"""
redis_refresh_token = await redis_client.get(f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{user_id}:{refresh_token}')
redis_refresh_token = await redis_client.get(f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{user_id}:{session_uuid}')
if not redis_refresh_token or redis_refresh_token != refresh_token:
raise errors.TokenError(msg='Refresh Token 已过期,请重新登录')

await redis_client.delete(f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{user_id}:{session_uuid}')
await redis_client.delete(f'{settings.TOKEN_REDIS_PREFIX}:{user_id}:{session_uuid}')

new_access_token = await create_access_token(user_id, multi_login, **kwargs)
new_refresh_token = await create_refresh_token(new_access_token.session_uuid, user_id, multi_login)
return NewToken(
new_access_token=new_access_token.access_token,
new_access_token_expire_time=new_access_token.access_token_expire_time,
new_refresh_token=new_refresh_token.refresh_token,
new_refresh_token_expire_time=new_refresh_token.refresh_token_expire_time,
session_uuid=new_access_token.session_uuid,
)


async def revoke_token(user_id: str, session_uuid: str) -> None:
async def revoke_token(user_id: int, session_uuid: str) -> None:
"""
撤销 token

:param user_id: 用户 ID
:param session_uuid: 会话 ID
:return:
"""
token_key = f'{settings.TOKEN_REDIS_PREFIX}:{user_id}:{session_uuid}'
await redis_client.delete(token_key)
await redis_client.delete(f'{settings.TOKEN_REDIS_PREFIX}:{user_id}:{session_uuid}')
await redis_client.delete(f'{settings.TOKEN_EXTRA_INFO_REDIS_PREFIX}:{user_id}:{session_uuid}')


def get_token(request: Request) -> str:
Expand Down
10 changes: 6 additions & 4 deletions backend/plugin/oauth2/service/oauth2_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,20 @@ async def create_with_login(

# 创建 token
access_token = await jwt.create_access_token(
str(sys_user.id),
sys_user.id,
sys_user.is_multi_login,
# extra info
username=sys_user.username,
nickname=sys_user.nickname or f'#{text_captcha(5)}',
last_login_time=timezone.t_str(timezone.now()),
last_login_time=timezone.to_str(timezone.now()),
ip=request.state.ip,
os=request.state.os,
browser=request.state.browser,
device=request.state.device,
)
refresh_token = await jwt.create_refresh_token(str(sys_user.id), multi_login=sys_user.is_multi_login)
refresh_token = await jwt.create_refresh_token(
access_token.session_uuid, sys_user.id, sys_user.is_multi_login
)
await user_dao.update_login_time(db, sys_user.username)
await db.refresh(sys_user)
login_log = dict(
Expand All @@ -118,7 +120,7 @@ async def create_with_login(
key=settings.COOKIE_REFRESH_TOKEN_KEY,
value=refresh_token.refresh_token,
max_age=settings.COOKIE_REFRESH_TOKEN_EXPIRE_SECONDS,
expires=timezone.f_utc(refresh_token.refresh_token_expire_time),
expires=timezone.to_utc(refresh_token.refresh_token_expire_time),
httponly=True,
)
data = GetLoginToken(
Expand Down
4 changes: 2 additions & 2 deletions backend/utils/server_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def get_service_info() -> dict[str, str | datetime]:

try:
create_time = datetime.fromtimestamp(process.create_time(), tz=tz.utc)
start_time = timezone.f_datetime(create_time)
start_time = timezone.from_datetime(create_time)
except (psutil.NoSuchProcess, OSError):
start_time = timezone.now()

Expand All @@ -164,7 +164,7 @@ def get_service_info() -> dict[str, str | datetime]:
'mem_vms': ServerInfo.format_bytes(mem_info.vms),
'mem_rss': ServerInfo.format_bytes(mem_info.rss),
'mem_free': ServerInfo.format_bytes(mem_info.vms - mem_info.rss),
'startup': timezone.t_str(start_time),
'startup': timezone.to_str(start_time),
'elapsed': elapsed,
}

Expand Down
Loading