diff --git a/build.gradle b/build.gradle index 928da15cb..d1c2ad68f 100644 --- a/build.gradle +++ b/build.gradle @@ -63,6 +63,7 @@ dependencies { compile group: 'io.micrometer', name: 'micrometer-core', version: '1.1.2' compile group: 'javax.annotation', name: 'javax.annotation-api', version: '1.3.2' compile group: 'com.auth0', name: 'java-jwt', version:'3.10.2' + compile group: 'io.opentelemetry', name: 'opentelemetry-sdk', version: '1.1.0' testCompile group: 'junit', name: 'junit', version: '4.12' testCompile group: 'com.googlecode.junit-toolbox', name: 'junit-toolbox', version: '2.4' diff --git a/src/main/java/com/uber/cadence/context/ContextPropagator.java b/src/main/java/com/uber/cadence/context/ContextPropagator.java index 3a618f96c..d302ffdaa 100644 --- a/src/main/java/com/uber/cadence/context/ContextPropagator.java +++ b/src/main/java/com/uber/cadence/context/ContextPropagator.java @@ -23,6 +23,9 @@ * Context Propagators are used to propagate information from workflow to activity, workflow to * child workflow, and workflow to child thread (using {@link com.uber.cadence.workflow.Async}). * + *

It is important to note that all threads share one ContextPropagator instance, so your + * implementation must be thread-safe and store any state in ThreadLocal variables. + * *

A sample ContextPropagator that copies all {@link org.slf4j.MDC} entries starting * with a given prefix along the code path looks like this: * @@ -136,4 +139,31 @@ public interface ContextPropagator { /** Sets the current context */ void setCurrentContext(Object context); + + /** + * This is a lifecycle method, called after the context has been propagated to the + * workflow/activity thread but the workflow/activity has not yet started. + */ + default void setUp() { + // No-op + } + + /** + * This is a lifecycle method, called after the workflow/activity has completed. If the method + * finished without exception, {@code successful} will be true. Otherwise, it will be false and + * {@link #onError(Throwable)} will have already been called. + */ + default void finish() { + // No-op + } + + /** + * This is a lifecycle method, called when the workflow/activity finishes by throwing an unhandled + * exception. {@link #finish()} is called after this method. + * + * @param t The unhandled exception that caused the workflow/activity to terminate + */ + default void onError(Throwable t) { + // No-op + } } diff --git a/src/main/java/com/uber/cadence/context/OpenTelemetryContextPropagator.java b/src/main/java/com/uber/cadence/context/OpenTelemetryContextPropagator.java new file mode 100644 index 000000000..927d3c71f --- /dev/null +++ b/src/main/java/com/uber/cadence/context/OpenTelemetryContextPropagator.java @@ -0,0 +1,144 @@ +/* + * Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Modifications copyright (C) 2017 Uber Technologies, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License is + * located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.uber.cadence.context; + +import io.opentelemetry.api.GlobalOpenTelemetry; +import io.opentelemetry.api.baggage.Baggage; +import io.opentelemetry.api.baggage.propagation.W3CBaggagePropagator; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.SpanKind; +import io.opentelemetry.api.trace.propagation.W3CTraceContextPropagator; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.Scope; +import io.opentelemetry.context.propagation.TextMapGetter; +import io.opentelemetry.context.propagation.TextMapPropagator; +import io.opentelemetry.context.propagation.TextMapSetter; +import java.nio.charset.Charset; +import java.util.HashMap; +import java.util.Map; +import javax.annotation.Nullable; +import org.slf4j.MDC; + +public class OpenTelemetryContextPropagator implements ContextPropagator { + + private static final TextMapPropagator w3cTraceContextPropagator = + W3CTraceContextPropagator.getInstance(); + private static final TextMapPropagator w3cBaggagePropagator = W3CBaggagePropagator.getInstance(); + private static ThreadLocal currentContextOtelScope = new ThreadLocal<>(); + private static ThreadLocal currentOtelSpan = new ThreadLocal<>(); + private static ThreadLocal currentOtelScope = new ThreadLocal<>(); + private static ThreadLocal> otelKeySet = new ThreadLocal<>(); + private static final TextMapSetter> setter = Map::put; + private static final TextMapGetter> getter = + new TextMapGetter>() { + @Override + public Iterable keys(Map carrier) { + return otelKeySet.get(); + } + + @Nullable + @Override + public String get(Map carrier, String key) { + return MDC.get(key); + } + }; + + @Override + public String getName() { + return "OpenTelemetry"; + } + + @Override + public Map serializeContext(Object context) { + Map serializedContext = new HashMap<>(); + Map contextMap = (Map) context; + if (contextMap != null) { + for (Map.Entry entry : contextMap.entrySet()) { + serializedContext.put(entry.getKey(), entry.getValue().getBytes(Charset.defaultCharset())); + } + } + return serializedContext; + } + + @Override + public Object deserializeContext(Map context) { + Map contextMap = new HashMap<>(); + for (Map.Entry entry : context.entrySet()) { + contextMap.put(entry.getKey(), new String(entry.getValue(), Charset.defaultCharset())); + } + return contextMap; + } + + @Override + public Object getCurrentContext() { + Map carrier = new HashMap<>(); + w3cTraceContextPropagator.inject(Context.current(), carrier, setter); + w3cBaggagePropagator.inject(Context.current(), carrier, setter); + return carrier; + } + + @Override + public void setCurrentContext(Object context) { + Map contextMap = (Map) context; + if (contextMap != null) { + for (Map.Entry entry : contextMap.entrySet()) { + MDC.put(entry.getKey(), entry.getValue()); + } + otelKeySet.set(contextMap.keySet()); + } + } + + @Override + @SuppressWarnings("MustBeClosedChecker") + public void setUp() { + Context context = + Baggage.fromContext(w3cBaggagePropagator.extract(Context.current(), null, getter)) + .toBuilder() + .build() + .storeInContext(w3cTraceContextPropagator.extract(Context.current(), null, getter)); + + currentContextOtelScope.set(context.makeCurrent()); + + Span span = + GlobalOpenTelemetry.getTracer("cadence-client") + .spanBuilder("cadence.workflow") + .setParent(context) + .setSpanKind(SpanKind.CLIENT) + .startSpan(); + + Scope scope = span.makeCurrent(); + currentOtelSpan.set(span); + currentOtelScope.set(scope); + } + + @Override + public void finish() { + Scope scope = currentOtelScope.get(); + if (scope != null) { + scope.close(); + } + Span span = currentOtelSpan.get(); + if (span != null) { + span.end(); + } + Scope contextScope = currentContextOtelScope.get(); + if (contextScope != null) { + contextScope.close(); + } + } +} diff --git a/src/main/java/com/uber/cadence/internal/common/InternalUtils.java b/src/main/java/com/uber/cadence/internal/common/InternalUtils.java index 520d8efbb..dc7341c76 100644 --- a/src/main/java/com/uber/cadence/internal/common/InternalUtils.java +++ b/src/main/java/com/uber/cadence/internal/common/InternalUtils.java @@ -27,6 +27,7 @@ import com.uber.cadence.SearchAttributes; import com.uber.cadence.TaskList; import com.uber.cadence.TaskListKind; +import com.uber.cadence.context.ContextPropagator; import com.uber.cadence.converter.DataConverter; import com.uber.cadence.converter.JsonDataConverter; import com.uber.cadence.internal.worker.Shutdownable; @@ -39,6 +40,8 @@ import java.util.Map; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + import org.apache.thrift.TDeserializer; import org.apache.thrift.TException; import org.apache.thrift.TSerializer; @@ -251,6 +254,21 @@ public static List DeserializeFromBlobDataToHistoryEvents(List FilterForPropagator( + Map data, + ContextPropagator propagator + ) { + return data + .entrySet() + .stream() + .filter(e -> e.getKey().startsWith(propagator.getName())) + .collect( + Collectors.toMap( + e -> e.getKey().substring(propagator.getName().length() + 1), + Map.Entry::getValue)); + } + /** Prohibit instantiation */ private InternalUtils() {} } diff --git a/src/main/java/com/uber/cadence/internal/context/ContextThreadLocal.java b/src/main/java/com/uber/cadence/internal/context/ContextThreadLocal.java index 227124170..b4423b923 100644 --- a/src/main/java/com/uber/cadence/internal/context/ContextThreadLocal.java +++ b/src/main/java/com/uber/cadence/internal/context/ContextThreadLocal.java @@ -24,10 +24,14 @@ import java.util.List; import java.util.Map; import java.util.function.Supplier; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; -/** This class holds the current set of context propagators */ +/** This class holds the current set of context propagators. */ public class ContextThreadLocal { + private static final Logger log = LoggerFactory.getLogger(ContextThreadLocal.class); + private static WorkflowThreadLocal> contextPropagators = WorkflowThreadLocal.withInitial( new Supplier>() { @@ -37,7 +41,7 @@ public List get() { } }); - /** Sets the list of context propagators for the thread */ + /** Sets the list of context propagators for the thread. */ public static void setContextPropagators(List propagators) { if (propagators == null || propagators.isEmpty()) { return; @@ -57,6 +61,11 @@ public static Map getCurrentContextForPropagation() { return contextData; } + /** + * Injects the context data into the thread for each configured context propagator. + * + * @param contextData The context data received from the server. + */ public static void propagateContextToCurrentThread(Map contextData) { if (contextData == null || contextData.isEmpty()) { return; @@ -67,4 +76,44 @@ public static void propagateContextToCurrentThread(Map contextDa } } } + + /** Calls {@link ContextPropagator#setUp()} for each propagator. */ + public static void setUpContextPropagators() { + for (ContextPropagator propagator : contextPropagators.get()) { + try { + propagator.setUp(); + } catch (Throwable t) { + // Don't let an error in one propagator block the others + log.error("Error calling setUp() on a contextpropagator", t); + } + } + } + + /** + * Calls {@link ContextPropagator#onError(Throwable)} for each propagator. + * + * @param t The Throwable that caused the workflow/activity to finish. + */ + public static void onErrorContextPropagators(Throwable t) { + for (ContextPropagator propagator : contextPropagators.get()) { + try { + propagator.onError(t); + } catch (Throwable t1) { + // Don't let an error in one propagator block the others + log.error("Error calling onError() on a contextpropagator", t1); + } + } + } + + /** Calls {@link ContextPropagator#finish()} for each propagator. */ + public static void finishContextPropagators() { + for (ContextPropagator propagator : contextPropagators.get()) { + try { + propagator.finish(); + } catch (Throwable t) { + // Don't let an error in one propagator block the others + log.error("Error calling finish() on a contextpropagator", t); + } + } + } } diff --git a/src/main/java/com/uber/cadence/internal/replay/WorkflowContext.java b/src/main/java/com/uber/cadence/internal/replay/WorkflowContext.java index 5f1d8f16f..d18c8651c 100644 --- a/src/main/java/com/uber/cadence/internal/replay/WorkflowContext.java +++ b/src/main/java/com/uber/cadence/internal/replay/WorkflowContext.java @@ -19,10 +19,13 @@ import com.uber.cadence.*; import com.uber.cadence.context.ContextPropagator; +import com.uber.cadence.internal.common.InternalUtils; + import java.nio.ByteBuffer; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; final class WorkflowContext { @@ -166,7 +169,10 @@ Map getPropagatedContexts() { Map contextData = new HashMap<>(); for (ContextPropagator propagator : contextPropagators) { - contextData.put(propagator.getName(), propagator.deserializeContext(headerData)); + // Only send the context propagator the fields that belong to them + // Change the map from MyPropagator:foo -> bar to foo -> bar + Map filteredData = InternalUtils.FilterForPropagator(headerData, propagator); + contextData.put(propagator.getName(), propagator.deserializeContext(filteredData)); } return contextData; diff --git a/src/main/java/com/uber/cadence/internal/sync/SyncDecisionContext.java b/src/main/java/com/uber/cadence/internal/sync/SyncDecisionContext.java index 3322f3e10..8fb4b0833 100644 --- a/src/main/java/com/uber/cadence/internal/sync/SyncDecisionContext.java +++ b/src/main/java/com/uber/cadence/internal/sync/SyncDecisionContext.java @@ -72,6 +72,7 @@ import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; +import java.util.stream.Collectors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -453,7 +454,18 @@ private Map extractContextsAndConvertToBytes( } Map result = new HashMap<>(); for (ContextPropagator propagator : contextPropagators) { - result.putAll(propagator.serializeContext(propagator.getCurrentContext())); + // Get the serialized context from the propagator + Map serializedContext = + propagator.serializeContext(propagator.getCurrentContext()); + // Namespace each entry in case of overlaps, so foo -> bar becomes MyPropagator:foo -> bar + Map namespacedSerializedContext = + serializedContext + .entrySet() + .stream() + .collect( + Collectors.toMap( + e -> propagator.getName() + ":" + e.getKey(), Map.Entry::getValue)); + result.putAll(namespacedSerializedContext); } return result; } diff --git a/src/main/java/com/uber/cadence/internal/sync/WorkflowStubImpl.java b/src/main/java/com/uber/cadence/internal/sync/WorkflowStubImpl.java index 47e4ff720..b115f2f6f 100644 --- a/src/main/java/com/uber/cadence/internal/sync/WorkflowStubImpl.java +++ b/src/main/java/com/uber/cadence/internal/sync/WorkflowStubImpl.java @@ -52,6 +52,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; class WorkflowStubImpl implements WorkflowStub { @@ -204,7 +205,18 @@ private Map extractContextsAndConvertToBytes( } Map result = new HashMap<>(); for (ContextPropagator propagator : contextPropagators) { - result.putAll(propagator.serializeContext(propagator.getCurrentContext())); + // Get the serialized context from the propagator + Map serializedContext = + propagator.serializeContext(propagator.getCurrentContext()); + // Namespace each entry in case of overlaps, so foo -> bar becomes MyPropagator:foo -> bar + Map namespacedSerializedContext = + serializedContext + .entrySet() + .stream() + .collect( + Collectors.toMap( + k -> propagator.getName() + ":" + k.getKey(), Map.Entry::getValue)); + result.putAll(namespacedSerializedContext); } return result; } diff --git a/src/main/java/com/uber/cadence/internal/sync/WorkflowThreadImpl.java b/src/main/java/com/uber/cadence/internal/sync/WorkflowThreadImpl.java index 9b7f5891f..bee091a2a 100644 --- a/src/main/java/com/uber/cadence/internal/sync/WorkflowThreadImpl.java +++ b/src/main/java/com/uber/cadence/internal/sync/WorkflowThreadImpl.java @@ -91,6 +91,7 @@ public void run() { // Repopulate the context(s) ContextThreadLocal.setContextPropagators(this.contextPropagators); ContextThreadLocal.propagateContextToCurrentThread(this.propagatedContexts); + ContextThreadLocal.setUpContextPropagators(); try { // initialYield blocks thread until the first runUntilBlocked is called. @@ -99,6 +100,7 @@ public void run() { cancellationScope.run(); } catch (DestroyWorkflowThreadError e) { if (!threadContext.isDestroyRequested()) { + ContextThreadLocal.onErrorContextPropagators(e); threadContext.setUnhandledException(e); } } catch (Error e) { @@ -111,9 +113,11 @@ public void run() { log.error( String.format("Workflow thread \"%s\" run failed with Error:\n%s", name, stackTrace)); } + ContextThreadLocal.onErrorContextPropagators(e); threadContext.setUnhandledException(e); } catch (CancellationException e) { if (!isCancelRequested()) { + ContextThreadLocal.onErrorContextPropagators(e); threadContext.setUnhandledException(e); } if (log.isDebugEnabled()) { @@ -130,8 +134,10 @@ public void run() { "Workflow thread \"%s\" run failed with unhandled exception:\n%s", name, stackTrace)); } + ContextThreadLocal.onErrorContextPropagators(e); threadContext.setUnhandledException(e); } finally { + ContextThreadLocal.finishContextPropagators(); DeterministicRunnerImpl.setCurrentThreadInternal(null); threadContext.setStatus(Status.DONE); thread.setName(originalName); diff --git a/src/main/java/com/uber/cadence/internal/worker/ActivityWorker.java b/src/main/java/com/uber/cadence/internal/worker/ActivityWorker.java index f1a100fb3..3b4a07c1f 100644 --- a/src/main/java/com/uber/cadence/internal/worker/ActivityWorker.java +++ b/src/main/java/com/uber/cadence/internal/worker/ActivityWorker.java @@ -41,6 +41,7 @@ import java.util.Objects; import java.util.concurrent.CancellationException; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; import org.apache.thrift.TException; import org.slf4j.MDC; @@ -161,7 +162,11 @@ public void handle(PollForActivityTaskResponse task) throws Exception { Stopwatch sw = metricsScope.timer(MetricsType.ACTIVITY_RESP_LATENCY).start(); sendReply(task, new Result(null, null, cancelledRequest), metricsScope); sw.stop(); + onErrorContextPropagation(e); + } catch (Exception e) { + onErrorContextPropagation(e); } finally { + finishContextPropagation(); MDC.remove(LoggerTag.ACTIVITY_ID); MDC.remove(LoggerTag.ACTIVITY_TYPE); MDC.remove(LoggerTag.WORKFLOW_ID); @@ -188,7 +193,31 @@ void propagateContext(PollForActivityTaskResponse response) { }); for (ContextPropagator propagator : options.getContextPropagators()) { - propagator.setCurrentContext(propagator.deserializeContext(headerData)); + // Only send the context propagator the fields that belong to them + // Change the map from MyPropagator:foo -> bar to foo -> bar + Map filteredData = + headerData + .entrySet() + .stream() + .filter(e -> e.getKey().startsWith(propagator.getName())) + .collect( + Collectors.toMap( + e -> e.getKey().substring(propagator.getName().length() + 1), + Map.Entry::getValue)); + propagator.setCurrentContext(propagator.deserializeContext(filteredData)); + propagator.setUp(); + } + } + + void onErrorContextPropagation(Exception error) { + for (ContextPropagator propagator : options.getContextPropagators()) { + propagator.onError(error); + } + } + + void finishContextPropagation() { + for (ContextPropagator propagator : options.getContextPropagators()) { + propagator.finish(); } } diff --git a/src/main/java/com/uber/cadence/internal/worker/LocalActivityWorker.java b/src/main/java/com/uber/cadence/internal/worker/LocalActivityWorker.java index c08870adf..08e765aad 100644 --- a/src/main/java/com/uber/cadence/internal/worker/LocalActivityWorker.java +++ b/src/main/java/com/uber/cadence/internal/worker/LocalActivityWorker.java @@ -22,6 +22,7 @@ import com.uber.cadence.MarkerRecordedEventAttributes; import com.uber.cadence.PollForActivityTaskResponse; import com.uber.cadence.common.RetryOptions; +import com.uber.cadence.context.ContextPropagator; import com.uber.cadence.internal.common.LocalActivityMarkerData; import com.uber.cadence.internal.metrics.MetricsTag; import com.uber.cadence.internal.metrics.MetricsType; @@ -37,6 +38,7 @@ import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.function.LongSupplier; +import java.util.stream.Collectors; public final class LocalActivityWorker extends SuspendableWorkerBase { @@ -225,9 +227,19 @@ private void propagateContext(ExecuteLocalActivityParameters params) { } private void restoreContext(Map context) { - options - .getContextPropagators() - .forEach( - propagator -> propagator.setCurrentContext(propagator.deserializeContext(context))); + for (ContextPropagator propagator : options.getContextPropagators()) { + // Only send the context propagator the fields that belong to them + // Change the map from MyPropagator:foo -> bar to foo -> bar + Map filteredData = + context + .entrySet() + .stream() + .filter(e -> e.getKey().startsWith(propagator.getName())) + .collect( + Collectors.toMap( + e -> e.getKey().substring(propagator.getName().length() + 1), + Map.Entry::getValue)); + propagator.setCurrentContext(propagator.deserializeContext(filteredData)); + } } } diff --git a/src/main/java/com/uber/cadence/worker/Worker.java b/src/main/java/com/uber/cadence/worker/Worker.java index 7c0b714ea..0e2073151 100644 --- a/src/main/java/com/uber/cadence/worker/Worker.java +++ b/src/main/java/com/uber/cadence/worker/Worker.java @@ -23,6 +23,7 @@ import com.uber.cadence.client.WorkflowClient; import com.uber.cadence.common.WorkflowExecutionHistory; import com.uber.cadence.context.ContextPropagator; +import com.uber.cadence.context.OpenTelemetryContextPropagator; import com.uber.cadence.converter.DataConverter; import com.uber.cadence.internal.common.InternalUtils; import com.uber.cadence.internal.metrics.MetricsTag; @@ -36,6 +37,7 @@ import com.uber.m3.tally.Scope; import com.uber.m3.util.ImmutableMap; import java.time.Duration; +import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.concurrent.CompletableFuture; @@ -82,6 +84,7 @@ public final class Worker implements Suspendable { .getOptions() .getMetricsScope() .tagged(ImmutableMap.of(MetricsTag.TASK_LIST, taskList)); + contextPropagators = new ArrayList<>(contextPropagators); SingleWorkerOptions activityOptions = SingleWorkerOptions.newBuilder() diff --git a/src/test/java/com/uber/cadence/context/ContextTests.java b/src/test/java/com/uber/cadence/context/ContextTests.java new file mode 100644 index 000000000..723f75dcf --- /dev/null +++ b/src/test/java/com/uber/cadence/context/ContextTests.java @@ -0,0 +1,514 @@ +/* + * Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Modifications copyright (C) 2017 Uber Technologies, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License is + * located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.uber.cadence.context; + +import static org.junit.Assert.*; + +import com.uber.cadence.activity.ActivityMethod; +import com.uber.cadence.activity.ActivityOptions; +import com.uber.cadence.client.WorkflowClient; +import com.uber.cadence.client.WorkflowClientOptions; +import com.uber.cadence.client.WorkflowException; +import com.uber.cadence.client.WorkflowOptions; +import com.uber.cadence.internal.testing.WorkflowTestingTest.ChildWorkflow; +import com.uber.cadence.testing.TestEnvironmentOptions; +import com.uber.cadence.testing.TestWorkflowEnvironment; +import com.uber.cadence.worker.Worker; +import com.uber.cadence.workflow.Async; +import com.uber.cadence.workflow.ChildWorkflowOptions; +import com.uber.cadence.workflow.Promise; +import com.uber.cadence.workflow.SignalMethod; +import com.uber.cadence.workflow.Workflow; +import com.uber.cadence.workflow.WorkflowMethod; +import io.opentelemetry.api.GlobalOpenTelemetry; +import io.opentelemetry.api.baggage.Baggage; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.SpanContext; +import io.opentelemetry.api.trace.SpanKind; +import io.opentelemetry.api.trace.TraceFlags; +import io.opentelemetry.api.trace.TraceState; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.Scope; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.Map; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestWatcher; +import org.junit.rules.Timeout; +import org.junit.runner.Description; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.slf4j.MDC; + +public class ContextTests { + private static final Logger log = LoggerFactory.getLogger(ContextTests.class); + + @Rule public Timeout globalTimeout = Timeout.seconds(5000); + + @Rule + public TestWatcher watchman = + new TestWatcher() { + @Override + protected void failed(Throwable e, Description description) { + System.err.println(testEnvironment.getDiagnostics()); + } + }; + + private static final String TASK_LIST = "test-workflow"; + + private TestWorkflowEnvironment testEnvironment; + + @Before + public void setUp() { + WorkflowClientOptions clientOptions = + WorkflowClientOptions.newBuilder() + .setContextPropagators( + Arrays.asList(new TestContextPropagator(), new OpenTelemetryContextPropagator())) + .build(); + TestEnvironmentOptions options = + new TestEnvironmentOptions.Builder().setWorkflowClientOptions(clientOptions).build(); + testEnvironment = TestWorkflowEnvironment.newInstance(options); + } + + @After + public void tearDown() { + testEnvironment.close(); + } + + public interface TestWorkflow { + + @WorkflowMethod(executionStartToCloseTimeoutSeconds = 3600 * 24, taskList = TASK_LIST) + String workflow1(String input); + } + + public interface ParentWorkflow { + + @WorkflowMethod(executionStartToCloseTimeoutSeconds = 3600 * 24, taskList = TASK_LIST) + String workflow(String input); + + @SignalMethod + void signal(String value); + } + + public interface TestActivity { + + @ActivityMethod(scheduleToCloseTimeoutSeconds = 3600) + String activity1(String input); + } + + public static class TestContextPropagator implements ContextPropagator { + + @Override + public String getName() { + return "TestContextPropagator::withSomeColons"; + } + + @Override + public Map serializeContext(Object context) { + String testKey = (String) context; + if (testKey != null) { + return Collections.singletonMap("test", testKey.getBytes(StandardCharsets.UTF_8)); + } else { + return Collections.emptyMap(); + } + } + + @Override + public Object deserializeContext(Map context) { + if (context.containsKey("test")) { + return new String(context.get("test"), StandardCharsets.UTF_8); + } else { + return null; + } + } + + @Override + public Object getCurrentContext() { + return MDC.get("test"); + } + + @Override + public void setCurrentContext(Object context) { + MDC.put("test", String.valueOf(context)); + } + } + + public static class ContextPropagationWorkflowImpl implements TestWorkflow { + + @Override + public String workflow1(String input) { + // The test value should be in the MDC + return MDC.get("test"); + } + } + + @Test() + public void testWorkflowContextPropagation() { + Worker worker = testEnvironment.newWorker(TASK_LIST); + worker.registerWorkflowImplementationTypes(ContextPropagationWorkflowImpl.class); + testEnvironment.start(); + MDC.put("test", "testing123"); + WorkflowClient client = testEnvironment.newWorkflowClient(); + WorkflowOptions options = + new WorkflowOptions.Builder() + .setContextPropagators(Collections.singletonList(new TestContextPropagator())) + .build(); + TestWorkflow workflow = client.newWorkflowStub(TestWorkflow.class, options); + String result = workflow.workflow1("input1"); + assertEquals("testing123", result); + } + + public static class ContextPropagationParentWorkflowImpl implements ParentWorkflow { + + @Override + public String workflow(String input) { + // Get the MDC value + String mdcValue = MDC.get("test"); + + // Fire up a child workflow + ChildWorkflowOptions options = + new ChildWorkflowOptions.Builder() + .setContextPropagators(Collections.singletonList(new TestContextPropagator())) + .build(); + ChildWorkflow child = Workflow.newChildWorkflowStub(ChildWorkflow.class, options); + + String result = child.workflow(mdcValue, Workflow.getWorkflowInfo().getWorkflowId()); + return result; + } + + @Override + public void signal(String value) {} + } + + public static class ContextPropagationChildWorkflowImpl implements ChildWorkflow { + + @Override + public String workflow(String input, String parentId) { + String mdcValue = MDC.get("test"); + return input + mdcValue; + } + } + + @Test + public void testChildWorkflowContextPropagation() { + Worker worker = testEnvironment.newWorker(TASK_LIST); + worker.registerWorkflowImplementationTypes( + ContextPropagationParentWorkflowImpl.class, ContextPropagationChildWorkflowImpl.class); + testEnvironment.start(); + MDC.put("test", "testing123"); + WorkflowClient client = testEnvironment.newWorkflowClient(); + WorkflowOptions options = + new WorkflowOptions.Builder() + .setContextPropagators(Collections.singletonList(new TestContextPropagator())) + .build(); + ParentWorkflow workflow = client.newWorkflowStub(ParentWorkflow.class, options); + String result = workflow.workflow("input1"); + assertEquals("testing123testing123", result); + } + + public static class ContextPropagationThreadWorkflowImpl implements TestWorkflow { + + @Override + public String workflow1(String input) { + Promise asyncPromise = Async.function(this::async); + return asyncPromise.get(); + } + + private String async() { + return "async" + MDC.get("test"); + } + } + + @Test + public void testThreadContextPropagation() { + Worker worker = testEnvironment.newWorker(TASK_LIST); + worker.registerWorkflowImplementationTypes(ContextPropagationThreadWorkflowImpl.class); + testEnvironment.start(); + MDC.put("test", "testing123"); + WorkflowClient client = testEnvironment.newWorkflowClient(); + WorkflowOptions options = + new WorkflowOptions.Builder() + .setContextPropagators(Collections.singletonList(new TestContextPropagator())) + .build(); + TestWorkflow workflow = client.newWorkflowStub(TestWorkflow.class, options); + String result = workflow.workflow1("input1"); + assertEquals("asynctesting123", result); + } + + public static class ContextActivityImpl implements TestActivity { + @Override + public String activity1(String input) { + return "activity" + MDC.get("test"); + } + } + + public static class ContextPropagationActivityWorkflowImpl implements TestWorkflow { + @Override + public String workflow1(String input) { + ActivityOptions options = + new ActivityOptions.Builder() + .setScheduleToCloseTimeout(Duration.ofSeconds(5)) + .setContextPropagators(Collections.singletonList(new TestContextPropagator())) + .build(); + TestActivity activity = Workflow.newActivityStub(TestActivity.class, options); + return activity.activity1("foo"); + } + } + + @Test + public void testActivityContextPropagation() { + Worker worker = testEnvironment.newWorker(TASK_LIST); + worker.registerWorkflowImplementationTypes(ContextPropagationActivityWorkflowImpl.class); + worker.registerActivitiesImplementations(new ContextActivityImpl()); + testEnvironment.start(); + MDC.put("test", "testing123"); + WorkflowClient client = testEnvironment.newWorkflowClient(); + WorkflowOptions options = + new WorkflowOptions.Builder() + .setContextPropagators(Collections.singletonList(new TestContextPropagator())) + .build(); + TestWorkflow workflow = client.newWorkflowStub(TestWorkflow.class, options); + String result = workflow.workflow1("input1"); + assertEquals("activitytesting123", result); + } + + public static class DefaultContextPropagationActivityWorkflowImpl implements TestWorkflow { + @Override + public String workflow1(String input) { + ActivityOptions options = + new ActivityOptions.Builder().setScheduleToCloseTimeout(Duration.ofSeconds(5)).build(); + TestActivity activity = Workflow.newActivityStub(TestActivity.class, options); + return activity.activity1("foo"); + } + } + + @Test + public void testDefaultActivityContextPropagation() { + Worker worker = testEnvironment.newWorker(TASK_LIST); + worker.registerWorkflowImplementationTypes(DefaultContextPropagationActivityWorkflowImpl.class); + worker.registerActivitiesImplementations(new ContextActivityImpl()); + testEnvironment.start(); + MDC.put("test", "testing123"); + WorkflowClient client = testEnvironment.newWorkflowClient(); + WorkflowOptions options = + new WorkflowOptions.Builder() + .setContextPropagators(Collections.singletonList(new TestContextPropagator())) + .build(); + TestWorkflow workflow = client.newWorkflowStub(TestWorkflow.class, options); + String result = workflow.workflow1("input1"); + assertEquals("activitytesting123", result); + } + + public static class DefaultContextPropagationParentWorkflowImpl implements ParentWorkflow { + + @Override + public String workflow(String input) { + // Get the MDC value + String mdcValue = MDC.get("test"); + + // Fire up a child workflow + ChildWorkflowOptions options = new ChildWorkflowOptions.Builder().build(); + ChildWorkflow child = Workflow.newChildWorkflowStub(ChildWorkflow.class, options); + + String result = child.workflow(mdcValue, Workflow.getWorkflowInfo().getWorkflowId()); + return result; + } + + @Override + public void signal(String value) {} + } + + @Test + public void testDefaultChildWorkflowContextPropagation() { + Worker worker = testEnvironment.newWorker(TASK_LIST); + worker.registerWorkflowImplementationTypes( + DefaultContextPropagationParentWorkflowImpl.class, + ContextPropagationChildWorkflowImpl.class); + testEnvironment.start(); + MDC.put("test", "testing123"); + WorkflowClient client = testEnvironment.newWorkflowClient(); + WorkflowOptions options = + new WorkflowOptions.Builder() + .setContextPropagators(Collections.singletonList(new TestContextPropagator())) + .build(); + ParentWorkflow workflow = client.newWorkflowStub(ParentWorkflow.class, options); + String result = workflow.workflow("input1"); + assertEquals("testing123testing123", result); + } + + public static class OpenTelemetryContextPropagationWorkflowImpl implements TestWorkflow { + @Override + public String workflow1(String input) { + if ("fail".equals(input)) { + throw new IllegalArgumentException(); + } else if ("baggage".equals(input)) { + return Baggage.current().toString(); + } else { + SpanContext ctx = Span.current().getSpanContext(); + return ctx.getTraceId() + ":" + ctx.getSpanId() + ":" + ctx.getTraceState().toString(); + } + } + } + + @Test + public void testOpenTelemetryContextPropagation() { + TraceState TRACE_STATE = TraceState.builder().put("foo", "bar").build(); + String TRACE_ID_BASE16 = "ff000000000000000000000000000041"; + String SPAN_ID_BASE16 = "ff00000000000041"; + + Context ctx = + Context.current() + .with( + Span.wrap( + SpanContext.create( + TRACE_ID_BASE16, SPAN_ID_BASE16, TraceFlags.getDefault(), TRACE_STATE))); + + Span span = + GlobalOpenTelemetry.getTracer("test-tracer") + .spanBuilder("test-span") + .setParent(ctx) + .setSpanKind(SpanKind.CLIENT) + .startSpan(); + + Worker worker = testEnvironment.newWorker(TASK_LIST); + worker.registerWorkflowImplementationTypes(OpenTelemetryContextPropagationWorkflowImpl.class); + testEnvironment.start(); + WorkflowClient client = testEnvironment.newWorkflowClient(); + WorkflowOptions options = + new WorkflowOptions.Builder() + .setContextPropagators( + Arrays.asList(new TestContextPropagator(), new OpenTelemetryContextPropagator())) + .build(); + + try (Scope scope = span.makeCurrent()) { + TestWorkflow workflow = client.newWorkflowStub(TestWorkflow.class, options); + assertEquals( + TRACE_ID_BASE16 + ":" + SPAN_ID_BASE16 + ":" + TRACE_STATE.toString(), + workflow.workflow1("success")); + } + + try (Scope scope = span.makeCurrent()) { + TestWorkflow workflow = client.newWorkflowStub(TestWorkflow.class, options); + workflow.workflow1("fail"); + } catch (WorkflowException e) { + assertEquals(IllegalArgumentException.class, e.getCause().getClass()); + } finally { + span.end(); + } + } + + @Test + public void testDefaultOpenTelemetryContextPropagation() { + TraceState TRACE_STATE = TraceState.builder().put("foo", "bar").build(); + String TRACE_ID_BASE16 = "ff000000000000000000000000000041"; + String SPAN_ID_BASE16 = "ff00000000000041"; + + Context ctx = + Context.current() + .with( + Span.wrap( + SpanContext.create( + TRACE_ID_BASE16, SPAN_ID_BASE16, TraceFlags.getDefault(), TRACE_STATE))); + + Span span = + GlobalOpenTelemetry.getTracer("test-tracer") + .spanBuilder("test-span") + .setParent(ctx) + .setSpanKind(SpanKind.CLIENT) + .startSpan(); + + Worker worker = testEnvironment.newWorker(TASK_LIST); + worker.registerWorkflowImplementationTypes(OpenTelemetryContextPropagationWorkflowImpl.class); + testEnvironment.start(); + WorkflowClient client = testEnvironment.newWorkflowClient(); + WorkflowOptions options = new WorkflowOptions.Builder().build(); + + try (Scope scope = span.makeCurrent()) { + TestWorkflow workflow = client.newWorkflowStub(TestWorkflow.class, options); + assertEquals( + TRACE_ID_BASE16 + ":" + SPAN_ID_BASE16 + ":" + TRACE_STATE.toString(), + workflow.workflow1("success")); + } + + try (Scope scope = span.makeCurrent()) { + TestWorkflow workflow = client.newWorkflowStub(TestWorkflow.class, options); + workflow.workflow1("fail"); + } catch (WorkflowException e) { + assertEquals(IllegalArgumentException.class, e.getCause().getClass()); + } finally { + span.end(); + } + } + + @Test + public void testNoDefaultOpenTelemetryContextPropagation() { + TraceState TRACE_STATE = TraceState.builder().put("foo", "bar").build(); + String TRACE_ID_BASE16 = "ff000000000000000000000000000041"; + String SPAN_ID_BASE16 = "ff00000000000041"; + + Context ctx = + Context.current() + .with( + Span.wrap( + SpanContext.create( + TRACE_ID_BASE16, SPAN_ID_BASE16, TraceFlags.getDefault(), TRACE_STATE))); + + Span span = + GlobalOpenTelemetry.getTracer("test-tracer") + .spanBuilder("test-span") + .setParent(ctx) + .setSpanKind(SpanKind.CLIENT) + .startSpan(); + + Worker worker = testEnvironment.newWorker(TASK_LIST); + worker.registerWorkflowImplementationTypes(OpenTelemetryContextPropagationWorkflowImpl.class); + testEnvironment.start(); + WorkflowClient client = testEnvironment.newWorkflowClient(); + WorkflowOptions options = + new WorkflowOptions.Builder().build(); + + try (Scope scope = span.makeCurrent()) { + TestWorkflow workflow = client.newWorkflowStub(TestWorkflow.class, options); + assertNotEquals( + TRACE_ID_BASE16 + ":" + SPAN_ID_BASE16 + ":" + TRACE_STATE.toString(), + workflow.workflow1("success")); + } + } + + @Test + public void testBaggagePropagation() { + Baggage baggage = + Baggage.builder().put("keyFoo1", "valueFoo1").put("keyFoo2", "valueFoo2").build(); + + Worker worker = testEnvironment.newWorker(TASK_LIST); + worker.registerWorkflowImplementationTypes(OpenTelemetryContextPropagationWorkflowImpl.class); + testEnvironment.start(); + WorkflowClient client = testEnvironment.newWorkflowClient(); + WorkflowOptions options = new WorkflowOptions.Builder().build(); + + try (Scope scope = Context.current().with(baggage).makeCurrent()) { + TestWorkflow workflow = client.newWorkflowStub(TestWorkflow.class, options); + assertEquals(baggage.toString(), workflow.workflow1("baggage")); + } + } +} diff --git a/src/test/java/com/uber/cadence/internal/testing/WorkflowTestingTest.java b/src/test/java/com/uber/cadence/internal/testing/WorkflowTestingTest.java index e22de1bdd..81214c1bf 100644 --- a/src/test/java/com/uber/cadence/internal/testing/WorkflowTestingTest.java +++ b/src/test/java/com/uber/cadence/internal/testing/WorkflowTestingTest.java @@ -17,6 +17,7 @@ package com.uber.cadence.internal.testing; +import static org.junit.Assert.*; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -39,7 +40,7 @@ import com.uber.cadence.activity.ActivityMethod; import com.uber.cadence.activity.ActivityOptions; import com.uber.cadence.client.*; -import com.uber.cadence.context.ContextPropagator; +import com.uber.cadence.context.ContextTests; import com.uber.cadence.internal.common.WorkflowExecutionUtils; import com.uber.cadence.testing.SimulatedTimeoutException; import com.uber.cadence.testing.TestEnvironmentOptions; @@ -47,17 +48,14 @@ import com.uber.cadence.worker.Worker; import com.uber.cadence.workflow.ActivityTimeoutException; import com.uber.cadence.workflow.Async; -import com.uber.cadence.workflow.ChildWorkflowOptions; import com.uber.cadence.workflow.ChildWorkflowTimedOutException; import com.uber.cadence.workflow.Promise; import com.uber.cadence.workflow.SignalMethod; import com.uber.cadence.workflow.Workflow; import com.uber.cadence.workflow.WorkflowMethod; -import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.Collections; import java.util.List; -import java.util.Map; import java.util.Optional; import java.util.UUID; import java.util.concurrent.CancellationException; @@ -73,7 +71,6 @@ import org.junit.runner.Description; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.slf4j.MDC; public class WorkflowTestingTest { private static final Logger log = LoggerFactory.getLogger(WorkflowTestingTest.class); @@ -99,7 +96,8 @@ public void setUp() { new TestEnvironmentOptions.Builder() .setWorkflowClientOptions( WorkflowClientOptions.newBuilder() - .setContextPropagators(Collections.singletonList(new TestContextPropagator())) + .setContextPropagators( + Collections.singletonList(new ContextTests.TestContextPropagator())) .build()) .build(); testEnvironment = TestWorkflowEnvironment.newInstance(options); @@ -739,244 +737,4 @@ public void testMockedChildSimulatedTimeout() { assertTrue(e.getCause() instanceof ChildWorkflowTimedOutException); } } - - public static class TestContextPropagator implements ContextPropagator { - - @Override - public String getName() { - return this.getClass().getName(); - } - - @Override - public Map serializeContext(Object context) { - String testKey = (String) context; - if (testKey != null) { - return Collections.singletonMap("test", testKey.getBytes(StandardCharsets.UTF_8)); - } else { - return Collections.emptyMap(); - } - } - - @Override - public Object deserializeContext(Map context) { - if (context.containsKey("test")) { - return new String(context.get("test"), StandardCharsets.UTF_8); - } else { - return null; - } - } - - @Override - public Object getCurrentContext() { - return MDC.get("test"); - } - - @Override - public void setCurrentContext(Object context) { - MDC.put("test", String.valueOf(context)); - } - } - - public static class ContextPropagationWorkflowImpl implements TestWorkflow { - - @Override - public String workflow1(String input) { - // The test value should be in the MDC - return MDC.get("test"); - } - } - - @Test - public void testWorkflowContextPropagation() { - Worker worker = testEnvironment.newWorker(TASK_LIST); - worker.registerWorkflowImplementationTypes(ContextPropagationWorkflowImpl.class); - testEnvironment.start(); - MDC.put("test", "testing123"); - WorkflowClient client = testEnvironment.newWorkflowClient(); - WorkflowOptions options = - new WorkflowOptions.Builder() - .setContextPropagators(Collections.singletonList(new TestContextPropagator())) - .build(); - TestWorkflow workflow = client.newWorkflowStub(TestWorkflow.class, options); - String result = workflow.workflow1("input1"); - assertEquals("testing123", result); - } - - public static class ContextPropagationParentWorkflowImpl implements ParentWorkflow { - - @Override - public String workflow(String input) { - // Get the MDC value - String mdcValue = MDC.get("test"); - - // Fire up a child workflow - ChildWorkflowOptions options = - new ChildWorkflowOptions.Builder() - .setContextPropagators(Collections.singletonList(new TestContextPropagator())) - .build(); - ChildWorkflow child = Workflow.newChildWorkflowStub(ChildWorkflow.class, options); - - String result = child.workflow(mdcValue, Workflow.getWorkflowInfo().getWorkflowId()); - return result; - } - - @Override - public void signal(String value) {} - } - - public static class ContextPropagationChildWorkflowImpl implements ChildWorkflow { - - @Override - public String workflow(String input, String parentId) { - String mdcValue = MDC.get("test"); - return input + mdcValue; - } - } - - @Test - public void testChildWorkflowContextPropagation() { - Worker worker = testEnvironment.newWorker(TASK_LIST); - worker.registerWorkflowImplementationTypes( - ContextPropagationParentWorkflowImpl.class, ContextPropagationChildWorkflowImpl.class); - testEnvironment.start(); - MDC.put("test", "testing123"); - WorkflowClient client = testEnvironment.newWorkflowClient(); - WorkflowOptions options = - new WorkflowOptions.Builder() - .setContextPropagators(Collections.singletonList(new TestContextPropagator())) - .build(); - ParentWorkflow workflow = client.newWorkflowStub(ParentWorkflow.class, options); - String result = workflow.workflow("input1"); - assertEquals("testing123testing123", result); - } - - public static class ContextPropagationThreadWorkflowImpl implements TestWorkflow { - - @Override - public String workflow1(String input) { - Promise asyncPromise = Async.function(this::async); - return asyncPromise.get(); - } - - private String async() { - return "async" + MDC.get("test"); - } - } - - @Test - public void testThreadContextPropagation() { - Worker worker = testEnvironment.newWorker(TASK_LIST); - worker.registerWorkflowImplementationTypes(ContextPropagationThreadWorkflowImpl.class); - testEnvironment.start(); - MDC.put("test", "testing123"); - WorkflowClient client = testEnvironment.newWorkflowClient(); - WorkflowOptions options = - new WorkflowOptions.Builder() - .setContextPropagators(Collections.singletonList(new TestContextPropagator())) - .build(); - TestWorkflow workflow = client.newWorkflowStub(TestWorkflow.class, options); - String result = workflow.workflow1("input1"); - assertEquals("asynctesting123", result); - } - - public static class ContextActivityImpl implements TestActivity { - @Override - public String activity1(String input) { - return "activity" + MDC.get("test"); - } - } - - public static class ContextPropagationActivityWorkflowImpl implements TestWorkflow { - @Override - public String workflow1(String input) { - ActivityOptions options = - new ActivityOptions.Builder() - .setScheduleToCloseTimeout(Duration.ofSeconds(5)) - .setContextPropagators(Collections.singletonList(new TestContextPropagator())) - .build(); - TestActivity activity = Workflow.newActivityStub(TestActivity.class, options); - return activity.activity1("foo"); - } - } - - @Test - public void testActivityContextPropagation() { - Worker worker = testEnvironment.newWorker(TASK_LIST); - worker.registerWorkflowImplementationTypes(ContextPropagationActivityWorkflowImpl.class); - worker.registerActivitiesImplementations(new ContextActivityImpl()); - testEnvironment.start(); - MDC.put("test", "testing123"); - WorkflowClient client = testEnvironment.newWorkflowClient(); - WorkflowOptions options = - new WorkflowOptions.Builder() - .setContextPropagators(Collections.singletonList(new TestContextPropagator())) - .build(); - TestWorkflow workflow = client.newWorkflowStub(TestWorkflow.class, options); - String result = workflow.workflow1("input1"); - assertEquals("activitytesting123", result); - } - - public static class DefaultContextPropagationActivityWorkflowImpl implements TestWorkflow { - @Override - public String workflow1(String input) { - ActivityOptions options = - new ActivityOptions.Builder().setScheduleToCloseTimeout(Duration.ofSeconds(5)).build(); - TestActivity activity = Workflow.newActivityStub(TestActivity.class, options); - return activity.activity1("foo"); - } - } - - @Test - public void testDefaultActivityContextPropagation() { - Worker worker = testEnvironment.newWorker(TASK_LIST); - worker.registerWorkflowImplementationTypes(DefaultContextPropagationActivityWorkflowImpl.class); - worker.registerActivitiesImplementations(new ContextActivityImpl()); - testEnvironment.start(); - MDC.put("test", "testing123"); - WorkflowClient client = testEnvironment.newWorkflowClient(); - WorkflowOptions options = - new WorkflowOptions.Builder() - .setContextPropagators(Collections.singletonList(new TestContextPropagator())) - .build(); - TestWorkflow workflow = client.newWorkflowStub(TestWorkflow.class, options); - String result = workflow.workflow1("input1"); - assertEquals("activitytesting123", result); - } - - public static class DefaultContextPropagationParentWorkflowImpl implements ParentWorkflow { - - @Override - public String workflow(String input) { - // Get the MDC value - String mdcValue = MDC.get("test"); - - // Fire up a child workflow - ChildWorkflowOptions options = new ChildWorkflowOptions.Builder().build(); - ChildWorkflow child = Workflow.newChildWorkflowStub(ChildWorkflow.class, options); - - String result = child.workflow(mdcValue, Workflow.getWorkflowInfo().getWorkflowId()); - return result; - } - - @Override - public void signal(String value) {} - } - - @Test - public void testDefaultChildWorkflowContextPropagation() { - Worker worker = testEnvironment.newWorker(TASK_LIST); - worker.registerWorkflowImplementationTypes( - DefaultContextPropagationParentWorkflowImpl.class, - ContextPropagationChildWorkflowImpl.class); - testEnvironment.start(); - MDC.put("test", "testing123"); - WorkflowClient client = testEnvironment.newWorkflowClient(); - WorkflowOptions options = - new WorkflowOptions.Builder() - .setContextPropagators(Collections.singletonList(new TestContextPropagator())) - .build(); - ParentWorkflow workflow = client.newWorkflowStub(ParentWorkflow.class, options); - String result = workflow.workflow("input1"); - assertEquals("testing123testing123", result); - } }