diff --git a/definition.go b/definition.go index f26f3002..3b5eb8f9 100644 --- a/definition.go +++ b/definition.go @@ -120,6 +120,9 @@ var _ Output = (*NonNull)(nil) // Composite interface for types that may describe the parent context of a selection set. type Composite interface { Name() string + Description() string + String() string + Error() error } var _ Composite = (*Object)(nil) diff --git a/directives.go b/directives.go index 676eab75..7ed839f7 100644 --- a/directives.go +++ b/directives.go @@ -1,6 +1,7 @@ package graphql const ( + // Operations DirectiveLocationQuery = "QUERY" DirectiveLocationMutation = "MUTATION" DirectiveLocationSubscription = "SUBSCRIPTION" @@ -8,8 +9,31 @@ const ( DirectiveLocationFragmentDefinition = "FRAGMENT_DEFINITION" DirectiveLocationFragmentSpread = "FRAGMENT_SPREAD" DirectiveLocationInlineFragment = "INLINE_FRAGMENT" + + // Schema Definitions + DirectiveLocationSchema = "SCHEMA" + DirectiveLocationScalar = "SCALAR" + DirectiveLocationObject = "OBJECT" + DirectiveLocationFieldDefinition = "FIELD_DEFINITION" + DirectiveLocationArgumentDefinition = "ARGUMENT_DEFINITION" + DirectiveLocationInterface = "INTERFACE" + DirectiveLocationUnion = "UNION" + DirectiveLocationEnum = "ENUM" + DirectiveLocationEnumValue = "ENUM_VALUE" + DirectiveLocationInputObject = "INPUT_OBJECT" + DirectiveLocationInputFieldDefinition = "INPUT_FIELD_DEFINITION" ) +// DefaultDeprecationReason Constant string used for default reason for a deprecation. +const DefaultDeprecationReason = "No longer supported" + +// SpecifiedRules The full list of specified directives. +var SpecifiedDirectives = []*Directive{ + IncludeDirective, + SkipDirective, + DeprecatedDirective, +} + // Directive structs are used by the GraphQL runtime as a way of modifying execution // behavior. Type system creators will usually not create these directly. type Directive struct { @@ -76,7 +100,7 @@ func NewDirective(config DirectiveConfig) *Directive { return dir } -// IncludeDirective is used to conditionally include fields or fragments +// IncludeDirective is used to conditionally include fields or fragments. var IncludeDirective = NewDirective(DirectiveConfig{ Name: "include", Description: "Directs the executor to include this field or fragment only when " + @@ -94,7 +118,7 @@ var IncludeDirective = NewDirective(DirectiveConfig{ }, }) -// SkipDirective Used to conditionally skip (exclude) fields or fragments +// SkipDirective Used to conditionally skip (exclude) fields or fragments. var SkipDirective = NewDirective(DirectiveConfig{ Name: "skip", Description: "Directs the executor to skip this field or fragment when the `if` " + @@ -111,3 +135,22 @@ var SkipDirective = NewDirective(DirectiveConfig{ DirectiveLocationInlineFragment, }, }) + +// DeprecatedDirective Used to declare element of a GraphQL schema as deprecated. +var DeprecatedDirective = NewDirective(DirectiveConfig{ + Name: "deprecated", + Description: "Marks an element of a GraphQL schema as no longer supported.", + Args: FieldConfigArgument{ + "reason": &ArgumentConfig{ + Type: String, + Description: "Explains why this element was deprecated, usually also including a " + + "suggestion for how to access supported similar data. 
Formatted" + + "in [Markdown](https://daringfireball.net/projects/markdown/).", + DefaultValue: DefaultDeprecationReason, + }, + }, + Locations: []string{ + DirectiveLocationFieldDefinition, + DirectiveLocationEnumValue, + }, +}) diff --git a/directives_test.go b/directives_test.go index 164ecd7b..3443ccb4 100644 --- a/directives_test.go +++ b/directives_test.go @@ -342,9 +342,6 @@ func TestDirectivesWorksOnInlineFragmentIfFalseOmitsInlineFragment(t *testing.T) b } } - fragment Frag on TestType { - b - } ` expected := &graphql.Result{ Data: map[string]interface{}{ @@ -368,9 +365,6 @@ func TestDirectivesWorksOnInlineFragmentIfTrueIncludesInlineFragment(t *testing. b } } - fragment Frag on TestType { - b - } ` expected := &graphql.Result{ Data: map[string]interface{}{ @@ -395,9 +389,6 @@ func TestDirectivesWorksOnInlineFragmentUnlessFalseIncludesInlineFragment(t *tes b } } - fragment Frag on TestType { - b - } ` expected := &graphql.Result{ Data: map[string]interface{}{ @@ -422,9 +413,6 @@ func TestDirectivesWorksOnInlineFragmentUnlessTrueIncludesInlineFragment(t *test b } } - fragment Frag on TestType { - b - } ` expected := &graphql.Result{ Data: map[string]interface{}{ diff --git a/executor.go b/executor.go index bf093f94..b11b1a27 100644 --- a/executor.go +++ b/executor.go @@ -5,6 +5,7 @@ import ( "fmt" "reflect" "strings" + "sync" "github.com/graphql-go/graphql/gqlerrors" "github.com/graphql-go/graphql/language/ast" @@ -48,8 +49,8 @@ func Execute(p ExecuteParams) (result *Result) { if r, ok := r.(error); ok { err = gqlerrors.FormatError(r) } - exeContext.Errors = append(exeContext.Errors, gqlerrors.FormatError(err)) - result.Errors = exeContext.Errors + exeContext.AppendError(err) + result.Errors = exeContext.Errors() } }() @@ -76,12 +77,39 @@ type ExecutionContext struct { Root interface{} Operation ast.Definition VariableValues map[string]interface{} - Errors []gqlerrors.FormattedError Context context.Context + + errLock sync.RWMutex + errors []gqlerrors.FormattedError +} + +func (eCtx *ExecutionContext) AppendError(errs ...error) { + formattedErrors := []gqlerrors.FormattedError{} + for _, err := range errs { + formattedErrors = append(formattedErrors, gqlerrors.FormatError(err)) + } + eCtx.errLock.Lock() + eCtx.errors = append(eCtx.errors, formattedErrors...) 
+ eCtx.errLock.Unlock() +} + +func (eCtx *ExecutionContext) Errors() (res []gqlerrors.FormattedError) { + eCtx.errLock.RLock() + res = eCtx.errors + eCtx.errLock.RUnlock() + return res +} + +func (eCtx *ExecutionContext) SetErrors(errors []gqlerrors.FormattedError) { + eCtx.errLock.Lock() + eCtx.errors = errors + eCtx.errLock.Unlock() } func buildExecutionContext(p BuildExecutionCtxParams) (*ExecutionContext, error) { - eCtx := &ExecutionContext{} + eCtx := &ExecutionContext{ + errLock: sync.RWMutex{}, + } var operation *ast.OperationDefinition fragments := map[string]ast.Definition{} @@ -122,7 +150,7 @@ func buildExecutionContext(p BuildExecutionCtxParams) (*ExecutionContext, error) eCtx.Root = p.Root eCtx.Operation = operation eCtx.VariableValues = variableValues - eCtx.Errors = p.Errors + eCtx.SetErrors(p.Errors) eCtx.Context = p.Context return eCtx, nil } @@ -233,7 +261,7 @@ func executeFieldsSerially(p ExecuteFieldsParams) *Result { return &Result{ Data: finalResults, - Errors: p.ExecutionContext.Errors, + Errors: p.ExecutionContext.Errors(), } } @@ -246,18 +274,42 @@ func executeFields(p ExecuteFieldsParams) *Result { p.Fields = map[string][]*ast.Field{} } - finalResults := map[string]interface{}{} + // concurrently resolve fields + wg := sync.WaitGroup{} + mtx := sync.Mutex{} + finalResults := make(map[string]interface{}, len(p.Fields)) + panics := make(chan interface{}, len(p.Fields)) for responseName, fieldASTs := range p.Fields { - resolved, state := resolveField(p.ExecutionContext, p.ParentType, p.Source, fieldASTs) - if state.hasNoFieldDefs { - continue - } - finalResults[responseName] = resolved + wg.Add(1) + go func(responseName string, fieldASTs []*ast.Field) { + defer func() { + if r := recover(); r != nil { + panics <- r + } + wg.Done() + }() + resolved, state := resolveField(p.ExecutionContext, p.ParentType, p.Source, fieldASTs) + if state.hasNoFieldDefs { + return + } + mtx.Lock() + finalResults[responseName] = resolved + mtx.Unlock() + }(responseName, fieldASTs) + } + + // wait for all routines to complete and then perform clean up + wg.Wait() + close(panics) + + // re-panic if a goroutine panicked + for p := range panics { + panic(p) } return &Result{ Data: finalResults, - Errors: p.ExecutionContext.Errors, + Errors: p.ExecutionContext.Errors(), } } @@ -266,6 +318,7 @@ type CollectFieldsParams struct { RuntimeType *Object // previously known as OperationType SelectionSet *ast.SelectionSet Fields map[string][]*ast.Field + FieldOrder []string VisitedFragmentNames map[string]bool } @@ -480,7 +533,6 @@ func resolveField(eCtx *ExecutionContext, parentType *Object, source interface{} var returnType Output defer func() (interface{}, resolveFieldResultState) { if r := recover(); r != nil { - var err error if r, ok := r.(string); ok { err = NewLocatedError( @@ -495,7 +547,7 @@ func resolveField(eCtx *ExecutionContext, parentType *Object, source interface{} if _, ok := returnType.(*NonNull); ok { panic(gqlerrors.FormatError(err)) } - eCtx.Errors = append(eCtx.Errors, gqlerrors.FormatError(err)) + eCtx.AppendError(err) return result, resultState } return result, resultState @@ -545,7 +597,8 @@ func resolveField(eCtx *ExecutionContext, parentType *Object, source interface{} }) if resolveFnError != nil { - panic(gqlerrors.FormatError(resolveFnError)) + eCtx.AppendError(resolveFnError) + return nil, resultState } completed := completeValueCatchingError(eCtx, returnType, fieldASTs, info, result) @@ -561,7 +614,7 @@ func completeValueCatchingError(eCtx *ExecutionContext, returnType 
Type, fieldAS panic(r) } if err, ok := r.(gqlerrors.FormattedError); ok { - eCtx.Errors = append(eCtx.Errors, err) + eCtx.AppendError(err) } return completed } @@ -664,7 +717,9 @@ func completeAbstractValue(eCtx *ExecutionContext, returnType Abstract, fieldAST } err := invariant(runtimeType != nil, - fmt.Sprintf(`Could not determine runtime type of value "%v" for field %v.%v.`, result, info.ParentType, info.FieldName), + fmt.Sprintf(`Abstract type %v must resolve to an Object type at runtime `+ + `for field %v.%v with value "%v", received "%v".`, + returnType, info.ParentType, info.FieldName, result, runtimeType), ) if err != nil { panic(err) @@ -755,13 +810,34 @@ func completeListValue(eCtx *ExecutionContext, returnType *List, fieldASTs []*as panic(gqlerrors.FormatError(err)) } + // concurrently resolve list elements itemType := returnType.OfType - completedResults := []interface{}{} + wg := sync.WaitGroup{} + completedResults := make([]interface{}, resultVal.Len()) + panics := make(chan interface{}, resultVal.Len()) for i := 0; i < resultVal.Len(); i++ { - val := resultVal.Index(i).Interface() - completedItem := completeValueCatchingError(eCtx, itemType, fieldASTs, info, val) - completedResults = append(completedResults, completedItem) + wg.Add(1) + go func(j int) { + defer func() { + if r := recover(); r != nil { + panics <- r + } + wg.Done() + }() + val := resultVal.Index(j).Interface() + completedResults[j] = completeValueCatchingError(eCtx, itemType, fieldASTs, info, val) + }(i) + } + + // wait for all routines to complete and then perform clean up + wg.Wait() + close(panics) + + // re-panic if a goroutine panicked + for p := range panics { + panic(p) } + return completedResults } diff --git a/introspection.go b/introspection.go index e837be2e..8415bd07 100644 --- a/introspection.go +++ b/introspection.go @@ -19,23 +19,44 @@ const ( TypeKindNonNull = "NON_NULL" ) -var directiveType *Object -var schemaType *Object -var typeType *Object -var fieldType *Object -var inputValueType *Object -var enumValueType *Object +// SchemaType is type definition for __Schema +var SchemaType *Object -var typeKindEnum *Enum -var directiveLocationEnum *Enum +// DirectiveType is type definition for __Directive +var DirectiveType *Object +// TypeType is type definition for __Type +var TypeType *Object + +// FieldType is type definition for __Field +var FieldType *Object + +// InputValueType is type definition for __InputValue +var InputValueType *Object + +// EnumValueType is type definition for __EnumValue +var EnumValueType *Object + +// TypeKindEnumType is type definition for __TypeKind +var TypeKindEnumType *Enum + +// DirectiveLocationEnumType is type definition for __DirectiveLocation +var DirectiveLocationEnumType *Enum + +// Meta-field definitions. 
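Stepping back to the executor.go hunks above: executeFields and completeListValue now fan work out to one goroutine per field or list item, guard the shared result with a mutex, collect any panics in a buffered channel, and re-panic on the calling goroutine after wg.Wait(). A standalone sketch of that pattern follows; the names (resolveAll, resolve) are illustrative and not part of the patch.

package main

import (
	"fmt"
	"sync"
)

// resolveAll runs resolve concurrently for every key, guards the shared map
// with a mutex, and re-raises a captured panic on the caller's goroutine.
// This mirrors the shape used by executeFields above.
func resolveAll(keys []string, resolve func(string) interface{}) map[string]interface{} {
	var (
		wg      sync.WaitGroup
		mtx     sync.Mutex
		results = make(map[string]interface{}, len(keys))
		panics  = make(chan interface{}, len(keys)) // buffered so workers never block
	)
	for _, key := range keys {
		wg.Add(1)
		go func(key string) {
			defer func() {
				if r := recover(); r != nil {
					panics <- r // capture instead of crashing the process
				}
				wg.Done()
			}()
			val := resolve(key)
			mtx.Lock()
			results[key] = val
			mtx.Unlock()
		}(key)
	}
	wg.Wait()
	close(panics)
	for p := range panics {
		panic(p) // surface a captured panic to the caller
	}
	return results
}

func main() {
	out := resolveAll([]string{"a", "b"}, func(k string) interface{} { return k + "!" })
	fmt.Println(out) // map[a:a! b:b!]
}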
+ +// SchemaMetaFieldDef Meta field definition for Schema var SchemaMetaFieldDef *FieldDefinition + +// TypeMetaFieldDef Meta field definition for types var TypeMetaFieldDef *FieldDefinition + +// TypeNameMetaFieldDef Meta field definition for type names var TypeNameMetaFieldDef *FieldDefinition func init() { - typeKindEnum = NewEnum(EnumConfig{ + TypeKindEnumType = NewEnum(EnumConfig{ Name: "__TypeKind", Description: "An enum describing what kind of type a given `__Type` is", Values: EnumValueConfigMap{ @@ -81,7 +102,7 @@ func init() { }, }) - directiveLocationEnum = NewEnum(EnumConfig{ + DirectiveLocationEnumType = NewEnum(EnumConfig{ Name: "__DirectiveLocation", Description: "A Directive can be adjacent to many parts of the GraphQL language, a " + "__DirectiveLocation describes one such possible adjacencies.", @@ -114,11 +135,55 @@ func init() { Value: DirectiveLocationInlineFragment, Description: "Location adjacent to an inline fragment.", }, + "SCHEMA": &EnumValueConfig{ + Value: DirectiveLocationSchema, + Description: "Location adjacent to a schema definition.", + }, + "SCALAR": &EnumValueConfig{ + Value: DirectiveLocationScalar, + Description: "Location adjacent to a scalar definition.", + }, + "OBJECT": &EnumValueConfig{ + Value: DirectiveLocationObject, + Description: "Location adjacent to a object definition.", + }, + "FIELD_DEFINITION": &EnumValueConfig{ + Value: DirectiveLocationFieldDefinition, + Description: "Location adjacent to a field definition.", + }, + "ARGUMENT_DEFINITION": &EnumValueConfig{ + Value: DirectiveLocationArgumentDefinition, + Description: "Location adjacent to an argument definition.", + }, + "INTERFACE": &EnumValueConfig{ + Value: DirectiveLocationInterface, + Description: "Location adjacent to an interface definition.", + }, + "UNION": &EnumValueConfig{ + Value: DirectiveLocationUnion, + Description: "Location adjacent to a union definition.", + }, + "ENUM": &EnumValueConfig{ + Value: DirectiveLocationEnum, + Description: "Location adjacent to an enum definition.", + }, + "ENUM_VALUE": &EnumValueConfig{ + Value: DirectiveLocationEnumValue, + Description: "Location adjacent to an enum value definition.", + }, + "INPUT_OBJECT": &EnumValueConfig{ + Value: DirectiveLocationInputObject, + Description: "Location adjacent to an input object type definition.", + }, + "INPUT_FIELD_DEFINITION": &EnumValueConfig{ + Value: DirectiveLocationInputFieldDefinition, + Description: "Location adjacent to an input object field definition.", + }, }, }) // Note: some fields (for e.g "fields", "interfaces") are defined later due to cyclic reference - typeType = NewObject(ObjectConfig{ + TypeType = NewObject(ObjectConfig{ Name: "__Type", Description: "The fundamental unit of any GraphQL Schema is the type. There are " + "many kinds of types in GraphQL as represented by the `__TypeKind` enum." 
+ @@ -131,7 +196,7 @@ func init() { Fields: Fields{ "kind": &Field{ - Type: NewNonNull(typeKindEnum), + Type: NewNonNull(TypeKindEnumType), Resolve: func(p ResolveParams) (interface{}, error) { switch p.Source.(type) { case *Scalar: @@ -169,7 +234,7 @@ func init() { }, }) - inputValueType = NewObject(ObjectConfig{ + InputValueType = NewObject(ObjectConfig{ Name: "__InputValue", Description: "Arguments provided to Fields or Directives and the input fields of an " + "InputObject are represented as Input Values which describe their type " + @@ -182,7 +247,7 @@ func init() { Type: String, }, "type": &Field{ - Type: NewNonNull(typeType), + Type: NewNonNull(TypeType), }, "defaultValue": &Field{ Type: String, @@ -212,7 +277,7 @@ func init() { }, }) - fieldType = NewObject(ObjectConfig{ + FieldType = NewObject(ObjectConfig{ Name: "__Field", Description: "Object and Interface types are described by a list of Fields, each of " + "which has a name, potentially a list of arguments, and a return type.", @@ -224,7 +289,7 @@ func init() { Type: String, }, "args": &Field{ - Type: NewNonNull(NewList(NewNonNull(inputValueType))), + Type: NewNonNull(NewList(NewNonNull(InputValueType))), Resolve: func(p ResolveParams) (interface{}, error) { if field, ok := p.Source.(*FieldDefinition); ok { return field.Args, nil @@ -233,7 +298,7 @@ func init() { }, }, "type": &Field{ - Type: NewNonNull(typeType), + Type: NewNonNull(TypeType), }, "isDeprecated": &Field{ Type: NewNonNull(Boolean), @@ -250,7 +315,7 @@ func init() { }, }) - directiveType = NewObject(ObjectConfig{ + DirectiveType = NewObject(ObjectConfig{ Name: "__Directive", Description: "A Directive provides a way to describe alternate runtime execution and " + "type validation behavior in a GraphQL document. " + @@ -267,12 +332,12 @@ func init() { }, "locations": &Field{ Type: NewNonNull(NewList( - NewNonNull(directiveLocationEnum), + NewNonNull(DirectiveLocationEnumType), )), }, "args": &Field{ Type: NewNonNull(NewList( - NewNonNull(inputValueType), + NewNonNull(InputValueType), )), }, // NOTE: the following three fields are deprecated and are no longer part @@ -335,7 +400,7 @@ func init() { }, }) - schemaType = NewObject(ObjectConfig{ + SchemaType = NewObject(ObjectConfig{ Name: "__Schema", Description: `A GraphQL Schema defines the capabilities of a GraphQL server. 
` + `It exposes all available types and directives on the server, as well as ` + @@ -344,7 +409,7 @@ func init() { "types": &Field{ Description: "A list of all types supported by this server.", Type: NewNonNull(NewList( - NewNonNull(typeType), + NewNonNull(TypeType), )), Resolve: func(p ResolveParams) (interface{}, error) { if schema, ok := p.Source.(Schema); ok { @@ -359,7 +424,7 @@ func init() { }, "queryType": &Field{ Description: "The type that query operations will be rooted at.", - Type: NewNonNull(typeType), + Type: NewNonNull(TypeType), Resolve: func(p ResolveParams) (interface{}, error) { if schema, ok := p.Source.(Schema); ok { return schema.QueryType(), nil @@ -370,7 +435,7 @@ func init() { "mutationType": &Field{ Description: `If this server supports mutation, the type that ` + `mutation operations will be rooted at.`, - Type: typeType, + Type: TypeType, Resolve: func(p ResolveParams) (interface{}, error) { if schema, ok := p.Source.(Schema); ok { if schema.MutationType() != nil { @@ -383,7 +448,7 @@ func init() { "subscriptionType": &Field{ Description: `If this server supports subscription, the type that ` + `subscription operations will be rooted at.`, - Type: typeType, + Type: TypeType, Resolve: func(p ResolveParams) (interface{}, error) { if schema, ok := p.Source.(Schema); ok { if schema.SubscriptionType() != nil { @@ -396,7 +461,7 @@ func init() { "directives": &Field{ Description: `A list of all directives supported by this server.`, Type: NewNonNull(NewList( - NewNonNull(directiveType), + NewNonNull(DirectiveType), )), Resolve: func(p ResolveParams) (interface{}, error) { if schema, ok := p.Source.(Schema); ok { @@ -408,7 +473,7 @@ func init() { }, }) - enumValueType = NewObject(ObjectConfig{ + EnumValueType = NewObject(ObjectConfig{ Name: "__EnumValue", Description: "One possible value for a given Enum. Enum values are unique values, not " + "a placeholder for a string or numeric value. 
However an Enum value is " + @@ -437,8 +502,8 @@ func init() { // Again, adding field configs to __Type that have cyclic reference here // because golang don't like them too much during init/compile-time - typeType.AddFieldConfig("fields", &Field{ - Type: NewList(NewNonNull(fieldType)), + TypeType.AddFieldConfig("fields", &Field{ + Type: NewList(NewNonNull(FieldType)), Args: FieldConfigArgument{ "includeDeprecated": &ArgumentConfig{ Type: Boolean, @@ -476,8 +541,8 @@ func init() { return nil, nil }, }) - typeType.AddFieldConfig("interfaces", &Field{ - Type: NewList(NewNonNull(typeType)), + TypeType.AddFieldConfig("interfaces", &Field{ + Type: NewList(NewNonNull(TypeType)), Resolve: func(p ResolveParams) (interface{}, error) { switch ttype := p.Source.(type) { case *Object: @@ -486,8 +551,8 @@ func init() { return nil, nil }, }) - typeType.AddFieldConfig("possibleTypes", &Field{ - Type: NewList(NewNonNull(typeType)), + TypeType.AddFieldConfig("possibleTypes", &Field{ + Type: NewList(NewNonNull(TypeType)), Resolve: func(p ResolveParams) (interface{}, error) { switch ttype := p.Source.(type) { case *Interface: @@ -498,8 +563,8 @@ func init() { return nil, nil }, }) - typeType.AddFieldConfig("enumValues", &Field{ - Type: NewList(NewNonNull(enumValueType)), + TypeType.AddFieldConfig("enumValues", &Field{ + Type: NewList(NewNonNull(EnumValueType)), Args: FieldConfigArgument{ "includeDeprecated": &ArgumentConfig{ Type: Boolean, @@ -525,8 +590,8 @@ func init() { return nil, nil }, }) - typeType.AddFieldConfig("inputFields", &Field{ - Type: NewList(NewNonNull(inputValueType)), + TypeType.AddFieldConfig("inputFields", &Field{ + Type: NewList(NewNonNull(InputValueType)), Resolve: func(p ResolveParams) (interface{}, error) { switch ttype := p.Source.(type) { case *InputObject: @@ -539,15 +604,15 @@ func init() { return nil, nil }, }) - typeType.AddFieldConfig("ofType", &Field{ - Type: typeType, + TypeType.AddFieldConfig("ofType", &Field{ + Type: TypeType, }) // Note that these are FieldDefinition and not FieldConfig, - // so the format for args is different. d + // so the format for args is different. 
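With the introspection types exported and @deprecated included in SpecifiedDirectives, deprecation metadata can be read back through the __type meta field defined just below. A minimal sketch, assuming the library's usual public constructors (NewObject, NewSchema, Do) and the canonical github.com/graphql-go/graphql import path, none of which are shown in this diff; the schema and field names are made up.

package main

import (
	"fmt"

	"github.com/graphql-go/graphql"
)

func main() {
	// A toy schema with one deprecated field.
	query := graphql.NewObject(graphql.ObjectConfig{
		Name: "Query",
		Fields: graphql.Fields{
			"name": &graphql.Field{Type: graphql.String},
			"nick": &graphql.Field{
				Type:              graphql.String,
				DeprecationReason: "Use name instead.",
			},
		},
	})
	schema, err := graphql.NewSchema(graphql.SchemaConfig{Query: query})
	if err != nil {
		panic(err)
	}

	// isDeprecated and deprecationReason are served by the FieldType
	// definition above; includeDeprecated is the argument added to "fields".
	result := graphql.Do(graphql.Params{
		Schema: schema,
		RequestString: `{
			__type(name: "Query") {
				fields(includeDeprecated: true) { name isDeprecated deprecationReason }
			}
		}`,
	})
	fmt.Printf("%+v\n", result.Data)
}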
SchemaMetaFieldDef = &FieldDefinition{ Name: "__schema", - Type: NewNonNull(schemaType), + Type: NewNonNull(SchemaType), Description: "Access the current type schema of this server.", Args: []*Argument{}, Resolve: func(p ResolveParams) (interface{}, error) { @@ -556,7 +621,7 @@ func init() { } TypeMetaFieldDef = &FieldDefinition{ Name: "__type", - Type: typeType, + Type: TypeType, Description: "Request the type information of a single type.", Args: []*Argument{ { diff --git a/language/ast/type_definitions.go b/language/ast/type_definitions.go index dd7940c9..a5132dd4 100644 --- a/language/ast/type_definitions.go +++ b/language/ast/type_definitions.go @@ -36,6 +36,7 @@ var _ TypeSystemDefinition = (*DirectiveDefinition)(nil) type SchemaDefinition struct { Kind string Loc *Location + Directives []*Directive OperationTypes []*OperationTypeDefinition } @@ -46,6 +47,7 @@ func NewSchemaDefinition(def *SchemaDefinition) *SchemaDefinition { return &SchemaDefinition{ Kind: kinds.SchemaDefinition, Loc: def.Loc, + Directives: def.Directives, OperationTypes: def.OperationTypes, } } @@ -100,9 +102,10 @@ func (def *OperationTypeDefinition) GetLoc() *Location { // ScalarDefinition implements Node, Definition type ScalarDefinition struct { - Kind string - Loc *Location - Name *Name + Kind string + Loc *Location + Name *Name + Directives []*Directive } func NewScalarDefinition(def *ScalarDefinition) *ScalarDefinition { @@ -110,9 +113,10 @@ func NewScalarDefinition(def *ScalarDefinition) *ScalarDefinition { def = &ScalarDefinition{} } return &ScalarDefinition{ - Kind: kinds.ScalarDefinition, - Loc: def.Loc, - Name: def.Name, + Kind: kinds.ScalarDefinition, + Loc: def.Loc, + Name: def.Name, + Directives: def.Directives, } } @@ -146,6 +150,7 @@ type ObjectDefinition struct { Loc *Location Name *Name Interfaces []*Named + Directives []*Directive Fields []*FieldDefinition } @@ -158,6 +163,7 @@ func NewObjectDefinition(def *ObjectDefinition) *ObjectDefinition { Loc: def.Loc, Name: def.Name, Interfaces: def.Interfaces, + Directives: def.Directives, Fields: def.Fields, } } @@ -188,11 +194,12 @@ func (def *ObjectDefinition) GetOperation() string { // FieldDefinition implements Node type FieldDefinition struct { - Kind string - Loc *Location - Name *Name - Arguments []*InputValueDefinition - Type Type + Kind string + Loc *Location + Name *Name + Arguments []*InputValueDefinition + Type Type + Directives []*Directive } func NewFieldDefinition(def *FieldDefinition) *FieldDefinition { @@ -200,11 +207,12 @@ func NewFieldDefinition(def *FieldDefinition) *FieldDefinition { def = &FieldDefinition{} } return &FieldDefinition{ - Kind: kinds.FieldDefinition, - Loc: def.Loc, - Name: def.Name, - Arguments: def.Arguments, - Type: def.Type, + Kind: kinds.FieldDefinition, + Loc: def.Loc, + Name: def.Name, + Arguments: def.Arguments, + Type: def.Type, + Directives: def.Directives, } } @@ -223,6 +231,7 @@ type InputValueDefinition struct { Name *Name Type Type DefaultValue Value + Directives []*Directive } func NewInputValueDefinition(def *InputValueDefinition) *InputValueDefinition { @@ -235,6 +244,7 @@ func NewInputValueDefinition(def *InputValueDefinition) *InputValueDefinition { Name: def.Name, Type: def.Type, DefaultValue: def.DefaultValue, + Directives: def.Directives, } } @@ -248,10 +258,11 @@ func (def *InputValueDefinition) GetLoc() *Location { // InterfaceDefinition implements Node, Definition type InterfaceDefinition struct { - Kind string - Loc *Location - Name *Name - Fields []*FieldDefinition + Kind string + Loc 
*Location + Name *Name + Directives []*Directive + Fields []*FieldDefinition } func NewInterfaceDefinition(def *InterfaceDefinition) *InterfaceDefinition { @@ -259,10 +270,11 @@ func NewInterfaceDefinition(def *InterfaceDefinition) *InterfaceDefinition { def = &InterfaceDefinition{} } return &InterfaceDefinition{ - Kind: kinds.InterfaceDefinition, - Loc: def.Loc, - Name: def.Name, - Fields: def.Fields, + Kind: kinds.InterfaceDefinition, + Loc: def.Loc, + Name: def.Name, + Directives: def.Directives, + Fields: def.Fields, } } @@ -292,10 +304,11 @@ func (def *InterfaceDefinition) GetOperation() string { // UnionDefinition implements Node, Definition type UnionDefinition struct { - Kind string - Loc *Location - Name *Name - Types []*Named + Kind string + Loc *Location + Name *Name + Directives []*Directive + Types []*Named } func NewUnionDefinition(def *UnionDefinition) *UnionDefinition { @@ -303,10 +316,11 @@ func NewUnionDefinition(def *UnionDefinition) *UnionDefinition { def = &UnionDefinition{} } return &UnionDefinition{ - Kind: kinds.UnionDefinition, - Loc: def.Loc, - Name: def.Name, - Types: def.Types, + Kind: kinds.UnionDefinition, + Loc: def.Loc, + Name: def.Name, + Directives: def.Directives, + Types: def.Types, } } @@ -336,10 +350,11 @@ func (def *UnionDefinition) GetOperation() string { // EnumDefinition implements Node, Definition type EnumDefinition struct { - Kind string - Loc *Location - Name *Name - Values []*EnumValueDefinition + Kind string + Loc *Location + Name *Name + Directives []*Directive + Values []*EnumValueDefinition } func NewEnumDefinition(def *EnumDefinition) *EnumDefinition { @@ -347,10 +362,11 @@ func NewEnumDefinition(def *EnumDefinition) *EnumDefinition { def = &EnumDefinition{} } return &EnumDefinition{ - Kind: kinds.EnumDefinition, - Loc: def.Loc, - Name: def.Name, - Values: def.Values, + Kind: kinds.EnumDefinition, + Loc: def.Loc, + Name: def.Name, + Directives: def.Directives, + Values: def.Values, } } @@ -380,9 +396,10 @@ func (def *EnumDefinition) GetOperation() string { // EnumValueDefinition implements Node, Definition type EnumValueDefinition struct { - Kind string - Loc *Location - Name *Name + Kind string + Loc *Location + Name *Name + Directives []*Directive } func NewEnumValueDefinition(def *EnumValueDefinition) *EnumValueDefinition { @@ -390,9 +407,10 @@ func NewEnumValueDefinition(def *EnumValueDefinition) *EnumValueDefinition { def = &EnumValueDefinition{} } return &EnumValueDefinition{ - Kind: kinds.EnumValueDefinition, - Loc: def.Loc, - Name: def.Name, + Kind: kinds.EnumValueDefinition, + Loc: def.Loc, + Name: def.Name, + Directives: def.Directives, } } @@ -406,10 +424,11 @@ func (def *EnumValueDefinition) GetLoc() *Location { // InputObjectDefinition implements Node, Definition type InputObjectDefinition struct { - Kind string - Loc *Location - Name *Name - Fields []*InputValueDefinition + Kind string + Loc *Location + Name *Name + Directives []*Directive + Fields []*InputValueDefinition } func NewInputObjectDefinition(def *InputObjectDefinition) *InputObjectDefinition { @@ -417,10 +436,11 @@ func NewInputObjectDefinition(def *InputObjectDefinition) *InputObjectDefinition def = &InputObjectDefinition{} } return &InputObjectDefinition{ - Kind: kinds.InputObjectDefinition, - Loc: def.Loc, - Name: def.Name, - Fields: def.Fields, + Kind: kinds.InputObjectDefinition, + Loc: def.Loc, + Name: def.Name, + Directives: def.Directives, + Fields: def.Fields, } } diff --git a/language/parser/parser.go b/language/parser/parser.go index 46ddf751..92cf7ac6 
100644 --- a/language/parser/parser.go +++ b/language/parser/parser.go @@ -877,7 +877,7 @@ func parseNamed(parser *Parser) (*ast.Named, error) { /* Implements the parsing rules in the Type Definition section. */ /** - * SchemaDefinition : schema { OperationTypeDefinition+ } + * SchemaDefinition : schema Directives? { OperationTypeDefinition+ } * * OperationTypeDefinition : OperationType : NamedType */ @@ -887,6 +887,10 @@ func parseSchemaDefinition(parser *Parser) (*ast.SchemaDefinition, error) { if err != nil { return nil, err } + directives, err := parseDirectives(parser) + if err != nil { + return nil, err + } operationTypesI, err := many( parser, lexer.TokenKind[lexer.BRACE_L], @@ -902,11 +906,11 @@ func parseSchemaDefinition(parser *Parser) (*ast.SchemaDefinition, error) { operationTypes = append(operationTypes, op) } } - def := ast.NewSchemaDefinition(&ast.SchemaDefinition{ + return ast.NewSchemaDefinition(&ast.SchemaDefinition{ OperationTypes: operationTypes, + Directives: directives, Loc: loc(parser, start), - }) - return def, nil + }), nil } func parseOperationTypeDefinition(parser *Parser) (interface{}, error) { @@ -931,7 +935,7 @@ func parseOperationTypeDefinition(parser *Parser) (interface{}, error) { } /** - * ScalarTypeDefinition : scalar Name + * ScalarTypeDefinition : scalar Name Directives? */ func parseScalarTypeDefinition(parser *Parser) (*ast.ScalarDefinition, error) { start := parser.Token.Start @@ -943,15 +947,20 @@ func parseScalarTypeDefinition(parser *Parser) (*ast.ScalarDefinition, error) { if err != nil { return nil, err } + directives, err := parseDirectives(parser) + if err != nil { + return nil, err + } def := ast.NewScalarDefinition(&ast.ScalarDefinition{ - Name: name, - Loc: loc(parser, start), + Name: name, + Directives: directives, + Loc: loc(parser, start), }) return def, nil } /** - * ObjectTypeDefinition : type Name ImplementsInterfaces? { FieldDefinition+ } + * ObjectTypeDefinition : type Name ImplementsInterfaces? Directives? { FieldDefinition+ } */ func parseObjectTypeDefinition(parser *Parser) (*ast.ObjectDefinition, error) { start := parser.Token.Start @@ -967,6 +976,10 @@ func parseObjectTypeDefinition(parser *Parser) (*ast.ObjectDefinition, error) { if err != nil { return nil, err } + directives, err := parseDirectives(parser) + if err != nil { + return nil, err + } iFields, err := any(parser, lexer.TokenKind[lexer.BRACE_L], parseFieldDefinition, lexer.TokenKind[lexer.BRACE_R]) if err != nil { return nil, err @@ -981,6 +994,7 @@ func parseObjectTypeDefinition(parser *Parser) (*ast.ObjectDefinition, error) { Name: name, Loc: loc(parser, start), Interfaces: interfaces, + Directives: directives, Fields: fields, }), nil } @@ -1000,7 +1014,7 @@ func parseImplementsInterfaces(parser *Parser) ([]*ast.Named, error) { return types, err } types = append(types, ttype) - if peek(parser, lexer.TokenKind[lexer.BRACE_L]) { + if !peek(parser, lexer.TokenKind[lexer.NAME]) { break } } @@ -1009,7 +1023,7 @@ func parseImplementsInterfaces(parser *Parser) ([]*ast.Named, error) { } /** - * FieldDefinition : Name ArgumentsDefinition? : Type + * FieldDefinition : Name ArgumentsDefinition? : Type Directives? 
*/ func parseFieldDefinition(parser *Parser) (interface{}, error) { start := parser.Token.Start @@ -1029,11 +1043,16 @@ func parseFieldDefinition(parser *Parser) (interface{}, error) { if err != nil { return nil, err } + directives, err := parseDirectives(parser) + if err != nil { + return nil, err + } return ast.NewFieldDefinition(&ast.FieldDefinition{ - Name: name, - Arguments: args, - Type: ttype, - Loc: loc(parser, start), + Name: name, + Arguments: args, + Type: ttype, + Directives: directives, + Loc: loc(parser, start), }), nil } @@ -1059,7 +1078,7 @@ func parseArgumentDefs(parser *Parser) ([]*ast.InputValueDefinition, error) { } /** - * InputValueDefinition : Name : Type DefaultValue? + * InputValueDefinition : Name : Type DefaultValue? Directives? */ func parseInputValueDef(parser *Parser) (interface{}, error) { start := parser.Token.Start @@ -1087,16 +1106,21 @@ func parseInputValueDef(parser *Parser) (interface{}, error) { defaultValue = val } } + directives, err := parseDirectives(parser) + if err != nil { + return nil, err + } return ast.NewInputValueDefinition(&ast.InputValueDefinition{ Name: name, Type: ttype, DefaultValue: defaultValue, + Directives: directives, Loc: loc(parser, start), }), nil } /** - * InterfaceTypeDefinition : interface Name { FieldDefinition+ } + * InterfaceTypeDefinition : interface Name Directives? { FieldDefinition+ } */ func parseInterfaceTypeDefinition(parser *Parser) (*ast.InterfaceDefinition, error) { start := parser.Token.Start @@ -1108,6 +1132,10 @@ func parseInterfaceTypeDefinition(parser *Parser) (*ast.InterfaceDefinition, err if err != nil { return nil, err } + directives, err := parseDirectives(parser) + if err != nil { + return nil, err + } iFields, err := any(parser, lexer.TokenKind[lexer.BRACE_L], parseFieldDefinition, lexer.TokenKind[lexer.BRACE_R]) if err != nil { return nil, err @@ -1119,14 +1147,15 @@ func parseInterfaceTypeDefinition(parser *Parser) (*ast.InterfaceDefinition, err } } return ast.NewInterfaceDefinition(&ast.InterfaceDefinition{ - Name: name, - Loc: loc(parser, start), - Fields: fields, + Name: name, + Directives: directives, + Loc: loc(parser, start), + Fields: fields, }), nil } /** - * UnionTypeDefinition : union Name = UnionMembers + * UnionTypeDefinition : union Name Directives? = UnionMembers */ func parseUnionTypeDefinition(parser *Parser) (*ast.UnionDefinition, error) { start := parser.Token.Start @@ -1138,6 +1167,10 @@ func parseUnionTypeDefinition(parser *Parser) (*ast.UnionDefinition, error) { if err != nil { return nil, err } + directives, err := parseDirectives(parser) + if err != nil { + return nil, err + } _, err = expect(parser, lexer.TokenKind[lexer.EQUALS]) if err != nil { return nil, err @@ -1147,9 +1180,10 @@ func parseUnionTypeDefinition(parser *Parser) (*ast.UnionDefinition, error) { return nil, err } return ast.NewUnionDefinition(&ast.UnionDefinition{ - Name: name, - Loc: loc(parser, start), - Types: types, + Name: name, + Directives: directives, + Loc: loc(parser, start), + Types: types, }), nil } @@ -1176,7 +1210,7 @@ func parseUnionMembers(parser *Parser) ([]*ast.Named, error) { } /** - * EnumTypeDefinition : enum Name { EnumValueDefinition+ } + * EnumTypeDefinition : enum Name Directives? 
{ EnumValueDefinition+ } */ func parseEnumTypeDefinition(parser *Parser) (*ast.EnumDefinition, error) { start := parser.Token.Start @@ -1188,6 +1222,10 @@ func parseEnumTypeDefinition(parser *Parser) (*ast.EnumDefinition, error) { if err != nil { return nil, err } + directives, err := parseDirectives(parser) + if err != nil { + return nil, err + } iEnumValueDefs, err := any(parser, lexer.TokenKind[lexer.BRACE_L], parseEnumValueDefinition, lexer.TokenKind[lexer.BRACE_R]) if err != nil { return nil, err @@ -1199,14 +1237,15 @@ func parseEnumTypeDefinition(parser *Parser) (*ast.EnumDefinition, error) { } } return ast.NewEnumDefinition(&ast.EnumDefinition{ - Name: name, - Loc: loc(parser, start), - Values: values, + Name: name, + Directives: directives, + Loc: loc(parser, start), + Values: values, }), nil } /** - * EnumValueDefinition : EnumValue + * EnumValueDefinition : EnumValue Directives? * * EnumValue : Name */ @@ -1216,14 +1255,19 @@ func parseEnumValueDefinition(parser *Parser) (interface{}, error) { if err != nil { return nil, err } + directives, err := parseDirectives(parser) + if err != nil { + return nil, err + } return ast.NewEnumValueDefinition(&ast.EnumValueDefinition{ - Name: name, - Loc: loc(parser, start), + Name: name, + Directives: directives, + Loc: loc(parser, start), }), nil } /** - * InputObjectTypeDefinition : input Name { InputValueDefinition+ } + * InputObjectTypeDefinition : input Name Directives? { InputValueDefinition+ } */ func parseInputObjectTypeDefinition(parser *Parser) (*ast.InputObjectDefinition, error) { start := parser.Token.Start @@ -1235,6 +1279,10 @@ func parseInputObjectTypeDefinition(parser *Parser) (*ast.InputObjectDefinition, if err != nil { return nil, err } + directives, err := parseDirectives(parser) + if err != nil { + return nil, err + } iInputValueDefinitions, err := any(parser, lexer.TokenKind[lexer.BRACE_L], parseInputValueDef, lexer.TokenKind[lexer.BRACE_R]) if err != nil { return nil, err @@ -1246,9 +1294,10 @@ func parseInputObjectTypeDefinition(parser *Parser) (*ast.InputObjectDefinition, } } return ast.NewInputObjectDefinition(&ast.InputObjectDefinition{ - Name: name, - Loc: loc(parser, start), - Fields: fields, + Name: name, + Directives: directives, + Loc: loc(parser, start), + Fields: fields, }), nil } diff --git a/language/parser/schema_parser_test.go b/language/parser/schema_parser_test.go index 510c7d26..a72928cf 100644 --- a/language/parser/schema_parser_test.go +++ b/language/parser/schema_parser_test.go @@ -29,6 +29,7 @@ func testLoc(start int, end int) *ast.Location { Start: start, End: end, } } + func TestSchemaParser_SimpleType(t *testing.T) { body := ` @@ -45,6 +46,7 @@ type Hello { Value: "Hello", Loc: testLoc(6, 11), }), + Directives: []*ast.Directive{}, Interfaces: []*ast.Named{}, Fields: []*ast.FieldDefinition{ ast.NewFieldDefinition(&ast.FieldDefinition{ @@ -53,7 +55,8 @@ type Hello { Value: "world", Loc: testLoc(16, 21), }), - Arguments: []*ast.InputValueDefinition{}, + Arguments: []*ast.InputValueDefinition{}, + Directives: []*ast.Directive{}, Type: ast.NewNamed(&ast.Named{ Loc: testLoc(23, 29), Name: ast.NewName(&ast.Name{ @@ -89,6 +92,7 @@ extend type Hello { Value: "Hello", Loc: testLoc(13, 18), }), + Directives: []*ast.Directive{}, Interfaces: []*ast.Named{}, Fields: []*ast.FieldDefinition{ ast.NewFieldDefinition(&ast.FieldDefinition{ @@ -97,7 +101,8 @@ extend type Hello { Value: "world", Loc: testLoc(23, 28), }), - Arguments: []*ast.InputValueDefinition{}, + Directives: []*ast.Directive{}, + Arguments: 
[]*ast.InputValueDefinition{}, Type: ast.NewNamed(&ast.Named{ Loc: testLoc(30, 36), Name: ast.NewName(&ast.Name{ @@ -132,6 +137,7 @@ type Hello { Value: "Hello", Loc: testLoc(6, 11), }), + Directives: []*ast.Directive{}, Interfaces: []*ast.Named{}, Fields: []*ast.FieldDefinition{ ast.NewFieldDefinition(&ast.FieldDefinition{ @@ -140,7 +146,8 @@ type Hello { Value: "world", Loc: testLoc(16, 21), }), - Arguments: []*ast.InputValueDefinition{}, + Directives: []*ast.Directive{}, + Arguments: []*ast.InputValueDefinition{}, Type: ast.NewNonNull(&ast.NonNull{ Kind: "NonNullType", Loc: testLoc(23, 30), @@ -174,6 +181,7 @@ func TestSchemaParser_SimpleTypeInheritingInterface(t *testing.T) { Value: "Hello", Loc: testLoc(5, 10), }), + Directives: []*ast.Directive{}, Interfaces: []*ast.Named{ ast.NewNamed(&ast.Named{ Name: ast.NewName(&ast.Name{ @@ -204,6 +212,7 @@ func TestSchemaParser_SimpleTypeInheritingMultipleInterfaces(t *testing.T) { Value: "Hello", Loc: testLoc(5, 10), }), + Directives: []*ast.Directive{}, Interfaces: []*ast.Named{ ast.NewNamed(&ast.Named{ Name: ast.NewName(&ast.Name{ @@ -241,13 +250,15 @@ func TestSchemaParser_SingleValueEnum(t *testing.T) { Value: "Hello", Loc: testLoc(5, 10), }), + Directives: []*ast.Directive{}, Values: []*ast.EnumValueDefinition{ ast.NewEnumValueDefinition(&ast.EnumValueDefinition{ Name: ast.NewName(&ast.Name{ Value: "WORLD", Loc: testLoc(13, 18), }), - Loc: testLoc(13, 18), + Directives: []*ast.Directive{}, + Loc: testLoc(13, 18), }), }, }), @@ -270,20 +281,23 @@ func TestSchemaParser_DoubleValueEnum(t *testing.T) { Value: "Hello", Loc: testLoc(5, 10), }), + Directives: []*ast.Directive{}, Values: []*ast.EnumValueDefinition{ ast.NewEnumValueDefinition(&ast.EnumValueDefinition{ Name: ast.NewName(&ast.Name{ Value: "WO", Loc: testLoc(13, 15), }), - Loc: testLoc(13, 15), + Directives: []*ast.Directive{}, + Loc: testLoc(13, 15), }), ast.NewEnumValueDefinition(&ast.EnumValueDefinition{ Name: ast.NewName(&ast.Name{ Value: "RLD", Loc: testLoc(17, 20), }), - Loc: testLoc(17, 20), + Directives: []*ast.Directive{}, + Loc: testLoc(17, 20), }), }, }), @@ -309,6 +323,7 @@ interface Hello { Value: "Hello", Loc: testLoc(11, 16), }), + Directives: []*ast.Directive{}, Fields: []*ast.FieldDefinition{ ast.NewFieldDefinition(&ast.FieldDefinition{ Loc: testLoc(21, 34), @@ -316,7 +331,8 @@ interface Hello { Value: "world", Loc: testLoc(21, 26), }), - Arguments: []*ast.InputValueDefinition{}, + Directives: []*ast.Directive{}, + Arguments: []*ast.InputValueDefinition{}, Type: ast.NewNamed(&ast.Named{ Loc: testLoc(28, 34), Name: ast.NewName(&ast.Name{ @@ -349,6 +365,7 @@ type Hello { Value: "Hello", Loc: testLoc(6, 11), }), + Directives: []*ast.Directive{}, Interfaces: []*ast.Named{}, Fields: []*ast.FieldDefinition{ ast.NewFieldDefinition(&ast.FieldDefinition{ @@ -357,6 +374,7 @@ type Hello { Value: "world", Loc: testLoc(16, 21), }), + Directives: []*ast.Directive{}, Arguments: []*ast.InputValueDefinition{ ast.NewInputValueDefinition(&ast.InputValueDefinition{ Loc: testLoc(22, 35), @@ -372,6 +390,7 @@ type Hello { }), }), DefaultValue: nil, + Directives: []*ast.Directive{}, }), }, Type: ast.NewNamed(&ast.Named{ @@ -406,6 +425,7 @@ type Hello { Value: "Hello", Loc: testLoc(6, 11), }), + Directives: []*ast.Directive{}, Interfaces: []*ast.Named{}, Fields: []*ast.FieldDefinition{ ast.NewFieldDefinition(&ast.FieldDefinition{ @@ -432,8 +452,10 @@ type Hello { Value: true, Loc: testLoc(38, 42), }), + Directives: []*ast.Directive{}, }), }, + Directives: []*ast.Directive{}, Type: 
ast.NewNamed(&ast.Named{ Loc: testLoc(45, 51), Name: ast.NewName(&ast.Name{ @@ -466,6 +488,7 @@ type Hello { Value: "Hello", Loc: testLoc(6, 11), }), + Directives: []*ast.Directive{}, Interfaces: []*ast.Named{}, Fields: []*ast.FieldDefinition{ ast.NewFieldDefinition(&ast.FieldDefinition{ @@ -474,6 +497,7 @@ type Hello { Value: "world", Loc: testLoc(16, 21), }), + Directives: []*ast.Directive{}, Arguments: []*ast.InputValueDefinition{ ast.NewInputValueDefinition(&ast.InputValueDefinition{ Loc: testLoc(22, 38), @@ -492,6 +516,7 @@ type Hello { }), }), DefaultValue: nil, + Directives: []*ast.Directive{}, }), }, Type: ast.NewNamed(&ast.Named{ @@ -526,6 +551,7 @@ type Hello { Value: "Hello", Loc: testLoc(6, 11), }), + Directives: []*ast.Directive{}, Interfaces: []*ast.Named{}, Fields: []*ast.FieldDefinition{ ast.NewFieldDefinition(&ast.FieldDefinition{ @@ -534,6 +560,7 @@ type Hello { Value: "world", Loc: testLoc(16, 21), }), + Directives: []*ast.Directive{}, Arguments: []*ast.InputValueDefinition{ ast.NewInputValueDefinition(&ast.InputValueDefinition{ Loc: testLoc(22, 37), @@ -549,6 +576,7 @@ type Hello { }), }), DefaultValue: nil, + Directives: []*ast.Directive{}, }), ast.NewInputValueDefinition(&ast.InputValueDefinition{ Loc: testLoc(39, 50), @@ -564,6 +592,7 @@ type Hello { }), }), DefaultValue: nil, + Directives: []*ast.Directive{}, }), }, Type: ast.NewNamed(&ast.Named{ @@ -595,6 +624,7 @@ func TestSchemaParser_SimpleUnion(t *testing.T) { Value: "Hello", Loc: testLoc(6, 11), }), + Directives: []*ast.Directive{}, Types: []*ast.Named{ ast.NewNamed(&ast.Named{ Loc: testLoc(14, 19), @@ -624,6 +654,7 @@ func TestSchemaParser_UnionWithTwoTypes(t *testing.T) { Value: "Hello", Loc: testLoc(6, 11), }), + Directives: []*ast.Directive{}, Types: []*ast.Named{ ast.NewNamed(&ast.Named{ Loc: testLoc(14, 16), @@ -660,6 +691,7 @@ func TestSchemaParser_Scalar(t *testing.T) { Value: "Hello", Loc: testLoc(7, 12), }), + Directives: []*ast.Directive{}, }), }, }) @@ -683,6 +715,7 @@ input Hello { Value: "Hello", Loc: testLoc(7, 12), }), + Directives: []*ast.Directive{}, Fields: []*ast.InputValueDefinition{ ast.NewInputValueDefinition(&ast.InputValueDefinition{ Loc: testLoc(17, 30), @@ -698,6 +731,7 @@ input Hello { }), }), DefaultValue: nil, + Directives: []*ast.Directive{}, }), }, }), diff --git a/language/printer/printer.go b/language/printer/printer.go index 47f9b2a4..3add3852 100644 --- a/language/printer/printer.go +++ b/language/printer/printer.go @@ -29,6 +29,23 @@ func getMapValue(m map[string]interface{}, key string) interface{} { } return valMap } +func getMapSliceValue(m map[string]interface{}, key string) []interface{} { + tokens := strings.Split(key, ".") + valMap := m + for _, token := range tokens { + v, ok := valMap[token] + if !ok { + return []interface{}{} + } + switch v := v.(type) { + case []interface{}: + return v + default: + return []interface{}{} + } + } + return []interface{}{} +} func getMapValueString(m map[string]interface{}, key string) string { tokens := strings.Split(key, ".") valMap := m @@ -92,11 +109,13 @@ func wrap(start, maybeString, end string) string { } return start + maybeString + end } + +// Given array, print each item on its own line, wrapped in an indented "{ }" block. 
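The parser hunks above accept an optional Directives list at each type-system position, and the printer reducers below emit them again (with block() now printing "{}" for an empty field list), so annotated SDL should round-trip. A quick sketch using the parser and printer sub-packages; the Parse and Print calls are the library's existing entry points rather than anything added in this patch, the SDL is lifted from the schema printer test further down, and the rendered output is only approximately as written.

package main

import (
	"fmt"

	"github.com/graphql-go/graphql/language/parser"
	"github.com/graphql-go/graphql/language/printer"
)

func main() {
	src := `
scalar AnnotatedScalar @onScalar

type AnnotatedObject @onObject(arg: "value") {
  annotatedField(arg: Type = "default" @onArg): Type @onField
}

type NoFields {}
`
	// Parse now accepts the directive annotations at type-system positions.
	doc, err := parser.Parse(parser.ParseParams{Source: src})
	if err != nil {
		panic(err)
	}
	// Print walks the AST with printDocASTReducer, emitting the directives
	// and the empty {} block back out.
	fmt.Println(printer.Print(doc))
}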
func block(maybeArray interface{}) string { - if maybeArray == nil { - return "" - } s := toSliceString(maybeArray) + if len(s) == 0 { + return "{}" + } return indent("{\n"+join(s, "\n")) + "\n}" } @@ -436,13 +455,27 @@ var printDocASTReducer = map[string]visitor.VisitFunc{ "SchemaDefinition": func(p visitor.VisitFuncParams) (string, interface{}) { switch node := p.Node.(type) { case *ast.SchemaDefinition: - operationTypesBlock := block(node.OperationTypes) - str := fmt.Sprintf("schema %v", operationTypesBlock) + directives := []string{} + for _, directive := range node.Directives { + directives = append(directives, fmt.Sprintf("%v", directive.Name)) + } + str := join([]string{ + "schema", + join(directives, " "), + block(node.OperationTypes), + }, " ") return visitor.ActionUpdate, str case map[string]interface{}: operationTypes := toSliceString(getMapValue(node, "OperationTypes")) - operationTypesBlock := block(operationTypes) - str := fmt.Sprintf("schema %v", operationTypesBlock) + directives := []string{} + for _, directive := range getMapSliceValue(node, "Directives") { + directives = append(directives, fmt.Sprintf("%v", directive)) + } + str := join([]string{ + "schema", + join(directives, " "), + block(operationTypes), + }, " ") return visitor.ActionUpdate, str } return visitor.ActionNoChange, nil @@ -463,12 +496,27 @@ var printDocASTReducer = map[string]visitor.VisitFunc{ "ScalarDefinition": func(p visitor.VisitFuncParams) (string, interface{}) { switch node := p.Node.(type) { case *ast.ScalarDefinition: - name := fmt.Sprintf("%v", node.Name) - str := "scalar " + name + directives := []string{} + for _, directive := range node.Directives { + directives = append(directives, fmt.Sprintf("%v", directive.Name)) + } + str := join([]string{ + "scalar", + fmt.Sprintf("%v", node.Name), + join(directives, " "), + }, " ") return visitor.ActionUpdate, str case map[string]interface{}: name := getMapValueString(node, "Name") - str := "scalar " + name + directives := []string{} + for _, directive := range getMapSliceValue(node, "Directives") { + directives = append(directives, fmt.Sprintf("%v", directive)) + } + str := join([]string{ + "scalar", + name, + join(directives, " "), + }, " ") return visitor.ActionUpdate, str } return visitor.ActionNoChange, nil @@ -479,13 +527,33 @@ var printDocASTReducer = map[string]visitor.VisitFunc{ name := fmt.Sprintf("%v", node.Name) interfaces := toSliceString(node.Interfaces) fields := node.Fields - str := "type " + name + " " + wrap("implements ", join(interfaces, ", "), " ") + block(fields) + directives := []string{} + for _, directive := range node.Directives { + directives = append(directives, fmt.Sprintf("%v", directive.Name)) + } + str := join([]string{ + "type", + name, + wrap("implements ", join(interfaces, ", "), ""), + join(directives, " "), + block(fields), + }, " ") return visitor.ActionUpdate, str case map[string]interface{}: name := getMapValueString(node, "Name") interfaces := toSliceString(getMapValue(node, "Interfaces")) fields := getMapValue(node, "Fields") - str := "type " + name + " " + wrap("implements ", join(interfaces, ", "), " ") + block(fields) + directives := []string{} + for _, directive := range getMapSliceValue(node, "Directives") { + directives = append(directives, fmt.Sprintf("%v", directive)) + } + str := join([]string{ + "type", + name, + wrap("implements ", join(interfaces, ", "), ""), + join(directives, " "), + block(fields), + }, " ") return visitor.ActionUpdate, str } return visitor.ActionNoChange, nil @@ -496,13 +564,21 
@@ var printDocASTReducer = map[string]visitor.VisitFunc{ name := fmt.Sprintf("%v", node.Name) ttype := fmt.Sprintf("%v", node.Type) args := toSliceString(node.Arguments) - str := name + wrap("(", join(args, ", "), ")") + ": " + ttype + directives := []string{} + for _, directive := range node.Directives { + directives = append(directives, fmt.Sprintf("%v", directive.Name)) + } + str := name + wrap("(", join(args, ", "), ")") + ": " + ttype + wrap(" ", join(directives, " "), "") return visitor.ActionUpdate, str case map[string]interface{}: name := getMapValueString(node, "Name") ttype := getMapValueString(node, "Type") args := toSliceString(getMapValue(node, "Arguments")) - str := name + wrap("(", join(args, ", "), ")") + ": " + ttype + directives := []string{} + for _, directive := range getMapSliceValue(node, "Directives") { + directives = append(directives, fmt.Sprintf("%v", directive)) + } + str := name + wrap("(", join(args, ", "), ")") + ": " + ttype + wrap(" ", join(directives, " "), "") return visitor.ActionUpdate, str } return visitor.ActionNoChange, nil @@ -513,13 +589,30 @@ var printDocASTReducer = map[string]visitor.VisitFunc{ name := fmt.Sprintf("%v", node.Name) ttype := fmt.Sprintf("%v", node.Type) defaultValue := fmt.Sprintf("%v", node.DefaultValue) - str := name + ": " + ttype + wrap(" = ", defaultValue, "") + directives := []string{} + for _, directive := range node.Directives { + directives = append(directives, fmt.Sprintf("%v", directive.Name)) + } + str := join([]string{ + name + ": " + ttype, + wrap("= ", defaultValue, ""), + join(directives, " "), + }, " ") + return visitor.ActionUpdate, str case map[string]interface{}: name := getMapValueString(node, "Name") ttype := getMapValueString(node, "Type") defaultValue := getMapValueString(node, "DefaultValue") - str := name + ": " + ttype + wrap(" = ", defaultValue, "") + directives := []string{} + for _, directive := range getMapSliceValue(node, "Directives") { + directives = append(directives, fmt.Sprintf("%v", directive)) + } + str := join([]string{ + name + ": " + ttype, + wrap("= ", defaultValue, ""), + join(directives, " "), + }, " ") return visitor.ActionUpdate, str } return visitor.ActionNoChange, nil @@ -529,12 +622,30 @@ var printDocASTReducer = map[string]visitor.VisitFunc{ case *ast.InterfaceDefinition: name := fmt.Sprintf("%v", node.Name) fields := node.Fields - str := "interface " + name + " " + block(fields) + directives := []string{} + for _, directive := range node.Directives { + directives = append(directives, fmt.Sprintf("%v", directive.Name)) + } + str := join([]string{ + "interface", + name, + join(directives, " "), + block(fields), + }, " ") return visitor.ActionUpdate, str case map[string]interface{}: name := getMapValueString(node, "Name") fields := getMapValue(node, "Fields") - str := "interface " + name + " " + block(fields) + directives := []string{} + for _, directive := range getMapSliceValue(node, "Directives") { + directives = append(directives, fmt.Sprintf("%v", directive)) + } + str := join([]string{ + "interface", + name, + join(directives, " "), + block(fields), + }, " ") return visitor.ActionUpdate, str } return visitor.ActionNoChange, nil @@ -544,12 +655,30 @@ var printDocASTReducer = map[string]visitor.VisitFunc{ case *ast.UnionDefinition: name := fmt.Sprintf("%v", node.Name) types := toSliceString(node.Types) - str := "union " + name + " = " + join(types, " | ") + directives := []string{} + for _, directive := range node.Directives { + directives = append(directives, fmt.Sprintf("%v", 
directive.Name)) + } + str := join([]string{ + "union", + name, + join(directives, " "), + "= " + join(types, " | "), + }, " ") return visitor.ActionUpdate, str case map[string]interface{}: name := getMapValueString(node, "Name") types := toSliceString(getMapValue(node, "Types")) - str := "union " + name + " = " + join(types, " | ") + directives := []string{} + for _, directive := range getMapSliceValue(node, "Directives") { + directives = append(directives, fmt.Sprintf("%v", directive)) + } + str := join([]string{ + "union", + name, + join(directives, " "), + "= " + join(types, " | "), + }, " ") return visitor.ActionUpdate, str } return visitor.ActionNoChange, nil @@ -559,12 +688,30 @@ var printDocASTReducer = map[string]visitor.VisitFunc{ case *ast.EnumDefinition: name := fmt.Sprintf("%v", node.Name) values := node.Values - str := "enum " + name + " " + block(values) + directives := []string{} + for _, directive := range node.Directives { + directives = append(directives, fmt.Sprintf("%v", directive.Name)) + } + str := join([]string{ + "enum", + name, + join(directives, " "), + block(values), + }, " ") return visitor.ActionUpdate, str case map[string]interface{}: name := getMapValueString(node, "Name") values := getMapValue(node, "Values") - str := "enum " + name + " " + block(values) + directives := []string{} + for _, directive := range getMapSliceValue(node, "Directives") { + directives = append(directives, fmt.Sprintf("%v", directive)) + } + str := join([]string{ + "enum", + name, + join(directives, " "), + block(values), + }, " ") return visitor.ActionUpdate, str } return visitor.ActionNoChange, nil @@ -573,10 +720,26 @@ var printDocASTReducer = map[string]visitor.VisitFunc{ switch node := p.Node.(type) { case *ast.EnumValueDefinition: name := fmt.Sprintf("%v", node.Name) - return visitor.ActionUpdate, name + directives := []string{} + for _, directive := range node.Directives { + directives = append(directives, fmt.Sprintf("%v", directive.Name)) + } + str := join([]string{ + name, + join(directives, " "), + }, " ") + return visitor.ActionUpdate, str case map[string]interface{}: name := getMapValueString(node, "Name") - return visitor.ActionUpdate, name + directives := []string{} + for _, directive := range getMapSliceValue(node, "Directives") { + directives = append(directives, fmt.Sprintf("%v", directive)) + } + str := join([]string{ + name, + join(directives, " "), + }, " ") + return visitor.ActionUpdate, str } return visitor.ActionNoChange, nil }, @@ -585,11 +748,31 @@ var printDocASTReducer = map[string]visitor.VisitFunc{ case *ast.InputObjectDefinition: name := fmt.Sprintf("%v", node.Name) fields := node.Fields - return visitor.ActionUpdate, "input " + name + " " + block(fields) + directives := []string{} + for _, directive := range node.Directives { + directives = append(directives, fmt.Sprintf("%v", directive.Name)) + } + str := join([]string{ + "input", + name, + join(directives, " "), + block(fields), + }, " ") + return visitor.ActionUpdate, str case map[string]interface{}: name := getMapValueString(node, "Name") fields := getMapValue(node, "Fields") - return visitor.ActionUpdate, "input " + name + " " + block(fields) + directives := []string{} + for _, directive := range getMapSliceValue(node, "Directives") { + directives = append(directives, fmt.Sprintf("%v", directive)) + } + str := join([]string{ + "input", + name, + join(directives, " "), + block(fields), + }, " ") + return visitor.ActionUpdate, str } return visitor.ActionNoChange, nil }, diff --git 
a/language/printer/schema_printer_test.go b/language/printer/schema_printer_test.go index 3e0f6f69..2ce5ed00 100644 --- a/language/printer/schema_printer_test.go +++ b/language/printer/schema_printer_test.go @@ -67,29 +67,54 @@ type Foo implements Bar { six(argument: InputType = {key: "value"}): Type } +type AnnotatedObject @onObject(arg: "value") { + annotatedField(arg: Type = "default" @onArg): Type @onField +} + interface Bar { one: Type four(argument: String = "string"): String } +interface AnnotatedInterface @onInterface { + annotatedField(arg: Type @onArg): Type @onField +} + union Feed = Story | Article | Advert +union AnnotatedUnion @onUnion = A | B + scalar CustomScalar +scalar AnnotatedScalar @onScalar + enum Site { DESKTOP MOBILE } +enum AnnotatedEnum @onEnum { + ANNOTATED_VALUE @onEnumValue + OTHER_VALUE +} + input InputType { key: String! answer: Int = 42 } +input AnnotatedInput @onInputObjectType { + annotatedField: Type @onField +} + extend type Foo { seven(argument: [String]): Type } +extend type Foo @onType {} + +type NoFields {} + directive @skip(if: Boolean!) on FIELD | FRAGMENT_SPREAD | INLINE_FRAGMENT directive @include(if: Boolean!) on FIELD | FRAGMENT_SPREAD | INLINE_FRAGMENT diff --git a/language/visitor/visitor.go b/language/visitor/visitor.go index b59df235..9a1c2ac2 100644 --- a/language/visitor/visitor.go +++ b/language/visitor/visitor.go @@ -83,40 +83,56 @@ var QueryDocumentKeys = KeyMap{ "List": []string{"Type"}, "NonNull": []string{"Type"}, - "SchemaDefinition": []string{"OperationTypes"}, + "SchemaDefinition": []string{ + "Directives", + "OperationTypes", + }, "OperationTypeDefinition": []string{"Type"}, - "ScalarDefinition": []string{"Name"}, + "ScalarDefinition": []string{ + "Name", + "Directives", + }, "ObjectDefinition": []string{ "Name", "Interfaces", + "Directives", "Fields", }, "FieldDefinition": []string{ "Name", "Arguments", "Type", + "Directives", }, "InputValueDefinition": []string{ "Name", "Type", "DefaultValue", + "Directives", }, "InterfaceDefinition": []string{ "Name", + "Directives", "Fields", }, "UnionDefinition": []string{ "Name", + "Directives", "Types", }, "EnumDefinition": []string{ "Name", + "Directives", "Values", }, - "EnumValueDefinition": []string{"Name"}, + "EnumValueDefinition": []string{ + "Name", + "Directives", + }, "InputObjectDefinition": []string{ "Name", + "Directives", "Fields", }, @@ -372,10 +388,13 @@ Loop: nodeIn = node } parentConcrete, _ := parent.(ast.Node) + // ancestorsConcrete slice may contain nil values ancestorsConcrete := []ast.Node{} for _, ancestor := range ancestors { if ancestorConcrete, ok := ancestor.(ast.Node); ok { ancestorsConcrete = append(ancestorsConcrete, ancestorConcrete) + } else { + ancestorsConcrete = append(ancestorsConcrete, nil) } } diff --git a/quoted_or_list_internal_test.go b/quoted_or_list_internal_test.go new file mode 100644 index 00000000..a2caccbe --- /dev/null +++ b/quoted_or_list_internal_test.go @@ -0,0 +1,35 @@ +package graphql + +import ( + "reflect" + "testing" +) + +func TestQuotedOrList_DoesNoAcceptAnEmptyList(t *testing.T) { + expected := "" + result := quotedOrList([]string{}) + if !reflect.DeepEqual(expected, result) { + t.Fatalf("Expected %v, got: %v", expected, result) + } +} +func TestQuotedOrList_ReturnsSingleQuotedItem(t *testing.T) { + expected := `"A"` + result := quotedOrList([]string{"A"}) + if !reflect.DeepEqual(expected, result) { + t.Fatalf("Expected %v, got: %v", expected, result) + } +} +func TestQuotedOrList_ReturnsTwoItems(t *testing.T) { + expected := 
`"A" or "B"` + result := quotedOrList([]string{"A", "B"}) + if !reflect.DeepEqual(expected, result) { + t.Fatalf("Expected %v, got: %v", expected, result) + } +} +func TestQuotedOrList_ReturnsCommaSeparatedManyItemList(t *testing.T) { + expected := `"A", "B", "C", "D", or "E"` + result := quotedOrList([]string{"A", "B", "C", "D", "E", "F"}) + if !reflect.DeepEqual(expected, result) { + t.Fatalf("Expected %v, got: %v", expected, result) + } +} diff --git a/rules.go b/rules.go index 3b3c4456..b234cf71 100644 --- a/rules.go +++ b/rules.go @@ -2,13 +2,15 @@ package graphql import ( "fmt" + "math" + "sort" + "strings" + "github.com/graphql-go/graphql/gqlerrors" "github.com/graphql-go/graphql/language/ast" "github.com/graphql-go/graphql/language/kinds" "github.com/graphql-go/graphql/language/printer" "github.com/graphql-go/graphql/language/visitor" - "sort" - "strings" ) // SpecifiedRules set includes all validation rules defined by the GraphQL spec. @@ -162,30 +164,40 @@ func DefaultValuesOfCorrectTypeRule(context *ValidationContext) *ValidationRuleI VisitorOpts: visitorOpts, } } - -func UndefinedFieldMessage(fieldName string, ttypeName string, suggestedTypes []string) string { - - quoteStrings := func(slice []string) []string { - quoted := []string{} - for _, s := range slice { - quoted = append(quoted, fmt.Sprintf(`"%v"`, s)) - } - return quoted +func quoteStrings(slice []string) []string { + quoted := []string{} + for _, s := range slice { + quoted = append(quoted, fmt.Sprintf(`"%v"`, s)) } + return quoted +} - // construct helpful (but long) message +// quotedOrList Given [ A, B, C ] return '"A", "B", or "C"'. +// Notice oxford comma +func quotedOrList(slice []string) string { + maxLength := 5 + if len(slice) == 0 { + return "" + } + quoted := quoteStrings(slice) + if maxLength > len(quoted) { + maxLength = len(quoted) + } + if maxLength > 2 { + return fmt.Sprintf("%v, or %v", strings.Join(quoted[0:maxLength-1], ", "), quoted[maxLength-1]) + } + if maxLength > 1 { + return fmt.Sprintf("%v or %v", strings.Join(quoted[0:maxLength-1], ", "), quoted[maxLength-1]) + } + return quoted[0] +} +func UndefinedFieldMessage(fieldName string, ttypeName string, suggestedTypeNames []string, suggestedFieldNames []string) string { message := fmt.Sprintf(`Cannot query field "%v" on type "%v".`, fieldName, ttypeName) - suggestions := strings.Join(quoteStrings(suggestedTypes), ", ") - const MaxLength = 5 - if len(suggestedTypes) > 0 { - if len(suggestedTypes) > MaxLength { - suggestions = strings.Join(quoteStrings(suggestedTypes[0:MaxLength]), ", ") + - fmt.Sprintf(`, and %v other types`, len(suggestedTypes)-MaxLength) - } - message = message + fmt.Sprintf(` However, this field exists on %v.`, suggestions) - message = message + ` Perhaps you meant to use an inline fragment?` + if len(suggestedTypeNames) > 0 { + message = fmt.Sprintf(`%v Did you mean to use an inline fragment on %v?`, message, quotedOrList(suggestedTypeNames)) + } else if len(suggestedFieldNames) > 0 { + message = fmt.Sprintf(`%v Did you mean %v?`, message, quotedOrList(suggestedFieldNames)) } - return message } @@ -206,37 +218,23 @@ func FieldsOnCorrectTypeRule(context *ValidationContext) *ValidationRuleInstance if ttype != nil { fieldDef := context.FieldDef() if fieldDef == nil { - // This isn't valid. Let's find suggestions, if any. - suggestedTypes := []string{} - + // This field doesn't exist, lets look for suggestions. 
nodeName := "" if node.Name != nil { nodeName = node.Name.Value } + // First determine if there are any suggested types to condition on. + suggestedTypeNames := getSuggestedTypeNames(context.Schema(), ttype, nodeName) - if ttype, ok := ttype.(Abstract); ok && IsAbstractType(ttype) { - siblingInterfaces := getSiblingInterfacesIncludingField(context.Schema(), ttype, nodeName) - implementations := getImplementationsIncludingField(context.Schema(), ttype, nodeName) - suggestedMaps := map[string]bool{} - for _, s := range siblingInterfaces { - if _, ok := suggestedMaps[s]; !ok { - suggestedMaps[s] = true - suggestedTypes = append(suggestedTypes, s) - } - } - for _, s := range implementations { - if _, ok := suggestedMaps[s]; !ok { - suggestedMaps[s] = true - suggestedTypes = append(suggestedTypes, s) - } - } + // If there are no suggested types, then perhaps this was a typo? + suggestedFieldNames := []string{} + if len(suggestedTypeNames) == 0 { + suggestedFieldNames = getSuggestedFieldNames(context.Schema(), ttype, nodeName) } - message := UndefinedFieldMessage(nodeName, ttype.Name(), suggestedTypes) - reportError( context, - message, + UndefinedFieldMessage(nodeName, ttype.Name(), suggestedTypeNames, suggestedFieldNames), []ast.Node{node}, ) } @@ -252,73 +250,100 @@ func FieldsOnCorrectTypeRule(context *ValidationContext) *ValidationRuleInstance } } -// Return implementations of `type` that include `fieldName` as a valid field. -func getImplementationsIncludingField(schema *Schema, ttype Abstract, fieldName string) []string { - - result := []string{} - for _, t := range schema.PossibleTypes(ttype) { - fields := t.Fields() - if _, ok := fields[fieldName]; ok { - result = append(result, fmt.Sprintf(`%v`, t.Name())) - } - } - - sort.Strings(result) - return result -} +// getSuggestedTypeNames Go through all of the implementations of type, as well as the interfaces +// that they implement. If any of those types include the provided field, +// suggest them, sorted by how often the type is referenced, starting +// with Interfaces. +func getSuggestedTypeNames(schema *Schema, ttype Output, fieldName string) []string { -// Go through all of the implementations of type, and find other interaces -// that they implement. If those interfaces include `field` as a valid field, -// return them, sorted by how often the implementations include the other -// interface. -func getSiblingInterfacesIncludingField(schema *Schema, ttype Abstract, fieldName string) []string { - implementingObjects := schema.PossibleTypes(ttype) + possibleTypes := schema.PossibleTypes(ttype) - result := []string{} - suggestedInterfaceSlice := []*suggestedInterface{} - - // stores a map of interface name => index in suggestedInterfaceSlice + suggestedObjectTypes := []string{} + suggestedInterfaces := []*suggestedInterface{} + // stores a map of interface name => index in suggestedInterfaces suggestedInterfaceMap := map[string]int{} + // stores a maps of object name => true to remove duplicates from results + suggestedObjectMap := map[string]bool{} - for _, t := range implementingObjects { - for _, i := range t.Interfaces() { - if i == nil { - continue - } - fields := i.Fields() - if _, ok := fields[fieldName]; !ok { + for _, possibleType := range possibleTypes { + if field, ok := possibleType.Fields()[fieldName]; !ok || field == nil { + continue + } + // This object type defines this field. 
+ suggestedObjectTypes = append(suggestedObjectTypes, possibleType.Name()) + suggestedObjectMap[possibleType.Name()] = true + + for _, possibleInterface := range possibleType.Interfaces() { + if field, ok := possibleInterface.Fields()[fieldName]; !ok || field == nil { continue } - index, ok := suggestedInterfaceMap[i.Name()] + + // This interface type defines this field. + + // - find the index of the suggestedInterface and retrieving the interface + // - increase count + index, ok := suggestedInterfaceMap[possibleInterface.Name()] if !ok { - suggestedInterfaceSlice = append(suggestedInterfaceSlice, &suggestedInterface{ - name: i.Name(), + suggestedInterfaces = append(suggestedInterfaces, &suggestedInterface{ + name: possibleInterface.Name(), count: 0, }) - index = len(suggestedInterfaceSlice) - 1 + index = len(suggestedInterfaces) - 1 + suggestedInterfaceMap[possibleInterface.Name()] = index } - if index < len(suggestedInterfaceSlice) { - s := suggestedInterfaceSlice[index] - if s.name == i.Name() { + if index < len(suggestedInterfaces) { + s := suggestedInterfaces[index] + if s.name == possibleInterface.Name() { s.count = s.count + 1 } } } } - sort.Sort(suggestedInterfaceSortedSlice(suggestedInterfaceSlice)) - for _, s := range suggestedInterfaceSlice { - result = append(result, fmt.Sprintf(`%v`, s.name)) + // sort results (by count usage for interfaces, alphabetical order for objects) + sort.Sort(suggestedInterfaceSortedSlice(suggestedInterfaces)) + sort.Sort(sort.StringSlice(suggestedObjectTypes)) + + // return concatenated slices of both interface and object type names + // and removing duplicates + // ordered by: interface (sorted) and object (sorted) + results := []string{} + for _, s := range suggestedInterfaces { + if _, ok := suggestedObjectMap[s.name]; !ok { + results = append(results, s.name) + + } } - return result + results = append(results, suggestedObjectTypes...) + return results +} +// getSuggestedFieldNames For the field name provided, determine if there are any similar field names +// that may be the result of a typo. 
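// For illustration, the suggestion helpers above can be exercised through the
// exported message builder. A minimal sketch follows; the import path is the
// project's usual root module, and the sample inputs ("f", "T", "A", "B", "z",
// "y") mirror the message tests elsewhere in this change.

package main

import (
	"fmt"

	"github.com/graphql-go/graphql"
)

func main() {
	// No suggestions of either kind.
	fmt.Println(graphql.UndefinedFieldMessage("f", "T", nil, nil))
	// Cannot query field "f" on type "T".

	// Field-name suggestions are rendered by quotedOrList: at most five
	// options, with an Oxford comma once three or more are listed.
	fmt.Println(graphql.UndefinedFieldMessage("f", "T", nil, []string{"z", "y"}))
	// Cannot query field "f" on type "T". Did you mean "z" or "y"?

	// Type suggestions (inline-fragment hints) take precedence over
	// field-name suggestions when both are available.
	fmt.Println(graphql.UndefinedFieldMessage("f", "T", []string{"A", "B"}, []string{"z", "y"}))
	// Cannot query field "f" on type "T". Did you mean to use an inline fragment on "A" or "B"?
}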
+func getSuggestedFieldNames(schema *Schema, ttype Output, fieldName string) []string { + + fields := FieldDefinitionMap{} + switch ttype := ttype.(type) { + case *Object: + fields = ttype.Fields() + case *Interface: + fields = ttype.Fields() + default: + return []string{} + } + + possibleFieldNames := []string{} + for possibleFieldName := range fields { + possibleFieldNames = append(possibleFieldNames, possibleFieldName) + } + return suggestionList(fieldName, possibleFieldNames) } +// suggestedInterface an internal struct to sort interface by usage count type suggestedInterface struct { name string count int } - type suggestedInterfaceSortedSlice []*suggestedInterface func (s suggestedInterfaceSortedSlice) Len() int { @@ -328,7 +353,10 @@ func (s suggestedInterfaceSortedSlice) Swap(i, j int) { s[i], s[j] = s[j], s[i] } func (s suggestedInterfaceSortedSlice) Less(i, j int) bool { - return s[i].count < s[j].count + if s[i].count == s[j].count { + return s[i].name < s[j].name + } + return s[i].count > s[j].count } // FragmentsOnCompositeTypesRule Fragments on composite type @@ -380,6 +408,26 @@ func FragmentsOnCompositeTypesRule(context *ValidationContext) *ValidationRuleIn } } +func unknownArgMessage(argName string, fieldName string, parentTypeName string, suggestedArgs []string) string { + message := fmt.Sprintf(`Unknown argument "%v" on field "%v" of type "%v".`, argName, fieldName, parentTypeName) + + if len(suggestedArgs) > 0 { + message = fmt.Sprintf(`%v Did you mean %v?`, message, quotedOrList(suggestedArgs)) + } + + return message +} + +func unknownDirectiveArgMessage(argName string, directiveName string, suggestedArgs []string) string { + message := fmt.Sprintf(`Unknown argument "%v" on directive "@%v".`, argName, directiveName) + + if len(suggestedArgs) > 0 { + message = fmt.Sprintf(`%v Did you mean %v?`, message, quotedOrList(suggestedArgs)) + } + + return message +} + // KnownArgumentNamesRule Known argument names // // A GraphQL field is only valid if all supplied arguments are defined by @@ -399,6 +447,7 @@ func KnownArgumentNamesRule(context *ValidationContext) *ValidationRuleInstance if argumentOf == nil { return action, result } + var fieldArgDef *Argument if argumentOf.GetKind() == kinds.Field { fieldDef := context.FieldDef() if fieldDef == nil { @@ -408,8 +457,9 @@ func KnownArgumentNamesRule(context *ValidationContext) *ValidationRuleInstance if node.Name != nil { nodeName = node.Name.Value } - var fieldArgDef *Argument + argNames := []string{} for _, arg := range fieldDef.Args { + argNames = append(argNames, arg.Name()) if arg.Name() == nodeName { fieldArgDef = arg } @@ -422,7 +472,7 @@ func KnownArgumentNamesRule(context *ValidationContext) *ValidationRuleInstance } reportError( context, - fmt.Sprintf(`Unknown argument "%v" on field "%v" of type "%v".`, nodeName, fieldDef.Name, parentTypeName), + unknownArgMessage(nodeName, fieldDef.Name, parentTypeName, suggestionList(nodeName, argNames)), []ast.Node{node}, ) } @@ -435,8 +485,10 @@ func KnownArgumentNamesRule(context *ValidationContext) *ValidationRuleInstance if node.Name != nil { nodeName = node.Name.Value } + argNames := []string{} var directiveArgDef *Argument for _, arg := range directive.Args { + argNames = append(argNames, arg.Name()) if arg.Name() == nodeName { directiveArgDef = arg } @@ -444,7 +496,7 @@ func KnownArgumentNamesRule(context *ValidationContext) *ValidationRuleInstance if directiveArgDef == nil { reportError( context, - fmt.Sprintf(`Unknown argument "%v" on directive "@%v".`, nodeName, 
directive.Name), + unknownDirectiveArgMessage(nodeName, directive.Name, suggestionList(nodeName, argNames)), []ast.Node{node}, ) } @@ -497,15 +549,7 @@ func KnownDirectivesRule(context *ValidationContext) *ValidationRuleInstance { ) } - var appliedTo ast.Node - if len(p.Ancestors) > 0 { - appliedTo = p.Ancestors[len(p.Ancestors)-1] - } - if appliedTo == nil { - return action, result - } - - candidateLocation := getLocationForAppliedNode(appliedTo) + candidateLocation := getDirectiveLocationForASTPath(p.Ancestors) directiveHasLocation := false for _, loc := range directiveDef.Locations { @@ -540,7 +584,14 @@ func KnownDirectivesRule(context *ValidationContext) *ValidationRuleInstance { } } -func getLocationForAppliedNode(appliedTo ast.Node) string { +func getDirectiveLocationForASTPath(ancestors []ast.Node) string { + var appliedTo ast.Node + if len(ancestors) > 0 { + appliedTo = ancestors[len(ancestors)-1] + } + if appliedTo == nil { + return "" + } kind := appliedTo.GetKind() if kind == kinds.OperationDefinition { appliedTo, _ := appliedTo.(*ast.OperationDefinition) @@ -566,6 +617,44 @@ func getLocationForAppliedNode(appliedTo ast.Node) string { if kind == kinds.FragmentDefinition { return DirectiveLocationFragmentDefinition } + if kind == kinds.SchemaDefinition { + return DirectiveLocationSchema + } + if kind == kinds.ScalarDefinition { + return DirectiveLocationScalar + } + if kind == kinds.ObjectDefinition { + return DirectiveLocationObject + } + if kind == kinds.FieldDefinition { + return DirectiveLocationFieldDefinition + } + if kind == kinds.InterfaceDefinition { + return DirectiveLocationInterface + } + if kind == kinds.UnionDefinition { + return DirectiveLocationUnion + } + if kind == kinds.EnumDefinition { + return DirectiveLocationEnum + } + if kind == kinds.EnumValueDefinition { + return DirectiveLocationEnumValue + } + if kind == kinds.InputObjectDefinition { + return DirectiveLocationInputObject + } + if kind == kinds.InputValueDefinition { + var parentNode ast.Node + if len(ancestors) >= 3 { + parentNode = ancestors[len(ancestors)-3] + } + if parentNode.GetKind() == kinds.InputObjectDefinition { + return DirectiveLocationInputFieldDefinition + } else { + return DirectiveLocationArgumentDefinition + } + } return "" } @@ -606,6 +695,15 @@ func KnownFragmentNamesRule(context *ValidationContext) *ValidationRuleInstance } } +func unknownTypeMessage(typeName string, suggestedTypes []string) string { + message := fmt.Sprintf(`Unknown type "%v".`, typeName) + if len(suggestedTypes) > 0 { + message = fmt.Sprintf(`%v Did you mean %v?`, message, quotedOrList(suggestedTypes)) + } + + return message +} + // KnownTypeNamesRule Known type names // // A GraphQL document is only valid if referenced types (specifically @@ -643,9 +741,13 @@ func KnownTypeNamesRule(context *ValidationContext) *ValidationRuleInstance { } ttype := context.Schema().Type(typeNameValue) if ttype == nil { + suggestedTypes := []string{} + for key := range context.Schema().TypeMap() { + suggestedTypes = append(suggestedTypes, key) + } reportError( context, - fmt.Sprintf(`Unknown type "%v".`, typeNameValue), + unknownTypeMessage(typeNameValue, suggestionList(typeNameValue, suggestedTypes)), []ast.Node{node}, ) } @@ -702,27 +804,6 @@ func LoneAnonymousOperationRule(context *ValidationContext) *ValidationRuleInsta } } -type nodeSet struct { - set map[ast.Node]bool -} - -func newNodeSet() *nodeSet { - return &nodeSet{ - set: map[ast.Node]bool{}, - } -} -func (set *nodeSet) Has(node ast.Node) bool { - _, ok := set.set[node] 
- return ok -} -func (set *nodeSet) Add(node ast.Node) bool { - if set.Has(node) { - return false - } - set.set[node] = true - return true -} - func CycleErrorMessage(fragName string, spreadNames []string) string { via := "" if len(spreadNames) > 0 { @@ -756,7 +837,7 @@ func NoFragmentCyclesRule(context *ValidationContext) *ValidationRuleInstance { } visitedFrags[fragmentName] = true - spreadNodes := context.FragmentSpreads(fragment) + spreadNodes := context.FragmentSpreads(fragment.SelectionSet) if len(spreadNodes) == 0 { return } @@ -1048,489 +1129,6 @@ func NoUnusedVariablesRule(context *ValidationContext) *ValidationRuleInstance { } } -type fieldDefPair struct { - ParentType Composite - Field *ast.Field - FieldDef *FieldDefinition -} - -func collectFieldASTsAndDefs(context *ValidationContext, parentType Named, selectionSet *ast.SelectionSet, visitedFragmentNames map[string]bool, astAndDefs map[string][]*fieldDefPair) map[string][]*fieldDefPair { - - if astAndDefs == nil { - astAndDefs = map[string][]*fieldDefPair{} - } - if visitedFragmentNames == nil { - visitedFragmentNames = map[string]bool{} - } - if selectionSet == nil { - return astAndDefs - } - for _, selection := range selectionSet.Selections { - switch selection := selection.(type) { - case *ast.Field: - fieldName := "" - if selection.Name != nil { - fieldName = selection.Name.Value - } - var fieldDef *FieldDefinition - if parentType, ok := parentType.(*Object); ok { - fieldDef, _ = parentType.Fields()[fieldName] - } - if parentType, ok := parentType.(*Interface); ok { - fieldDef, _ = parentType.Fields()[fieldName] - } - - responseName := fieldName - if selection.Alias != nil { - responseName = selection.Alias.Value - } - _, ok := astAndDefs[responseName] - if !ok { - astAndDefs[responseName] = []*fieldDefPair{} - } - if parentType, ok := parentType.(Composite); ok { - astAndDefs[responseName] = append(astAndDefs[responseName], &fieldDefPair{ - ParentType: parentType, - Field: selection, - FieldDef: fieldDef, - }) - } else { - astAndDefs[responseName] = append(astAndDefs[responseName], &fieldDefPair{ - Field: selection, - FieldDef: fieldDef, - }) - } - case *ast.InlineFragment: - inlineFragmentType := parentType - if selection.TypeCondition != nil { - parentType, _ := typeFromAST(*context.Schema(), selection.TypeCondition) - inlineFragmentType = parentType - } - astAndDefs = collectFieldASTsAndDefs( - context, - inlineFragmentType, - selection.SelectionSet, - visitedFragmentNames, - astAndDefs, - ) - case *ast.FragmentSpread: - fragName := "" - if selection.Name != nil { - fragName = selection.Name.Value - } - if _, ok := visitedFragmentNames[fragName]; ok { - continue - } - visitedFragmentNames[fragName] = true - fragment := context.Fragment(fragName) - if fragment == nil { - continue - } - parentType, _ := typeFromAST(*context.Schema(), fragment.TypeCondition) - astAndDefs = collectFieldASTsAndDefs( - context, - parentType, - fragment.SelectionSet, - visitedFragmentNames, - astAndDefs, - ) - } - } - return astAndDefs -} - -// pairSet A way to keep track of pairs of things when the ordering of the pair does -// not matter. We do this by maintaining a sort of double adjacency sets. 
-type pairSet struct { - data map[ast.Node]*nodeSet -} - -func newPairSet() *pairSet { - return &pairSet{ - data: map[ast.Node]*nodeSet{}, - } -} -func (pair *pairSet) Has(a ast.Node, b ast.Node) bool { - first, ok := pair.data[a] - if !ok || first == nil { - return false - } - res := first.Has(b) - return res -} -func (pair *pairSet) Add(a ast.Node, b ast.Node) bool { - pair.data = pairSetAdd(pair.data, a, b) - pair.data = pairSetAdd(pair.data, b, a) - return true -} - -func pairSetAdd(data map[ast.Node]*nodeSet, a, b ast.Node) map[ast.Node]*nodeSet { - set, ok := data[a] - if !ok || set == nil { - set = newNodeSet() - data[a] = set - } - set.Add(b) - return data -} - -type conflictReason struct { - Name string - Message interface{} // conflictReason || []conflictReason -} -type conflict struct { - Reason conflictReason - FieldsLeft []ast.Node - FieldsRight []ast.Node -} - -func sameArguments(args1 []*ast.Argument, args2 []*ast.Argument) bool { - if len(args1) != len(args2) { - return false - } - - for _, arg1 := range args1 { - arg1Name := "" - if arg1.Name != nil { - arg1Name = arg1.Name.Value - } - - var foundArgs2 *ast.Argument - for _, arg2 := range args2 { - arg2Name := "" - if arg2.Name != nil { - arg2Name = arg2.Name.Value - } - if arg1Name == arg2Name { - foundArgs2 = arg2 - } - break - } - if foundArgs2 == nil { - return false - } - if sameValue(arg1.Value, foundArgs2.Value) == false { - return false - } - } - - return true -} -func sameValue(value1 ast.Value, value2 ast.Value) bool { - if value1 == nil && value2 == nil { - return true - } - val1 := printer.Print(value1) - val2 := printer.Print(value2) - - return val1 == val2 -} - -func sameType(typeA, typeB Type) bool { - if typeA == typeB { - return true - } - - if typeA, ok := typeA.(*List); ok { - if typeB, ok := typeB.(*List); ok { - return sameType(typeA.OfType, typeB.OfType) - } - } - if typeA, ok := typeA.(*NonNull); ok { - if typeB, ok := typeB.(*NonNull); ok { - return sameType(typeA.OfType, typeB.OfType) - } - } - - return false -} - -// Two types conflict if both types could not apply to a value simultaneously. -// Composite types are ignored as their individual field types will be compared -// later recursively. However List and Non-Null types must match. -func doTypesConflict(type1 Output, type2 Output) bool { - if type1, ok := type1.(*List); ok { - if type2, ok := type2.(*List); ok { - return doTypesConflict(type1.OfType, type2.OfType) - } - return true - } - if type2, ok := type2.(*List); ok { - if type1, ok := type1.(*List); ok { - return doTypesConflict(type1.OfType, type2.OfType) - } - return true - } - if type1, ok := type1.(*NonNull); ok { - if type2, ok := type2.(*NonNull); ok { - return doTypesConflict(type1.OfType, type2.OfType) - } - return true - } - if type2, ok := type2.(*NonNull); ok { - if type1, ok := type1.(*NonNull); ok { - return doTypesConflict(type1.OfType, type2.OfType) - } - return true - } - if IsLeafType(type1) || IsLeafType(type2) { - return type1 != type2 - } - return false -} - -// OverlappingFieldsCanBeMergedRule Overlapping fields can be merged -// -// A selection set is only valid if all fields (including spreading any -// fragments) either correspond to distinct response names or can be merged -// without ambiguity. 
-func OverlappingFieldsCanBeMergedRule(context *ValidationContext) *ValidationRuleInstance { - - var getSubfieldMap func(ast1 *ast.Field, type1 Output, ast2 *ast.Field, type2 Output) map[string][]*fieldDefPair - var subfieldConflicts func(conflicts []*conflict, responseName string, ast1 *ast.Field, ast2 *ast.Field) *conflict - var findConflicts func(parentFieldsAreMutuallyExclusive bool, fieldMap map[string][]*fieldDefPair) (conflicts []*conflict) - - comparedSet := newPairSet() - findConflict := func(parentFieldsAreMutuallyExclusive bool, responseName string, field *fieldDefPair, field2 *fieldDefPair) *conflict { - - parentType1 := field.ParentType - ast1 := field.Field - def1 := field.FieldDef - - parentType2 := field2.ParentType - ast2 := field2.Field - def2 := field2.FieldDef - - // Not a pair. - if ast1 == ast2 { - return nil - } - - // Memoize, do not report the same issue twice. - // Note: Two overlapping ASTs could be encountered both when - // `parentFieldsAreMutuallyExclusive` is true and is false, which could - // produce different results (when `true` being a subset of `false`). - // However we do not need to include this piece of information when - // memoizing since this rule visits leaf fields before their parent fields, - // ensuring that `parentFieldsAreMutuallyExclusive` is `false` the first - // time two overlapping fields are encountered, ensuring that the full - // set of validation rules are always checked when necessary. - if comparedSet.Has(ast1, ast2) { - return nil - } - comparedSet.Add(ast1, ast2) - - // The return type for each field. - var type1 Type - var type2 Type - if def1 != nil { - type1 = def1.Type - } - if def2 != nil { - type2 = def2.Type - } - - // If it is known that two fields could not possibly apply at the same - // time, due to the parent types, then it is safe to permit them to diverge - // in aliased field or arguments used as they will not present any ambiguity - // by differing. - // It is known that two parent types could never overlap if they are - // different Object types. Interface or Union types might overlap - if not - // in the current state of the schema, then perhaps in some future version, - // thus may not safely diverge. - _, isParentType1Object := parentType1.(*Object) - _, isParentType2Object := parentType2.(*Object) - fieldsAreMutuallyExclusive := parentFieldsAreMutuallyExclusive || parentType1 != parentType2 && isParentType1Object && isParentType2Object - - if !fieldsAreMutuallyExclusive { - // Two aliases must refer to the same field. - name1 := "" - name2 := "" - - if ast1.Name != nil { - name1 = ast1.Name.Value - } - if ast2.Name != nil { - name2 = ast2.Name.Value - } - if name1 != name2 { - return &conflict{ - Reason: conflictReason{ - Name: responseName, - Message: fmt.Sprintf(`%v and %v are different fields`, name1, name2), - }, - FieldsLeft: []ast.Node{ast1}, - FieldsRight: []ast.Node{ast2}, - } - } - - // Two field calls must have the same arguments. 
- if !sameArguments(ast1.Arguments, ast2.Arguments) { - return &conflict{ - Reason: conflictReason{ - Name: responseName, - Message: `they have differing arguments`, - }, - FieldsLeft: []ast.Node{ast1}, - FieldsRight: []ast.Node{ast2}, - } - } - } - - if type1 != nil && type2 != nil && doTypesConflict(type1, type2) { - return &conflict{ - Reason: conflictReason{ - Name: responseName, - Message: fmt.Sprintf(`they return conflicting types %v and %v`, type1, type2), - }, - FieldsLeft: []ast.Node{ast1}, - FieldsRight: []ast.Node{ast2}, - } - } - - subFieldMap := getSubfieldMap(ast1, type1, ast2, type2) - if subFieldMap != nil { - conflicts := findConflicts(fieldsAreMutuallyExclusive, subFieldMap) - return subfieldConflicts(conflicts, responseName, ast1, ast2) - } - - return nil - } - - getSubfieldMap = func(ast1 *ast.Field, type1 Output, ast2 *ast.Field, type2 Output) map[string][]*fieldDefPair { - selectionSet1 := ast1.SelectionSet - selectionSet2 := ast2.SelectionSet - if selectionSet1 != nil && selectionSet2 != nil { - visitedFragmentNames := map[string]bool{} - subfieldMap := collectFieldASTsAndDefs( - context, - GetNamed(type1), - selectionSet1, - visitedFragmentNames, - nil, - ) - subfieldMap = collectFieldASTsAndDefs( - context, - GetNamed(type2), - selectionSet2, - visitedFragmentNames, - subfieldMap, - ) - return subfieldMap - } - return nil - } - - subfieldConflicts = func(conflicts []*conflict, responseName string, ast1 *ast.Field, ast2 *ast.Field) *conflict { - if len(conflicts) > 0 { - conflictReasons := []conflictReason{} - conflictFieldsLeft := []ast.Node{ast1} - conflictFieldsRight := []ast.Node{ast2} - for _, c := range conflicts { - conflictReasons = append(conflictReasons, c.Reason) - conflictFieldsLeft = append(conflictFieldsLeft, c.FieldsLeft...) - conflictFieldsRight = append(conflictFieldsRight, c.FieldsRight...) - } - - return &conflict{ - Reason: conflictReason{ - Name: responseName, - Message: conflictReasons, - }, - FieldsLeft: conflictFieldsLeft, - FieldsRight: conflictFieldsRight, - } - } - return nil - } - findConflicts = func(parentFieldsAreMutuallyExclusive bool, fieldMap map[string][]*fieldDefPair) (conflicts []*conflict) { - - // ensure field traversal - orderedName := sort.StringSlice{} - for responseName := range fieldMap { - orderedName = append(orderedName, responseName) - } - orderedName.Sort() - - for _, responseName := range orderedName { - fields, _ := fieldMap[responseName] - for _, fieldA := range fields { - for _, fieldB := range fields { - c := findConflict(parentFieldsAreMutuallyExclusive, responseName, fieldA, fieldB) - if c != nil { - conflicts = append(conflicts, c) - } - } - } - } - return conflicts - } - - var reasonMessage func(message interface{}) string - reasonMessage = func(message interface{}) string { - switch reason := message.(type) { - case string: - return reason - case conflictReason: - return reasonMessage(reason.Message) - case []conflictReason: - messages := []string{} - for _, r := range reason { - messages = append(messages, fmt.Sprintf( - `subfields "%v" conflict because %v`, - r.Name, - reasonMessage(r.Message), - )) - } - return strings.Join(messages, " and ") - } - return "" - } - - visitorOpts := &visitor.VisitorOptions{ - KindFuncMap: map[string]visitor.NamedVisitFuncs{ - kinds.SelectionSet: { - // Note: we validate on the reverse traversal so deeper conflicts will be - // caught first, for correct calculation of mutual exclusivity and for - // clearer error messages. 
- Leave: func(p visitor.VisitFuncParams) (string, interface{}) { - if selectionSet, ok := p.Node.(*ast.SelectionSet); ok && selectionSet != nil { - parentType, _ := context.ParentType().(Named) - fieldMap := collectFieldASTsAndDefs( - context, - parentType, - selectionSet, - nil, - nil, - ) - conflicts := findConflicts(false, fieldMap) - if len(conflicts) > 0 { - for _, c := range conflicts { - responseName := c.Reason.Name - reason := c.Reason - reportError( - context, - fmt.Sprintf( - `Fields "%v" conflict because %v.`, - responseName, - reasonMessage(reason), - ), - append(c.FieldsLeft, c.FieldsRight...), - ) - } - return visitor.ActionNoChange, nil - } - } - return visitor.ActionNoChange, nil - }, - }, - }, - } - return &ValidationRuleInstance{ - VisitorOpts: visitorOpts, - } -} - func getFragmentType(context *ValidationContext, name string) Type { frag := context.Fragment(name) if frag == nil { @@ -2210,3 +1808,85 @@ func isValidLiteralValue(ttype Input, valueAST ast.Value) (bool, []string) { return true, nil } + +// Internal struct to sort results from suggestionList() +type suggestionListResult struct { + Options []string + Distances []float64 +} + +func (s suggestionListResult) Len() int { + return len(s.Options) +} +func (s suggestionListResult) Swap(i, j int) { + s.Options[i], s.Options[j] = s.Options[j], s.Options[i] +} +func (s suggestionListResult) Less(i, j int) bool { + return s.Distances[i] < s.Distances[j] +} + +// suggestionList Given an invalid input string and a list of valid options, returns a filtered +// list of valid options sorted based on their similarity with the input. +func suggestionList(input string, options []string) []string { + dists := []float64{} + filteredOpts := []string{} + inputThreshold := float64(len(input) / 2) + + for _, opt := range options { + dist := lexicalDistance(input, opt) + threshold := math.Max(inputThreshold, float64(len(opt)/2)) + threshold = math.Max(threshold, 1) + if dist <= threshold { + filteredOpts = append(filteredOpts, opt) + dists = append(dists, dist) + } + } + //sort results + suggested := suggestionListResult{filteredOpts, dists} + sort.Sort(suggested) + return suggested.Options +} + +// lexicalDistance Computes the lexical distance between strings A and B. +// The "distance" between two strings is given by counting the minimum number +// of edits needed to transform string A into string B. An edit can be an +// insertion, deletion, or substitution of a single character, or a swap of two +// adjacent characters. 
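// For example, lexicalDistance("of", "if") evaluates to 1 (a single
// substitution), which is within suggestionList's threshold of
// max(len("of")/2, len("if")/2, 1) = 1, so a misspelled @skip(of: ...)
// argument yields the suggestion "if". Similarly "Peettt" is within distance 3
// of "Pet" and meets its threshold of 3, which is what the updated
// KnownArgumentNames and KnownTypeNames test expectations rely on.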
+// This distance can be useful for detecting typos in input or sorting +func lexicalDistance(a, b string) float64 { + d := [][]float64{} + aLen := len(a) + bLen := len(b) + for i := 0; i <= aLen; i++ { + d = append(d, []float64{float64(i)}) + } + for k := 1; k <= bLen; k++ { + d[0] = append(d[0], float64(k)) + } + + for i := 1; i <= aLen; i++ { + for k := 1; k <= bLen; k++ { + cost := 1.0 + if a[i-1] == b[k-1] { + cost = 0.0 + } + minCostFloat := math.Min( + d[i-1][k]+1.0, + d[i][k-1]+1.0, + ) + minCostFloat = math.Min( + minCostFloat, + d[i-1][k-1]+cost, + ) + d[i] = append(d[i], minCostFloat) + + if i > 1 && k > 1 && + a[i-1] == b[k-2] && + a[i-2] == b[k-1] { + d[i][k] = math.Min(d[i][k], d[i-2][k-2]+cost) + } + } + } + + return d[aLen][bLen] +} diff --git a/rules_fields_on_correct_type_test.go b/rules_fields_on_correct_type_test.go index 294a0682..833a8349 100644 --- a/rules_fields_on_correct_type_test.go +++ b/rules_fields_on_correct_type_test.go @@ -73,7 +73,7 @@ func TestValidate_FieldsOnCorrectType_FieldNotDefinedOnFragment(t *testing.T) { meowVolume } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Cannot query field "meowVolume" on type "Dog".`, 3, 9), + testutil.RuleError(`Cannot query field "meowVolume" on type "Dog". Did you mean "barkVolume"?`, 3, 9), }) } func TestValidate_FieldsOnCorrectType_IgnoreDeeplyUnknownField(t *testing.T) { @@ -106,7 +106,7 @@ func TestValidate_FieldsOnCorrectType_FieldNotDefinedOnInlineFragment(t *testing } } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Cannot query field "meowVolume" on type "Dog".`, 4, 11), + testutil.RuleError(`Cannot query field "meowVolume" on type "Dog". Did you mean "barkVolume"?`, 4, 11), }) } func TestValidate_FieldsOnCorrectType_AliasedFieldTargetNotDefined(t *testing.T) { @@ -115,7 +115,7 @@ func TestValidate_FieldsOnCorrectType_AliasedFieldTargetNotDefined(t *testing.T) volume : mooVolume } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Cannot query field "mooVolume" on type "Dog".`, 3, 9), + testutil.RuleError(`Cannot query field "mooVolume" on type "Dog". Did you mean "barkVolume"?`, 3, 9), }) } func TestValidate_FieldsOnCorrectType_AliasedLyingFieldTargetNotDefined(t *testing.T) { @@ -124,7 +124,7 @@ func TestValidate_FieldsOnCorrectType_AliasedLyingFieldTargetNotDefined(t *testi barkVolume : kawVolume } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Cannot query field "kawVolume" on type "Dog".`, 3, 9), + testutil.RuleError(`Cannot query field "kawVolume" on type "Dog". Did you mean "barkVolume"?`, 3, 9), }) } func TestValidate_FieldsOnCorrectType_NotDefinedOnInterface(t *testing.T) { @@ -142,7 +142,7 @@ func TestValidate_FieldsOnCorrectType_DefinedOnImplementorsButNotOnInterface(t * nickname } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Cannot query field "nickname" on type "Pet". However, this field exists on "Cat", "Dog". Perhaps you meant to use an inline fragment?`, 3, 9), + testutil.RuleError(`Cannot query field "nickname" on type "Pet". Did you mean to use an inline fragment on "Cat" or "Dog"?`, 3, 9), }) } func TestValidate_FieldsOnCorrectType_MetaFieldSelectionOnUnion(t *testing.T) { @@ -167,7 +167,7 @@ func TestValidate_FieldsOnCorrectType_DefinedImplementorsQueriedOnUnion(t *testi name } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Cannot query field "name" on type "CatOrDog". However, this field exists on "Being", "Pet", "Canine", "Cat", "Dog". 
Perhaps you meant to use an inline fragment?`, 3, 9), + testutil.RuleError(`Cannot query field "name" on type "CatOrDog". Did you mean to use an inline fragment on "Being", "Pet", "Canine", "Cat", or "Dog"?`, 3, 9), }) } func TestValidate_FieldsOnCorrectType_ValidFieldInInlineFragment(t *testing.T) { @@ -184,27 +184,54 @@ func TestValidate_FieldsOnCorrectType_ValidFieldInInlineFragment(t *testing.T) { } func TestValidate_FieldsOnCorrectTypeErrorMessage_WorksWithNoSuggestions(t *testing.T) { - message := graphql.UndefinedFieldMessage("T", "f", []string{}) - expected := `Cannot query field "T" on type "f".` + message := graphql.UndefinedFieldMessage("f", "T", []string{}, []string{}) + expected := `Cannot query field "f" on type "T".` if message != expected { t.Fatalf("Unexpected message, expected: %v, got %v", expected, message) } } -func TestValidate_FieldsOnCorrectTypeErrorMessage_WorksWithNoSmallNumbersOfSuggestions(t *testing.T) { - message := graphql.UndefinedFieldMessage("T", "f", []string{"A", "B"}) - expected := `Cannot query field "T" on type "f". ` + - `However, this field exists on "A", "B". ` + - `Perhaps you meant to use an inline fragment?` +func TestValidate_FieldsOnCorrectTypeErrorMessage_WorksWithNoSmallNumbersOfTypeSuggestions(t *testing.T) { + message := graphql.UndefinedFieldMessage("f", "T", []string{"A", "B"}, []string{}) + expected := `Cannot query field "f" on type "T". ` + + `Did you mean to use an inline fragment on "A" or "B"?` if message != expected { t.Fatalf("Unexpected message, expected: %v, got %v", expected, message) } } -func TestValidate_FieldsOnCorrectTypeErrorMessage_WorksWithLotsOfSuggestions(t *testing.T) { - message := graphql.UndefinedFieldMessage("T", "f", []string{"A", "B", "C", "D", "E", "F"}) - expected := `Cannot query field "T" on type "f". ` + - `However, this field exists on "A", "B", "C", "D", "E", and 1 other types. ` + - `Perhaps you meant to use an inline fragment?` + +func TestValidate_FieldsOnCorrectTypeErrorMessage_WorksWithNoSmallNumbersOfFieldSuggestions(t *testing.T) { + message := graphql.UndefinedFieldMessage("f", "T", []string{}, []string{"z", "y"}) + expected := `Cannot query field "f" on type "T". ` + + `Did you mean "z" or "y"?` + if message != expected { + t.Fatalf("Unexpected message, expected: %v, got %v", expected, message) + } +} +func TestValidate_FieldsOnCorrectTypeErrorMessage_OnlyShowsOneSetOfSuggestionsAtATimePreferringTypes(t *testing.T) { + message := graphql.UndefinedFieldMessage("f", "T", []string{"A", "B"}, []string{"z", "y"}) + expected := `Cannot query field "f" on type "T". ` + + `Did you mean to use an inline fragment on "A" or "B"?` + if message != expected { + t.Fatalf("Unexpected message, expected: %v, got %v", expected, message) + } +} + +func TestValidate_FieldsOnCorrectTypeErrorMessage_LimitLotsOfTypeSuggestions(t *testing.T) { + message := graphql.UndefinedFieldMessage("f", "T", []string{"A", "B", "C", "D", "E", "F"}, []string{}) + expected := `Cannot query field "f" on type "T". ` + + `Did you mean to use an inline fragment on "A", "B", "C", "D", or "E"?` + if message != expected { + t.Fatalf("Unexpected message, expected: %v, got %v", expected, message) + } +} + +func TestValidate_FieldsOnCorrectTypeErrorMessage_LimitLotsOfFieldSuggestions(t *testing.T) { + message := graphql.UndefinedFieldMessage( + "f", "T", []string{}, []string{"z", "y", "x", "w", "v", "u"}, + ) + expected := `Cannot query field "f" on type "T". 
` + + `Did you mean "z", "y", "x", "w", or "v"?` if message != expected { t.Fatalf("Unexpected message, expected: %v, got %v", expected, message) } diff --git a/rules_known_argument_names_test.go b/rules_known_argument_names_test.go index 7536161d..574a0037 100644 --- a/rules_known_argument_names_test.go +++ b/rules_known_argument_names_test.go @@ -75,6 +75,16 @@ func TestValidate_KnownArgumentNames_UndirectiveArgsAreInvalid(t *testing.T) { testutil.RuleError(`Unknown argument "unless" on directive "@skip".`, 3, 19), }) } +func TestValidate_KnownArgumentNames_UndirectiveArgsAreInvalidWithSuggestion(t *testing.T) { + testutil.ExpectFailsRule(t, graphql.KnownArgumentNamesRule, ` + { + dog @skip(of: true) + } + `, []gqlerrors.FormattedError{ + testutil.RuleError(`Unknown argument "of" on directive "@skip". `+ + `Did you mean "if"?`, 3, 19), + }) +} func TestValidate_KnownArgumentNames_InvalidArgName(t *testing.T) { testutil.ExpectFailsRule(t, graphql.KnownArgumentNamesRule, ` fragment invalidArgName on Dog { @@ -94,6 +104,16 @@ func TestValidate_KnownArgumentNames_UnknownArgsAmongstKnownArgs(t *testing.T) { testutil.RuleError(`Unknown argument "unknown" on field "doesKnowCommand" of type "Dog".`, 3, 55), }) } +func TestValidate_KnownArgumentNames_UnknownArgsAmongstKnownArgsWithSuggestions(t *testing.T) { + testutil.ExpectFailsRule(t, graphql.KnownArgumentNamesRule, ` + fragment oneGoodArgOneInvalidArg on Dog { + doesKnowCommand(ddogCommand: SIT,) + } + `, []gqlerrors.FormattedError{ + testutil.RuleError(`Unknown argument "ddogCommand" on field "doesKnowCommand" of type "Dog". `+ + `Did you mean "dogCommand"?`, 3, 25), + }) +} func TestValidate_KnownArgumentNames_UnknownArgsDeeply(t *testing.T) { testutil.ExpectFailsRule(t, graphql.KnownArgumentNamesRule, ` { diff --git a/rules_known_directives_rule_test.go b/rules_known_directives_rule_test.go index ea6c07e9..f3d8231c 100644 --- a/rules_known_directives_rule_test.go +++ b/rules_known_directives_rule_test.go @@ -64,23 +64,102 @@ func TestValidate_KnownDirectives_WithManyUnknownDirectives(t *testing.T) { } func TestValidate_KnownDirectives_WithWellPlacedDirectives(t *testing.T) { testutil.ExpectPassesRule(t, graphql.KnownDirectivesRule, ` - query Foo { + query Foo @onQuery { name @include(if: true) ...Frag @include(if: true) skippedField @skip(if: true) ...SkippedFrag @skip(if: true) } + + mutation Bar @onMutation { + someField + } `) } func TestValidate_KnownDirectives_WithMisplacedDirectives(t *testing.T) { testutil.ExpectFailsRule(t, graphql.KnownDirectivesRule, ` query Foo @include(if: true) { - name @operationOnly - ...Frag @operationOnly + name @onQuery + ...Frag @onQuery + } + + mutation Bar @onQuery { + someField } `, []gqlerrors.FormattedError{ testutil.RuleError(`Directive "include" may not be used on QUERY.`, 2, 17), - testutil.RuleError(`Directive "operationOnly" may not be used on FIELD.`, 3, 14), - testutil.RuleError(`Directive "operationOnly" may not be used on FRAGMENT_SPREAD.`, 4, 17), + testutil.RuleError(`Directive "onQuery" may not be used on FIELD.`, 3, 14), + testutil.RuleError(`Directive "onQuery" may not be used on FRAGMENT_SPREAD.`, 4, 17), + testutil.RuleError(`Directive "onQuery" may not be used on MUTATION.`, 7, 20), + }) +} + +func TestValidate_KnownDirectives_WithinSchemaLanguage_WithWellPlacedDirectives(t *testing.T) { + testutil.ExpectPassesRule(t, graphql.KnownDirectivesRule, ` + type MyObj implements MyInterface @onObject { + myField(myArg: Int @onArgumentDefinition): String @onFieldDefinition + } + + scalar 
MyScalar @onScalar + + interface MyInterface @onInterface { + myField(myArg: Int @onArgumentDefinition): String @onFieldDefinition + } + + union MyUnion @onUnion = MyObj | Other + + enum MyEnum @onEnum { + MY_VALUE @onEnumValue + } + + input MyInput @onInputObject { + myField: Int @onInputFieldDefinition + } + + schema @onSchema { + query: MyQuery + } + `) +} + +func TestValidate_KnownDirectives_WithinSchemaLanguage_WithMisplacedDirectives(t *testing.T) { + testutil.ExpectFailsRule(t, graphql.KnownDirectivesRule, ` + type MyObj implements MyInterface @onInterface { + myField(myArg: Int @onInputFieldDefinition): String @onInputFieldDefinition + } + + scalar MyScalar @onEnum + + interface MyInterface @onObject { + myField(myArg: Int @onInputFieldDefinition): String @onInputFieldDefinition + } + + union MyUnion @onEnumValue = MyObj | Other + + enum MyEnum @onScalar { + MY_VALUE @onUnion + } + + input MyInput @onEnum { + myField: Int @onArgumentDefinition + } + + schema @onObject { + query: MyQuery + } + `, []gqlerrors.FormattedError{ + testutil.RuleError(`Directive "onInterface" may not be used on OBJECT.`, 2, 43), + testutil.RuleError(`Directive "onInputFieldDefinition" may not be used on ARGUMENT_DEFINITION.`, 3, 30), + testutil.RuleError(`Directive "onInputFieldDefinition" may not be used on FIELD_DEFINITION.`, 3, 63), + testutil.RuleError(`Directive "onEnum" may not be used on SCALAR.`, 6, 25), + testutil.RuleError(`Directive "onObject" may not be used on INTERFACE.`, 8, 31), + testutil.RuleError(`Directive "onInputFieldDefinition" may not be used on ARGUMENT_DEFINITION.`, 9, 30), + testutil.RuleError(`Directive "onInputFieldDefinition" may not be used on FIELD_DEFINITION.`, 9, 63), + testutil.RuleError(`Directive "onEnumValue" may not be used on UNION.`, 12, 23), + testutil.RuleError(`Directive "onScalar" may not be used on ENUM.`, 14, 21), + testutil.RuleError(`Directive "onUnion" may not be used on ENUM_VALUE.`, 15, 20), + testutil.RuleError(`Directive "onEnum" may not be used on INPUT_OBJECT.`, 18, 23), + testutil.RuleError(`Directive "onArgumentDefinition" may not be used on INPUT_FIELD_DEFINITION.`, 19, 24), + testutil.RuleError(`Directive "onObject" may not be used on SCHEMA.`, 22, 16), }) } diff --git a/rules_known_type_names_test.go b/rules_known_type_names_test.go index eec9a0ae..611a8037 100644 --- a/rules_known_type_names_test.go +++ b/rules_known_type_names_test.go @@ -34,7 +34,7 @@ func TestValidate_KnownTypeNames_UnknownTypeNamesAreInValid(t *testing.T) { `, []gqlerrors.FormattedError{ testutil.RuleError(`Unknown type "JumbledUpLetters".`, 2, 23), testutil.RuleError(`Unknown type "Badger".`, 5, 25), - testutil.RuleError(`Unknown type "Peettt".`, 8, 29), + testutil.RuleError(`Unknown type "Peettt". Did you mean "Pet"?`, 8, 29), }) } diff --git a/rules_overlapping_fields_can_be_merged.go b/rules_overlapping_fields_can_be_merged.go new file mode 100644 index 00000000..f44849b5 --- /dev/null +++ b/rules_overlapping_fields_can_be_merged.go @@ -0,0 +1,706 @@ +package graphql + +import ( + "fmt" + "strings" + + "github.com/graphql-go/graphql/language/ast" + "github.com/graphql-go/graphql/language/kinds" + "github.com/graphql-go/graphql/language/printer" + "github.com/graphql-go/graphql/language/visitor" +) + +func fieldsConflictMessage(responseName string, reason conflictReason) string { + return fmt.Sprintf(`Fields "%v" conflict because %v. 
`+ + `Use different aliases on the fields to fetch both if this was intentional.`, + responseName, + fieldsConflictReasonMessage(reason), + ) +} + +func fieldsConflictReasonMessage(message interface{}) string { + switch reason := message.(type) { + case string: + return reason + case conflictReason: + return fieldsConflictReasonMessage(reason.Message) + case []conflictReason: + messages := []string{} + for _, r := range reason { + messages = append(messages, fmt.Sprintf( + `subfields "%v" conflict because %v`, + r.Name, + fieldsConflictReasonMessage(r.Message), + )) + } + return strings.Join(messages, " and ") + } + return "" +} + +// OverlappingFieldsCanBeMergedRule Overlapping fields can be merged +// +// A selection set is only valid if all fields (including spreading any +// fragments) either correspond to distinct response names or can be merged +// without ambiguity. +func OverlappingFieldsCanBeMergedRule(context *ValidationContext) *ValidationRuleInstance { + + // A memoization for when two fragments are compared "between" each other for + // conflicts. Two fragments may be compared many times, so memoizing this can + // dramatically improve the performance of this validator. + comparedSet := newPairSet() + + // A cache for the "field map" and list of fragment names found in any given + // selection set. Selection sets may be asked for this information multiple + // times, so this improves the performance of this validator. + cacheMap := map[*ast.SelectionSet]*fieldsAndFragmentNames{} + + visitorOpts := &visitor.VisitorOptions{ + KindFuncMap: map[string]visitor.NamedVisitFuncs{ + kinds.SelectionSet: { + Kind: func(p visitor.VisitFuncParams) (string, interface{}) { + if selectionSet, ok := p.Node.(*ast.SelectionSet); ok && selectionSet != nil { + parentType, _ := context.ParentType().(Named) + + rule := &overlappingFieldsCanBeMergedRule{ + context: context, + comparedSet: comparedSet, + cacheMap: cacheMap, + } + conflicts := rule.findConflictsWithinSelectionSet(parentType, selectionSet) + if len(conflicts) > 0 { + for _, c := range conflicts { + responseName := c.Reason.Name + reason := c.Reason + reportError( + context, + fieldsConflictMessage(responseName, reason), + append(c.FieldsLeft, c.FieldsRight...), + ) + } + return visitor.ActionNoChange, nil + } + } + return visitor.ActionNoChange, nil + }, + }, + }, + } + return &ValidationRuleInstance{ + VisitorOpts: visitorOpts, + } +} + +/** + * Algorithm: + * + * Conflicts occur when two fields exist in a query which will produce the same + * response name, but represent differing values, thus creating a conflict. + * The algorithm below finds all conflicts via making a series of comparisons + * between fields. In order to compare as few fields as possible, this makes + * a series of comparisons "within" sets of fields and "between" sets of fields. + * + * Given any selection set, a collection produces both a set of fields by + * also including all inline fragments, as well as a list of fragments + * referenced by fragment spreads. + * + * A) Each selection set represented in the document first compares "within" its + * collected set of fields, finding any conflicts between every pair of + * overlapping fields. + * Note: This is the *only time* that a the fields "within" a set are compared + * to each other. After this only fields "between" sets are compared. + * + * B) Also, if any fragment is referenced in a selection set, then a + * comparison is made "between" the original set of fields and the + * referenced fragment. 
+ * + * C) Also, if multiple fragments are referenced, then comparisons + * are made "between" each referenced fragment. + * + * D) When comparing "between" a set of fields and a referenced fragment, first + * a comparison is made between each field in the original set of fields and + * each field in the the referenced set of fields. + * + * E) Also, if any fragment is referenced in the referenced selection set, + * then a comparison is made "between" the original set of fields and the + * referenced fragment (recursively referring to step D). + * + * F) When comparing "between" two fragments, first a comparison is made between + * each field in the first referenced set of fields and each field in the the + * second referenced set of fields. + * + * G) Also, any fragments referenced by the first must be compared to the + * second, and any fragments referenced by the second must be compared to the + * first (recursively referring to step F). + * + * H) When comparing two fields, if both have selection sets, then a comparison + * is made "between" both selection sets, first comparing the set of fields in + * the first selection set with the set of fields in the second. + * + * I) Also, if any fragment is referenced in either selection set, then a + * comparison is made "between" the other set of fields and the + * referenced fragment. + * + * J) Also, if two fragments are referenced in both selection sets, then a + * comparison is made "between" the two fragments. + * + */ + +type overlappingFieldsCanBeMergedRule struct { + context *ValidationContext + + // A memoization for when two fragments are compared "between" each other for + // conflicts. Two fragments may be compared many times, so memoizing this can + // dramatically improve the performance of this validator. + comparedSet *pairSet + + // A cache for the "field map" and list of fragment names found in any given + // selection set. Selection sets may be asked for this information multiple + // times, so this improves the performance of this validator. + cacheMap map[*ast.SelectionSet]*fieldsAndFragmentNames +} + +// Find all conflicts found "within" a selection set, including those found +// via spreading in fragments. Called when visiting each SelectionSet in the +// GraphQL Document. +func (rule *overlappingFieldsCanBeMergedRule) findConflictsWithinSelectionSet(parentType Named, selectionSet *ast.SelectionSet) []conflict { + conflicts := []conflict{} + + fieldsInfo := rule.getFieldsAndFragmentNames(parentType, selectionSet) + + // (A) Find find all conflicts "within" the fields of this selection set. + // Note: this is the *only place* `collectConflictsWithin` is called. + conflicts = rule.collectConflictsWithin(conflicts, fieldsInfo) + + // (B) Then collect conflicts between these fields and those represented by + // each spread fragment name found. + for i := 0; i < len(fieldsInfo.fragmentNames); i++ { + + conflicts = rule.collectConflictsBetweenFieldsAndFragment(conflicts, false, fieldsInfo, fieldsInfo.fragmentNames[i]) + + // (C) Then compare this fragment with all other fragments found in this + // selection set to collect conflicts between fragments spread together. + // This compares each item in the list of fragment names to every other item + // in that same list (except for itself). 
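// For instance, with three spreads [FragA, FragB, FragC] collected from one
// selection set, the loop below compares the pairs (FragA, FragB),
// (FragA, FragC) and then (FragB, FragC): a fragment is never compared with
// itself, and each unordered pair is visited only once.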
+ for k := i + 1; k < len(fieldsInfo.fragmentNames); k++ { + conflicts = rule.collectConflictsBetweenFragments(conflicts, false, fieldsInfo.fragmentNames[i], fieldsInfo.fragmentNames[k]) + } + } + return conflicts +} + +// Collect all conflicts found between a set of fields and a fragment reference +// including via spreading in any nested fragments. +func (rule *overlappingFieldsCanBeMergedRule) collectConflictsBetweenFieldsAndFragment(conflicts []conflict, areMutuallyExclusive bool, fieldsInfo *fieldsAndFragmentNames, fragmentName string) []conflict { + fragment := rule.context.Fragment(fragmentName) + if fragment == nil { + return conflicts + } + + fieldsInfo2 := rule.getReferencedFieldsAndFragmentNames(fragment) + + // (D) First collect any conflicts between the provided collection of fields + // and the collection of fields represented by the given fragment. + conflicts = rule.collectConflictsBetween(conflicts, areMutuallyExclusive, fieldsInfo, fieldsInfo2) + + // (E) Then collect any conflicts between the provided collection of fields + // and any fragment names found in the given fragment. + for _, fragmentName2 := range fieldsInfo2.fragmentNames { + conflicts = rule.collectConflictsBetweenFieldsAndFragment(conflicts, areMutuallyExclusive, fieldsInfo2, fragmentName2) + } + + return conflicts + +} + +// Collect all conflicts found between two fragments, including via spreading in +// any nested fragments. +func (rule *overlappingFieldsCanBeMergedRule) collectConflictsBetweenFragments(conflicts []conflict, areMutuallyExclusive bool, fragmentName1 string, fragmentName2 string) []conflict { + fragment1 := rule.context.Fragment(fragmentName1) + fragment2 := rule.context.Fragment(fragmentName2) + + if fragment1 == nil || fragment2 == nil { + return conflicts + } + + // No need to compare a fragment to itself. + if fragment1 == fragment2 { + return conflicts + } + + // Memoize so two fragments are not compared for conflicts more than once. + if rule.comparedSet.Has(fragmentName1, fragmentName2, areMutuallyExclusive) { + return conflicts + } + rule.comparedSet.Add(fragmentName1, fragmentName2, areMutuallyExclusive) + + fieldsInfo1 := rule.getReferencedFieldsAndFragmentNames(fragment1) + fieldsInfo2 := rule.getReferencedFieldsAndFragmentNames(fragment2) + + // (F) First, collect all conflicts between these two collections of fields + // (not including any nested fragments). + conflicts = rule.collectConflictsBetween(conflicts, areMutuallyExclusive, fieldsInfo1, fieldsInfo2) + + // (G) Then collect conflicts between the first fragment and any nested + // fragments spread in the second fragment. + for _, innerFragmentName2 := range fieldsInfo2.fragmentNames { + conflicts = rule.collectConflictsBetweenFragments(conflicts, areMutuallyExclusive, fragmentName1, innerFragmentName2) + } + + // (G) Then collect conflicts between the second fragment and any nested + // fragments spread in the first fragment. + for _, innerFragmentName1 := range fieldsInfo1.fragmentNames { + conflicts = rule.collectConflictsBetweenFragments(conflicts, areMutuallyExclusive, innerFragmentName1, fragmentName2) + } + + return conflicts +} + +// Find all conflicts found between two selection sets, including those found +// via spreading in fragments. Called when determining if conflicts exist +// between the sub-fields of two overlapping fields. 
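// As a concrete illustration (the field names here are hypothetical): if one
// branch selects `pet { a: x }` and another selects `pet { a: y }`, the
// sub-selection comparison performed by the function below surfaces a nested
// conflict, and fieldsConflictMessage above renders it roughly as:
//   Fields "pet" conflict because subfields "a" conflict because x and y are
//   different fields. Use different aliases on the fields to fetch both if
//   this was intentional.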
+func (rule *overlappingFieldsCanBeMergedRule) findConflictsBetweenSubSelectionSets(areMutuallyExclusive bool, parentType1 Named, selectionSet1 *ast.SelectionSet, parentType2 Named, selectionSet2 *ast.SelectionSet) []conflict { + conflicts := []conflict{} + + fieldsInfo1 := rule.getFieldsAndFragmentNames(parentType1, selectionSet1) + fieldsInfo2 := rule.getFieldsAndFragmentNames(parentType2, selectionSet2) + + // (H) First, collect all conflicts between these two collections of field. + conflicts = rule.collectConflictsBetween(conflicts, areMutuallyExclusive, fieldsInfo1, fieldsInfo2) + + // (I) Then collect conflicts between the first collection of fields and + // those referenced by each fragment name associated with the second. + for _, fragmentName2 := range fieldsInfo2.fragmentNames { + conflicts = rule.collectConflictsBetweenFieldsAndFragment(conflicts, areMutuallyExclusive, fieldsInfo1, fragmentName2) + } + + // (I) Then collect conflicts between the second collection of fields and + // those referenced by each fragment name associated with the first. + for _, fragmentName1 := range fieldsInfo1.fragmentNames { + conflicts = rule.collectConflictsBetweenFieldsAndFragment(conflicts, areMutuallyExclusive, fieldsInfo2, fragmentName1) + } + + // (J) Also collect conflicts between any fragment names by the first and + // fragment names by the second. This compares each item in the first set of + // names to each item in the second set of names. + for _, fragmentName1 := range fieldsInfo1.fragmentNames { + for _, fragmentName2 := range fieldsInfo2.fragmentNames { + conflicts = rule.collectConflictsBetweenFragments(conflicts, areMutuallyExclusive, fragmentName1, fragmentName2) + } + } + return conflicts +} + +// Collect all Conflicts "within" one collection of fields. +func (rule *overlappingFieldsCanBeMergedRule) collectConflictsWithin(conflicts []conflict, fieldsInfo *fieldsAndFragmentNames) []conflict { + // A field map is a keyed collection, where each key represents a response + // name and the value at that key is a list of all fields which provide that + // response name. For every response name, if there are multiple fields, they + // must be compared to find a potential conflict. + for _, responseName := range fieldsInfo.fieldsOrder { + fields, ok := fieldsInfo.fieldMap[responseName] + if !ok { + continue + } + // This compares every field in the list to every other field in this list + // (except to itself). If the list only has one item, nothing needs to + // be compared. + if len(fields) <= 1 { + continue + } + for i := 0; i < len(fields); i++ { + for k := i + 1; k < len(fields); k++ { + // within one collection is never mutually exclusive + isMutuallyExclusive := false + conflict := rule.findConflict(isMutuallyExclusive, responseName, fields[i], fields[k]) + if conflict != nil { + conflicts = append(conflicts, *conflict) + } + } + } + } + return conflicts +} + +// Collect all Conflicts between two collections of fields. This is similar to, +// but different from the `collectConflictsWithin` function above. This check +// assumes that `collectConflictsWithin` has already been called on each +// provided collection of fields. This is true because this validator traverses +// each individual selection set. 
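// A note on mutual exclusivity (hypothetical example): selections such as
// `... on Dog { name: barkVolume }` and `... on Cat { name: meowVolume }`
// reuse the response key "name" for different fields, but Dog and Cat are
// distinct object types, so the findConflict logic further below skips the
// alias and argument checks for that pair; the two fields must still return
// non-conflicting types, and any sub-selections are still compared.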
+func (rule *overlappingFieldsCanBeMergedRule) collectConflictsBetween(conflicts []conflict, parentFieldsAreMutuallyExclusive bool, + fieldsInfo1 *fieldsAndFragmentNames, + fieldsInfo2 *fieldsAndFragmentNames) []conflict { + // A field map is a keyed collection, where each key represents a response + // name and the value at that key is a list of all fields which provide that + // response name. For any response name which appears in both provided field + // maps, each field from the first field map must be compared to every field + // in the second field map to find potential conflicts. + for _, responseName := range fieldsInfo1.fieldsOrder { + fields1, ok1 := fieldsInfo1.fieldMap[responseName] + fields2, ok2 := fieldsInfo2.fieldMap[responseName] + if !ok1 || !ok2 { + continue + } + for i := 0; i < len(fields1); i++ { + for k := 0; k < len(fields2); k++ { + conflict := rule.findConflict(parentFieldsAreMutuallyExclusive, responseName, fields1[i], fields2[k]) + if conflict != nil { + conflicts = append(conflicts, *conflict) + } + } + } + } + return conflicts +} + +// findConflict Determines if there is a conflict between two particular fields. +func (rule *overlappingFieldsCanBeMergedRule) findConflict(parentFieldsAreMutuallyExclusive bool, responseName string, field *fieldDefPair, field2 *fieldDefPair) *conflict { + + parentType1 := field.ParentType + ast1 := field.Field + def1 := field.FieldDef + + parentType2 := field2.ParentType + ast2 := field2.Field + def2 := field2.FieldDef + + // If it is known that two fields could not possibly apply at the same + // time, due to the parent types, then it is safe to permit them to diverge + // in aliased field or arguments used as they will not present any ambiguity + // by differing. + // It is known that two parent types could never overlap if they are + // different Object types. Interface or Union types might overlap - if not + // in the current state of the schema, then perhaps in some future version, + // thus may not safely diverge. + _, isParentType1Object := parentType1.(*Object) + _, isParentType2Object := parentType2.(*Object) + areMutuallyExclusive := parentFieldsAreMutuallyExclusive || parentType1 != parentType2 && isParentType1Object && isParentType2Object + + // The return type for each field. + var type1 Type + var type2 Type + if def1 != nil { + type1 = def1.Type + } + if def2 != nil { + type2 = def2.Type + } + + if !areMutuallyExclusive { + // Two aliases must refer to the same field. + name1 := "" + name2 := "" + + if ast1.Name != nil { + name1 = ast1.Name.Value + } + if ast2.Name != nil { + name2 = ast2.Name.Value + } + if name1 != name2 { + return &conflict{ + Reason: conflictReason{ + Name: responseName, + Message: fmt.Sprintf(`%v and %v are different fields`, name1, name2), + }, + FieldsLeft: []ast.Node{ast1}, + FieldsRight: []ast.Node{ast2}, + } + } + + // Two field calls must have the same arguments. + if !sameArguments(ast1.Arguments, ast2.Arguments) { + return &conflict{ + Reason: conflictReason{ + Name: responseName, + Message: `they have differing arguments`, + }, + FieldsLeft: []ast.Node{ast1}, + FieldsRight: []ast.Node{ast2}, + } + } + } + + if type1 != nil && type2 != nil && doTypesConflict(type1, type2) { + return &conflict{ + Reason: conflictReason{ + Name: responseName, + Message: fmt.Sprintf(`they return conflicting types %v and %v`, type1, type2), + }, + FieldsLeft: []ast.Node{ast1}, + FieldsRight: []ast.Node{ast2}, + } + } + + // Collect and compare sub-fields. 
Use the same "visited fragment names" list + // for both collections so fields in a fragment reference are never + // compared to themselves. + selectionSet1 := ast1.SelectionSet + selectionSet2 := ast2.SelectionSet + if selectionSet1 != nil && selectionSet2 != nil { + conflicts := rule.findConflictsBetweenSubSelectionSets(areMutuallyExclusive, GetNamed(type1), selectionSet1, GetNamed(type2), selectionSet2) + return subfieldConflicts(conflicts, responseName, ast1, ast2) + } + return nil +} + +// Given a selection set, return the collection of fields (a mapping of response +// name to field ASTs and definitions) as well as a list of fragment names +// referenced via fragment spreads. +func (rule *overlappingFieldsCanBeMergedRule) getFieldsAndFragmentNames(parentType Named, selectionSet *ast.SelectionSet) *fieldsAndFragmentNames { + if cached, ok := rule.cacheMap[selectionSet]; ok && cached != nil { + return cached + } + + astAndDefs := astAndDefCollection{} + fieldsOrder := []string{} + fragmentNames := []string{} + fragmentNamesMap := map[string]bool{} + + var collectFieldsAndFragmentNames func(parentType Named, selectionSet *ast.SelectionSet) + collectFieldsAndFragmentNames = func(parentType Named, selectionSet *ast.SelectionSet) { + for _, selection := range selectionSet.Selections { + switch selection := selection.(type) { + case *ast.Field: + fieldName := "" + if selection.Name != nil { + fieldName = selection.Name.Value + } + var fieldDef *FieldDefinition + if parentType, ok := parentType.(*Object); ok { + fieldDef, _ = parentType.Fields()[fieldName] + } + if parentType, ok := parentType.(*Interface); ok { + fieldDef, _ = parentType.Fields()[fieldName] + } + + responseName := fieldName + if selection.Alias != nil { + responseName = selection.Alias.Value + } + + fieldDefPairs, ok := astAndDefs[responseName] + if !ok || fieldDefPairs == nil { + fieldDefPairs = []*fieldDefPair{} + fieldsOrder = append(fieldsOrder, responseName) + } + + fieldDefPairs = append(fieldDefPairs, &fieldDefPair{ + ParentType: parentType, + Field: selection, + FieldDef: fieldDef, + }) + astAndDefs[responseName] = fieldDefPairs + case *ast.FragmentSpread: + fieldName := "" + if selection.Name != nil { + fieldName = selection.Name.Value + } + if val, ok := fragmentNamesMap[fieldName]; !ok || !val { + fragmentNamesMap[fieldName] = true + fragmentNames = append(fragmentNames, fieldName) + } + case *ast.InlineFragment: + typeCondition := selection.TypeCondition + inlineFragmentType := parentType + if typeCondition != nil { + ttype, err := typeFromAST(*(rule.context.Schema()), typeCondition) + if err == nil { + inlineFragmentType, _ = ttype.(Named) + } + } + collectFieldsAndFragmentNames(inlineFragmentType, selection.SelectionSet) + } + } + } + collectFieldsAndFragmentNames(parentType, selectionSet) + + cached := &fieldsAndFragmentNames{ + fieldMap: astAndDefs, + fieldsOrder: fieldsOrder, + fragmentNames: fragmentNames, + } + + rule.cacheMap[selectionSet] = cached + return cached +} + +func (rule *overlappingFieldsCanBeMergedRule) getReferencedFieldsAndFragmentNames(fragment *ast.FragmentDefinition) *fieldsAndFragmentNames { + // Short-circuit building a type from the AST if possible. 
+ if cached, ok := rule.cacheMap[fragment.SelectionSet]; ok && cached != nil { + return cached + } + fragmentType, err := typeFromAST(*(rule.context.Schema()), fragment.TypeCondition) + if err != nil { + return nil + } + return rule.getFieldsAndFragmentNames(fragmentType, fragment.SelectionSet) +} + +type conflictReason struct { + Name string + Message interface{} // conflictReason || []conflictReason +} +type conflict struct { + Reason conflictReason + FieldsLeft []ast.Node + FieldsRight []ast.Node +} + +// a.k.a AstAndDef +type fieldDefPair struct { + ParentType Named + Field *ast.Field + FieldDef *FieldDefinition +} +type astAndDefCollection map[string][]*fieldDefPair + +// cache struct for fields, its order and fragments names +type fieldsAndFragmentNames struct { + fieldMap astAndDefCollection + fieldsOrder []string // stores the order of field names in fieldMap + fragmentNames []string +} + +// pairSet A way to keep track of pairs of things when the ordering of the pair does +// not matter. We do this by maintaining a sort of double adjacency sets. +type pairSet struct { + data map[string]map[string]bool +} + +func newPairSet() *pairSet { + return &pairSet{ + data: map[string]map[string]bool{}, + } +} +func (pair *pairSet) Has(a string, b string, areMutuallyExclusive bool) bool { + first, ok := pair.data[a] + if !ok || first == nil { + return false + } + res, ok := first[b] + if !ok { + return false + } + // areMutuallyExclusive being false is a superset of being true, + // hence if we want to know if this PairSet "has" these two with no + // exclusivity, we have to ensure it was added as such. + if !areMutuallyExclusive { + return res == false + } + return true +} +func (pair *pairSet) Add(a string, b string, areMutuallyExclusive bool) { + pair.data = pairSetAdd(pair.data, a, b, areMutuallyExclusive) + pair.data = pairSetAdd(pair.data, b, a, areMutuallyExclusive) +} +func pairSetAdd(data map[string]map[string]bool, a, b string, areMutuallyExclusive bool) map[string]map[string]bool { + set, ok := data[a] + if !ok || set == nil { + set = map[string]bool{} + } + set[b] = areMutuallyExclusive + data[a] = set + return data +} + +func sameArguments(args1 []*ast.Argument, args2 []*ast.Argument) bool { + if len(args1) != len(args2) { + return false + } + + for _, arg1 := range args1 { + arg1Name := "" + if arg1.Name != nil { + arg1Name = arg1.Name.Value + } + + var foundArgs2 *ast.Argument + for _, arg2 := range args2 { + arg2Name := "" + if arg2.Name != nil { + arg2Name = arg2.Name.Value + } + if arg1Name == arg2Name { + foundArgs2 = arg2 + } + break + } + if foundArgs2 == nil { + return false + } + if sameValue(arg1.Value, foundArgs2.Value) == false { + return false + } + } + + return true +} + +func sameValue(value1 ast.Value, value2 ast.Value) bool { + if value1 == nil && value2 == nil { + return true + } + val1 := printer.Print(value1) + val2 := printer.Print(value2) + + return val1 == val2 +} + +// Two types conflict if both types could not apply to a value simultaneously. +// Composite types are ignored as their individual field types will be compared +// later recursively. However List and Non-Null types must match. 
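
sameArguments and sameValue above decide argument equality by matching argument names and then comparing the printer output of each value AST. A library-free sketch of that "same name, same printed value" check, operating on already-printed values (the map form is an illustration, not the rule's actual representation):

package main

import "fmt"

// sameArgs reports whether two argument lists agree: equal length, and for
// every name in the first list a matching name in the second whose printed
// value is identical. The rule above obtains the printed values via
// printer.Print on the AST nodes.
func sameArgs(args1, args2 map[string]string) bool {
	if len(args1) != len(args2) {
		return false
	}
	for name, val1 := range args1 {
		val2, ok := args2[name]
		if !ok || val1 != val2 {
			return false
		}
	}
	return true
}

func main() {
	a := map[string]string{"dogCommand": "SIT"}
	b := map[string]string{"dogCommand": "HEEL"}
	fmt.Println(sameArgs(a, a)) // true
	fmt.Println(sameArgs(a, b)) // false: differing arguments, so the fields cannot be merged
}
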
+func doTypesConflict(type1 Output, type2 Output) bool { + if type1, ok := type1.(*List); ok { + if type2, ok := type2.(*List); ok { + return doTypesConflict(type1.OfType, type2.OfType) + } + return true + } + if type2, ok := type2.(*List); ok { + if type1, ok := type1.(*List); ok { + return doTypesConflict(type1.OfType, type2.OfType) + } + return true + } + if type1, ok := type1.(*NonNull); ok { + if type2, ok := type2.(*NonNull); ok { + return doTypesConflict(type1.OfType, type2.OfType) + } + return true + } + if type2, ok := type2.(*NonNull); ok { + if type1, ok := type1.(*NonNull); ok { + return doTypesConflict(type1.OfType, type2.OfType) + } + return true + } + if IsLeafType(type1) || IsLeafType(type2) { + return type1 != type2 + } + return false +} + +// subfieldConflicts Given a series of Conflicts which occurred between two sub-fields, generate a single Conflict. +func subfieldConflicts(conflicts []conflict, responseName string, ast1 *ast.Field, ast2 *ast.Field) *conflict { + if len(conflicts) > 0 { + conflictReasons := []conflictReason{} + conflictFieldsLeft := []ast.Node{ast1} + conflictFieldsRight := []ast.Node{ast2} + for _, c := range conflicts { + conflictReasons = append(conflictReasons, c.Reason) + conflictFieldsLeft = append(conflictFieldsLeft, c.FieldsLeft...) + conflictFieldsRight = append(conflictFieldsRight, c.FieldsRight...) + } + + return &conflict{ + Reason: conflictReason{ + Name: responseName, + Message: conflictReasons, + }, + FieldsLeft: conflictFieldsLeft, + FieldsRight: conflictFieldsRight, + } + } + return nil +} diff --git a/rules_overlapping_fields_can_be_merged_test.go b/rules_overlapping_fields_can_be_merged_test.go index b38b13a1..7034a858 100644 --- a/rules_overlapping_fields_can_be_merged_test.go +++ b/rules_overlapping_fields_can_be_merged_test.go @@ -74,7 +74,9 @@ func TestValidate_OverlappingFieldsCanBeMerged_SameAliasesWithDifferentFieldTarg fido: nickname } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Fields "fido" conflict because name and nickname are different fields.`, 3, 9, 4, 9), + testutil.RuleError(`Fields "fido" conflict because name and nickname are different fields. `+ + `Use different aliases on the fields to fetch both if this was intentional.`, + 3, 9, 4, 9), }) } func TestValidate_OverlappingFieldsCanBeMerged_SameAliasesAllowedOnNonOverlappingFields(t *testing.T) { @@ -96,7 +98,9 @@ func TestValidate_OverlappingFieldsCanBeMerged_AliasMaskingDirectFieldAccess(t * name } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Fields "name" conflict because nickname and name are different fields.`, 3, 9, 4, 9), + testutil.RuleError(`Fields "name" conflict because nickname and name are different fields. `+ + `Use different aliases on the fields to fetch both if this was intentional.`, + 3, 9, 4, 9), }) } func TestValidate_OverlappingFieldsCanBeMerged_DifferentArgs_SecondAddsAnArgument(t *testing.T) { @@ -106,7 +110,9 @@ func TestValidate_OverlappingFieldsCanBeMerged_DifferentArgs_SecondAddsAnArgumen doesKnowCommand(dogCommand: HEEL) } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Fields "doesKnowCommand" conflict because they have differing arguments.`, 3, 9, 4, 9), + testutil.RuleError(`Fields "doesKnowCommand" conflict because they have differing arguments. 
`+ + `Use different aliases on the fields to fetch both if this was intentional.`, + 3, 9, 4, 9), }) } func TestValidate_OverlappingFieldsCanBeMerged_DifferentArgs_SecondMissingAnArgument(t *testing.T) { @@ -116,7 +122,9 @@ func TestValidate_OverlappingFieldsCanBeMerged_DifferentArgs_SecondMissingAnArgu doesKnowCommand } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Fields "doesKnowCommand" conflict because they have differing arguments.`, 3, 9, 4, 9), + testutil.RuleError(`Fields "doesKnowCommand" conflict because they have differing arguments. `+ + `Use different aliases on the fields to fetch both if this was intentional.`, + 3, 9, 4, 9), }) } func TestValidate_OverlappingFieldsCanBeMerged_ConflictingArgs(t *testing.T) { @@ -126,7 +134,9 @@ func TestValidate_OverlappingFieldsCanBeMerged_ConflictingArgs(t *testing.T) { doesKnowCommand(dogCommand: HEEL) } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Fields "doesKnowCommand" conflict because they have differing arguments.`, 3, 9, 4, 9), + testutil.RuleError(`Fields "doesKnowCommand" conflict because they have differing arguments. `+ + `Use different aliases on the fields to fetch both if this was intentional.`, + 3, 9, 4, 9), }) } func TestValidate_OverlappingFieldsCanBeMerged_AllowDifferentArgsWhereNoConflictIsPossible(t *testing.T) { @@ -156,7 +166,9 @@ func TestValidate_OverlappingFieldsCanBeMerged_EncountersConflictInFragments(t * x: b } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Fields "x" conflict because a and b are different fields.`, 7, 9, 10, 9), + testutil.RuleError(`Fields "x" conflict because a and b are different fields. `+ + `Use different aliases on the fields to fetch both if this was intentional.`, + 7, 9, 10, 9), }) } func TestValidate_OverlappingFieldsCanBeMerged_ReportsEachConflictOnce(t *testing.T) { @@ -183,9 +195,15 @@ func TestValidate_OverlappingFieldsCanBeMerged_ReportsEachConflictOnce(t *testin x: b } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Fields "x" conflict because a and b are different fields.`, 18, 9, 21, 9), - testutil.RuleError(`Fields "x" conflict because a and c are different fields.`, 18, 9, 14, 11), - testutil.RuleError(`Fields "x" conflict because b and c are different fields.`, 21, 9, 14, 11), + testutil.RuleError(`Fields "x" conflict because a and b are different fields. `+ + `Use different aliases on the fields to fetch both if this was intentional.`, + 18, 9, 21, 9), + testutil.RuleError(`Fields "x" conflict because c and a are different fields. `+ + `Use different aliases on the fields to fetch both if this was intentional.`, + 14, 11, 18, 9), + testutil.RuleError(`Fields "x" conflict because c and b are different fields. `+ + `Use different aliases on the fields to fetch both if this was intentional.`, + 14, 11, 21, 9), }) } func TestValidate_OverlappingFieldsCanBeMerged_DeepConflict(t *testing.T) { @@ -199,7 +217,8 @@ func TestValidate_OverlappingFieldsCanBeMerged_DeepConflict(t *testing.T) { } } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Fields "field" conflict because subfields "x" conflict because a and b are different fields.`, + testutil.RuleError(`Fields "field" conflict because subfields "x" conflict because a and b are different fields. 
`+ + `Use different aliases on the fields to fetch both if this was intentional.`, 3, 9, 4, 11, 6, 9, @@ -219,9 +238,9 @@ func TestValidate_OverlappingFieldsCanBeMerged_DeepConflictWithMultipleIssues(t } } `, []gqlerrors.FormattedError{ - testutil.RuleError( - `Fields "field" conflict because subfields "x" conflict because a and b are different fields and `+ - `subfields "y" conflict because c and d are different fields.`, + testutil.RuleError(`Fields "field" conflict because subfields "x" conflict because a and b are different fields and `+ + `subfields "y" conflict because c and d are different fields. `+ + `Use different aliases on the fields to fetch both if this was intentional.`, 3, 9, 4, 11, 5, 11, @@ -245,9 +264,9 @@ func TestValidate_OverlappingFieldsCanBeMerged_VeryDeepConflict(t *testing.T) { } } `, []gqlerrors.FormattedError{ - testutil.RuleError( - `Fields "field" conflict because subfields "deepField" conflict because subfields "x" conflict because `+ - `a and b are different fields.`, + testutil.RuleError(`Fields "field" conflict because subfields "deepField" conflict because subfields "x" conflict because `+ + `a and b are different fields. `+ + `Use different aliases on the fields to fetch both if this was intentional.`, 3, 9, 4, 11, 5, 13, @@ -274,15 +293,101 @@ func TestValidate_OverlappingFieldsCanBeMerged_ReportsDeepConflictToNearestCommo } } `, []gqlerrors.FormattedError{ - testutil.RuleError( - `Fields "deepField" conflict because subfields "x" conflict because `+ - `a and b are different fields.`, + testutil.RuleError(`Fields "deepField" conflict because subfields "x" conflict because `+ + `a and b are different fields. `+ + `Use different aliases on the fields to fetch both if this was intentional.`, 4, 11, 5, 13, 7, 11, 8, 13), }) } +func TestValidate_OverlappingFieldsCanBeMerged_ReportsDeepConflictToNearestCommonAncestorInFragments(t *testing.T) { + testutil.ExpectFailsRule(t, graphql.OverlappingFieldsCanBeMergedRule, ` + { + field { + ...F + } + field { + ...F + } + } + fragment F on T { + deepField { + deeperField { + x: a + } + deeperField { + x: b + } + }, + deepField { + deeperField { + y + } + } + } + `, []gqlerrors.FormattedError{ + testutil.RuleError(`Fields "deeperField" conflict because subfields "x" conflict because `+ + `a and b are different fields. `+ + `Use different aliases on the fields to fetch both if this was intentional.`, + 12, 11, + 13, 13, + 15, 11, + 16, 13), + }) +} +func TestValidate_OverlappingFieldsCanBeMerged_ReportsDeepConflictInNestedFragments(t *testing.T) { + testutil.ExpectFailsRule(t, graphql.OverlappingFieldsCanBeMergedRule, ` + { + field { + ...F + } + field { + ...I + } + } + fragment F on T { + x: a + ...G + } + fragment G on T { + y: c + } + fragment I on T { + y: d + ...J + } + fragment J on T { + x: b + } + `, []gqlerrors.FormattedError{ + testutil.RuleError(`Fields "field" conflict because `+ + `subfields "x" conflict because a and b are different fields and `+ + `subfields "y" conflict because c and d are different fields. 
`+ + `Use different aliases on the fields to fetch both if this was intentional.`, + 3, 9, + 11, 9, + 15, 9, + 6, 9, + 22, 9, + 18, 9), + }) +} +func TestValidate_OverlappingFieldsCanBeMerged_IgnoresUnknownFragments(t *testing.T) { + testutil.ExpectPassesRule(t, graphql.OverlappingFieldsCanBeMergedRule, ` + { + field + ...Unknown + ...Known + } + + fragment Known on T { + field + ...OtherUnknown + } + `) +} var someBoxInterface *graphql.Interface var stringBoxObject *graphql.Object @@ -486,8 +591,8 @@ func TestValidate_OverlappingFieldsCanBeMerged_ReturnTypesMustBeUnambiguous_Conf } } `, []gqlerrors.FormattedError{ - testutil.RuleError( - `Fields "scalar" conflict because they return conflicting types Int and String!.`, + testutil.RuleError(`Fields "scalar" conflict because they return conflicting types Int and String!. `+ + `Use different aliases on the fields to fetch both if this was intentional.`, 5, 15, 8, 15), }) @@ -526,12 +631,66 @@ func TestValidate_OverlappingFieldsCanBeMerged_ReturnTypesMustBeUnambiguous_Disa } } `, []gqlerrors.FormattedError{ - testutil.RuleError( - `Fields "scalar" conflict because they return conflicting types Int and String.`, + testutil.RuleError(`Fields "scalar" conflict because they return conflicting types Int and String. `+ + `Use different aliases on the fields to fetch both if this was intentional.`, 5, 15, 8, 15), }) } +func TestValidate_OverlappingFieldsCanBeMerged_ReturnTypesMustBeUnambiguous_ReportsCorrectlyWhenANonExclusiveFollosAnExclusive(t *testing.T) { + testutil.ExpectFailsRuleWithSchema(t, &schema, graphql.OverlappingFieldsCanBeMergedRule, ` + { + someBox { + ... on IntBox { + deepBox { + ...X + } + } + } + someBox { + ... on StringBox { + deepBox { + ...Y + } + } + } + memoed: someBox { + ... on IntBox { + deepBox { + ...X + } + } + } + memoed: someBox { + ... on StringBox { + deepBox { + ...Y + } + } + } + other: someBox { + ...X + } + other: someBox { + ...Y + } + } + fragment X on SomeBox { + scalar + } + fragment Y on SomeBox { + scalar: unrelatedField + } + `, []gqlerrors.FormattedError{ + testutil.RuleError(`Fields "other" conflict because subfields "scalar" conflict `+ + `because scalar and unrelatedField are different fields. `+ + `Use different aliases on the fields to fetch both if this was intentional.`, + 31, 11, + 39, 11, + 34, 11, + 42, 11), + }) +} func TestValidate_OverlappingFieldsCanBeMerged_ReturnTypesMustBeUnambiguous_DisallowsDifferingReturnTypeNullabilityDespiteNoOverlap(t *testing.T) { testutil.ExpectFailsRuleWithSchema(t, &schema, graphql.OverlappingFieldsCanBeMergedRule, ` { @@ -545,8 +704,8 @@ func TestValidate_OverlappingFieldsCanBeMerged_ReturnTypesMustBeUnambiguous_Disa } } `, []gqlerrors.FormattedError{ - testutil.RuleError( - `Fields "scalar" conflict because they return conflicting types String! and String.`, + testutil.RuleError(`Fields "scalar" conflict because they return conflicting types String! and String. `+ + `Use different aliases on the fields to fetch both if this was intentional.`, 5, 15, 8, 15), }) @@ -568,8 +727,8 @@ func TestValidate_OverlappingFieldsCanBeMerged_ReturnTypesMustBeUnambiguous_Disa } } `, []gqlerrors.FormattedError{ - testutil.RuleError( - `Fields "box" conflict because they return conflicting types [StringBox] and StringBox.`, + testutil.RuleError(`Fields "box" conflict because they return conflicting types [StringBox] and StringBox. 
`+ + `Use different aliases on the fields to fetch both if this was intentional.`, 5, 15, 10, 15), }) @@ -590,8 +749,8 @@ func TestValidate_OverlappingFieldsCanBeMerged_ReturnTypesMustBeUnambiguous_Disa } } `, []gqlerrors.FormattedError{ - testutil.RuleError( - `Fields "box" conflict because they return conflicting types StringBox and [StringBox].`, + testutil.RuleError(`Fields "box" conflict because they return conflicting types StringBox and [StringBox]. `+ + `Use different aliases on the fields to fetch both if this was intentional.`, 5, 15, 10, 15), }) @@ -614,8 +773,8 @@ func TestValidate_OverlappingFieldsCanBeMerged_ReturnTypesMustBeUnambiguous_Disa } } `, []gqlerrors.FormattedError{ - testutil.RuleError( - `Fields "val" conflict because scalar and unrelatedField are different fields.`, + testutil.RuleError(`Fields "val" conflict because scalar and unrelatedField are different fields. `+ + `Use different aliases on the fields to fetch both if this was intentional.`, 6, 17, 7, 17), }) @@ -637,8 +796,8 @@ func TestValidate_OverlappingFieldsCanBeMerged_ReturnTypesMustBeUnambiguous_Disa } } `, []gqlerrors.FormattedError{ - testutil.RuleError( - `Fields "box" conflict because subfields "scalar" conflict because they return conflicting types String and Int.`, + testutil.RuleError(`Fields "box" conflict because subfields "scalar" conflict because they return conflicting types String and Int. `+ + `Use different aliases on the fields to fetch both if this was intentional.`, 5, 15, 6, 17, 10, 15, @@ -704,15 +863,15 @@ func TestValidate_OverlappingFieldsCanBeMerged_ReturnTypesMustBeUnambiguous_Comp } } `, []gqlerrors.FormattedError{ - testutil.RuleError( - `Fields "edges" conflict because subfields "node" conflict because subfields "id" conflict because `+ - `id and name are different fields.`, - 14, 11, - 15, 13, - 16, 15, + testutil.RuleError(`Fields "edges" conflict because subfields "node" conflict because subfields "id" conflict because `+ + `name and id are different fields. `+ + `Use different aliases on the fields to fetch both if this was intentional.`, 5, 13, 6, 15, - 7, 17), + 7, 17, + 14, 11, + 15, 13, + 16, 15), }) } func TestValidate_OverlappingFieldsCanBeMerged_ReturnTypesMustBeUnambiguous_IgnoresUnknownTypes(t *testing.T) { diff --git a/schema-kitchen-sink.graphql b/schema-kitchen-sink.graphql index efc1b469..582d94ae 100644 --- a/schema-kitchen-sink.graphql +++ b/schema-kitchen-sink.graphql @@ -14,29 +14,54 @@ type Foo implements Bar { six(argument: InputType = {key: "value"}): Type } +type AnnotatedObject @onObject(arg: "value") { + annotatedField(arg: Type = "default" @onArg): Type @onField +} + interface Bar { one: Type four(argument: String = "string"): String } +interface AnnotatedInterface @onInterface { + annotatedField(arg: Type @onArg): Type @onField +} + union Feed = Story | Article | Advert +union AnnotatedUnion @onUnion = A | B + scalar CustomScalar +scalar AnnotatedScalar @onScalar + enum Site { DESKTOP MOBILE } +enum AnnotatedEnum @onEnum { + ANNOTATED_VALUE @onEnumValue + OTHER_VALUE +} + input InputType { key: String! answer: Int = 42 } +input AnnotatedInput @onInputObjectType { + annotatedField: Type @onField +} + extend type Foo { seven(argument: [String]): Type } +extend type Foo @onType {} + +type NoFields {} + directive @skip(if: Boolean!) on FIELD | FRAGMENT_SPREAD | INLINE_FRAGMENT directive @include(if: Boolean!) 
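
The kitchen-sink additions above exercise the new schema-definition directive locations, and the schema.go change that follows documents the contract around SchemaConfig.Directives: supplying the list replaces the defaults entirely. A hedged Go sketch of wiring a custom directive in alongside the specified ones once this patch is applied (the query type and the directive are invented for the example):

package main

import (
	"fmt"

	"github.com/graphql-go/graphql"
)

func main() {
	// A custom directive using one of the schema-definition locations added
	// by this patch.
	onObject := graphql.NewDirective(graphql.DirectiveConfig{
		Name:      "onObject",
		Locations: []string{graphql.DirectiveLocationObject},
	})

	queryType := graphql.NewObject(graphql.ObjectConfig{
		Name: "Query",
		Fields: graphql.Fields{
			"hello": &graphql.Field{Type: graphql.String},
		},
	})

	// SchemaConfig.Directives is the exact allowed list, so the specified
	// directives (@include, @skip, @deprecated) are appended explicitly
	// rather than relied on implicitly.
	directives := append([]*graphql.Directive{}, graphql.SpecifiedDirectives...)
	directives = append(directives, onObject)

	schema, err := graphql.NewSchema(graphql.SchemaConfig{
		Query:      queryType,
		Directives: directives,
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(len(schema.Directives())) // 4: include, skip, deprecated, onObject
}
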
diff --git a/schema.go b/schema.go index f3a4eaff..420a8e60 100644 --- a/schema.go +++ b/schema.go @@ -14,16 +14,26 @@ type SchemaConfig struct { type TypeMap map[string]Type -//Schema Definition -//A Schema is created by supplying the root types of each type of operation, -//query, mutation (optional) and subscription (optional). A schema definition is then supplied to the -//validator and executor. -//Example: -// myAppSchema, err := NewSchema(SchemaConfig({ -// Query: MyAppQueryRootType, -// Mutation: MyAppMutationRootType, -// Subscription: MyAppSubscriptionRootType, -// }); +// Schema Definition +// A Schema is created by supplying the root types of each type of operation, +// query, mutation (optional) and subscription (optional). A schema definition is then supplied to the +// validator and executor. +// Example: +// myAppSchema, err := NewSchema(SchemaConfig({ +// Query: MyAppQueryRootType, +// Mutation: MyAppMutationRootType, +// Subscription: MyAppSubscriptionRootType, +// }); +// Note: If an array of `directives` are provided to GraphQLSchema, that will be +// the exact list of directives represented and allowed. If `directives` is not +// provided then a default set of the specified directives (e.g. @include and +// @skip) will be used. If you wish to provide *additional* directives to these +// specified directives, you must explicitly declare them. Example: +// +// const MyAppSchema = new GraphQLSchema({ +// ... +// directives: specifiedDirectives.concat([ myCustomDirective ]), +// }) type Schema struct { typeMap TypeMap directives []*Directive @@ -57,13 +67,10 @@ func NewSchema(config SchemaConfig) (Schema, error) { schema.mutationType = config.Mutation schema.subscriptionType = config.Subscription - // Provide `@include() and `@skip()` directives by default. + // Provide specified directives (e.g. @include and @skip) by default. schema.directives = config.Directives if len(schema.directives) == 0 { - schema.directives = []*Directive{ - IncludeDirective, - SkipDirective, - } + schema.directives = SpecifiedDirectives } // Ensure directive definitions are error-free for _, dir := range schema.directives { @@ -84,8 +91,8 @@ func NewSchema(config SchemaConfig) (Schema, error) { if schema.SubscriptionType() != nil { initialTypes = append(initialTypes, schema.SubscriptionType()) } - if schemaType != nil { - initialTypes = append(initialTypes, schemaType) + if SchemaType != nil { + initialTypes = append(initialTypes, SchemaType) } for _, ttype := range config.Types { @@ -453,10 +460,8 @@ func isEqualType(typeA Type, typeB Type) bool { return false } -/** - * Provided a type and a super type, return true if the first type is either - * equal or a subset of the second super type (covariant). - */ +// isTypeSubTypeOf Provided a type and a super type, return true if the first type is either +// equal or a subset of the second super type (covariant). 
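
For readers skimming the diff, the covariance that isTypeSubTypeOf (continued below) implements can be summarised as: equal types pass, non-null may narrow, lists compare element-wise, and a concrete object may stand in for an interface it implements or a union it belongs to. A toy model of those rules, independent of the package's actual type system (all type and field names here are hypothetical):

package main

import "fmt"

// toyType models just enough structure to state the covariance rule: a name,
// optional list/non-null wrapping, and the abstract types an object belongs to.
type toyType struct {
	name     string
	ofType   *toyType // element type for lists, inner type for non-null
	list     bool
	nonNull  bool
	abstract []string // interfaces implemented / unions joined, by name
}

func isSubType(sub, super *toyType) bool {
	switch {
	case sub == nil || super == nil:
		return false
	case !sub.list && !sub.nonNull && !super.list && !super.nonNull && sub.name == super.name:
		return true // equal named types
	case super.nonNull:
		return sub.nonNull && isSubType(sub.ofType, super.ofType) // non-null super needs non-null sub
	case sub.nonNull:
		return isSubType(sub.ofType, super) // sub may be stricter than super
	case super.list || sub.list:
		return super.list && sub.list && isSubType(sub.ofType, super.ofType) // element-wise
	default:
		for _, name := range sub.abstract {
			if name == super.name {
				return true // concrete type standing in for its abstract type
			}
		}
		return false
	}
}

func main() {
	pet := &toyType{name: "Pet"}
	dog := &toyType{name: "Dog", abstract: []string{"Pet"}}
	nonNullDog := &toyType{nonNull: true, ofType: dog}
	fmt.Println(isSubType(nonNullDog, pet)) // true: Dog! may appear where Pet is expected
	fmt.Println(isSubType(pet, dog))        // false: the relation is not symmetric
}
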
func isTypeSubTypeOf(schema *Schema, maybeSubType Type, superType Type) bool { // Equivalent type is a valid subtype if maybeSubType == superType { diff --git a/suggested_list_internal_test.go b/suggested_list_internal_test.go new file mode 100644 index 00000000..52e04ae3 --- /dev/null +++ b/suggested_list_internal_test.go @@ -0,0 +1,28 @@ +package graphql + +import ( + "reflect" + "testing" +) + +func TestSuggestionList_ReturnsResultsWhenInputIsEmpty(t *testing.T) { + expected := []string{"a"} + result := suggestionList("", []string{"a"}) + if !reflect.DeepEqual(expected, result) { + t.Fatalf("Expected %v, got: %v", expected, result) + } +} +func TestSuggestionList_ReturnsEmptyArrayWhenThereAreNoOptions(t *testing.T) { + expected := []string{} + result := suggestionList("input", []string{}) + if !reflect.DeepEqual(expected, result) { + t.Fatalf("Expected %v, got: %v", expected, result) + } +} +func TestSuggestionList_ReturnsOptionsSortedBasedOnSimilarity(t *testing.T) { + expected := []string{"abc", "ab"} + result := suggestionList("abc", []string{"a", "ab", "abc"}) + if !reflect.DeepEqual(expected, result) { + t.Fatalf("Expected %v, got: %v", expected, result) + } +} diff --git a/testutil/introspection_query.go b/testutil/introspection_query.go index d92b81e2..a8914154 100644 --- a/testutil/introspection_query.go +++ b/testutil/introspection_query.go @@ -76,6 +76,22 @@ var IntrospectionQuery = ` ofType { kind name + ofType { + kind + name + ofType { + kind + name + ofType { + kind + name + ofType { + kind + name + } + } + } + } } } } diff --git a/testutil/rules_test_harness.go b/testutil/rules_test_harness.go index 809a1eff..40a8d3d5 100644 --- a/testutil/rules_test_harness.go +++ b/testutil/rules_test_harness.go @@ -182,10 +182,6 @@ func init() { dogType, catType, }, - ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object { - // not used for validation - return nil - }, }) var intelligentInterface = graphql.NewInterface(graphql.InterfaceConfig{ Name: "Intelligent", @@ -459,12 +455,80 @@ func init() { schema, err := graphql.NewSchema(graphql.SchemaConfig{ Query: queryRoot, Directives: []*graphql.Directive{ + graphql.IncludeDirective, + graphql.SkipDirective, graphql.NewDirective(graphql.DirectiveConfig{ - Name: "operationOnly", + Name: "onQuery", Locations: []string{graphql.DirectiveLocationQuery}, }), - graphql.IncludeDirective, - graphql.SkipDirective, + graphql.NewDirective(graphql.DirectiveConfig{ + Name: "onMutation", + Locations: []string{graphql.DirectiveLocationMutation}, + }), + graphql.NewDirective(graphql.DirectiveConfig{ + Name: "onSubscription", + Locations: []string{graphql.DirectiveLocationSubscription}, + }), + graphql.NewDirective(graphql.DirectiveConfig{ + Name: "onField", + Locations: []string{graphql.DirectiveLocationField}, + }), + graphql.NewDirective(graphql.DirectiveConfig{ + Name: "onFragmentDefinition", + Locations: []string{graphql.DirectiveLocationFragmentDefinition}, + }), + graphql.NewDirective(graphql.DirectiveConfig{ + Name: "onFragmentSpread", + Locations: []string{graphql.DirectiveLocationFragmentSpread}, + }), + graphql.NewDirective(graphql.DirectiveConfig{ + Name: "onInlineFragment", + Locations: []string{graphql.DirectiveLocationInlineFragment}, + }), + graphql.NewDirective(graphql.DirectiveConfig{ + Name: "onSchema", + Locations: []string{graphql.DirectiveLocationSchema}, + }), + graphql.NewDirective(graphql.DirectiveConfig{ + Name: "onScalar", + Locations: []string{graphql.DirectiveLocationScalar}, + }), + 
graphql.NewDirective(graphql.DirectiveConfig{ + Name: "onObject", + Locations: []string{graphql.DirectiveLocationObject}, + }), + graphql.NewDirective(graphql.DirectiveConfig{ + Name: "onFieldDefinition", + Locations: []string{graphql.DirectiveLocationFieldDefinition}, + }), + graphql.NewDirective(graphql.DirectiveConfig{ + Name: "onArgumentDefinition", + Locations: []string{graphql.DirectiveLocationArgumentDefinition}, + }), + graphql.NewDirective(graphql.DirectiveConfig{ + Name: "onInterface", + Locations: []string{graphql.DirectiveLocationInterface}, + }), + graphql.NewDirective(graphql.DirectiveConfig{ + Name: "onUnion", + Locations: []string{graphql.DirectiveLocationUnion}, + }), + graphql.NewDirective(graphql.DirectiveConfig{ + Name: "onEnum", + Locations: []string{graphql.DirectiveLocationEnum}, + }), + graphql.NewDirective(graphql.DirectiveConfig{ + Name: "onEnumValue", + Locations: []string{graphql.DirectiveLocationEnumValue}, + }), + graphql.NewDirective(graphql.DirectiveConfig{ + Name: "onInputObject", + Locations: []string{graphql.DirectiveLocationInputObject}, + }), + graphql.NewDirective(graphql.DirectiveConfig{ + Name: "onInputFieldDefinition", + Locations: []string{graphql.DirectiveLocationInputFieldDefinition}, + }), }, Types: []graphql.Type{ catType, diff --git a/type_comparators_test.go b/type_comparators_internal_test.go similarity index 100% rename from type_comparators_test.go rename to type_comparators_internal_test.go diff --git a/validator.go b/validator.go index 7baf109f..73c213eb 100644 --- a/validator.go +++ b/validator.go @@ -95,7 +95,7 @@ type ValidationContext struct { variableUsages map[HasSelectionSet][]*VariableUsage recursiveVariableUsages map[*ast.OperationDefinition][]*VariableUsage recursivelyReferencedFragments map[*ast.OperationDefinition][]*ast.FragmentDefinition - fragmentSpreads map[HasSelectionSet][]*ast.FragmentSpread + fragmentSpreads map[*ast.SelectionSet][]*ast.FragmentSpread } func NewValidationContext(schema *Schema, astDoc *ast.Document, typeInfo *TypeInfo) *ValidationContext { @@ -107,7 +107,7 @@ func NewValidationContext(schema *Schema, astDoc *ast.Document, typeInfo *TypeIn variableUsages: map[HasSelectionSet][]*VariableUsage{}, recursiveVariableUsages: map[*ast.OperationDefinition][]*VariableUsage{}, recursivelyReferencedFragments: map[*ast.OperationDefinition][]*ast.FragmentDefinition{}, - fragmentSpreads: map[HasSelectionSet][]*ast.FragmentSpread{}, + fragmentSpreads: map[*ast.SelectionSet][]*ast.FragmentSpread{}, } } @@ -146,13 +146,13 @@ func (ctx *ValidationContext) Fragment(name string) *ast.FragmentDefinition { f, _ := ctx.fragments[name] return f } -func (ctx *ValidationContext) FragmentSpreads(node HasSelectionSet) []*ast.FragmentSpread { +func (ctx *ValidationContext) FragmentSpreads(node *ast.SelectionSet) []*ast.FragmentSpread { if spreads, ok := ctx.fragmentSpreads[node]; ok && spreads != nil { return spreads } spreads := []*ast.FragmentSpread{} - setsToVisit := []*ast.SelectionSet{node.GetSelectionSet()} + setsToVisit := []*ast.SelectionSet{node} for { if len(setsToVisit) == 0 { @@ -189,14 +189,14 @@ func (ctx *ValidationContext) RecursivelyReferencedFragments(operation *ast.Oper fragments := []*ast.FragmentDefinition{} collectedNames := map[string]bool{} - nodesToVisit := []HasSelectionSet{operation} + nodesToVisit := []*ast.SelectionSet{operation.SelectionSet} for { if len(nodesToVisit) == 0 { break } - var node HasSelectionSet + var node *ast.SelectionSet node, nodesToVisit = nodesToVisit[len(nodesToVisit)-1], 
nodesToVisit[:len(nodesToVisit)-1] spreads := ctx.FragmentSpreads(node) @@ -210,7 +210,7 @@ func (ctx *ValidationContext) RecursivelyReferencedFragments(operation *ast.Oper fragment := ctx.Fragment(fragName) if fragment != nil { fragments = append(fragments, fragment) - nodesToVisit = append(nodesToVisit, fragment) + nodesToVisit = append(nodesToVisit, fragment.SelectionSet) } } diff --git a/validator_test.go b/validator_test.go index 67b7a3dd..0dfb1dec 100644 --- a/validator_test.go +++ b/validator_test.go @@ -1,6 +1,7 @@ package graphql_test import ( + "reflect" "testing" "github.com/graphql-go/graphql" @@ -10,7 +11,6 @@ import ( "github.com/graphql-go/graphql/language/parser" "github.com/graphql-go/graphql/language/source" "github.com/graphql-go/graphql/testutil" - "reflect" ) func expectValid(t *testing.T, schema *graphql.Schema, queryString string) { @@ -74,19 +74,19 @@ func TestValidator_SupportsFullValidation_ValidatesUsingACustomTypeInfo(t *testi expectedErrors := []gqlerrors.FormattedError{ { - Message: "Cannot query field \"catOrDog\" on type \"QueryRoot\".", + Message: `Cannot query field "catOrDog" on type "QueryRoot". Did you mean "catOrDog"?`, Locations: []location.SourceLocation{ {Line: 3, Column: 9}, }, }, { - Message: "Cannot query field \"furColor\" on type \"Cat\".", + Message: `Cannot query field "furColor" on type "Cat". Did you mean "furColor"?`, Locations: []location.SourceLocation{ {Line: 5, Column: 13}, }, }, { - Message: "Cannot query field \"isHousetrained\" on type \"Dog\".", + Message: `Cannot query field "isHousetrained" on type "Dog". Did you mean "isHousetrained"?`, Locations: []location.SourceLocation{ {Line: 8, Column: 13}, },
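
The validator.go changes above switch the fragment-spread caches to be keyed by *ast.SelectionSet and collect recursively referenced fragments with an explicit stack rather than recursion. A reduced sketch of that worklist traversal over a made-up node shape (the types are placeholders, not the ast package):

package main

import "fmt"

// node stands in for a selection set: it may reference named fragments and
// contain nested selection sets (fields, inline fragments).
type node struct {
	spreads  []string
	children []*node
}

// collectSpreads gathers every fragment name reachable from root using an
// explicit stack, deduplicating by name the way RecursivelyReferencedFragments
// above does with collectedNames.
func collectSpreads(root *node) []string {
	var names []string
	seen := map[string]bool{}
	stack := []*node{root}
	for len(stack) > 0 {
		var n *node
		n, stack = stack[len(stack)-1], stack[:len(stack)-1]
		for _, name := range n.spreads {
			if !seen[name] {
				seen[name] = true
				names = append(names, name)
			}
		}
		stack = append(stack, n.children...)
	}
	return names
}

func main() {
	inner := &node{spreads: []string{"Frag2"}}
	root := &node{spreads: []string{"Frag1"}, children: []*node{inner, inner}}
	fmt.Println(collectSpreads(root)) // [Frag1 Frag2]
}
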