Skip to content

Commit c3fb17d

Browse files
authored
Merge pull request #756 from graphql-java-kickstart/334-check-subscription-while-parsing
Validate subscription data resolver during schema parsing
2 parents 8b6971f + a0ba473 commit c3fb17d

File tree

3 files changed

+72
-7
lines changed

3 files changed

+72
-7
lines changed

src/main/kotlin/graphql/kickstart/tools/SchemaParserOptions.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ data class SchemaParserOptions internal constructor(
163163
GenericWrapper(CompletableFuture::class, 0),
164164
GenericWrapper(CompletionStage::class, 0),
165165
GenericWrapper(Publisher::class, 0),
166-
GenericWrapper.withTransformer(ReceiveChannel::class, 0, { receiveChannel, _ ->
166+
GenericWrapper.withTransformer(ReceiveChannel::class, 0, { receiveChannel ->
167167
publish(coroutineContextProvider.provide()) {
168168
try {
169169
for (item in receiveChannel) {

src/main/kotlin/graphql/kickstart/tools/resolver/FieldResolverScanner.kt

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,18 @@ package graphql.kickstart.tools.resolver
22

33
import graphql.GraphQLContext
44
import graphql.Scalars
5+
import graphql.kickstart.tools.GraphQLSubscriptionResolver
56
import graphql.kickstart.tools.ResolverInfo
67
import graphql.kickstart.tools.RootResolverInfo
78
import graphql.kickstart.tools.SchemaParserOptions
89
import graphql.kickstart.tools.util.*
910
import graphql.language.FieldDefinition
1011
import graphql.language.TypeName
1112
import graphql.schema.DataFetchingEnvironment
13+
import kotlinx.coroutines.channels.ReceiveChannel
1214
import org.apache.commons.lang3.ClassUtils
1315
import org.apache.commons.lang3.reflect.FieldUtils
16+
import org.reactivestreams.Publisher
1417
import org.slf4j.LoggerFactory
1518
import java.lang.reflect.AccessibleObject
1619
import java.lang.reflect.Method
@@ -86,7 +89,7 @@ internal class FieldResolverScanner(val options: SchemaParserOptions) {
8689
}
8790

8891
private fun findResolverMethod(field: FieldDefinition, search: Search): Method? {
89-
val methods = getAllMethods(search.type)
92+
val methods = getAllMethods(search)
9093
val argumentCount = field.inputValueDefinitions.size + if (search.requiredFirstParameterType != null) 1 else 0
9194
val name = field.name
9295

@@ -109,10 +112,11 @@ internal class FieldResolverScanner(val options: SchemaParserOptions) {
109112
}
110113
}
111114

112-
private fun getAllMethods(type: JavaType): List<Method> {
113-
val declaredMethods = type.unwrap().declaredNonProxyMethods
114-
val superClassesMethods = ClassUtils.getAllSuperclasses(type.unwrap()).flatMap { it.methods.toList() }
115-
val interfacesMethods = ClassUtils.getAllInterfaces(type.unwrap()).flatMap { it.methods.toList() }
115+
private fun getAllMethods(search: Search): List<Method> {
116+
val type = search.type.unwrap()
117+
val declaredMethods = type.declaredNonProxyMethods
118+
val superClassesMethods = ClassUtils.getAllSuperclasses(type).flatMap { it.methods.toList() }
119+
val interfacesMethods = ClassUtils.getAllInterfaces(type).flatMap { it.methods.toList() }
116120

117121
return (declaredMethods + superClassesMethods + interfacesMethods)
118122
.asSequence()
@@ -121,9 +125,26 @@ internal class FieldResolverScanner(val options: SchemaParserOptions) {
121125
// discard any methods that are coming off the root of the class hierarchy
122126
// to avoid issues with duplicate method declarations
123127
.filter { it.declaringClass != Object::class.java }
128+
// subscription resolvers must return a publisher
129+
.filter { search.source !is GraphQLSubscriptionResolver || resolverMethodReturnsPublisher(it) }
124130
.toList()
125131
}
126132

133+
private fun resolverMethodReturnsPublisher(method: Method) =
134+
method.returnType.isAssignableFrom(Publisher::class.java) || receiveChannelToPublisherWrapper(method)
135+
136+
private fun receiveChannelToPublisherWrapper(method: Method) =
137+
method.returnType.isAssignableFrom(ReceiveChannel::class.java)
138+
&& options.genericWrappers.any { wrapper ->
139+
val isReceiveChannelWrapper = wrapper.type == method.returnType
140+
val hasPublisherTransformer = wrapper
141+
.transformer.javaClass
142+
.declaredMethods
143+
.filter { it.name == "invoke" }
144+
.any { it.returnType.isAssignableFrom(Publisher::class.java) }
145+
isReceiveChannelWrapper && hasPublisherTransformer
146+
}
147+
127148
private fun isBoolean(type: GraphQLLangType) = type.unwrap().let { it is TypeName && it.name == Scalars.GraphQLBoolean.name }
128149

129150
private fun verifyMethodArguments(method: Method, requiredCount: Int, search: Search): Boolean {
@@ -166,14 +187,18 @@ internal class FieldResolverScanner(val options: SchemaParserOptions) {
166187
private fun getMissingFieldMessage(field: FieldDefinition, searches: List<Search>, scannedProperties: Boolean): String {
167188
val signatures = mutableListOf("")
168189
val isBoolean = isBoolean(field.type)
190+
var isSubscription = false
169191

170192
searches.forEach { search ->
171193
signatures.addAll(getMissingMethodSignatures(field, search, isBoolean, scannedProperties))
194+
isSubscription = isSubscription || search.source is GraphQLSubscriptionResolver
172195
}
173196

174197
val sourceName = if (field.sourceLocation != null && field.sourceLocation.sourceName != null) field.sourceLocation.sourceName else "<unknown>"
175198
val sourceLocation = if (field.sourceLocation != null) "$sourceName:${field.sourceLocation.line}" else "<unknown>"
176-
return "No method${if (scannedProperties) " or field" else ""} found as defined in schema $sourceLocation with any of the following signatures (with or without one of $allowedLastArgumentTypes as the last argument), in priority order:\n${signatures.joinToString("\n ")}"
199+
return "No method${if (scannedProperties) " or field" else ""} found as defined in schema $sourceLocation with any of the following signatures " +
200+
"(with or without one of $allowedLastArgumentTypes as the last argument), in priority order:\n${signatures.joinToString("\n ")}" +
201+
if (isSubscription) "\n\nNote that a Subscription data fetcher must return a Publisher of events" else ""
177202
}
178203

179204
private fun getMissingMethodSignatures(field: FieldDefinition, search: Search, isBoolean: Boolean, scannedProperties: Boolean): List<String> {

src/test/kotlin/graphql/kickstart/tools/SchemaParserTest.kt

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,4 +662,44 @@ class SchemaParserTest {
662662
}
663663
}
664664
}
665+
666+
@Test
667+
fun `parser should verify subscription resolver return type`() {
668+
val error = assertThrows(FieldResolverError::class.java) {
669+
SchemaParser.newParser()
670+
.schemaString(
671+
"""
672+
type Subscription {
673+
onItemCreated: Int!
674+
}
675+
676+
type Query {
677+
test: String
678+
}
679+
"""
680+
)
681+
.resolvers(
682+
Subscription(),
683+
object : GraphQLQueryResolver { fun test() = "test" }
684+
)
685+
.build()
686+
.makeExecutableSchema()
687+
}
688+
689+
val expected = """
690+
No method or field found as defined in schema <unknown>:3 with any of the following signatures (with or without one of [interface graphql.schema.DataFetchingEnvironment, class graphql.GraphQLContext] as the last argument), in priority order:
691+
692+
graphql.kickstart.tools.SchemaParserTest${"$"}Subscription.onItemCreated()
693+
graphql.kickstart.tools.SchemaParserTest${"$"}Subscription.getOnItemCreated()
694+
graphql.kickstart.tools.SchemaParserTest${"$"}Subscription.onItemCreated
695+
696+
Note that a Subscription data fetcher must return a Publisher of events
697+
""".trimIndent()
698+
699+
assertEquals(error.message, expected)
700+
}
701+
702+
class Subscription : GraphQLSubscriptionResolver {
703+
fun onItemCreated(env: DataFetchingEnvironment) = env.hashCode()
704+
}
665705
}

0 commit comments

Comments
 (0)