Skip to content

Commit 847d722

Browse files
authored
Update uniform return to custom encoder (#60)
* Add custom jsonable encoder * Update uniform return to custom encoder * Add some description of the return structure
1 parent e43f128 commit 847d722

File tree

3 files changed

+224
-13
lines changed

3 files changed

+224
-13
lines changed

backend/app/api/v1/user.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ async def password_reset(obj: ResetPassword):
2929
async def userinfo(username: str):
3030
current_user = await UserService.get_userinfo(username)
3131
data = GetUserInfo(**select_to_json(current_user))
32-
return response_base.success(data=data, exclude={'password'})
32+
return response_base.success(data=data)
3333

3434

3535
@router.put('/{username}', summary='更新用户信息')
Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,34 @@
11
#!/usr/bin/env python3
22
# -*- coding: utf-8 -*-
33
from datetime import datetime
4-
from typing import Any, Union, Set, Dict
4+
from typing import Any
55

6-
from fastapi.encoders import jsonable_encoder
76
from pydantic import validate_arguments, BaseModel
87

9-
_JsonEncoder = Union[Set[int | str], Dict[int | str, Any]]
8+
from backend.app.utils.encoders import jsonable_encoder
9+
10+
_ExcludeData = set[int | str] | dict[int | str, Any]
1011

1112
__all__ = ['ResponseModel', 'response_base']
1213

1314

1415
class ResponseModel(BaseModel):
1516
"""
16-
统一返回模型, 可在 FastAPI 接口请求中指定 response_model 及更多操作
17+
统一返回模型
18+
19+
.. tip::
20+
21+
如果你不想使用 ResponseBase 中的自定义编码器,可以使用此模型,返回数据将通过 fastapi 内部的编码器直接自动解析并返回
22+
23+
E.g. ::
24+
25+
@router.get('/test', response_model=ResponseModel)
26+
def test():
27+
return ResponseModel(data={'test': 'test'})
28+
29+
@router.get('/test')
30+
def test() -> ResponseModel:
31+
return ResponseModel(data={'test': 'test'})
1732
"""
1833

1934
code: int = 200
@@ -25,13 +40,30 @@ class Config:
2540

2641

2742
class ResponseBase:
43+
"""
44+
统一返回方法
45+
46+
.. tip::
47+
48+
此类中的返回方法将通过自定义编码器预解析,然后由 fastapi 内部的编码器再次处理并返回,可能存在性能损耗,取决于个人喜好
49+
50+
E.g. ::
51+
52+
@router.get('/test')
53+
def test():
54+
return response_base.success(data={'test': 'test'})
55+
"""
2856
@staticmethod
29-
def __encode_json(data: Any):
30-
return jsonable_encoder(data, custom_encoder={datetime: lambda x: x.strftime('%Y-%m-%d %H:%M:%S')})
57+
def __json_encoder(data: Any, exclude: _ExcludeData | None = None, **kwargs):
58+
custom_encoder = {datetime: lambda x: x.strftime('%Y-%m-%d %H:%M:%S')}
59+
kwargs.update({'custom_encoder': custom_encoder})
60+
return jsonable_encoder(data, exclude=exclude, **kwargs)
3161

3262
@staticmethod
3363
@validate_arguments
34-
def success(*, code: int = 200, msg: str = 'Success', data: Any | None = None, exclude: _JsonEncoder | None = None):
64+
def success(
65+
*, code: int = 200, msg: str = 'Success', data: Any | None = None, exclude: _ExcludeData | None = None, **kwargs
66+
) -> dict:
3567
"""
3668
请求成功返回通用方法
3769
@@ -41,14 +73,16 @@ def success(*, code: int = 200, msg: str = 'Success', data: Any | None = None, e
4173
:param exclude: 排除返回数据(data)字段
4274
:return:
4375
"""
44-
data = data if data is None else ResponseBase.__encode_json(data)
45-
return ResponseModel(code=code, msg=msg, data=data).dict(exclude={'data': exclude})
76+
data = data if data is None else ResponseBase.__json_encoder(data, exclude, **kwargs)
77+
return {'code': code, 'msg': msg, 'data': data}
4678

4779
@staticmethod
4880
@validate_arguments
49-
def fail(*, code: int = 400, msg: str = 'Bad Request', data: Any = None, exclude: _JsonEncoder | None = None):
50-
data = data if data is None else ResponseBase.__encode_json(data)
51-
return ResponseModel(code=code, msg=msg, data=data).dict(exclude={'data': exclude})
81+
def fail(
82+
*, code: int = 400, msg: str = 'Bad Request', data: Any = None, exclude: _ExcludeData | None = None, **kwargs
83+
) -> dict:
84+
data = data if data is None else ResponseBase.__json_encoder(data, exclude, **kwargs)
85+
return {'code': code, 'msg': msg, 'data': data}
5286

5387

5488
response_base = ResponseBase()

backend/app/utils/encoders.py

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
import dataclasses
4+
from collections import defaultdict
5+
from enum import Enum
6+
from pathlib import PurePath
7+
from types import GeneratorType
8+
from typing import Any, Callable, Iterable
9+
10+
from pydantic import BaseModel
11+
from pydantic.json import ENCODERS_BY_TYPE
12+
13+
SetIntStr = set[int | str]
14+
DictIntStrAny = dict[int | str, Any]
15+
16+
PRIMITIVE_TYPE = (str, bool, int, float, type(None))
17+
ARRAY_TYPES = (list, set, frozenset, GeneratorType, tuple)
18+
19+
20+
def _generate_encoders_by_class_tuples(
21+
type_encoder_map: dict[Any, Callable[[Any], Any]]
22+
) -> dict[Callable[[Any], Any], tuple[Any, ...]]:
23+
encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict(tuple)
24+
for type_, encoder in type_encoder_map.items():
25+
encoders_by_class_tuples[encoder] += (type_,)
26+
return encoders_by_class_tuples
27+
28+
29+
encoders_by_class_tuples = _generate_encoders_by_class_tuples(ENCODERS_BY_TYPE)
30+
31+
32+
def jsonable_encoder(
33+
obj: Any,
34+
include: SetIntStr | DictIntStrAny | None = None,
35+
exclude: SetIntStr | DictIntStrAny | None = None,
36+
by_alias: bool = True,
37+
exclude_unset: bool = False,
38+
exclude_defaults: bool = False,
39+
exclude_none: bool = False,
40+
custom_encoder: dict[Any, Callable[[Any], Any]] | None = None,
41+
sqlalchemy_safe: bool = True,
42+
) -> Any:
43+
custom_encoder = custom_encoder or {}
44+
if custom_encoder:
45+
if type(obj) in custom_encoder:
46+
return custom_encoder[type(obj)](obj)
47+
else:
48+
for encoder_type, encoder_instance in custom_encoder.items():
49+
if isinstance(obj, encoder_type):
50+
return encoder_instance(obj)
51+
if include is not None and not isinstance(include, (set, dict)):
52+
include = set(include)
53+
if exclude is not None and not isinstance(exclude, (set, dict)):
54+
exclude = set(exclude)
55+
56+
def encode_dict(obj: Any) -> Any:
57+
encoded_dict = {}
58+
allowed_keys = set(obj.keys())
59+
if include is not None:
60+
allowed_keys &= set(include)
61+
if exclude is not None:
62+
allowed_keys -= set(exclude)
63+
64+
for key, value in obj.items():
65+
if (
66+
(not sqlalchemy_safe or (not isinstance(key, str)) or (not key.startswith('_sa')))
67+
and (value is not None or not exclude_none)
68+
and key in allowed_keys
69+
):
70+
if isinstance(key, PRIMITIVE_TYPE):
71+
encoded_key = key
72+
else:
73+
encoded_key = jsonable_encoder(
74+
key,
75+
by_alias=by_alias,
76+
exclude_unset=exclude_unset,
77+
exclude_none=exclude_none,
78+
custom_encoder=custom_encoder,
79+
sqlalchemy_safe=sqlalchemy_safe,
80+
)
81+
encoded_value = jsonable_encoder(
82+
value,
83+
by_alias=by_alias,
84+
exclude_unset=exclude_unset,
85+
exclude_none=exclude_none,
86+
custom_encoder=custom_encoder,
87+
sqlalchemy_safe=sqlalchemy_safe,
88+
)
89+
encoded_dict[encoded_key] = encoded_value
90+
return encoded_dict
91+
92+
def encode_array(obj: Iterable[Any]) -> Any:
93+
encoded_list = []
94+
for item in obj:
95+
encoded_list.append(
96+
jsonable_encoder(
97+
item,
98+
include=include,
99+
exclude=exclude,
100+
by_alias=by_alias,
101+
exclude_unset=exclude_unset,
102+
exclude_defaults=exclude_defaults,
103+
exclude_none=exclude_none,
104+
custom_encoder=custom_encoder,
105+
sqlalchemy_safe=sqlalchemy_safe,
106+
)
107+
)
108+
return encoded_list
109+
110+
def encode_base_model(obj: BaseModel) -> Any:
111+
encoder = getattr(obj.__config__, 'json_encoders', {})
112+
if custom_encoder:
113+
encoder.update(custom_encoder)
114+
115+
obj_dict = obj.dict(
116+
include=include,
117+
exclude=exclude,
118+
by_alias=by_alias,
119+
exclude_unset=exclude_unset,
120+
exclude_none=exclude_none,
121+
exclude_defaults=exclude_defaults,
122+
)
123+
if '__root__' in obj_dict:
124+
obj_dict = obj_dict['__root__']
125+
126+
return jsonable_encoder(
127+
obj_dict,
128+
exclude_none=exclude_none,
129+
exclude_defaults=exclude_defaults,
130+
custom_encoder=encoder,
131+
sqlalchemy_safe=sqlalchemy_safe,
132+
)
133+
134+
# Use type comparisons on common types before expensive isinstance checks
135+
if type(obj) in PRIMITIVE_TYPE:
136+
return obj
137+
if type(obj) == dict:
138+
return encode_dict(obj)
139+
if type(obj) in ARRAY_TYPES:
140+
return encode_array(obj)
141+
142+
if isinstance(obj, BaseModel):
143+
return encode_base_model(obj)
144+
if dataclasses.is_dataclass(obj):
145+
obj_dict = dataclasses.asdict(obj)
146+
return encode_dict(obj_dict)
147+
if isinstance(obj, Enum):
148+
return obj.value
149+
if isinstance(obj, PurePath):
150+
return str(obj)
151+
152+
# Back up for Inherited types
153+
if isinstance(obj, PRIMITIVE_TYPE):
154+
return obj
155+
if isinstance(obj, dict):
156+
return encode_dict(obj)
157+
if isinstance(obj, ARRAY_TYPES):
158+
return encode_array(obj)
159+
160+
if type(obj) in ENCODERS_BY_TYPE:
161+
return ENCODERS_BY_TYPE[type(obj)](obj)
162+
for encoder, classes_tuple in encoders_by_class_tuples.items():
163+
if isinstance(obj, classes_tuple):
164+
return encoder(obj)
165+
166+
try:
167+
data = dict(obj)
168+
except Exception as e:
169+
errors: list[Exception] = []
170+
errors.append(e)
171+
try:
172+
data = vars(obj)
173+
except Exception as e:
174+
errors.append(e)
175+
raise ValueError(errors)
176+
177+
return encode_dict(data)

0 commit comments

Comments
 (0)