Skip to content

Commit 090fca7

Browse files
authored
pydantic : replace uses of __annotations__ with get_type_hints (#8474)
* pydantic : replace uses of __annotations__ with get_type_hints
* pydantic : fix Python 3.9 and 3.10 support
1 parent aaab241 commit 090fca7

File tree

3 files changed: 46 additions and 34 deletions

examples/pydantic_models_to_grammar.py

Lines changed: 43 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from copy import copy
77
from enum import Enum
88
from inspect import getdoc, isclass
9-
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin
9+
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin, get_type_hints
1010

1111
from docstring_parser import parse
1212
from pydantic import BaseModel, create_model
@@ -53,35 +53,38 @@ class PydanticDataType(Enum):
5353

5454

5555
def map_pydantic_type_to_gbnf(pydantic_type: type[Any]) -> str:
56-
if isclass(pydantic_type) and issubclass(pydantic_type, str):
56+
origin_type = get_origin(pydantic_type)
57+
origin_type = pydantic_type if origin_type is None else origin_type
58+
59+
if isclass(origin_type) and issubclass(origin_type, str):
5760
return PydanticDataType.STRING.value
58-
elif isclass(pydantic_type) and issubclass(pydantic_type, bool):
61+
elif isclass(origin_type) and issubclass(origin_type, bool):
5962
return PydanticDataType.BOOLEAN.value
60-
elif isclass(pydantic_type) and issubclass(pydantic_type, int):
63+
elif isclass(origin_type) and issubclass(origin_type, int):
6164
return PydanticDataType.INTEGER.value
62-
elif isclass(pydantic_type) and issubclass(pydantic_type, float):
65+
elif isclass(origin_type) and issubclass(origin_type, float):
6366
return PydanticDataType.FLOAT.value
64-
elif isclass(pydantic_type) and issubclass(pydantic_type, Enum):
67+
elif isclass(origin_type) and issubclass(origin_type, Enum):
6568
return PydanticDataType.ENUM.value
6669

67-
elif isclass(pydantic_type) and issubclass(pydantic_type, BaseModel):
68-
return format_model_and_field_name(pydantic_type.__name__)
69-
elif get_origin(pydantic_type) is list:
70+
elif isclass(origin_type) and issubclass(origin_type, BaseModel):
71+
return format_model_and_field_name(origin_type.__name__)
72+
elif origin_type is list:
7073
element_type = get_args(pydantic_type)[0]
7174
return f"{map_pydantic_type_to_gbnf(element_type)}-list"
72-
elif get_origin(pydantic_type) is set:
75+
elif origin_type is set:
7376
element_type = get_args(pydantic_type)[0]
7477
return f"{map_pydantic_type_to_gbnf(element_type)}-set"
75-
elif get_origin(pydantic_type) is Union:
78+
elif origin_type is Union:
7679
union_types = get_args(pydantic_type)
7780
union_rules = [map_pydantic_type_to_gbnf(ut) for ut in union_types]
7881
return f"union-{'-or-'.join(union_rules)}"
79-
elif get_origin(pydantic_type) is Optional:
82+
elif origin_type is Optional:
8083
element_type = get_args(pydantic_type)[0]
8184
return f"optional-{map_pydantic_type_to_gbnf(element_type)}"
82-
elif isclass(pydantic_type):
83-
return f"{PydanticDataType.CUSTOM_CLASS.value}-{format_model_and_field_name(pydantic_type.__name__)}"
84-
elif get_origin(pydantic_type) is dict:
85+
elif isclass(origin_type):
86+
return f"{PydanticDataType.CUSTOM_CLASS.value}-{format_model_and_field_name(origin_type.__name__)}"
87+
elif origin_type is dict:
8588
key_type, value_type = get_args(pydantic_type)
8689
return f"custom-dict-key-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(key_type))}-value-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(value_type))}"
8790
else:
@@ -118,7 +121,7 @@ def get_members_structure(cls, rule_name):
118121
# Modify this comprehension
119122
members = [
120123
f' "\\"{name}\\"" ":" {map_pydantic_type_to_gbnf(param_type)}'
121-
for name, param_type in cls.__annotations__.items()
124+
for name, param_type in get_type_hints(cls).items()
122125
if name != "self"
123126
]
124127

@@ -297,17 +300,20 @@ def generate_gbnf_rule_for_type(
297300
field_name = format_model_and_field_name(field_name)
298301
gbnf_type = map_pydantic_type_to_gbnf(field_type)
299302

300-
if isclass(field_type) and issubclass(field_type, BaseModel):
303+
origin_type = get_origin(field_type)
304+
origin_type = field_type if origin_type is None else origin_type
305+
306+
if isclass(origin_type) and issubclass(origin_type, BaseModel):
301307
nested_model_name = format_model_and_field_name(field_type.__name__)
302308
nested_model_rules, _ = generate_gbnf_grammar(field_type, processed_models, created_rules)
303309
rules.extend(nested_model_rules)
304310
gbnf_type, rules = nested_model_name, rules
305-
elif isclass(field_type) and issubclass(field_type, Enum):
311+
elif isclass(origin_type) and issubclass(origin_type, Enum):
306312
enum_values = [f'"\\"{e.value}\\""' for e in field_type] # Adding escaped quotes
307313
enum_rule = f"{model_name}-{field_name} ::= {' | '.join(enum_values)}"
308314
rules.append(enum_rule)
309315
gbnf_type, rules = model_name + "-" + field_name, rules
310-
elif get_origin(field_type) == list: # Array
316+
elif origin_type is list: # Array
311317
element_type = get_args(field_type)[0]
312318
element_rule_name, additional_rules = generate_gbnf_rule_for_type(
313319
model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules
@@ -317,7 +323,7 @@ def generate_gbnf_rule_for_type(
317323
rules.append(array_rule)
318324
gbnf_type, rules = model_name + "-" + field_name, rules
319325

320-
elif get_origin(field_type) == set or field_type == set: # Array
326+
elif origin_type is set: # Array
321327
element_type = get_args(field_type)[0]
322328
element_rule_name, additional_rules = generate_gbnf_rule_for_type(
323329
model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules
@@ -371,7 +377,7 @@ def generate_gbnf_rule_for_type(
371377
gbnf_type = f"{model_name}-{field_name}-optional"
372378
else:
373379
gbnf_type = f"{model_name}-{field_name}-union"
374-
elif isclass(field_type) and issubclass(field_type, str):
380+
elif isclass(origin_type) and issubclass(origin_type, str):
375381
if field_info and hasattr(field_info, "json_schema_extra") and field_info.json_schema_extra is not None:
376382
triple_quoted_string = field_info.json_schema_extra.get("triple_quoted_string", False)
377383
markdown_string = field_info.json_schema_extra.get("markdown_code_block", False)
@@ -387,8 +393,8 @@ def generate_gbnf_rule_for_type(
387393
gbnf_type = PydanticDataType.STRING.value
388394

389395
elif (
390-
isclass(field_type)
391-
and issubclass(field_type, float)
396+
isclass(origin_type)
397+
and issubclass(origin_type, float)
392398
and field_info
393399
and hasattr(field_info, "json_schema_extra")
394400
and field_info.json_schema_extra is not None
@@ -413,8 +419,8 @@ def generate_gbnf_rule_for_type(
413419
)
414420

415421
elif (
416-
isclass(field_type)
417-
and issubclass(field_type, int)
422+
isclass(origin_type)
423+
and issubclass(origin_type, int)
418424
and field_info
419425
and hasattr(field_info, "json_schema_extra")
420426
and field_info.json_schema_extra is not None
@@ -462,15 +468,15 @@ def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[Bas
462468
if not issubclass(model, BaseModel):
463469
# For non-Pydantic classes, generate model_fields from __annotations__ or __init__
464470
if hasattr(model, "__annotations__") and model.__annotations__:
465-
model_fields = {name: (typ, ...) for name, typ in model.__annotations__.items()} # pyright: ignore[reportGeneralTypeIssues]
471+
model_fields = {name: (typ, ...) for name, typ in get_type_hints(model).items()}
466472
else:
467473
init_signature = inspect.signature(model.__init__)
468474
parameters = init_signature.parameters
469475
model_fields = {name: (param.annotation, param.default) for name, param in parameters.items() if
470476
name != "self"}
471477
else:
472478
# For Pydantic models, use model_fields and check for ellipsis (required fields)
473-
model_fields = model.__annotations__
479+
model_fields = get_type_hints(model)
474480

475481
model_rule_parts = []
476482
nested_rules = []
@@ -706,7 +712,7 @@ def generate_markdown_documentation(
706712
else:
707713
documentation += f" Fields:\n" # noqa: F541
708714
if isclass(model) and issubclass(model, BaseModel):
709-
for name, field_type in model.__annotations__.items():
715+
for name, field_type in get_type_hints(model).items():
710716
# if name == "markdown_code_block":
711717
# continue
712718
if get_origin(field_type) == list:
@@ -754,14 +760,17 @@ def generate_field_markdown(
754760
field_info = model.model_fields.get(field_name)
755761
field_description = field_info.description if field_info and field_info.description else ""
756762

757-
if get_origin(field_type) == list:
763+
origin_type = get_origin(field_type)
764+
origin_type = field_type if origin_type is None else origin_type
765+
766+
if origin_type == list:
758767
element_type = get_args(field_type)[0]
759768
field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)} of {format_model_and_field_name(element_type.__name__)})"
760769
if field_description != "":
761770
field_text += ":\n"
762771
else:
763772
field_text += "\n"
764-
elif get_origin(field_type) == Union:
773+
elif origin_type == Union:
765774
element_types = get_args(field_type)
766775
types = []
767776
for element_type in element_types:
@@ -792,9 +801,9 @@ def generate_field_markdown(
792801
example_text = f"'{field_example}'" if isinstance(field_example, str) else field_example
793802
field_text += f"{indent} Example: {example_text}\n"
794803

795-
if isclass(field_type) and issubclass(field_type, BaseModel):
804+
if isclass(origin_type) and issubclass(origin_type, BaseModel):
796805
field_text += f"{indent} Details:\n"
797-
for name, type_ in field_type.__annotations__.items():
806+
for name, type_ in get_type_hints(field_type).items():
798807
field_text += generate_field_markdown(name, type_, field_type, depth + 2)
799808

800809
return field_text
@@ -855,7 +864,7 @@ def generate_text_documentation(
855864

856865
if isclass(model) and issubclass(model, BaseModel):
857866
documentation_fields = ""
858-
for name, field_type in model.__annotations__.items():
867+
for name, field_type in get_type_hints(model).items():
859868
# if name == "markdown_code_block":
860869
# continue
861870
if get_origin(field_type) == list:
@@ -948,7 +957,7 @@ def generate_field_text(
948957

949958
if isclass(field_type) and issubclass(field_type, BaseModel):
950959
field_text += f"{indent} Details:\n"
951-
for name, type_ in field_type.__annotations__.items():
960+
for name, type_ in get_type_hints(field_type).items():
952961
field_text += generate_field_text(name, type_, field_type, depth + 2)
953962

954963
return field_text

examples/pydantic_models_to_grammar_examples.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ def create_completion(prompt, grammar):
2020
response = requests.post("http://127.0.0.1:8080/completion", headers=headers, json=data)
2121
data = response.json()
2222

23+
assert data.get("error") is None, data
24+
2325
print(data["content"])
2426
return data["content"]
2527

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
docstring_parser~=0.15
22
pydantic~=2.6.3
3+
requests

0 commit comments

Comments
 (0)