@@ -4,19 +4,18 @@ use quote::{quote, ToTokens};
4
4
use syn:: {
5
5
parenthesized, parse:: Parse , parse2, parse_macro_input, parse_quote, punctuated:: Punctuated ,
6
6
spanned:: Spanned , token:: Paren , Attribute , Expr , FnArg , Ident , ItemFn , Pat , PatIdent , Path ,
7
- ReturnType , Signature , Stmt , Token , Type , TypePath ,
7
+ PathSegment , ReturnType , Signature , Stmt , Token , Type , TypePath ,
8
8
} ;
9
9
10
10
pub ( crate ) fn r#impl ( _attr : TokenStream , item : TokenStream ) -> TokenStream {
11
11
let mut fn_decl = parse_macro_input ! ( item as ItemFn ) ;
12
12
13
13
let loader = Loader :: from_item_fn ( & mut fn_decl) ;
14
14
15
+ let main_fn = MainFn :: from_item_fn ( & mut fn_decl) ;
16
+
15
17
let expanded = quote ! {
16
- #[ tokio:: main]
17
- async fn main( ) {
18
- shuttle_runtime:: start( loader) . await ;
19
- }
18
+ #main_fn
20
19
21
20
#loader
22
21
@@ -30,6 +29,11 @@ struct Loader {
30
29
fn_ident : Ident ,
31
30
fn_inputs : Vec < Input > ,
32
31
fn_return : TypePath ,
32
+ import_path : PathSegment ,
33
+ }
34
+
35
+ struct MainFn {
36
+ import_path : PathSegment ,
33
37
}
34
38
35
39
#[ derive( Debug , PartialEq ) ]
@@ -128,17 +132,38 @@ impl Loader {
128
132
. collect ( ) ;
129
133
130
134
if let Some ( type_path) = check_return_type ( item_fn. sig . clone ( ) ) {
135
+ // We need the first segment of the path so we can import the codegen dependencies from it.
136
+ let Some ( import_path) = type_path. path . segments . first ( ) . cloned ( ) else {
137
+ return None ;
138
+ } ;
139
+
131
140
Some ( Self {
132
141
fn_ident : item_fn. sig . ident . clone ( ) ,
133
142
fn_inputs : inputs,
134
143
fn_return : type_path,
144
+ import_path,
135
145
} )
136
146
} else {
137
147
None
138
148
}
139
149
}
140
150
}
141
151
152
+ impl MainFn {
153
+ pub ( crate ) fn from_item_fn ( item_fn : & mut ItemFn ) -> Option < Self > {
154
+ if let Some ( type_path) = check_return_type ( item_fn. sig . clone ( ) ) {
155
+ // We need the first segment of the path so we can import the codegen dependencies from it.
156
+ let Some ( import_path) = type_path. path . segments . first ( ) . cloned ( ) else {
157
+ return None ;
158
+ } ;
159
+
160
+ Some ( Self { import_path } )
161
+ } else {
162
+ None
163
+ }
164
+ }
165
+ }
166
+
142
167
fn check_return_type ( signature : Signature ) -> Option < TypePath > {
143
168
match signature. output {
144
169
ReturnType :: Default => {
@@ -193,6 +218,8 @@ impl ToTokens for Loader {
193
218
194
219
let return_type = & self . fn_return ;
195
220
221
+ let import_path = & self . import_path ;
222
+
196
223
let mut fn_inputs: Vec < _ > = Vec :: with_capacity ( self . fn_inputs . len ( ) ) ;
197
224
let mut fn_inputs_builder: Vec < _ > = Vec :: with_capacity ( self . fn_inputs . len ( ) ) ;
198
225
let mut fn_inputs_builder_options: Vec < _ > = Vec :: with_capacity ( self . fn_inputs . len ( ) ) ;
@@ -213,25 +240,26 @@ impl ToTokens for Loader {
213
240
None
214
241
} else {
215
242
Some ( parse_quote ! (
216
- use shuttle_service :: ResourceBuilder ;
243
+ use #import_path :: shuttle_runtime :: ResourceBuilder ;
217
244
) )
218
245
} ;
219
246
247
+ // let import_lib: TypePath = parse_quote!( shuttle_<#fn_ident>);
220
248
let loader = quote ! {
221
- async fn loader<S : shuttle_runtime:: StorageManager >(
222
- mut #factory_ident: shuttle_runtime:: ProvisionerFactory <S >,
223
- logger: shuttle_runtime:: Logger ,
249
+ async fn loader<S : #import_path :: shuttle_runtime:: StorageManager >(
250
+ mut #factory_ident: #import_path :: shuttle_runtime:: ProvisionerFactory <S >,
251
+ logger: #import_path :: shuttle_runtime:: Logger ,
224
252
) -> #return_type {
225
- use shuttle_service :: Context ;
226
- use shuttle_service :: tracing_subscriber:: prelude:: * ;
253
+ use #import_path :: shuttle_runtime :: Context ;
254
+ use #import_path :: shuttle_runtime :: tracing_subscriber:: prelude:: * ;
227
255
#extra_imports
228
256
229
257
let filter_layer =
230
- shuttle_service :: tracing_subscriber:: EnvFilter :: try_from_default_env( )
231
- . or_else( |_| shuttle_service :: tracing_subscriber:: EnvFilter :: try_new( "INFO" ) )
258
+ #import_path :: shuttle_runtime :: tracing_subscriber:: EnvFilter :: try_from_default_env( )
259
+ . or_else( |_| #import_path :: shuttle_runtime :: tracing_subscriber:: EnvFilter :: try_new( "INFO" ) )
232
260
. unwrap( ) ;
233
261
234
- shuttle_service :: tracing_subscriber:: registry( )
262
+ #import_path :: shuttle_runtime :: tracing_subscriber:: registry( )
235
263
. with( filter_layer)
236
264
. with( logger)
237
265
. init( ) ;
@@ -246,11 +274,28 @@ impl ToTokens for Loader {
246
274
}
247
275
}
248
276
277
+ impl ToTokens for MainFn {
278
+ fn to_tokens ( & self , tokens : & mut proc_macro2:: TokenStream ) {
279
+ let import_path = & self . import_path ;
280
+
281
+ // let import_lib: TypePath = parse_quote!( shuttle_<#fn_ident>);
282
+ let main_fn = quote ! {
283
+ #[ tokio:: main]
284
+ async fn main( ) {
285
+ #import_path:: shuttle_runtime:: start( loader) . await ;
286
+ }
287
+
288
+ } ;
289
+
290
+ main_fn. to_tokens ( tokens) ;
291
+ }
292
+ }
293
+
249
294
#[ cfg( test) ]
250
295
mod tests {
251
296
use pretty_assertions:: assert_eq;
252
297
use quote:: quote;
253
- use syn:: { parse_quote, Ident } ;
298
+ use syn:: { parse_quote, Ident , PathSegment } ;
254
299
255
300
use super :: { Builder , BuilderOptions , Input , Loader } ;
256
301
@@ -269,27 +314,30 @@ mod tests {
269
314
270
315
#[ test]
271
316
fn output_with_return ( ) {
317
+ let import_path: PathSegment = parse_quote ! ( shuttle_simple) ;
318
+
272
319
let input = Loader {
273
320
fn_ident : parse_quote ! ( simple) ,
274
321
fn_inputs : Vec :: new ( ) ,
275
322
fn_return : parse_quote ! ( ShuttleSimple ) ,
323
+ import_path,
276
324
} ;
277
325
278
326
let actual = quote ! ( #input) ;
279
327
let expected = quote ! {
280
- async fn loader<S : shuttle_runtime:: StorageManager >(
281
- mut _factory: shuttle_runtime:: ProvisionerFactory <S >,
282
- logger: shuttle_runtime:: Logger ,
328
+ async fn loader<S : shuttle_simple :: shuttle_runtime:: StorageManager >(
329
+ mut _factory: shuttle_simple :: shuttle_runtime:: ProvisionerFactory <S >,
330
+ logger: shuttle_simple :: shuttle_runtime:: Logger ,
283
331
) -> ShuttleSimple {
284
- use shuttle_service :: Context ;
285
- use shuttle_service :: tracing_subscriber:: prelude:: * ;
332
+ use shuttle_simple :: shuttle_runtime :: Context ;
333
+ use shuttle_simple :: shuttle_runtime :: tracing_subscriber:: prelude:: * ;
286
334
287
335
let filter_layer =
288
- shuttle_service :: tracing_subscriber:: EnvFilter :: try_from_default_env( )
289
- . or_else( |_| shuttle_service :: tracing_subscriber:: EnvFilter :: try_new( "INFO" ) )
336
+ shuttle_simple :: shuttle_runtime :: tracing_subscriber:: EnvFilter :: try_from_default_env( )
337
+ . or_else( |_| shuttle_simple :: shuttle_runtime :: tracing_subscriber:: EnvFilter :: try_new( "INFO" ) )
290
338
. unwrap( ) ;
291
339
292
- shuttle_service :: tracing_subscriber:: registry( )
340
+ shuttle_simple :: shuttle_runtime :: tracing_subscriber:: registry( )
293
341
. with( filter_layer)
294
342
. with( logger)
295
343
. init( ) ;
@@ -334,6 +382,8 @@ mod tests {
334
382
335
383
#[ test]
336
384
fn output_with_inputs ( ) {
385
+ let import_path: PathSegment = parse_quote ! ( shuttle_complex) ;
386
+
337
387
let input = Loader {
338
388
fn_ident : parse_quote ! ( complex) ,
339
389
fn_inputs : vec ! [
@@ -353,24 +403,25 @@ mod tests {
353
403
} ,
354
404
] ,
355
405
fn_return : parse_quote ! ( ShuttleComplex ) ,
406
+ import_path,
356
407
} ;
357
408
358
409
let actual = quote ! ( #input) ;
359
410
let expected = quote ! {
360
- async fn loader<S : shuttle_runtime:: StorageManager >(
361
- mut factory: shuttle_runtime:: ProvisionerFactory <S >,
362
- logger: shuttle_runtime:: Logger ,
411
+ async fn loader<S : shuttle_complex :: shuttle_runtime:: StorageManager >(
412
+ mut factory: shuttle_complex :: shuttle_runtime:: ProvisionerFactory <S >,
413
+ logger: shuttle_complex :: shuttle_runtime:: Logger ,
363
414
) -> ShuttleComplex {
364
- use shuttle_service :: Context ;
365
- use shuttle_service :: tracing_subscriber:: prelude:: * ;
366
- use shuttle_service :: ResourceBuilder ;
415
+ use shuttle_complex :: shuttle_runtime :: Context ;
416
+ use shuttle_complex :: shuttle_runtime :: tracing_subscriber:: prelude:: * ;
417
+ use shuttle_complex :: shuttle_runtime :: ResourceBuilder ;
367
418
368
419
let filter_layer =
369
- shuttle_service :: tracing_subscriber:: EnvFilter :: try_from_default_env( )
370
- . or_else( |_| shuttle_service :: tracing_subscriber:: EnvFilter :: try_new( "INFO" ) )
420
+ shuttle_complex :: shuttle_runtime :: tracing_subscriber:: EnvFilter :: try_from_default_env( )
421
+ . or_else( |_| shuttle_complex :: shuttle_runtime :: tracing_subscriber:: EnvFilter :: try_new( "INFO" ) )
371
422
. unwrap( ) ;
372
423
373
- shuttle_service :: tracing_subscriber:: registry( )
424
+ shuttle_complex :: shuttle_runtime :: tracing_subscriber:: registry( )
374
425
. with( filter_layer)
375
426
. with( logger)
376
427
. init( ) ;
@@ -460,6 +511,8 @@ mod tests {
460
511
461
512
#[ test]
462
513
fn output_with_input_options ( ) {
514
+ let import_path: PathSegment = parse_quote ! ( shuttle_complex) ;
515
+
463
516
let mut input = Loader {
464
517
fn_ident : parse_quote ! ( complex) ,
465
518
fn_inputs : vec ! [ Input {
@@ -470,6 +523,7 @@ mod tests {
470
523
} ,
471
524
} ] ,
472
525
fn_return : parse_quote ! ( ShuttleComplex ) ,
526
+ import_path,
473
527
} ;
474
528
475
529
input. fn_inputs [ 0 ]
@@ -485,20 +539,20 @@ mod tests {
485
539
486
540
let actual = quote ! ( #input) ;
487
541
let expected = quote ! {
488
- async fn loader<S : shuttle_runtime:: StorageManager >(
489
- mut factory: shuttle_runtime:: ProvisionerFactory <S >,
490
- logger: shuttle_runtime:: Logger ,
542
+ async fn loader<S : shuttle_complex :: shuttle_runtime:: StorageManager >(
543
+ mut factory: shuttle_complex :: shuttle_runtime:: ProvisionerFactory <S >,
544
+ logger: shuttle_complex :: shuttle_runtime:: Logger ,
491
545
) -> ShuttleComplex {
492
- use shuttle_service :: Context ;
493
- use shuttle_service :: tracing_subscriber:: prelude:: * ;
494
- use shuttle_service :: ResourceBuilder ;
546
+ use shuttle_complex :: shuttle_runtime :: Context ;
547
+ use shuttle_complex :: shuttle_runtime :: tracing_subscriber:: prelude:: * ;
548
+ use shuttle_complex :: shuttle_runtime :: ResourceBuilder ;
495
549
496
550
let filter_layer =
497
- shuttle_service :: tracing_subscriber:: EnvFilter :: try_from_default_env( )
498
- . or_else( |_| shuttle_service :: tracing_subscriber:: EnvFilter :: try_new( "INFO" ) )
551
+ shuttle_complex :: shuttle_runtime :: tracing_subscriber:: EnvFilter :: try_from_default_env( )
552
+ . or_else( |_| shuttle_complex :: shuttle_runtime :: tracing_subscriber:: EnvFilter :: try_new( "INFO" ) )
499
553
. unwrap( ) ;
500
554
501
- shuttle_service :: tracing_subscriber:: registry( )
555
+ shuttle_complex :: shuttle_runtime :: tracing_subscriber:: registry( )
502
556
. with( filter_layer)
503
557
. with( logger)
504
558
. init( ) ;
0 commit comments