Skip to content

Commit a7d5ab3

Browse files
committed
validate_schema: inline get_all_nodes function
Replicates graphql/graphql-js@50b6d97
1 parent 190f5cd commit a7d5ab3

File tree

1 file changed

+49
-50
lines changed

1 file changed

+49
-50
lines changed

src/graphql/type/validate.py

Lines changed: 49 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,16 @@
1616
from ..language import (
1717
DirectiveNode,
1818
InputValueDefinitionNode,
19+
InterfaceTypeDefinitionNode,
1920
InterfaceTypeExtensionNode,
2021
NamedTypeNode,
2122
Node,
23+
ObjectTypeDefinitionNode,
2224
ObjectTypeExtensionNode,
2325
OperationType,
26+
SchemaDefinitionNode,
2427
SchemaExtensionNode,
28+
UnionTypeDefinitionNode,
2529
UnionTypeExtensionNode,
2630
)
2731
from .definition import (
@@ -45,7 +49,7 @@
4549
)
4650
from ..utilities.assert_valid_name import is_valid_name_error
4751
from ..utilities.type_comparators import is_equal_type, is_type_sub_type_of
48-
from .directives import is_directive, GraphQLDirective, GraphQLDeprecatedDirective
52+
from .directives import is_directive, GraphQLDeprecatedDirective
4953
from .introspection import is_introspection_type
5054
from .schema import GraphQLSchema, assert_schema
5155

