Skip to content

Commit 9e5e3e0

Browse files
authored
Merge pull request #1 from introproventures/master
feat: add JPA @EmbeddedId support (#84)
2 parents ac75192 + 0def68d commit 9e5e3e0

File tree

5 files changed

+206
-109
lines changed

5 files changed

+206
-109
lines changed

graphql-jpa-query-schema/src/main/java/com/introproventures/graphql/jpa/query/schema/impl/GraphQLJpaSchemaBuilder.java

Lines changed: 133 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,16 @@
4242
import javax.persistence.metamodel.SingularAttribute;
4343
import javax.persistence.metamodel.Type;
4444

45+
import org.slf4j.Logger;
46+
import org.slf4j.LoggerFactory;
47+
4548
import com.introproventures.graphql.jpa.query.annotation.GraphQLDescription;
4649
import com.introproventures.graphql.jpa.query.annotation.GraphQLIgnore;
4750
import com.introproventures.graphql.jpa.query.schema.GraphQLSchemaBuilder;
4851
import com.introproventures.graphql.jpa.query.schema.JavaScalars;
4952
import com.introproventures.graphql.jpa.query.schema.NamingStrategy;
5053
import com.introproventures.graphql.jpa.query.schema.impl.PredicateFilter.Criteria;
54+
5155
import graphql.Assert;
5256
import graphql.Scalars;
5357
import graphql.schema.Coercing;
@@ -65,8 +69,6 @@
6569
import graphql.schema.GraphQLType;
6670
import graphql.schema.GraphQLTypeReference;
6771
import graphql.schema.PropertyDataFetcher;
68-
import org.slf4j.Logger;
69-
import org.slf4j.LoggerFactory;
7072

7173
/**
7274
* JPA specific schema builder implementation of {code #GraphQLSchemaBuilder} interface
@@ -95,7 +97,8 @@ public class GraphQLJpaSchemaBuilder implements GraphQLSchemaBuilder {
9597

9698
private Map<Class<?>, GraphQLType> classCache = new HashMap<>();
9799
private Map<EntityType<?>, GraphQLObjectType> entityCache = new HashMap<>();
98-
private Map<EmbeddableType<?>, GraphQLObjectType> embeddableCache = new HashMap<>();
100+
private Map<EmbeddableType<?>, GraphQLObjectType> embeddableOutputCache = new HashMap<>();
101+
private Map<EmbeddableType<?>, GraphQLInputObjectType> embeddableInputCache = new HashMap<>();
99102

100103
private static final Logger log = LoggerFactory.getLogger(GraphQLJpaSchemaBuilder.class);
101104

@@ -292,13 +295,13 @@ private GraphQLInputType getWhereAttributeType(Attribute<?,?> attribute) {
292295
.field(GraphQLInputObjectField.newInputObjectField()
293296
.name(Criteria.EQ.name())
294297
.description("Equals criteria")
295-
.type((GraphQLInputType) getAttributeType(attribute))
298+
.type(getAttributeInputType(attribute))
296299
.build()
297300
)
298301
.field(GraphQLInputObjectField.newInputObjectField()
299302
.name(Criteria.NE.name())
300303
.description("Not Equals criteria")
301-
.type((GraphQLInputType) getAttributeType(attribute))
304+
.type(getAttributeInputType(attribute))
302305
.build()
303306
);
304307

@@ -307,25 +310,25 @@ private GraphQLInputType getWhereAttributeType(Attribute<?,?> attribute) {
307310
builder.field(GraphQLInputObjectField.newInputObjectField()
308311
.name(Criteria.LE.name())
309312
.description("Less then or Equals criteria")
310-
.type((GraphQLInputType) getAttributeType(attribute))
313+
.type(getAttributeInputType(attribute))
311314
.build()
312315
)
313316
.field(GraphQLInputObjectField.newInputObjectField()
314317
.name(Criteria.GE.name())
315318
.description("Greater or Equals criteria")
316-
.type((GraphQLInputType) getAttributeType(attribute))
319+
.type(getAttributeInputType(attribute))
317320
.build()
318321
)
319322
.field(GraphQLInputObjectField.newInputObjectField()
320323
.name(Criteria.GT.name())
321324
.description("Greater Then criteria")
322-
.type((GraphQLInputType) getAttributeType(attribute))
325+
.type(getAttributeInputType(attribute))
323326
.build()
324327
)
325328
.field(GraphQLInputObjectField.newInputObjectField()
326329
.name(Criteria.LT.name())
327330
.description("Less Then criteria")
328-
.type((GraphQLInputType) getAttributeType(attribute))
331+
.type(getAttributeInputType(attribute))
329332
.build()
330333
);
331334
}
@@ -334,25 +337,25 @@ private GraphQLInputType getWhereAttributeType(Attribute<?,?> attribute) {
334337
builder.field(GraphQLInputObjectField.newInputObjectField()
335338
.name(Criteria.LIKE.name())
336339
.description("Like criteria")
337-
.type((GraphQLInputType) getAttributeType(attribute))
340+
.type(getAttributeInputType(attribute))
338341
.build()
339342
)
340343
.field(GraphQLInputObjectField.newInputObjectField()
341344
.name(Criteria.CASE.name())
342345
.description("Case sensitive match criteria")
343-
.type((GraphQLInputType) getAttributeType(attribute))
346+
.type(getAttributeInputType(attribute))
344347
.build()
345348
)
346349
.field(GraphQLInputObjectField.newInputObjectField()
347350
.name(Criteria.STARTS.name())
348351
.description("Starts with criteria")
349-
.type((GraphQLInputType) getAttributeType(attribute))
352+
.type(getAttributeInputType(attribute))
350353
.build()
351354
)
352355
.field(GraphQLInputObjectField.newInputObjectField()
353356
.name(Criteria.ENDS.name())
354357
.description("Ends with criteria")
355-
.type((GraphQLInputType) getAttributeType(attribute))
358+
.type(getAttributeInputType(attribute))
356359
.build()
357360
);
358361
}
@@ -373,25 +376,25 @@ private GraphQLInputType getWhereAttributeType(Attribute<?,?> attribute) {
373376
.field(GraphQLInputObjectField.newInputObjectField()
374377
.name(Criteria.IN.name())
375378
.description("In criteria")
376-
.type(new GraphQLList(getAttributeType(attribute)))
379+
.type(new GraphQLList(getAttributeInputType(attribute)))
377380
.build()
378381
)
379382
.field(GraphQLInputObjectField.newInputObjectField()
380383
.name(Criteria.NIN.name())
381384
.description("Not In criteria")
382-
.type(new GraphQLList(getAttributeType(attribute)))
385+
.type(new GraphQLList(getAttributeInputType(attribute)))
383386
.build()
384387
)
385388
.field(GraphQLInputObjectField.newInputObjectField()
386389
.name(Criteria.BETWEEN.name())
387390
.description("Between criteria")
388-
.type(new GraphQLList(getAttributeType(attribute)))
391+
.type(new GraphQLList(getAttributeInputType(attribute)))
389392
.build()
390393
)
391394
.field(GraphQLInputObjectField.newInputObjectField()
392395
.name(Criteria.NOT_BETWEEN.name())
393396
.description("Not Between criteria")
394-
.type(new GraphQLList(getAttributeType(attribute)))
397+
.type(new GraphQLList(getAttributeInputType(attribute)))
395398
.build()
396399
);
397400

@@ -404,39 +407,52 @@ private GraphQLInputType getWhereAttributeType(Attribute<?,?> attribute) {
404407
}
405408

406409
private GraphQLArgument getArgument(Attribute<?,?> attribute) {
407-
GraphQLType type = getAttributeType(attribute);
410+
GraphQLInputType type = getAttributeInputType(attribute);
408411
String description = getSchemaDescription(attribute.getJavaMember());
409412

410-
if (type instanceof GraphQLInputType) {
411-
return GraphQLArgument.newArgument()
412-
.name(attribute.getName())
413-
.type((GraphQLInputType) type)
414-
.description(description)
415-
.build();
416-
}
417-
418-
throw new IllegalArgumentException("Attribute " + attribute + " cannot be mapped as an Input Argument");
413+
return GraphQLArgument.newArgument()
414+
.name(attribute.getName())
415+
.type((GraphQLInputType) type)
416+
.description(description)
417+
.build();
419418
}
420419

421-
private GraphQLObjectType getEmbeddableType(EmbeddableType<?> embeddableType) {
422-
if (embeddableCache.containsKey(embeddableType))
423-
return embeddableCache.get(embeddableType);
424-
425-
String embeddableTypeName = namingStrategy.singularize(embeddableType.getJavaType().getSimpleName())+"EmbeddableType";
426-
427-
GraphQLObjectType objectType = GraphQLObjectType.newObject()
428-
.name(embeddableTypeName)
429-
.description(getSchemaDescription( embeddableType.getJavaType()))
430-
.fields(embeddableType.getAttributes().stream()
431-
.filter(this::isNotIgnored)
432-
.map(this::getObjectField)
433-
.collect(Collectors.toList())
434-
)
435-
.build();
436-
437-
embeddableCache.putIfAbsent(embeddableType, objectType);
420+
private GraphQLType getEmbeddableType(EmbeddableType<?> embeddableType, boolean input) {
421+
if (input && embeddableInputCache.containsKey(embeddableType))
422+
return embeddableInputCache.get(embeddableType);
423+
424+
if (!input && embeddableOutputCache.containsKey(embeddableType))
425+
return embeddableOutputCache.get(embeddableType);
426+
String embeddableTypeName = namingStrategy.singularize(embeddableType.getJavaType().getSimpleName())+ (input ? "Input" : "") +"EmbeddableType";
427+
GraphQLType graphQLType=null;
428+
if (input) {
429+
graphQLType = GraphQLInputObjectType.newInputObject()
430+
.name(embeddableTypeName)
431+
.description(getSchemaDescription(embeddableType.getJavaType()))
432+
.fields(embeddableType.getAttributes().stream()
433+
.filter(this::isNotIgnored)
434+
.map(this::getInputObjectField)
435+
.collect(Collectors.toList())
436+
)
437+
.build();
438+
} else {
439+
graphQLType = GraphQLObjectType.newObject()
440+
.name(embeddableTypeName)
441+
.description(getSchemaDescription(embeddableType.getJavaType()))
442+
.fields(embeddableType.getAttributes().stream()
443+
.filter(this::isNotIgnored)
444+
.map(this::getObjectField)
445+
.collect(Collectors.toList())
446+
)
447+
.build();
448+
}
449+
if (input) {
450+
embeddableInputCache.putIfAbsent(embeddableType, (GraphQLInputObjectType) graphQLType);
451+
} else{
452+
embeddableOutputCache.putIfAbsent(embeddableType, (GraphQLObjectType) graphQLType);
453+
}
438454

439-
return objectType;
455+
return graphQLType;
440456
}
441457

442458

@@ -462,67 +478,92 @@ private GraphQLObjectType getObjectType(EntityType<?> entityType) {
462478

463479
@SuppressWarnings( { "rawtypes", "unchecked" } )
464480
private GraphQLFieldDefinition getObjectField(Attribute attribute) {
465-
GraphQLType type = getAttributeType(attribute);
466-
467-
if (type instanceof GraphQLOutputType) {
468-
List<GraphQLArgument> arguments = new ArrayList<>();
469-
DataFetcher dataFetcher = PropertyDataFetcher.fetching(attribute.getName());
470-
471-
// Only add the orderBy argument for basic attribute types
472-
if (attribute instanceof SingularAttribute
473-
&& attribute.getPersistentAttributeType() == Attribute.PersistentAttributeType.BASIC) {
474-
arguments.add(GraphQLArgument.newArgument()
475-
.name(ORDER_BY_PARAM_NAME)
476-
.description("Specifies field sort direction in the query results.")
477-
.type(orderByDirectionEnum)
478-
.build()
479-
);
480-
}
481-
482-
// Get the fields that can be queried on (i.e. Simple Types, no Sub-Objects)
483-
if (attribute instanceof SingularAttribute
484-
&& attribute.getPersistentAttributeType() != Attribute.PersistentAttributeType.BASIC) {
485-
ManagedType foreignType = (ManagedType) ((SingularAttribute) attribute).getType();
486-
487-
// TODO fix page count query
488-
arguments.add(getWhereArgument(foreignType));
489-
490-
} // Get Sub-Objects fields queries via DataFetcher
491-
else if (attribute instanceof PluralAttribute
492-
&& (attribute.getPersistentAttributeType() == Attribute.PersistentAttributeType.ONE_TO_MANY
493-
|| attribute.getPersistentAttributeType() == Attribute.PersistentAttributeType.MANY_TO_MANY)) {
494-
EntityType declaringType = (EntityType) ((PluralAttribute) attribute).getDeclaringType();
495-
EntityType elementType = (EntityType) ((PluralAttribute) attribute).getElementType();
496-
497-
arguments.add(getWhereArgument(elementType));
498-
dataFetcher = new GraphQLJpaOneToManyDataFetcher(entityManager, declaringType, (PluralAttribute) attribute);
499-
}
481+
GraphQLOutputType type = getAttributeOutputType(attribute);
482+
483+
List<GraphQLArgument> arguments = new ArrayList<>();
484+
DataFetcher dataFetcher = PropertyDataFetcher.fetching(attribute.getName());
485+
486+
// Only add the orderBy argument for basic attribute types
487+
if (attribute instanceof SingularAttribute
488+
&& attribute.getPersistentAttributeType() == Attribute.PersistentAttributeType.BASIC) {
489+
arguments.add(GraphQLArgument.newArgument()
490+
.name(ORDER_BY_PARAM_NAME)
491+
.description("Specifies field sort direction in the query results.")
492+
.type(orderByDirectionEnum)
493+
.build()
494+
);
495+
}
500496

501-
return GraphQLFieldDefinition.newFieldDefinition()
502-
.name(attribute.getName())
503-
.description(getSchemaDescription(attribute.getJavaMember()))
504-
.type((GraphQLOutputType) type)
505-
.dataFetcher(dataFetcher)
506-
.argument(arguments)
507-
.build();
497+
// Get the fields that can be queried on (i.e. Simple Types, no Sub-Objects)
498+
if (attribute instanceof SingularAttribute
499+
&& attribute.getPersistentAttributeType() != Attribute.PersistentAttributeType.BASIC) {
500+
ManagedType foreignType = (ManagedType) ((SingularAttribute) attribute).getType();
501+
502+
// TODO fix page count query
503+
arguments.add(getWhereArgument(foreignType));
504+
505+
} // Get Sub-Objects fields queries via DataFetcher
506+
else if (attribute instanceof PluralAttribute
507+
&& (attribute.getPersistentAttributeType() == Attribute.PersistentAttributeType.ONE_TO_MANY
508+
|| attribute.getPersistentAttributeType() == Attribute.PersistentAttributeType.MANY_TO_MANY)) {
509+
EntityType declaringType = (EntityType) ((PluralAttribute) attribute).getDeclaringType();
510+
EntityType elementType = (EntityType) ((PluralAttribute) attribute).getElementType();
511+
512+
arguments.add(getWhereArgument(elementType));
513+
dataFetcher = new GraphQLJpaOneToManyDataFetcher(entityManager, declaringType, (PluralAttribute) attribute);
508514
}
509515

510-
throw new IllegalArgumentException("Attribute " + attribute + " cannot be mapped as an Output Argument");
516+
return GraphQLFieldDefinition.newFieldDefinition()
517+
.name(attribute.getName())
518+
.description(getSchemaDescription(attribute.getJavaMember()))
519+
.type(type)
520+
.dataFetcher(dataFetcher)
521+
.argument(arguments)
522+
.build();
523+
}
524+
525+
@SuppressWarnings( { "rawtypes", "unchecked" } )
526+
private GraphQLInputObjectField getInputObjectField(Attribute attribute) {
527+
GraphQLInputType type = getAttributeInputType(attribute);
528+
529+
return GraphQLInputObjectField.newInputObjectField()
530+
.name(attribute.getName())
531+
.description(getSchemaDescription(attribute.getJavaMember()))
532+
.type(type)
533+
.build();
511534
}
512535

513536
private Stream<Attribute<?,?>> findBasicAttributes(Collection<Attribute<?,?>> attributes) {
514537
return attributes.stream().filter(it -> it.getPersistentAttributeType() == Attribute.PersistentAttributeType.BASIC);
515538
}
516539

517540
@SuppressWarnings( "rawtypes" )
518-
private GraphQLType getAttributeType(Attribute<?,?> attribute) {
541+
private GraphQLInputType getAttributeInputType(Attribute<?,?> attribute) {
542+
try{
543+
return (GraphQLInputType) getAttributeType(attribute, true);
544+
} catch (ClassCastException e){
545+
throw new IllegalArgumentException("Attribute " + attribute + " cannot be mapped as an Input Argument");
546+
}
547+
}
548+
549+
@SuppressWarnings( "rawtypes" )
550+
private GraphQLOutputType getAttributeOutputType(Attribute<?,?> attribute) {
551+
try {
552+
return (GraphQLOutputType) getAttributeType(attribute, false);
553+
} catch (ClassCastException e){
554+
throw new IllegalArgumentException("Attribute " + attribute + " cannot be mapped as an Output Argument");
555+
}
556+
}
557+
558+
@SuppressWarnings( "rawtypes" )
559+
private GraphQLType getAttributeType(Attribute<?,?> attribute, boolean input) {
519560

520561
if (isBasic(attribute)) {
521562
return getGraphQLTypeFromJavaType(attribute.getJavaType());
522563
}
523564
else if (isEmbeddable(attribute)) {
524565
EmbeddableType embeddableType = (EmbeddableType) ((SingularAttribute) attribute).getType();
525-
return getEmbeddableType(embeddableType);
566+
return getEmbeddableType(embeddableType, input);
526567
}
527568
else if (isToMany(attribute)) {
528569
EntityType foreignType = (EntityType) ((PluralAttribute) attribute).getElementType();
@@ -572,7 +613,8 @@ protected final boolean isToOne(Attribute<?,?> attribute) {
572613

573614
protected final boolean isValidInput(Attribute<?,?> attribute) {
574615
return attribute.getPersistentAttributeType() == Attribute.PersistentAttributeType.BASIC ||
575-
attribute.getPersistentAttributeType() == Attribute.PersistentAttributeType.ELEMENT_COLLECTION;
616+
attribute.getPersistentAttributeType() == Attribute.PersistentAttributeType.ELEMENT_COLLECTION ||
617+
attribute.getPersistentAttributeType() == Attribute.PersistentAttributeType.EMBEDDED;
576618
}
577619

578620
private String getSchemaDescription(Member member) {

0 commit comments

Comments
 (0)