|
4 | 4 | using System;
|
5 | 5 | using System.Runtime.InteropServices;
|
6 | 6 | using System.Text;
|
| 7 | +using System.Collections.Generic; |
| 8 | +using System.Text.Json; |
7 | 9 |
|
8 | 10 |
|
9 | 11 | #nullable enable
|
10 | 12 | namespace Regorus
|
11 | 13 | {
|
| 14 | + /// <summary> |
| 15 | + /// Delegate for callback functions that can be invoked from Rego policies |
| 16 | + /// </summary> |
| 17 | + /// <param name="payload">Deserialized JSON object containing the payload from Rego</param> |
| 18 | + /// <returns>Object that will be serialized to JSON and converted to a Rego value</returns> |
| 19 | + public delegate object RegoCallback(object payload); |
| 20 | + |
12 | 21 | public unsafe sealed class Engine : System.IDisposable
|
13 | 22 | {
|
14 | 23 | private Regorus.Internal.RegorusEngine* E;
|
15 | 24 | // Detect redundant Dispose() calls in a thread-safe manner.
|
16 | 25 | // _isDisposed == 0 means Dispose(bool) has not been called yet.
|
17 | 26 | // _isDisposed == 1 means Dispose(bool) has been already called.
|
18 | 27 | private int isDisposed;
|
| 28 | + |
| 29 | + // Store callback delegates to prevent garbage collection |
| 30 | + private readonly Dictionary<string, (Internal.RegorusCallbackDelegate Delegate, GCHandle Handle)> callbackDelegates |
| 31 | + = new Dictionary<string, (Internal.RegorusCallbackDelegate, GCHandle)>(); |
| 32 | + |
| 33 | + // Store user callbacks |
| 34 | + private readonly Dictionary<string, RegoCallback> callbacks = new Dictionary<string, RegoCallback>(); |
| 35 | + |
| 36 | + // JSON serialization options |
| 37 | + private static readonly JsonSerializerOptions JsonOptions = new JsonSerializerOptions |
| 38 | + { |
| 39 | + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, |
| 40 | + WriteIndented = false |
| 41 | + }; |
19 | 42 |
|
20 | 43 | public Engine()
|
21 | 44 | {
|
@@ -51,7 +74,11 @@ void Dispose(bool disposing)
|
51 | 74 | // and unmanaged resources.
|
52 | 75 | if (disposing)
|
53 | 76 | {
|
54 |
| - // No managed resource to dispose. |
| 77 | + // Unregister all callbacks |
| 78 | + foreach (var name in new List<string>(callbackDelegates.Keys)) |
| 79 | + { |
| 80 | + UnregisterCallback(name); |
| 81 | + } |
55 | 82 | }
|
56 | 83 |
|
57 | 84 | // Call the appropriate methods to clean up
|
@@ -202,7 +229,139 @@ public void SetGatherPrints(bool enable)
|
202 | 229 | return CheckAndDropResult(Regorus.Internal.API.regorus_engine_take_prints(E));
|
203 | 230 | }
|
204 | 231 |
|
205 |
| - |
| 232 | + // Generic callback handler that routes to the appropriate user-provided callback |
| 233 | + private static unsafe byte* CallbackHandler(byte* payloadPtr, void* contextPtr) |
| 234 | + { |
| 235 | + try |
| 236 | + { |
| 237 | + // Context pointer contains the engine instance and callback name |
| 238 | + var context = GCHandle.FromIntPtr(new IntPtr(contextPtr)); |
| 239 | + var contextData = (CallbackContext)context.Target!; |
| 240 | + |
| 241 | + if (contextData == null || contextData.Engine == null) |
| 242 | + { |
| 243 | + return null; |
| 244 | + } |
| 245 | + |
| 246 | + // Convert the payload to a string |
| 247 | + var payload = Marshal.PtrToStringUTF8(new IntPtr(payloadPtr)); |
| 248 | + if (payload == null) |
| 249 | + { |
| 250 | + return null; |
| 251 | + } |
| 252 | + |
| 253 | + // Deserialize the payload to an object |
| 254 | + var payloadObject = JsonSerializer.Deserialize<object>(payload, JsonOptions); |
| 255 | + if (payloadObject == null) |
| 256 | + { |
| 257 | + return null; |
| 258 | + } |
| 259 | + |
| 260 | + // Get the user callback |
| 261 | + if (!contextData.Engine.callbacks.TryGetValue(contextData.CallbackName, out var callback)) |
| 262 | + { |
| 263 | + return null; |
| 264 | + } |
| 265 | + |
| 266 | + // Call the user callback |
| 267 | + var result = callback(payloadObject); |
| 268 | + |
| 269 | + if (result == null) |
| 270 | + { |
| 271 | + return null; |
| 272 | + } |
| 273 | + |
| 274 | + // Always serialize the result to JSON, even if it's a string |
| 275 | + string jsonResult = JsonSerializer.Serialize(result, JsonOptions); |
| 276 | + |
| 277 | + // Convert the result back to a C string that Rust will free |
| 278 | + return (byte*)Marshal.StringToCoTaskMemUTF8(jsonResult).ToPointer(); |
| 279 | + } |
| 280 | + catch |
| 281 | + { |
| 282 | + return null; |
| 283 | + } |
| 284 | + } |
| 285 | + |
| 286 | + private class CallbackContext |
| 287 | + { |
| 288 | + public Engine Engine { get; set; } |
| 289 | + public string CallbackName { get; set; } |
| 290 | + |
| 291 | + public CallbackContext(Engine engine, string name) |
| 292 | + { |
| 293 | + Engine = engine; |
| 294 | + CallbackName = name; |
| 295 | + } |
| 296 | + } |
| 297 | + |
| 298 | + /// <summary> |
| 299 | + /// Register a callback function that can be invoked from Rego policies |
| 300 | + /// </summary> |
| 301 | + /// <param name="name">Name of the callback function to register</param> |
| 302 | + /// <param name="callback">Callback function to be invoked</param> |
| 303 | + /// <returns>True if registration succeeded, otherwise false</returns> |
| 304 | + public bool RegisterCallback(string name, RegoCallback callback) |
| 305 | + { |
| 306 | + if (string.IsNullOrEmpty(name) || callback == null) |
| 307 | + { |
| 308 | + return false; |
| 309 | + } |
| 310 | + |
| 311 | + // Store the callback in our dictionary |
| 312 | + callbacks[name] = callback; |
| 313 | + |
| 314 | + // Create a context object and GCHandle |
| 315 | + var contextData = new CallbackContext(this, name); |
| 316 | + var contextHandle = GCHandle.Alloc(contextData); |
| 317 | + var contextPtr = GCHandle.ToIntPtr(contextHandle); |
| 318 | + |
| 319 | + // Create a delegate for the callback handler |
| 320 | + var callbackDelegate = new Internal.RegorusCallbackDelegate(CallbackHandler); |
| 321 | + |
| 322 | + // Store the delegate to prevent garbage collection |
| 323 | + callbackDelegates[name] = (callbackDelegate, contextHandle); |
| 324 | + |
| 325 | + // Register the callback with the native code |
| 326 | + var nameBytes = NullTerminatedUTF8Bytes(name); |
| 327 | + fixed (byte* namePtr = nameBytes) |
| 328 | + { |
| 329 | + var result = Internal.API.regorus_register_callback(namePtr, callbackDelegate, (void*)contextPtr); |
| 330 | + return result == Internal.RegorusStatus.RegorusStatusOk; |
| 331 | + } |
| 332 | + } |
| 333 | + |
| 334 | + /// <summary> |
| 335 | + /// Unregister a previously registered callback function |
| 336 | + /// </summary> |
| 337 | + /// <param name="name">Name of the callback function to unregister</param> |
| 338 | + /// <returns>True if unregistration succeeded, otherwise false</returns> |
| 339 | + public bool UnregisterCallback(string name) |
| 340 | + { |
| 341 | + if (string.IsNullOrEmpty(name)) |
| 342 | + { |
| 343 | + return false; |
| 344 | + } |
| 345 | + |
| 346 | + // Remove the callback from our dictionary |
| 347 | + callbacks.Remove(name); |
| 348 | + |
| 349 | + // Unregister the callback from the native code |
| 350 | + var nameBytes = NullTerminatedUTF8Bytes(name); |
| 351 | + fixed (byte* namePtr = nameBytes) |
| 352 | + { |
| 353 | + var result = Internal.API.regorus_unregister_callback(namePtr); |
| 354 | + |
| 355 | + // Free the GCHandle if we have it |
| 356 | + if (callbackDelegates.TryGetValue(name, out var delegateInfo)) |
| 357 | + { |
| 358 | + delegateInfo.Handle.Free(); |
| 359 | + callbackDelegates.Remove(name); |
| 360 | + } |
| 361 | + |
| 362 | + return result == Internal.RegorusStatus.RegorusStatusOk; |
| 363 | + } |
| 364 | + } |
206 | 365 |
|
207 | 366 | string? StringFromUTF8(IntPtr ptr)
|
208 | 367 | {
|
|
0 commit comments