diff --git a/backend/app/api/v1/user.py b/backend/app/api/v1/user.py index aa129a4c..c4261d3c 100644 --- a/backend/app/api/v1/user.py +++ b/backend/app/api/v1/user.py @@ -29,7 +29,7 @@ async def password_reset(obj: ResetPassword): async def userinfo(username: str): current_user = await UserService.get_userinfo(username) data = GetUserInfo(**select_to_json(current_user)) - return response_base.success(data=data, exclude={'password'}) + return response_base.success(data=data) @router.put('/{username}', summary='更新用户信息') diff --git a/backend/app/common/response/response_schema.py b/backend/app/common/response/response_schema.py index 999982f1..8e7f1e70 100644 --- a/backend/app/common/response/response_schema.py +++ b/backend/app/common/response/response_schema.py @@ -1,19 +1,34 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- from datetime import datetime -from typing import Any, Union, Set, Dict +from typing import Any -from fastapi.encoders import jsonable_encoder from pydantic import validate_arguments, BaseModel -_JsonEncoder = Union[Set[int | str], Dict[int | str, Any]] +from backend.app.utils.encoders import jsonable_encoder + +_ExcludeData = set[int | str] | dict[int | str, Any] __all__ = ['ResponseModel', 'response_base'] class ResponseModel(BaseModel): """ - 统一返回模型, 可在 FastAPI 接口请求中指定 response_model 及更多操作 + 统一返回模型 + + .. tip:: + + 如果你不想使用 ResponseBase 中的自定义编码器,可以使用此模型,返回数据将通过 fastapi 内部的编码器直接自动解析并返回 + + E.g. :: + + @router.get('/test', response_model=ResponseModel) + def test(): + return ResponseModel(data={'test': 'test'}) + + @router.get('/test') + def test() -> ResponseModel: + return ResponseModel(data={'test': 'test'}) """ code: int = 200 @@ -25,13 +40,30 @@ class Config: class ResponseBase: + """ + 统一返回方法 + + .. tip:: + + 此类中的返回方法将通过自定义编码器预解析,然后由 fastapi 内部的编码器再次处理并返回,可能存在性能损耗,取决于个人喜好 + + E.g. :: + + @router.get('/test') + def test(): + return response_base.success(data={'test': 'test'}) + """ @staticmethod - def __encode_json(data: Any): - return jsonable_encoder(data, custom_encoder={datetime: lambda x: x.strftime('%Y-%m-%d %H:%M:%S')}) + def __json_encoder(data: Any, exclude: _ExcludeData | None = None, **kwargs): + custom_encoder = {datetime: lambda x: x.strftime('%Y-%m-%d %H:%M:%S')} + kwargs.update({'custom_encoder': custom_encoder}) + return jsonable_encoder(data, exclude=exclude, **kwargs) @staticmethod @validate_arguments - def success(*, code: int = 200, msg: str = 'Success', data: Any | None = None, exclude: _JsonEncoder | None = None): + def success( + *, code: int = 200, msg: str = 'Success', data: Any | None = None, exclude: _ExcludeData | None = None, **kwargs + ) -> dict: """ 请求成功返回通用方法 @@ -41,14 +73,16 @@ def success(*, code: int = 200, msg: str = 'Success', data: Any | None = None, e :param exclude: 排除返回数据(data)字段 :return: """ - data = data if data is None else ResponseBase.__encode_json(data) - return ResponseModel(code=code, msg=msg, data=data).dict(exclude={'data': exclude}) + data = data if data is None else ResponseBase.__json_encoder(data, exclude, **kwargs) + return {'code': code, 'msg': msg, 'data': data} @staticmethod @validate_arguments - def fail(*, code: int = 400, msg: str = 'Bad Request', data: Any = None, exclude: _JsonEncoder | None = None): - data = data if data is None else ResponseBase.__encode_json(data) - return ResponseModel(code=code, msg=msg, data=data).dict(exclude={'data': exclude}) + def fail( + *, code: int = 400, msg: str = 'Bad Request', data: Any = None, exclude: _ExcludeData | None = None, **kwargs + ) -> dict: + data = data if data is None else ResponseBase.__json_encoder(data, exclude, **kwargs) + return {'code': code, 'msg': msg, 'data': data} response_base = ResponseBase() diff --git a/backend/app/utils/encoders.py b/backend/app/utils/encoders.py new file mode 100644 index 00000000..976720ac --- /dev/null +++ b/backend/app/utils/encoders.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +import dataclasses +from collections import defaultdict +from enum import Enum +from pathlib import PurePath +from types import GeneratorType +from typing import Any, Callable, Iterable + +from pydantic import BaseModel +from pydantic.json import ENCODERS_BY_TYPE + +SetIntStr = set[int | str] +DictIntStrAny = dict[int | str, Any] + +PRIMITIVE_TYPE = (str, bool, int, float, type(None)) +ARRAY_TYPES = (list, set, frozenset, GeneratorType, tuple) + + +def _generate_encoders_by_class_tuples( + type_encoder_map: dict[Any, Callable[[Any], Any]] +) -> dict[Callable[[Any], Any], tuple[Any, ...]]: + encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict(tuple) + for type_, encoder in type_encoder_map.items(): + encoders_by_class_tuples[encoder] += (type_,) + return encoders_by_class_tuples + + +encoders_by_class_tuples = _generate_encoders_by_class_tuples(ENCODERS_BY_TYPE) + + +def jsonable_encoder( + obj: Any, + include: SetIntStr | DictIntStrAny | None = None, + exclude: SetIntStr | DictIntStrAny | None = None, + by_alias: bool = True, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + custom_encoder: dict[Any, Callable[[Any], Any]] | None = None, + sqlalchemy_safe: bool = True, +) -> Any: + custom_encoder = custom_encoder or {} + if custom_encoder: + if type(obj) in custom_encoder: + return custom_encoder[type(obj)](obj) + else: + for encoder_type, encoder_instance in custom_encoder.items(): + if isinstance(obj, encoder_type): + return encoder_instance(obj) + if include is not None and not isinstance(include, (set, dict)): + include = set(include) + if exclude is not None and not isinstance(exclude, (set, dict)): + exclude = set(exclude) + + def encode_dict(obj: Any) -> Any: + encoded_dict = {} + allowed_keys = set(obj.keys()) + if include is not None: + allowed_keys &= set(include) + if exclude is not None: + allowed_keys -= set(exclude) + + for key, value in obj.items(): + if ( + (not sqlalchemy_safe or (not isinstance(key, str)) or (not key.startswith('_sa'))) + and (value is not None or not exclude_none) + and key in allowed_keys + ): + if isinstance(key, PRIMITIVE_TYPE): + encoded_key = key + else: + encoded_key = jsonable_encoder( + key, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + sqlalchemy_safe=sqlalchemy_safe, + ) + encoded_value = jsonable_encoder( + value, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + sqlalchemy_safe=sqlalchemy_safe, + ) + encoded_dict[encoded_key] = encoded_value + return encoded_dict + + def encode_array(obj: Iterable[Any]) -> Any: + encoded_list = [] + for item in obj: + encoded_list.append( + jsonable_encoder( + item, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + sqlalchemy_safe=sqlalchemy_safe, + ) + ) + return encoded_list + + def encode_base_model(obj: BaseModel) -> Any: + encoder = getattr(obj.__config__, 'json_encoders', {}) + if custom_encoder: + encoder.update(custom_encoder) + + obj_dict = obj.dict( + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + exclude_defaults=exclude_defaults, + ) + if '__root__' in obj_dict: + obj_dict = obj_dict['__root__'] + + return jsonable_encoder( + obj_dict, + exclude_none=exclude_none, + exclude_defaults=exclude_defaults, + custom_encoder=encoder, + sqlalchemy_safe=sqlalchemy_safe, + ) + + # Use type comparisons on common types before expensive isinstance checks + if type(obj) in PRIMITIVE_TYPE: + return obj + if type(obj) == dict: + return encode_dict(obj) + if type(obj) in ARRAY_TYPES: + return encode_array(obj) + + if isinstance(obj, BaseModel): + return encode_base_model(obj) + if dataclasses.is_dataclass(obj): + obj_dict = dataclasses.asdict(obj) + return encode_dict(obj_dict) + if isinstance(obj, Enum): + return obj.value + if isinstance(obj, PurePath): + return str(obj) + + # Back up for Inherited types + if isinstance(obj, PRIMITIVE_TYPE): + return obj + if isinstance(obj, dict): + return encode_dict(obj) + if isinstance(obj, ARRAY_TYPES): + return encode_array(obj) + + if type(obj) in ENCODERS_BY_TYPE: + return ENCODERS_BY_TYPE[type(obj)](obj) + for encoder, classes_tuple in encoders_by_class_tuples.items(): + if isinstance(obj, classes_tuple): + return encoder(obj) + + try: + data = dict(obj) + except Exception as e: + errors: list[Exception] = [] + errors.append(e) + try: + data = vars(obj) + except Exception as e: + errors.append(e) + raise ValueError(errors) + + return encode_dict(data)