16
16
from ..language import (
17
17
DirectiveNode ,
18
18
InputValueDefinitionNode ,
19
+ InterfaceTypeDefinitionNode ,
19
20
InterfaceTypeExtensionNode ,
20
21
NamedTypeNode ,
21
22
Node ,
23
+ ObjectTypeDefinitionNode ,
22
24
ObjectTypeExtensionNode ,
23
25
OperationType ,
26
+ SchemaDefinitionNode ,
24
27
SchemaExtensionNode ,
28
+ UnionTypeDefinitionNode ,
25
29
UnionTypeExtensionNode ,
26
30
)
27
31
from .definition import (
45
49
)
46
50
from ..utilities .assert_valid_name import is_valid_name_error
47
51
from ..utilities .type_comparators import is_equal_type , is_type_sub_type_of
48
- from .directives import is_directive , GraphQLDirective , GraphQLDeprecatedDirective
52
+ from .directives import is_directive , GraphQLDeprecatedDirective
49
53
from .introspection import is_introspection_type
50
54
from .schema import GraphQLSchema , assert_schema
51
55
@@ -252,7 +256,7 @@ def validate_fields(
252
256
if not fields :
253
257
self .report_error (
254
258
f"Type { type_ .name } must define one or more fields." ,
255
- get_all_nodes (type_ ) ,
259
+ [ type_ . ast_node , * (type_ . extension_ast_nodes or ())] ,
256
260
)
257
261
258
262
for field_name , field in fields .items ():
@@ -339,7 +343,11 @@ def validate_type_implements_interface(
339
343
self .report_error (
340
344
f"Interface field { iface .name } .{ field_name } "
341
345
f" expected but { type_ .name } does not provide it." ,
342
- [iface_field .ast_node , * get_all_nodes (type_ )],
346
+ [
347
+ iface_field .ast_node ,
348
+ type_ .ast_node ,
349
+ * (type_ .extension_ast_nodes or ()),
350
+ ],
343
351
)
344
352
continue
345
353
@@ -422,7 +430,7 @@ def validate_union_members(self, union: GraphQLUnionType) -> None:
422
430
if not member_types :
423
431
self .report_error (
424
432
f"Union type { union .name } must define one or more member types." ,
425
- get_all_nodes (union ) ,
433
+ [ union . ast_node , * (union . extension_ast_nodes or ())] ,
426
434
)
427
435
428
436
included_type_names : Set [str ] = set ()
@@ -449,7 +457,7 @@ def validate_enum_values(self, enum_type: GraphQLEnumType) -> None:
449
457
if not enum_values :
450
458
self .report_error (
451
459
f"Enum type { enum_type .name } must define one or more values." ,
452
- get_all_nodes (enum_type ) ,
460
+ [ enum_type . ast_node , * (enum_type . extension_ast_nodes or ())] ,
453
461
)
454
462
455
463
for value_name , enum_value in enum_values .items ():
@@ -469,7 +477,7 @@ def validate_input_fields(self, input_obj: GraphQLInputObjectType) -> None:
469
477
self .report_error (
470
478
f"Input Object type { input_obj .name } "
471
479
" must define one or more fields." ,
472
- get_all_nodes (input_obj ) ,
480
+ [ input_obj . ast_node , * (input_obj . extension_ast_nodes or ())] ,
473
481
)
474
482
475
483
# Ensure the arguments are valid
@@ -500,12 +508,14 @@ def validate_input_fields(self, input_obj: GraphQLInputObjectType) -> None:
500
508
def get_operation_type_node (
501
509
schema : GraphQLSchema , operation : OperationType
502
510
) -> Optional [Node ]:
503
- for extension_node in get_all_nodes (schema ):
504
- operation_types = cast (SchemaExtensionNode , extension_node ).operation_types
505
- if operation_types : # pragma: no cover else
506
- for operation_type in operation_types :
507
- if operation_type .operation == operation :
508
- return operation_type .type
511
+ ast_node : Optional [Union [SchemaDefinitionNode , SchemaExtensionNode ]]
512
+ for ast_node in [schema .ast_node , * (schema .extension_ast_nodes or ())]:
513
+ if ast_node :
514
+ operation_types = ast_node .operation_types
515
+ if operation_types : # pragma: no cover else
516
+ for operation_type in operation_types :
517
+ if operation_type .operation == operation :
518
+ return operation_type .type
509
519
return None
510
520
511
521
@@ -561,55 +571,44 @@ def __call__(self, input_obj: GraphQLInputObjectType) -> None:
561
571
del self .field_path_index_by_type_name [name ]
562
572
563
573
564
- SDLDefinedObject = Union [
565
- GraphQLSchema ,
566
- GraphQLDirective ,
567
- GraphQLInterfaceType ,
568
- GraphQLObjectType ,
569
- GraphQLInputObjectType ,
570
- GraphQLUnionType ,
571
- GraphQLEnumType ,
572
- ]
573
-
574
-
575
- def get_all_nodes (obj : SDLDefinedObject ) -> List [Node ]:
576
- node = obj .ast_node
577
- nodes : List [Node ] = [node ] if node else []
578
- extension_nodes = getattr (obj , "extension_ast_nodes" , None )
579
- if extension_nodes :
580
- nodes .extend (extension_nodes )
581
- return nodes
582
-
583
-
584
574
def get_all_implements_interface_nodes (
585
575
type_ : Union [GraphQLObjectType , GraphQLInterfaceType ], iface : GraphQLInterfaceType
586
576
) -> List [NamedTypeNode ]:
587
577
implements_nodes : List [NamedTypeNode ] = []
588
- for extension_node in get_all_nodes (type_ ):
589
- iface_nodes = cast (
590
- Union [ObjectTypeExtensionNode , InterfaceTypeExtensionNode ], extension_node
591
- ).interfaces
592
- if iface_nodes : # pragma: no cover else
593
- implements_nodes .extend (
594
- iface_node
595
- for iface_node in iface_nodes
596
- if iface_node .name .value == iface .name
597
- )
578
+ ast_node : Optional [
579
+ Union [
580
+ ObjectTypeDefinitionNode ,
581
+ ObjectTypeExtensionNode ,
582
+ InterfaceTypeDefinitionNode ,
583
+ InterfaceTypeExtensionNode ,
584
+ ]
585
+ ]
586
+ for ast_node in [type_ .ast_node , * (type_ .extension_ast_nodes or ())]:
587
+ if ast_node :
588
+ iface_nodes = ast_node .interfaces
589
+ if iface_nodes : # pragma: no cover else
590
+ implements_nodes .extend (
591
+ iface_node
592
+ for iface_node in iface_nodes
593
+ if iface_node .name .value == iface .name
594
+ )
598
595
return implements_nodes
599
596
600
597
601
598
def get_union_member_type_nodes (
602
599
union : GraphQLUnionType , type_name : str
603
600
) -> Optional [List [NamedTypeNode ]]:
604
601
member_type_nodes : List [NamedTypeNode ] = []
605
- for extension_node in get_all_nodes (union ):
606
- type_nodes = cast (UnionTypeExtensionNode , extension_node ).types
607
- if type_nodes : # pragma: no cover else
608
- member_type_nodes .extend (
609
- type_node
610
- for type_node in type_nodes
611
- if type_node .name .value == type_name
612
- )
602
+ ast_node : Optional [Union [UnionTypeDefinitionNode , UnionTypeExtensionNode ]]
603
+ for ast_node in [union .ast_node , * (union .extension_ast_nodes or ())]:
604
+ if ast_node :
605
+ type_nodes = ast_node .types
606
+ if type_nodes : # pragma: no cover else
607
+ member_type_nodes .extend (
608
+ type_node
609
+ for type_node in type_nodes
610
+ if type_node .name .value == type_name
611
+ )
613
612
return member_type_nodes
614
613
615
614
0 commit comments