Skip to content

Commit 18a5595

Browse files
committed
Nested sorting on fields (GH #3)
This commit add the following features: * sorting on multiple properties * sorting on fields * augmentation-tests provides a diff on failure that can be viewed in IntelliJ Bugfix: * The formatted field for temporal is now a string rather than an object
1 parent 4792d3f commit 18a5595

File tree

6 files changed

+291
-122
lines changed

6 files changed

+291
-122
lines changed

src/main/kotlin/org/neo4j/graphql/SchemaBuilder.kt

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import org.neo4j.graphql.handler.projection.ProjectionBase
1212
import org.neo4j.graphql.handler.relation.CreateRelationHandler
1313
import org.neo4j.graphql.handler.relation.CreateRelationTypeHandler
1414
import org.neo4j.graphql.handler.relation.DeleteRelationHandler
15+
import java.util.concurrent.ConcurrentHashMap
1516

1617
object SchemaBuilder {
1718
private const val MUTATION = "Mutation"
@@ -57,7 +58,7 @@ object SchemaBuilder {
5758

5859
val handler = getHandler(config)
5960

60-
var targetSchema = augmentSchema(sourceSchema, handler)
61+
var targetSchema = augmentSchema(sourceSchema, handler, config)
6162
targetSchema = addDataFetcher(targetSchema, dataFetchingInterceptor, handler)
6263
return targetSchema
6364
}
@@ -82,8 +83,8 @@ object SchemaBuilder {
8283
return handler
8384
}
8485

85-
private fun augmentSchema(sourceSchema: GraphQLSchema, handler: List<AugmentationHandler>): GraphQLSchema {
86-
val types = sourceSchema.typeMap.toMutableMap()
86+
private fun augmentSchema(sourceSchema: GraphQLSchema, handler: List<AugmentationHandler>, config: SchemaConfig): GraphQLSchema {
87+
val types = sourceSchema.typeMap.toMap(ConcurrentHashMap())
8788
val env = BuildingEnv(types)
8889

8990
types.values
@@ -106,11 +107,11 @@ object SchemaBuilder {
106107
builder.clearFields().clearInterfaces()
107108
// to prevent duplicated types in schema
108109
sourceType.interfaces.forEach { builder.withInterface(GraphQLTypeReference(it.name)) }
109-
sourceType.fieldDefinitions.forEach { f -> builder.field(enhanceRelations(f)) }
110+
sourceType.fieldDefinitions.forEach { f -> builder.field(enhanceRelations(f, env, config)) }
110111
}
111112
sourceType is GraphQLInterfaceType -> sourceType.transform { builder ->
112113
builder.clearFields()
113-
sourceType.fieldDefinitions.forEach { f -> builder.field(enhanceRelations(f)) }
114+
sourceType.fieldDefinitions.forEach { f -> builder.field(enhanceRelations(f, env, config)) }
114115
}
115116
else -> sourceType
116117
}
@@ -125,27 +126,29 @@ object SchemaBuilder {
125126
.build()
126127
}
127128

128-
private fun enhanceRelations(fd: GraphQLFieldDefinition): GraphQLFieldDefinition {
129-
return fd.transform {
129+
private fun enhanceRelations(fd: GraphQLFieldDefinition, env: BuildingEnv, config: SchemaConfig): GraphQLFieldDefinition {
130+
return fd.transform { fieldBuilder ->
130131
// to prevent duplicated types in schema
131-
it.type(fd.type.ref() as GraphQLOutputType)
132+
fieldBuilder.type(fd.type.ref() as GraphQLOutputType)
132133

133134
if (!fd.isRelationship() || !fd.type.isList()) {
134135
return@transform
135136
}
136137

137138
if (fd.getArgument(ProjectionBase.FIRST) == null) {
138-
it.argument { a -> a.name(ProjectionBase.FIRST).type(Scalars.GraphQLInt) }
139+
fieldBuilder.argument { a -> a.name(ProjectionBase.FIRST).type(Scalars.GraphQLInt) }
139140
}
140141
if (fd.getArgument(ProjectionBase.OFFSET) == null) {
141-
it.argument { a -> a.name(ProjectionBase.OFFSET).type(Scalars.GraphQLInt) }
142+
fieldBuilder.argument { a -> a.name(ProjectionBase.OFFSET).type(Scalars.GraphQLInt) }
143+
}
144+
if (fd.getArgument(ProjectionBase.ORDER_BY) == null && fd.type.isList()) {
145+
(fd.type.inner() as? GraphQLFieldsContainer)?.let { fieldType ->
146+
env.addOrdering(fieldType)?.let { orderingTypeName ->
147+
val orderType = GraphQLList(GraphQLNonNull(GraphQLTypeReference(orderingTypeName)))
148+
fieldBuilder.argument { a -> a.name(ProjectionBase.ORDER_BY).type(orderType) }
149+
}
150+
}
142151
}
143-
// TODO implement ordering
144-
// if (fd.getArgument(ProjectionBase.ORDER_BY) == null) {
145-
// val typeName = fd.type.name()!!
146-
// val orderingType = addOrdering(typeName, metaProvider.getNodeType(typeName)!!.fieldDefinitions().filter { it.type.isScalar() })
147-
// it.argument { a -> a.name(ProjectionBase.ORDER_BY).type(orderingType) }
148-
// }
149152
}
150153
}
151154

@@ -205,4 +208,4 @@ object SchemaBuilder {
205208
typeDefinitionRegistry.add(inputType)
206209
return inputName
207210
}
208-
}
211+
}

src/main/kotlin/org/neo4j/graphql/handler/QueryHandler.kt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ class QueryHandler private constructor(
3434
.argument(input(OFFSET, Scalars.GraphQLInt))
3535
.type(GraphQLNonNull(GraphQLList(GraphQLNonNull(GraphQLTypeReference(type.name)))))
3636
if (orderingTypeName != null) {
37-
builder.argument(input(ORDER_BY, GraphQLTypeReference(orderingTypeName)))
37+
val orderType = GraphQLList(GraphQLNonNull(GraphQLTypeReference(orderingTypeName)))
38+
builder.argument(input(ORDER_BY, orderType))
3839
}
3940
val def = builder.build()
4041
buildingEnv.addOperation(QUERY, def)
@@ -102,4 +103,4 @@ class QueryHandler private constructor(
102103
|RETURN ${mapProjection.query} AS $variable$ordering${skipLimit.query}""".trimMargin(),
103104
(where.params + mapProjection.params + skipLimit.params))
104105
}
105-
}
106+
}

src/main/kotlin/org/neo4j/graphql/handler/projection/ProjectionBase.kt

Lines changed: 61 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,30 @@ open class ProjectionBase {
1616
}
1717

1818
fun orderBy(variable: String, args: MutableList<Argument>): String {
19+
val values = getOrderByArgs(args)
20+
if (values.isEmpty()) {
21+
return ""
22+
}
23+
return " ORDER BY " + values.joinToString(", ", transform = { (property, direction) -> "$variable.$property $direction" })
24+
}
25+
26+
private fun getOrderByArgs(args: MutableList<Argument>): List<Pair<String, Sort>> {
1927
val arg = args.find { it.name == ORDER_BY }
20-
val values = arg?.value?.let { it ->
21-
when (it) {
22-
is ArrayValue -> it.values.map { it.toJavaValue().toString() }
23-
is EnumValue -> listOf(it.name)
24-
is StringValue -> listOf(it.value)
25-
else -> null
28+
return arg?.value
29+
?.let { it ->
30+
when (it) {
31+
is ArrayValue -> it.values.map { it.toJavaValue().toString() }
32+
is EnumValue -> listOf(it.name)
33+
is StringValue -> listOf(it.value)
34+
else -> null
35+
}
2636
}
27-
}
28-
@Suppress("SimplifiableCallChain")
29-
return if (values == null) ""
30-
else " ORDER BY " + values
31-
.map { it.split("_") }
32-
.map { "$variable.${it[0]} ${it[1].toUpperCase()}" }
33-
.joinToString(", ")
37+
?.map {
38+
val index = it.lastIndexOf('_')
39+
val property = it.substring(0, index)
40+
val direction = Sort.valueOf(it.substring(index + 1).toUpperCase())
41+
property to direction
42+
} ?: emptyList()
3443
}
3544

3645
fun where(variable: String, fieldDefinition: GraphQLFieldDefinition, type: GraphQLFieldsContainer, arguments: List<Argument>, field: Field): Cypher {
@@ -133,8 +142,8 @@ open class ProjectionBase {
133142
return predicates.values + defaults
134143
}
135144

136-
fun projectFields(variable: String, field: Field, nodeType: GraphQLFieldsContainer, env: DataFetchingEnvironment, variableSuffix: String?): Cypher {
137-
val queries = projectSelectionSet(variable, field.selectionSet, nodeType, env, variableSuffix)
145+
fun projectFields(variable: String, field: Field, nodeType: GraphQLFieldsContainer, env: DataFetchingEnvironment, variableSuffix: String?, neo4jFieldsToPass: Set<String> = emptySet()): Cypher {
146+
val queries = projectSelection(variable, field.selectionSet.selections, nodeType, env, variableSuffix, neo4jFieldsToPass)
138147
@Suppress("SimplifiableCallChain")
139148
val projection = queries
140149
.map { it.query }
@@ -145,18 +154,18 @@ open class ProjectionBase {
145154
return Cypher("$variable $projection", params)
146155
}
147156

148-
private fun projectSelectionSet(variable: String, selectionSet: SelectionSet, nodeType: GraphQLFieldsContainer, env: DataFetchingEnvironment, variableSuffix: String?): List<Cypher> {
157+
private fun projectSelection(variable: String, selection: List<Selection<*>>, nodeType: GraphQLFieldsContainer, env: DataFetchingEnvironment, variableSuffix: String?, neo4jFieldsToPass: Set<String> = emptySet()): List<Cypher> {
149158
// TODO just render fragments on valid types (Labels) by using cypher like this:
150159
// apoc.map.mergeList([
151160
// a{.name},
152161
// CASE WHEN a:Location THEN a { .foo } ELSE {} END
153162
// ])
154163
var hasTypeName = false
155-
val projections = selectionSet.selections.flatMapTo(mutableListOf<Cypher>()) {
164+
val projections = selection.flatMapTo(mutableListOf<Cypher>()) {
156165
when (it) {
157166
is Field -> {
158167
hasTypeName = hasTypeName || (it.name == TYPE_NAME)
159-
listOf(projectField(variable, it, nodeType, env, variableSuffix))
168+
listOf(projectField(variable, it, nodeType, env, variableSuffix, neo4jFieldsToPass))
160169
}
161170
is InlineFragment -> projectInlineFragment(variable, it, env, variableSuffix)
162171
is FragmentSpread -> projectNamedFragments(variable, it, env, variableSuffix)
@@ -173,7 +182,7 @@ open class ProjectionBase {
173182
return projections
174183
}
175184

176-
private fun projectField(variable: String, field: Field, type: GraphQLFieldsContainer, env: DataFetchingEnvironment, variableSuffix: String?): Cypher {
185+
private fun projectField(variable: String, field: Field, type: GraphQLFieldsContainer, env: DataFetchingEnvironment, variableSuffix: String?, neo4jFieldsToPass: Set<String> = emptySet()): Cypher {
177186
if (field.name == TYPE_NAME) {
178187
return if (type.isRelationType()) {
179188
Cypher("${field.aliasOrName()}: '${type.name}'")
@@ -199,7 +208,11 @@ open class ProjectionBase {
199208
} ?: when {
200209
isObjectField -> {
201210
val patternComprehensions = if (fieldDefinition.isNeo4jType()) {
202-
projectNeo4jObjectType(variable, field)
211+
if (neo4jFieldsToPass.contains(fieldDefinition.innerName())) {
212+
Cypher(variable + "." + fieldDefinition.propertyName().quote())
213+
} else {
214+
projectNeo4jObjectType(variable, field)
215+
}
203216
} else {
204217
projectRelationship(variable, field, fieldDefinition, type, env, variableSuffix)
205218
}
@@ -223,7 +236,7 @@ open class ProjectionBase {
223236
.filterIsInstance<Field>()
224237
.map {
225238
val value = when (it.name) {
226-
NEO4j_FORMATTED_PROPERTY_KEY -> "$variable.${field.name}"
239+
NEO4j_FORMATTED_PROPERTY_KEY -> "toString($variable.${field.name})"
227240
else -> "$variable.${field.name}.${it.name}"
228241
}
229242
"${it.name}: $value"
@@ -259,7 +272,7 @@ open class ProjectionBase {
259272
val fragmentType = env.graphQLSchema.getType(fragmentTypeName) as? GraphQLFieldsContainer ?: return emptyList()
260273
// these are the nested fields of the fragment
261274
// it could be that we have to adapt the variable name too, and perhaps add some kind of rename
262-
return projectSelectionSet(variable, selectionSet, fragmentType, env, variableSuffix)
275+
return projectSelection(variable, selectionSet.selections, fragmentType, env, variableSuffix)
263276
}
264277

265278

@@ -329,9 +342,27 @@ open class ProjectionBase {
329342
val relPattern = if (isRelFromType) "$childVariable:${relInfo.relType}" else ":${relInfo.relType}"
330343

331344
val where = where(childVariable, fieldDefinition, nodeType, propertyArguments(field), field)
332-
val fieldProjection = projectFields(childVariable, field, nodeType, env, variableSuffix)
333345

334-
val comprehension = "[($variable)$inArrow-[$relPattern]-$outArrow($endNodePattern)${where.query} | ${fieldProjection.query}]"
346+
val orderBy = getOrderByArgs(field.arguments)
347+
val sortByNeo4jTypeFields = orderBy
348+
.filter { (property, _) -> nodeType.getFieldDefinition(property)?.isNeo4jType() == true }
349+
.map { (property, _) -> property }
350+
.toSet()
351+
352+
val fieldProjection = projectFields(childVariable, field, nodeType, env, variableSuffix, sortByNeo4jTypeFields)
353+
var comprehension = "[($variable)$inArrow-[$relPattern]-$outArrow($endNodePattern)${where.query} | ${fieldProjection.query}]"
354+
if (orderBy.isNotEmpty()) {
355+
val sortArgs = orderBy.joinToString(", ", transform = { (property, direction) -> if (direction == Sort.ASC) "'^$property'" else "'$property'" })
356+
comprehension = "apoc.coll.sortMulti($comprehension, [$sortArgs])"
357+
if (sortByNeo4jTypeFields.isNotEmpty()) {
358+
val neo4jFiledSelection = field.selectionSet.selections
359+
.filter { selection -> sortByNeo4jTypeFields.contains((selection as? Field)?.name) }
360+
val deferredProjection = projectSelection("sortedElement", neo4jFiledSelection, nodeType, env, variableSuffix)
361+
.map { cypher -> cypher.query }
362+
.joinNonEmpty(", ")
363+
comprehension = "[sortedElement IN $comprehension | sortedElement { .*, $deferredProjection }]"
364+
}
365+
}
335366
val skipLimit = SkipLimit(childVariable, field.arguments)
336367
val slice = skipLimit.slice(fieldType.isList())
337368
return Cypher(comprehension + slice.query, (where.params + fieldProjection.params + slice.params))
@@ -388,4 +419,9 @@ open class ProjectionBase {
388419
}
389420
}
390421
}
391-
}
422+
423+
enum class Sort {
424+
ASC,
425+
DESC
426+
}
427+
}

src/test/kotlin/org/neo4j/graphql/utils/GraphQLSchemaTestSuite.kt

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import org.junit.jupiter.api.DynamicTest
1515
import org.neo4j.graphql.DynamicProperties
1616
import org.neo4j.graphql.SchemaBuilder
1717
import org.neo4j.graphql.SchemaConfig
18+
import org.opentest4j.AssertionFailedError
1819
import java.io.File
1920
import java.util.regex.Pattern
2021
import javax.ws.rs.core.UriBuilder
@@ -62,31 +63,33 @@ class GraphQLSchemaTestSuite(fileName: String) : AsciiDocTestSuite() {
6263
private val ignore: Boolean) {
6364

6465
fun run() {
65-
println(title)
66+
var augmentedSchema: GraphQLSchema? = null
67+
var expectedSchema: GraphQLSchema? = null
6668
try {
67-
val augmentedSchema = SchemaBuilder.buildSchema(suite.schema, config)
69+
augmentedSchema = SchemaBuilder.buildSchema(suite.schema, config)
6870
val schemaParser = SchemaParser()
6971

70-
println("Augmented Schema:")
71-
println(suite.schemaPrinter.print(augmentedSchema))
72-
7372
val reg = schemaParser.parse(targetSchema)
7473
val schemaGenerator = SchemaGenerator()
7574
val runtimeWiring = RuntimeWiring.newRuntimeWiring()
7675
reg
7776
.getTypes(InterfaceTypeDefinition::class.java)
7877
.forEach { typeDefinition -> runtimeWiring.type(typeDefinition.name) { it.typeResolver { null } } }
79-
val expected = schemaGenerator.makeExecutableSchema(reg, runtimeWiring
78+
expectedSchema = schemaGenerator.makeExecutableSchema(reg, runtimeWiring
8079
.scalar(DynamicProperties.INSTANCE)
8180
.build())
8281

83-
diff(expected, augmentedSchema)
84-
diff(augmentedSchema, expected)
82+
diff(expectedSchema, augmentedSchema)
83+
diff(augmentedSchema, expectedSchema)
8584
} catch (e: Throwable) {
8685
if (ignore) {
8786
Assumptions.assumeFalse(true, e.message)
8887
} else {
89-
throw e
88+
throw AssertionFailedError("augmented schema differs for '$title'",
89+
expectedSchema?.let { suite.schemaPrinter.print(it) } ?: targetSchema,
90+
suite.schemaPrinter.print(augmentedSchema),
91+
e)
92+
9093
}
9194
}
9295
}
@@ -139,4 +142,4 @@ class GraphQLSchemaTestSuite(fileName: String) : AsciiDocTestSuite() {
139142
}
140143
}
141144
}
142-
}
145+
}

0 commit comments

Comments
 (0)