@@ -30,8 +30,6 @@ fn expand_enum(ast: syn::DeriveInput, mode: Mode) -> syn::Result<UnionDefinition
3030 let enum_span = ast. span ( ) ;
3131 let enum_ident = ast. ident ;
3232
33- // TODO: validate type has no generics
34-
3533 let name = meta
3634 . name
3735 . clone ( )
@@ -45,13 +43,46 @@ fn expand_enum(ast: syn::DeriveInput, mode: Mode) -> syn::Result<UnionDefinition
4543 ) ;
4644 }
4745
48- let variants: Vec < _ > = match ast. data {
46+ let mut variants: Vec < _ > = match ast. data {
4947 Data :: Enum ( data) => data. variants ,
5048 _ => unreachable ! ( ) ,
5149 }
5250 . into_iter ( )
5351 . filter_map ( |var| graphql_union_variant_from_enum_variant ( var, & enum_ident) )
5452 . collect ( ) ;
53+ if !meta. custom_resolvers . is_empty ( ) {
54+ let crate_path = mode. crate_path ( ) ;
55+ // TODO: refactor into separate function
56+ for ( ty, rslvr) in meta. custom_resolvers {
57+ let span = rslvr. span_joined ( ) ;
58+
59+ let resolver_fn = rslvr. into_inner ( ) ;
60+ let resolver_code = parse_quote ! {
61+ #resolver_fn( self , #crate_path:: FromContext :: from( context) )
62+ } ;
63+ // Doing this may be quite an expensive, because resolving may contain some heavy
64+ // computation, so we're preforming it twice. Unfortunately, we have no other options
65+ // here, until the `juniper::GraphQLType` itself will allow to do it in some cleverer
66+ // way.
67+ let resolver_check = parse_quote ! {
68+ ( { #resolver_code } as :: std:: option:: Option <& #ty>) . is_some( )
69+ } ;
70+
71+ if let Some ( var) = variants. iter_mut ( ) . find ( |v| v. ty == ty) {
72+ var. resolver_code = resolver_code;
73+ var. resolver_check = resolver_check;
74+ var. span = span;
75+ } else {
76+ variants. push ( UnionVariantDefinition {
77+ ty,
78+ resolver_code,
79+ resolver_check,
80+ enum_path : None ,
81+ span,
82+ } )
83+ }
84+ }
85+ }
5586 if variants. is_empty ( ) {
5687 SCOPE . not_empty ( enum_span) ;
5788 }
@@ -97,7 +128,19 @@ fn graphql_union_variant_from_enum_variant(
97128
98129 let var_span = var. span ( ) ;
99130 let var_ident = var. ident ;
100- let path = quote ! { #enum_ident:: #var_ident } ;
131+ let enum_path = quote ! { #enum_ident:: #var_ident } ;
132+
133+ // TODO
134+ if meta. custom_resolver . is_some ( ) {
135+ unimplemented ! ( )
136+ }
137+
138+ let resolver_code = parse_quote ! {
139+ match self { #enum_ident:: #var_ident( ref v) => Some ( v) , _ => None , }
140+ } ;
141+ let resolver_check = parse_quote ! {
142+ matches!( self , #enum_path( _) )
143+ } ;
101144
102145 let ty = match var. fields {
103146 Fields :: Unnamed ( fields) => {
@@ -121,14 +164,18 @@ fn graphql_union_variant_from_enum_variant(
121164
122165 Some ( UnionVariantDefinition {
123166 ty,
124- path,
167+ resolver_code,
168+ resolver_check,
169+ enum_path : Some ( enum_path) ,
125170 span : var_span,
126171 } )
127172}
128173
129174struct UnionVariantDefinition {
130175 pub ty : syn:: Type ,
131- pub path : TokenStream ,
176+ pub resolver_code : syn:: Expr ,
177+ pub resolver_check : syn:: Expr ,
178+ pub enum_path : Option < TokenStream > ,
132179 pub span : Span ,
133180}
134181
@@ -177,23 +224,16 @@ impl UnionDefinition {
177224
178225 let match_names = self . variants . iter ( ) . map ( |var| {
179226 let var_ty = & var. ty ;
180- let var_path = & var. path ;
227+ let var_check = & var. resolver_check ;
181228 quote ! {
182- #var_path( _) => <#var_ty as #crate_path:: GraphQLType <#scalar>>:: name( & ( ) )
183- . unwrap( ) . to_string( ) ,
229+ if #var_check {
230+ return <#var_ty as #crate_path:: GraphQLType <#scalar>>:: name( & ( ) )
231+ . unwrap( ) . to_string( ) ;
232+ }
184233 }
185234 } ) ;
186235
187- let match_resolves: Vec < _ > = self
188- . variants
189- . iter ( )
190- . map ( |var| {
191- let var_path = & var. path ;
192- quote ! {
193- match self { #var_path( ref val) => Some ( val) , _ => None , }
194- }
195- } )
196- . collect ( ) ;
236+ let match_resolves: Vec < _ > = self . variants . iter ( ) . map ( |var| & var. resolver_code ) . collect ( ) ;
197237 let resolve_into_type = self . variants . iter ( ) . zip ( match_resolves. iter ( ) ) . map ( |( var, expr) | {
198238 let var_ty = & var. ty ;
199239
@@ -291,12 +331,15 @@ impl UnionDefinition {
291331
292332 fn concrete_type_name(
293333 & self ,
294- _ : & Self :: Context ,
334+ context : & Self :: Context ,
295335 _: & Self :: TypeInfo ,
296336 ) -> String {
297- match self {
298- #( #match_names ) *
299- }
337+ #( #match_names ) *
338+ panic!(
339+ "GraphQL union {} cannot be resolved into any of its variants in its \
340+ current state",
341+ #name,
342+ ) ;
300343 }
301344
302345 fn resolve_into_type(
@@ -306,9 +349,10 @@ impl UnionDefinition {
306349 _: Option <& [ #crate_path:: Selection <#scalar>] >,
307350 executor: & #crate_path:: Executor <Self :: Context , #scalar>,
308351 ) -> #crate_path:: ExecutionResult <#scalar> {
352+ let context = executor. context( ) ;
309353 #( #resolve_into_type ) *
310354 panic!(
311- "Concrete type {} is not handled by instance resolvers on GraphQL Union {}" ,
355+ "Concrete type {} is not handled by instance resolvers on GraphQL union {}" ,
312356 type_name, #name,
313357 ) ;
314358 }
@@ -327,26 +371,27 @@ impl UnionDefinition {
327371 _: Option <& ' b [ #crate_path:: Selection <' b, #scalar>] >,
328372 executor: & ' b #crate_path:: Executor <' b, ' b, Self :: Context , #scalar>
329373 ) -> #crate_path:: BoxFuture <' b, #crate_path:: ExecutionResult <#scalar>> {
374+ let context = executor. context( ) ;
330375 #( #resolve_into_type_async ) *
331376 panic!(
332- "Concrete type {} is not handled by instance resolvers on GraphQL Union {}" ,
377+ "Concrete type {} is not handled by instance resolvers on GraphQL union {}" ,
333378 type_name, #name,
334379 ) ;
335380 }
336381 }
337382 } ;
338383
339- let conversion_impls = self . variants . iter ( ) . map ( |var| {
384+ let conversion_impls = self . variants . iter ( ) . filter_map ( |var| {
340385 let var_ty = & var. ty ;
341- let var_path = & var. path ;
342- quote ! {
386+ let var_path = var. enum_path . as_ref ( ) ? ;
387+ Some ( quote ! {
343388 #[ automatically_derived]
344389 impl #impl_generics :: std:: convert:: From <#var_ty> for #ty#ty_generics {
345390 fn from( v: #var_ty) -> Self {
346391 #var_path( v)
347392 }
348393 }
349- }
394+ } )
350395 } ) ;
351396
352397 let output_type_impl = quote ! {
0 commit comments