6
6
from copy import copy
7
7
from enum import Enum
8
8
from inspect import getdoc , isclass
9
- from typing import TYPE_CHECKING , Any , Callable , List , Optional , Union , get_args , get_origin
9
+ from typing import TYPE_CHECKING , Any , Callable , List , Optional , Union , get_args , get_origin , get_type_hints
10
10
11
11
from docstring_parser import parse
12
12
from pydantic import BaseModel , create_model
@@ -53,35 +53,38 @@ class PydanticDataType(Enum):
53
53
54
54
55
55
def map_pydantic_type_to_gbnf (pydantic_type : type [Any ]) -> str :
56
- if isclass (pydantic_type ) and issubclass (pydantic_type , str ):
56
+ origin_type = get_origin (pydantic_type )
57
+ origin_type = pydantic_type if origin_type is None else origin_type
58
+
59
+ if isclass (origin_type ) and issubclass (origin_type , str ):
57
60
return PydanticDataType .STRING .value
58
- elif isclass (pydantic_type ) and issubclass (pydantic_type , bool ):
61
+ elif isclass (origin_type ) and issubclass (origin_type , bool ):
59
62
return PydanticDataType .BOOLEAN .value
60
- elif isclass (pydantic_type ) and issubclass (pydantic_type , int ):
63
+ elif isclass (origin_type ) and issubclass (origin_type , int ):
61
64
return PydanticDataType .INTEGER .value
62
- elif isclass (pydantic_type ) and issubclass (pydantic_type , float ):
65
+ elif isclass (origin_type ) and issubclass (origin_type , float ):
63
66
return PydanticDataType .FLOAT .value
64
- elif isclass (pydantic_type ) and issubclass (pydantic_type , Enum ):
67
+ elif isclass (origin_type ) and issubclass (origin_type , Enum ):
65
68
return PydanticDataType .ENUM .value
66
69
67
- elif isclass (pydantic_type ) and issubclass (pydantic_type , BaseModel ):
68
- return format_model_and_field_name (pydantic_type .__name__ )
69
- elif get_origin ( pydantic_type ) is list :
70
+ elif isclass (origin_type ) and issubclass (origin_type , BaseModel ):
71
+ return format_model_and_field_name (origin_type .__name__ )
72
+ elif origin_type is list :
70
73
element_type = get_args (pydantic_type )[0 ]
71
74
return f"{ map_pydantic_type_to_gbnf (element_type )} -list"
72
- elif get_origin ( pydantic_type ) is set :
75
+ elif origin_type is set :
73
76
element_type = get_args (pydantic_type )[0 ]
74
77
return f"{ map_pydantic_type_to_gbnf (element_type )} -set"
75
- elif get_origin ( pydantic_type ) is Union :
78
+ elif origin_type is Union :
76
79
union_types = get_args (pydantic_type )
77
80
union_rules = [map_pydantic_type_to_gbnf (ut ) for ut in union_types ]
78
81
return f"union-{ '-or-' .join (union_rules )} "
79
- elif get_origin ( pydantic_type ) is Optional :
82
+ elif origin_type is Optional :
80
83
element_type = get_args (pydantic_type )[0 ]
81
84
return f"optional-{ map_pydantic_type_to_gbnf (element_type )} "
82
- elif isclass (pydantic_type ):
83
- return f"{ PydanticDataType .CUSTOM_CLASS .value } -{ format_model_and_field_name (pydantic_type .__name__ )} "
84
- elif get_origin ( pydantic_type ) is dict :
85
+ elif isclass (origin_type ):
86
+ return f"{ PydanticDataType .CUSTOM_CLASS .value } -{ format_model_and_field_name (origin_type .__name__ )} "
87
+ elif origin_type is dict :
85
88
key_type , value_type = get_args (pydantic_type )
86
89
return f"custom-dict-key-type-{ format_model_and_field_name (map_pydantic_type_to_gbnf (key_type ))} -value-type-{ format_model_and_field_name (map_pydantic_type_to_gbnf (value_type ))} "
87
90
else :
@@ -118,7 +121,7 @@ def get_members_structure(cls, rule_name):
118
121
# Modify this comprehension
119
122
members = [
120
123
f' "\\ "{ name } \\ "" ":" { map_pydantic_type_to_gbnf (param_type )} '
121
- for name , param_type in cls . __annotations__ .items ()
124
+ for name , param_type in get_type_hints ( cls ) .items ()
122
125
if name != "self"
123
126
]
124
127
@@ -297,17 +300,20 @@ def generate_gbnf_rule_for_type(
297
300
field_name = format_model_and_field_name (field_name )
298
301
gbnf_type = map_pydantic_type_to_gbnf (field_type )
299
302
300
- if isclass (field_type ) and issubclass (field_type , BaseModel ):
303
+ origin_type = get_origin (field_type )
304
+ origin_type = field_type if origin_type is None else origin_type
305
+
306
+ if isclass (origin_type ) and issubclass (origin_type , BaseModel ):
301
307
nested_model_name = format_model_and_field_name (field_type .__name__ )
302
308
nested_model_rules , _ = generate_gbnf_grammar (field_type , processed_models , created_rules )
303
309
rules .extend (nested_model_rules )
304
310
gbnf_type , rules = nested_model_name , rules
305
- elif isclass (field_type ) and issubclass (field_type , Enum ):
311
+ elif isclass (origin_type ) and issubclass (origin_type , Enum ):
306
312
enum_values = [f'"\\ "{ e .value } \\ ""' for e in field_type ] # Adding escaped quotes
307
313
enum_rule = f"{ model_name } -{ field_name } ::= { ' | ' .join (enum_values )} "
308
314
rules .append (enum_rule )
309
315
gbnf_type , rules = model_name + "-" + field_name , rules
310
- elif get_origin ( field_type ) == list : # Array
316
+ elif origin_type is list : # Array
311
317
element_type = get_args (field_type )[0 ]
312
318
element_rule_name , additional_rules = generate_gbnf_rule_for_type (
313
319
model_name , f"{ field_name } -element" , element_type , is_optional , processed_models , created_rules
@@ -317,7 +323,7 @@ def generate_gbnf_rule_for_type(
317
323
rules .append (array_rule )
318
324
gbnf_type , rules = model_name + "-" + field_name , rules
319
325
320
- elif get_origin ( field_type ) == set or field_type == set : # Array
326
+ elif origin_type is set : # Array
321
327
element_type = get_args (field_type )[0 ]
322
328
element_rule_name , additional_rules = generate_gbnf_rule_for_type (
323
329
model_name , f"{ field_name } -element" , element_type , is_optional , processed_models , created_rules
@@ -371,7 +377,7 @@ def generate_gbnf_rule_for_type(
371
377
gbnf_type = f"{ model_name } -{ field_name } -optional"
372
378
else :
373
379
gbnf_type = f"{ model_name } -{ field_name } -union"
374
- elif isclass (field_type ) and issubclass (field_type , str ):
380
+ elif isclass (origin_type ) and issubclass (origin_type , str ):
375
381
if field_info and hasattr (field_info , "json_schema_extra" ) and field_info .json_schema_extra is not None :
376
382
triple_quoted_string = field_info .json_schema_extra .get ("triple_quoted_string" , False )
377
383
markdown_string = field_info .json_schema_extra .get ("markdown_code_block" , False )
@@ -387,8 +393,8 @@ def generate_gbnf_rule_for_type(
387
393
gbnf_type = PydanticDataType .STRING .value
388
394
389
395
elif (
390
- isclass (field_type )
391
- and issubclass (field_type , float )
396
+ isclass (origin_type )
397
+ and issubclass (origin_type , float )
392
398
and field_info
393
399
and hasattr (field_info , "json_schema_extra" )
394
400
and field_info .json_schema_extra is not None
@@ -413,8 +419,8 @@ def generate_gbnf_rule_for_type(
413
419
)
414
420
415
421
elif (
416
- isclass (field_type )
417
- and issubclass (field_type , int )
422
+ isclass (origin_type )
423
+ and issubclass (origin_type , int )
418
424
and field_info
419
425
and hasattr (field_info , "json_schema_extra" )
420
426
and field_info .json_schema_extra is not None
@@ -462,15 +468,15 @@ def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[Bas
462
468
if not issubclass (model , BaseModel ):
463
469
# For non-Pydantic classes, generate model_fields from __annotations__ or __init__
464
470
if hasattr (model , "__annotations__" ) and model .__annotations__ :
465
- model_fields = {name : (typ , ...) for name , typ in model . __annotations__ . items ()} # pyright: ignore[reportGeneralTypeIssues]
471
+ model_fields = {name : (typ , ...) for name , typ in get_type_hints ( model ). items ()}
466
472
else :
467
473
init_signature = inspect .signature (model .__init__ )
468
474
parameters = init_signature .parameters
469
475
model_fields = {name : (param .annotation , param .default ) for name , param in parameters .items () if
470
476
name != "self" }
471
477
else :
472
478
# For Pydantic models, use model_fields and check for ellipsis (required fields)
473
- model_fields = model . __annotations__
479
+ model_fields = get_type_hints ( model )
474
480
475
481
model_rule_parts = []
476
482
nested_rules = []
@@ -706,7 +712,7 @@ def generate_markdown_documentation(
706
712
else :
707
713
documentation += f" Fields:\n " # noqa: F541
708
714
if isclass (model ) and issubclass (model , BaseModel ):
709
- for name , field_type in model . __annotations__ .items ():
715
+ for name , field_type in get_type_hints ( model ) .items ():
710
716
# if name == "markdown_code_block":
711
717
# continue
712
718
if get_origin (field_type ) == list :
@@ -754,14 +760,17 @@ def generate_field_markdown(
754
760
field_info = model .model_fields .get (field_name )
755
761
field_description = field_info .description if field_info and field_info .description else ""
756
762
757
- if get_origin (field_type ) == list :
763
+ origin_type = get_origin (field_type )
764
+ origin_type = field_type if origin_type is None else origin_type
765
+
766
+ if origin_type == list :
758
767
element_type = get_args (field_type )[0 ]
759
768
field_text = f"{ indent } { field_name } ({ format_model_and_field_name (field_type .__name__ )} of { format_model_and_field_name (element_type .__name__ )} )"
760
769
if field_description != "" :
761
770
field_text += ":\n "
762
771
else :
763
772
field_text += "\n "
764
- elif get_origin ( field_type ) == Union :
773
+ elif origin_type == Union :
765
774
element_types = get_args (field_type )
766
775
types = []
767
776
for element_type in element_types :
@@ -792,9 +801,9 @@ def generate_field_markdown(
792
801
example_text = f"'{ field_example } '" if isinstance (field_example , str ) else field_example
793
802
field_text += f"{ indent } Example: { example_text } \n "
794
803
795
- if isclass (field_type ) and issubclass (field_type , BaseModel ):
804
+ if isclass (origin_type ) and issubclass (origin_type , BaseModel ):
796
805
field_text += f"{ indent } Details:\n "
797
- for name , type_ in field_type . __annotations__ .items ():
806
+ for name , type_ in get_type_hints ( field_type ) .items ():
798
807
field_text += generate_field_markdown (name , type_ , field_type , depth + 2 )
799
808
800
809
return field_text
@@ -855,7 +864,7 @@ def generate_text_documentation(
855
864
856
865
if isclass (model ) and issubclass (model , BaseModel ):
857
866
documentation_fields = ""
858
- for name , field_type in model . __annotations__ .items ():
867
+ for name , field_type in get_type_hints ( model ) .items ():
859
868
# if name == "markdown_code_block":
860
869
# continue
861
870
if get_origin (field_type ) == list :
@@ -948,7 +957,7 @@ def generate_field_text(
948
957
949
958
if isclass (field_type ) and issubclass (field_type , BaseModel ):
950
959
field_text += f"{ indent } Details:\n "
951
- for name , type_ in field_type . __annotations__ .items ():
960
+ for name , type_ in get_type_hints ( field_type ) .items ():
952
961
field_text += generate_field_text (name , type_ , field_type , depth + 2 )
953
962
954
963
return field_text
0 commit comments