@@ -252,7 +256,7 @@ def validate_fields(
252256
if not fields:
253257
self.report_error(
254258
f"Type {type_.name} must define one or more fields.",
255-
get_all_nodes(type_),
259+
[type_.ast_node, *(type_.extension_ast_nodes or ())],
256260
)
257261

258262
for field_name, field in fields.items():
@@ -339,7 +343,11 @@ def validate_type_implements_interface(
339343
self.report_error(
340344
f"Interface field {iface.name}.{field_name}"
341345
f" expected but {type_.name} does not provide it.",
342-
[iface_field.ast_node, *get_all_nodes(type_)],
346+
[
347+
iface_field.ast_node,
348+
type_.ast_node,
349+
*(type_.extension_ast_nodes or ()),
350+
],
343351
)
344352
continue
345353

@@ -422,7 +430,7 @@ def validate_union_members(self, union: GraphQLUnionType) -> None:
422430
if not member_types:
423431
self.report_error(
424432
f"Union type {union.name} must define one or more member types.",
425-
get_all_nodes(union),
433+
[union.ast_node, *(union.extension_ast_nodes or ())],
426434
)
427435

428436
included_type_names: Set[str] = set()
@@ -449,7 +457,7 @@ def validate_enum_values(self, enum_type: GraphQLEnumType) -> None:
449457
if not enum_values:
450458
self.report_error(
451459
f"Enum type {enum_type.name} must define one or more values.",
452-
get_all_nodes(enum_type),
460+
[enum_type.ast_node, *(enum_type.extension_ast_nodes or ())],
453461
)
454462

455463
for value_name, enum_value in enum_values.items():
@@ -469,7 +477,7 @@ def validate_input_fields(self, input_obj: GraphQLInputObjectType) -> None:
469477
self.report_error(
470478
f"Input Object type {input_obj.name}"
471479
" must define one or more fields.",
472-
get_all_nodes(input_obj),
480+
[input_obj.ast_node, *(input_obj.extension_ast_nodes or ())],
473481
)
474482

475483
# Ensure the arguments are valid
@@ -500,12 +508,14 @@ def validate_input_fields(self, input_obj: GraphQLInputObjectType) -> None:
500508
def get_operation_type_node(
501509
schema: GraphQLSchema, operation: OperationType
502510
) -> Optional[Node]:
503-
for extension_node in get_all_nodes(schema):
504-
operation_types = cast(SchemaExtensionNode, extension_node).operation_types
505-
if operation_types: # pragma: no cover else
506-
for operation_type in operation_types:
507-
if operation_type.operation == operation:
508-
return operation_type.type
511+
ast_node: Optional[Union[SchemaDefinitionNode, SchemaExtensionNode]]
512+
for ast_node in [schema.ast_node, *(schema.extension_ast_nodes or ())]:
513+
if ast_node:
514+
operation_types = ast_node.operation_types
515+
if operation_types: # pragma: no cover else
516+
for operation_type in operation_types:
517+
if operation_type.operation == operation:
518+
return operation_type.type
509519
return None
510520

511521

@@ -561,55 +571,44 @@ def __call__(self, input_obj: GraphQLInputObjectType) -> None:
561571
del self.field_path_index_by_type_name[name]
562572

563573

564-
SDLDefinedObject = Union[
565-
GraphQLSchema,
566-
GraphQLDirective,
567-
GraphQLInterfaceType,
568-
GraphQLObjectType,
569-
GraphQLInputObjectType,
570-
GraphQLUnionType,
571-
GraphQLEnumType,
572-
]
573-
574-
575-
def get_all_nodes(obj: SDLDefinedObject) -> List[Node]:
576-
node = obj.ast_node
577-
nodes: List[Node] = [node] if node else []
578-
extension_nodes = getattr(obj, "extension_ast_nodes", None)
579-
if extension_nodes:
580-
nodes.extend(extension_nodes)
581-
return nodes
582-
583-
584574
def get_all_implements_interface_nodes(
585575
type_: Union[GraphQLObjectType, GraphQLInterfaceType], iface: GraphQLInterfaceType
586576
) -> List[NamedTypeNode]:
587577
implements_nodes: List[NamedTypeNode] = []
588-
for extension_node in get_all_nodes(type_):
589-
iface_nodes = cast(
590-
Union[ObjectTypeExtensionNode, InterfaceTypeExtensionNode], extension_node
591-
).interfaces
592-
if iface_nodes: # pragma: no cover else
593-
implements_nodes.extend(
594-
iface_node
595-
for iface_node in iface_nodes
596-
if iface_node.name.value == iface.name
597-
)
578+
ast_node: Optional[
579+
Union[
580+
ObjectTypeDefinitionNode,
581+
ObjectTypeExtensionNode,
582+
InterfaceTypeDefinitionNode,
583+
InterfaceTypeExtensionNode,
584+
]
585+
]
586+
for ast_node in [type_.ast_node, *(type_.extension_ast_nodes or ())]:
587+
if ast_node:
588+
iface_nodes = ast_node.interfaces
589+
if iface_nodes: # pragma: no cover else
590+
implements_nodes.extend(
591+
iface_node
592+
for iface_node in iface_nodes
593+
if iface_node.name.value == iface.name
594+
)
598595
return implements_nodes
599596

600597

601598
def get_union_member_type_nodes(
602599
union: GraphQLUnionType, type_name: str
603600
) -> Optional[List[NamedTypeNode]]:
604601
member_type_nodes: List[NamedTypeNode] = []
605-
for extension_node in get_all_nodes(union):
606-
type_nodes = cast(UnionTypeExtensionNode, extension_node).types
607-
if type_nodes: # pragma: no cover else
608-
member_type_nodes.extend(
609-
type_node
610-
for type_node in type_nodes
611-
if type_node.name.value == type_name
612-
)
602+
ast_node: Optional[Union[UnionTypeDefinitionNode, UnionTypeExtensionNode]]
603+
for ast_node in [union.ast_node, *(union.extension_ast_nodes or ())]:
604+
if ast_node:
605+
type_nodes = ast_node.types
606+
if type_nodes: # pragma: no cover else
607+
member_type_nodes.extend(
608+
type_node
609+
for type_node in type_nodes
610+
if type_node.name.value == type_name
611+
)
613612
return member_type_nodes
614613

615614

0 commit comments

Comments
 (0)