Skip to content

Commit 5161e0b

Browse files
authored
Add dependency injection support for Nexus services (#561)
1 parent 17743ae commit 5161e0b

File tree

6 files changed

+737
-0
lines changed

6 files changed

+737
-0
lines changed

src/Temporalio.Extensions.Hosting/README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,18 @@ For registering workflows on the worker, `AddWorkflow` extension method is avail
8585
collection because the construction and lifecycle of workflows is managed by Temporal. Dependency injection for
8686
workflows is intentionally not supported.
8787

88+
⚠️WARNING: Nexus support is experimental.
89+
90+
For adding Nexus service handlers to the service collection and registering operations with the worker, the following
91+
extensions methods exist on the builder each accepting Nexus service handler type:
92+
93+
* `AddSingletonNexusService` - `TryAddSingleton` + register Nexus service handler on worker
94+
* `AddScopedNexusService` - `TryAddScoped` + register Nexus service handler on worker
95+
* `AddTransientNexusService` - `TryAddTransient` + register Nexus service handler on worker
96+
97+
These all expect the Nexus service handler to have the `NexusServiceHandler` attribute that references the service
98+
interface and the Nexus operation handlers to have the `NexusOperationHandler` attribute.
99+
88100
Other worker and client options can be configured on the builder via the `ConfigureOptions` extension method. With no
89101
parameters, this returns an `OptionsBuilder<TemporalWorkerServiceOptions>` to use. When provided an action, the options
90102
are available as parameters that can be configured. `TemporalWorkerServiceOptions` simply extends
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Reflection;
5+
using NexusRpc;
6+
using NexusRpc.Handlers;
7+
8+
namespace Temporalio.Extensions.Hosting
9+
{
10+
/// <summary>
11+
/// Helper for contructing <see cref="ServiceHandlerInstance"/>.
12+
/// </summary>
13+
/// <remarks>
14+
/// This is internal and should be moved to NexusRpc in the future.
15+
/// </remarks>
16+
internal static class ServiceHandlerInstanceHelper
17+
{
18+
/// <summary>
19+
/// Create a service handler instance from the given service handler type and handler factory.
20+
/// </summary>
21+
/// <param name="serviceHandlerType">The type of the Nexus service handler.</param>
22+
/// <param name="handlerFactory">A factory that converts method information into an operation handler.</param>
23+
/// <returns>A <see cref="ServiceHandlerInstance"/> for the given <paramref name="serviceHandlerType"/> type.</returns>
24+
public static ServiceHandlerInstance FromType(Type serviceHandlerType, Func<MethodInfo, IOperationHandler<object?, object?>> handlerFactory)
25+
{
26+
var serviceDef = GetServiceDefinition(serviceHandlerType);
27+
28+
return new ServiceHandlerInstance(
29+
serviceDef,
30+
CreateHandlers(
31+
serviceDef,
32+
serviceHandlerType,
33+
handlerFactory));
34+
}
35+
36+
/// <summary>
37+
/// Creates a <see cref="ServiceDefinition"/> for the given service handler type.
38+
/// </summary>
39+
/// <param name="serviceHandlerType">The type of the Nexus service handler.</param>
40+
/// <returns>A <see cref="ServiceDefinition"/> for the given <paramref name="serviceHandlerType"/> type.</returns>
41+
private static ServiceDefinition GetServiceDefinition(Type serviceHandlerType)
42+
{
43+
// Make sure the attribute is on the declaring type of the instance
44+
var handlerAttr = serviceHandlerType.GetCustomAttribute<NexusServiceHandlerAttribute>() ??
45+
throw new ArgumentException("Missing NexusServiceHandler attribute");
46+
return ServiceDefinition.FromType(handlerAttr.ServiceType);
47+
}
48+
49+
/// <summary>
50+
/// Collects all public methods from the given type and its base types recursively.
51+
/// </summary>
52+
/// <param name="serviceHandlerType">The type of the Nexus service handler.</param>
53+
/// <param name="methods">The list to which discovered methods are added.</param>
54+
private static void CollectTypeMethods(Type serviceHandlerType, List<MethodInfo> methods)
55+
{
56+
// Add all declared public static/instance methods that do not already have one like
57+
// it present
58+
foreach (var method in serviceHandlerType.GetMethods(
59+
BindingFlags.Public | BindingFlags.Instance | BindingFlags.Static | BindingFlags.DeclaredOnly))
60+
{
61+
// Only add if there isn't already one that matches the base definition
62+
var baseDef = method.GetBaseDefinition();
63+
if (!methods.Any(m => baseDef == m.GetBaseDefinition()))
64+
{
65+
methods.Add(method);
66+
}
67+
}
68+
if (serviceHandlerType.BaseType is { } baseType)
69+
{
70+
CollectTypeMethods(baseType, methods);
71+
}
72+
}
73+
74+
/// <summary>
75+
/// Validates and adds an operation handler created from the given operation handler method.
76+
/// </summary>
77+
/// <param name="serviceDef">A <see cref="ServiceDefinition"/> for the given service handler type.</param>
78+
/// <param name="method">The method from which an operation hander is created.</param>
79+
/// <param name="handlerFactory">A factory that creates an operation handler for a given method.</param>
80+
/// <param name="opHandlers">The mapping of operation names to operation handlers.</param>
81+
private static void AddOperationHandler(
82+
ServiceDefinition serviceDef,
83+
MethodInfo method,
84+
Func<MethodInfo, IOperationHandler<object?, object?>> handlerFactory,
85+
Dictionary<string, IOperationHandler<object?, object?>> opHandlers)
86+
{
87+
// Validate
88+
if (method.GetParameters().Length != 0)
89+
{
90+
throw new ArgumentException("Cannot have parameters");
91+
}
92+
if (method.ContainsGenericParameters)
93+
{
94+
throw new ArgumentException("Cannot be generic");
95+
}
96+
if (!method.IsPublic)
97+
{
98+
throw new ArgumentException("Must be public");
99+
}
100+
101+
// Find definition by the method name
102+
var opDef = serviceDef.Operations.Values.FirstOrDefault(o => o.MethodInfo?.Name == method.Name) ??
103+
throw new ArgumentException("No matching NexusOperation on the service interface");
104+
105+
// Check return
106+
var goodReturn = false;
107+
if (method.ReturnType.IsGenericType &&
108+
method.ReturnType.GetGenericTypeDefinition() == typeof(IOperationHandler<,>))
109+
{
110+
var args = method.ReturnType.GetGenericArguments();
111+
goodReturn = args.Length == 2 &&
112+
NoValue.NormalizeVoidType(args[0]) == opDef.InputType &&
113+
NoValue.NormalizeVoidType(args[1]) == opDef.OutputType;
114+
}
115+
if (!goodReturn)
116+
{
117+
var inType = opDef.InputType == typeof(void) ? typeof(NoValue) : opDef.InputType;
118+
var outType = opDef.OutputType == typeof(void) ? typeof(NoValue) : opDef.OutputType;
119+
throw new ArgumentException(
120+
$"Expected return type of IOperationHandler<{inType.Name}, {outType.Name}>");
121+
}
122+
123+
// Confirm not present already
124+
if (opHandlers.ContainsKey(opDef.Name))
125+
{
126+
throw new ArgumentException($"Duplicate operation handler named ${opDef.Name}");
127+
}
128+
129+
opHandlers[opDef.Name] = handlerFactory(method);
130+
}
131+
132+
/// <summary>
133+
/// Creates a mapping of operation names to operation handlers for the given service handler type.
134+
/// </summary>
135+
/// <param name="serviceDef">A <see cref="ServiceDefinition"/> for the given service handler type.</param>
136+
/// <param name="serviceHandlerType">The type of the Nexus service handler.</param>
137+
/// <param name="handlerFactory">A factory that creates an operation handler for a given method.</param>
138+
/// <returns>A mapping of operation names to operation handlers.</returns>
139+
private static Dictionary<string, IOperationHandler<object?, object?>> CreateHandlers(
140+
ServiceDefinition serviceDef,
141+
Type serviceHandlerType,
142+
Func<MethodInfo, IOperationHandler<object?, object?>> handlerFactory)
143+
{
144+
// Collect all methods recursively
145+
var methods = new List<MethodInfo>();
146+
CollectTypeMethods(serviceHandlerType, methods);
147+
148+
// Collect handlers from the method list
149+
var opHandlers = new Dictionary<string, IOperationHandler<object?, object?>>();
150+
foreach (var method in methods)
151+
{
152+
// Only care about ones with operation attribute
153+
if (method.GetCustomAttribute<NexusOperationHandlerAttribute>() == null)
154+
{
155+
continue;
156+
}
157+
158+
try
159+
{
160+
AddOperationHandler(serviceDef, method, handlerFactory, opHandlers);
161+
}
162+
catch (Exception e)
163+
{
164+
throw new ArgumentException(
165+
$"Failed obtaining operation handler from {method.Name}", e);
166+
}
167+
}
168+
169+
return opHandlers;
170+
}
171+
}
172+
}

src/Temporalio.Extensions.Hosting/ServiceProviderExtensions.cs

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
using System.Runtime.ExceptionServices;
66
using System.Threading.Tasks;
77
using Microsoft.Extensions.DependencyInjection;
8+
using NexusRpc;
9+
using NexusRpc.Handlers;
810
using Temporalio.Activities;
911

1012
namespace Temporalio.Extensions.Hosting
@@ -143,5 +145,104 @@ public static ActivityDefinition CreateTemporalActivityDefinition(
143145
}
144146
return ActivityDefinition.Create(method, Invoker);
145147
}
148+
149+
/// <summary>
150+
/// Create <see cref="ServiceHandlerInstance"/> for the given nexus-attributed service handler type.
151+
/// If a service handler method is non-static, this will use the service provider to get the service
152+
/// instance to call the method on.
153+
/// </summary>
154+
/// <param name="provider">Service provider for creating the service instance if the
155+
/// method is non-static.</param>
156+
/// <param name="serviceHandlerType">The type of the Nexus service handler.</param>
157+
/// <returns>Created <see cref="ServiceHandlerInstance"/>.</returns>
158+
internal static ServiceHandlerInstance CreateNexusServiceHandlerInstance(
159+
this IServiceProvider provider,
160+
Type serviceHandlerType) =>
161+
ServiceHandlerInstanceHelper.FromType(
162+
serviceHandlerType,
163+
serviceOperationMethod => new ScopedNexusOperationHandler(
164+
serviceHandlerType,
165+
serviceOperationMethod,
166+
provider));
167+
168+
/// <summary>
169+
/// An operation handler that defers the resolution of the Nexus service handler and the invocation
170+
/// of a Nexus operation to be within a service scope.
171+
/// </summary>
172+
private sealed class ScopedNexusOperationHandler :
173+
IOperationHandler<object?, object?>
174+
{
175+
private readonly Type serviceHandlerType;
176+
private readonly MethodInfo serviceOperationMethod;
177+
private readonly IServiceProvider serviceProvider;
178+
179+
public ScopedNexusOperationHandler(Type serviceHandlerType, MethodInfo serviceOperationMethod, IServiceProvider serviceProvider)
180+
{
181+
this.serviceHandlerType = serviceHandlerType;
182+
this.serviceOperationMethod = serviceOperationMethod;
183+
this.serviceProvider = serviceProvider;
184+
}
185+
186+
public async Task<OperationStartResult<object?>> StartAsync(OperationStartContext context, object? input) =>
187+
await InvokeWithScopeAsync(handler => handler.StartAsync(context, input)).ConfigureAwait(false);
188+
189+
public async Task CancelAsync(OperationCancelContext context) =>
190+
await InvokeWithScopeAsync(handler => handler.CancelAsync(context).ContinueWith(
191+
_ => ValueTuple.Create(),
192+
default,
193+
TaskContinuationOptions.ExecuteSynchronously | TaskContinuationOptions.OnlyOnRanToCompletion,
194+
TaskScheduler.Current)).ConfigureAwait(false);
195+
196+
public async Task<OperationInfo> FetchInfoAsync(OperationFetchInfoContext context) =>
197+
await InvokeWithScopeAsync(handler => handler.FetchInfoAsync(context)).ConfigureAwait(false);
198+
199+
public async Task<object?> FetchResultAsync(OperationFetchResultContext context) =>
200+
await InvokeWithScopeAsync(handler => handler.FetchResultAsync(context)).ConfigureAwait(false);
201+
202+
private async Task<T> InvokeWithScopeAsync<T>(Func<IOperationHandler<object?, object?>, Task<T>> handlerInvoker)
203+
{
204+
#if NET6_0_OR_GREATER
205+
AsyncServiceScope scope = this.serviceProvider.CreateAsyncScope();
206+
#else
207+
IServiceScope scope = this.serviceProvider.CreateScope();
208+
#endif
209+
210+
try
211+
{
212+
object handler;
213+
try
214+
{
215+
// Create the instance if not static and not already created
216+
var serviceHandlerInstance = this.serviceOperationMethod.IsStatic
217+
? null
218+
: scope.ServiceProvider.GetRequiredService(this.serviceHandlerType);
219+
220+
handler = this.serviceOperationMethod.Invoke(serviceHandlerInstance, null) ??
221+
throw new ArgumentException("Operation handler was null");
222+
}
223+
catch (TargetInvocationException e)
224+
{
225+
#if NET6_0_OR_GREATER
226+
ExceptionDispatchInfo.Capture(e.InnerException!).Throw();
227+
#else
228+
ExceptionDispatchInfo.Capture(e.InnerException).Throw();
229+
#endif
230+
// Unreachable
231+
throw new InvalidOperationException("Unreachable");
232+
}
233+
234+
var genericHandler = OperationHandler.WrapAsGenericHandler(handler, this.serviceOperationMethod.ReturnType);
235+
return await handlerInvoker(genericHandler).ConfigureAwait(false);
236+
}
237+
finally
238+
{
239+
#if NET6_0_OR_GREATER
240+
await scope.DisposeAsync().ConfigureAwait(false);
241+
#else
242+
scope.Dispose();
243+
#endif
244+
}
245+
}
246+
}
146247
}
147248
}

0 commit comments

Comments
 (0)