|
14 | 14 | OperationType,
|
15 | 15 | SelectionSetNode,
|
16 | 16 | VariableDefinitionNode,
|
| 17 | + Visitor, |
17 | 18 | )
|
18 | 19 | from ..type import (
|
19 | 20 | GraphQLArgument,
|
|
43 | 44 | )
|
44 | 45 | from .type_from_ast import type_from_ast
|
45 | 46 |
|
46 |
| -__all__ = ["TypeInfo"] |
| 47 | +__all__ = ["TypeInfo", "TypeInfoVisitor"] |
47 | 48 |
|
48 | 49 |
|
49 | 50 | GetFieldDefType = Callable[
|
@@ -282,3 +283,28 @@ def get_field_def(
|
282 | 283 | parent_type = cast(Union[GraphQLObjectType, GraphQLInterfaceType], parent_type)
|
283 | 284 | return parent_type.fields.get(name)
|
284 | 285 | return None
|
| 286 | + |
| 287 | + |
| 288 | +class TypeInfoVisitor(Visitor): |
| 289 | + """A visitor which maintains a provided TypeInfo.""" |
| 290 | + |
| 291 | + def __init__(self, type_info: "TypeInfo", visitor: Visitor) -> None: |
| 292 | + self.type_info = type_info |
| 293 | + self.visitor = visitor |
| 294 | + |
| 295 | + def enter(self, node, *args): |
| 296 | + self.type_info.enter(node) |
| 297 | + fn = self.visitor.get_visit_fn(node.kind) |
| 298 | + if fn: |
| 299 | + result = fn(self.visitor, node, *args) |
| 300 | + if result is not None: |
| 301 | + self.type_info.leave(node) |
| 302 | + if isinstance(result, Node): |
| 303 | + self.type_info.enter(result) |
| 304 | + return result |
| 305 | + |
| 306 | + def leave(self, node, *args): |
| 307 | + fn = self.visitor.get_visit_fn(node.kind, is_leaving=True) |
| 308 | + result = fn(self.visitor, node, *args) if fn else None |
| 309 | + self.type_info.leave(node) |
| 310 | + return result |
0 commit comments