Skip to content

Add tenant id to lambda context and structured log messages #540

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,7 @@ public ByteArrayOutputStream call(InvocationRequest request) throws Error, Excep
cognitoIdentity,
LambdaEnvironment.FUNCTION_VERSION,
request.getInvokedFunctionArn(),
request.getTenantId(),
clientContext
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public class LambdaContext implements Context {
private final long deadlineTimeInMs;
private final CognitoIdentity cognitoIdentity;
private final ClientContext clientContext;
private final String tenantId;
private final LambdaLogger logger;

public LambdaContext(
Expand All @@ -34,6 +35,7 @@ public LambdaContext(
CognitoIdentity identity,
String functionVersion,
String invokedFunctionArn,
String tenantId,
ClientContext clientContext
) {
this.memoryLimit = memoryLimit;
Expand All @@ -46,6 +48,7 @@ public LambdaContext(
this.clientContext = clientContext;
this.functionVersion = functionVersion;
this.invokedFunctionArn = invokedFunctionArn;
this.tenantId = tenantId;
this.logger = com.amazonaws.services.lambda.runtime.LambdaRuntime.getLogger();
}

Expand Down Expand Up @@ -91,6 +94,10 @@ public int getRemainingTimeInMillis() {
return delta > 0 ? delta : 0;
}

public String getTenantId() {
return tenantId;
}

public LambdaLogger getLogger() {
return logger;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ private StructuredLogMessage createLogMessage(String message, LogLevel logLevel)

if (lambdaContext != null) {
msg.AWSRequestId = lambdaContext.getAwsRequestId();
msg.tenantId = lambdaContext.getTenantId();
}
return msg;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ class StructuredLogMessage {
public String message;
public LogLevel level;
public String AWSRequestId;
public String tenantId;
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ public class InvocationRequest {
*/
private String cognitoIdentity;

/**
* The tenant ID associated with the request.
*/
private String tenantId;

private byte[] content;

public String getId() {
Expand Down Expand Up @@ -94,6 +99,14 @@ public void setCognitoIdentity(String cognitoIdentity) {
this.cognitoIdentity = cognitoIdentity;
}

public String getTenantId() {
return tenantId;
}

public void setTenantId(String tenantId) {
this.tenantId = tenantId;
}

public byte[] getContent() {
return content;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ static jfieldID contentField;
static jfieldID clientContextField;
static jfieldID cognitoIdentityField;
static jfieldID xrayTraceIdField;
static jfieldID tenantIdField;


jint JNI_OnLoad(JavaVM* vm, void* reserved) {
Expand All @@ -41,6 +42,7 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) {
xrayTraceIdField = env->GetFieldID(invocationRequestClass , "xrayTraceId", "Ljava/lang/String;");
clientContextField = env->GetFieldID(invocationRequestClass , "clientContext", "Ljava/lang/String;");
cognitoIdentityField = env->GetFieldID(invocationRequestClass , "cognitoIdentity", "Ljava/lang/String;");
tenantIdField = env->GetFieldID(invocationRequestClass, "tenantId", "Ljava/lang/String;");

return JNI_VERSION;
}
Expand Down Expand Up @@ -106,6 +108,10 @@ JNIEXPORT jobject JNICALL Java_com_amazonaws_services_lambda_runtime_api_client_
CHECK_EXCEPTION(env, env->SetObjectField(invocationRequest, cognitoIdentityField, env->NewStringUTF(response.cognito_identity.c_str())));
}

if(response.tenant_id != ""){
CHECK_EXCEPTION(env, env->SetObjectField(invocationRequest, tenantIdField, env->NewStringUTF(response.tenant_id.c_str())));
}

bytes = reinterpret_cast<const jbyte*>(response.payload.c_str());
CHECK_EXCEPTION(env, jArray = env->NewByteArray(response.payload.length()));
CHECK_EXCEPTION(env, env->SetByteArrayRegion(jArray, 0, response.payload.length(), bytes));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ struct invocation_request {
*/
std::chrono::time_point<std::chrono::system_clock> deadline;

/**
Copy link
Contributor

@maxday maxday Apr 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hopefully at some point, we won't have a copy of aws-lambda-cpp in here, so let's keep that change here but I think you also need to modify the source of truth here: https://github.com/awslabs/aws-lambda-cpp? (Not blocking for this PR to be merged) I've also created an issue for us to track that change at some point: #541

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ack.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, one day, we could get rid of this. (Not blocking PR from my end either).

* Tenant ID of the current invocation.
*/
std::string tenant_id;

/**
* The number of milliseconds left before lambda terminates the current execution.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ static constexpr auto CLIENT_CONTEXT_HEADER = "lambda-runtime-client-context";
static constexpr auto COGNITO_IDENTITY_HEADER = "lambda-runtime-cognito-identity";
static constexpr auto DEADLINE_MS_HEADER = "lambda-runtime-deadline-ms";
static constexpr auto FUNCTION_ARN_HEADER = "lambda-runtime-invoked-function-arn";
static constexpr auto TENANT_ID_HEADER = "lambda-runtime-aws-tenant-id";

enum Endpoints {
INIT,
Expand Down Expand Up @@ -301,6 +302,10 @@ runtime::next_outcome runtime::get_next()
req.payload.c_str(),
static_cast<int64_t>(req.get_time_remaining().count()));
}

if (resp.has_header(TENANT_ID_HEADER)) {
req.tenant_id = resp.get_header(TENANT_ID_HEADER);
}
return next_outcome(req);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public class LambdaContextTest {
private static final String INVOKED_FUNCTION_ARN = "invoked-function-arn";
private static final LambdaClientContext CLIENT_CONTEXT = new LambdaClientContext();
public static final int MEMORY_LIMIT = 128;
public static final String TENANT_ID = "tenant-id";

@Test
public void getRemainingTimeInMillis() {
Expand Down Expand Up @@ -54,6 +55,6 @@ public void getRemainingTimeInMillis_Deadline() throws InterruptedException {

private LambdaContext createContextWithDeadline(long deadlineTimeInMs) {
return new LambdaContext(MEMORY_LIMIT, deadlineTimeInMs, REQUEST_ID, LOG_GROUP_NAME, LOG_STREAM_NAME,
FUNCTION_NAME, IDENTITY, FUNCTION_VERSION, INVOKED_FUNCTION_ARN, CLIENT_CONTEXT);
FUNCTION_NAME, IDENTITY, FUNCTION_VERSION, INVOKED_FUNCTION_ARN, TENANT_ID, CLIENT_CONTEXT);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,25 @@ void testFormattingWithLambdaContext() {
null,
null,
"function-arn",
null,
null
);
assertFormatsString("test log", LogLevel.WARN, context);
}

@Test
void testFormattingWithTenantIdInLambdaContext() {
LambdaContext context = new LambdaContext(
0,
0,
"request-id",
null,
null,
"function-name",
null,
null,
"function-arn",
"tenant-id",
null
);
assertFormatsString("test log", LogLevel.WARN, context);
Expand All @@ -52,6 +71,7 @@ void assert_expected_log_message(StructuredLogMessage result, String message, Lo

if (context != null) {
assertEquals(context.getAwsRequestId(), result.AWSRequestId);
assertEquals(context.getTenantId(), result.tenantId);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import com.amazonaws.services.lambda.runtime.api.client.runtimeapi.dto.ErrorRequest;
import com.amazonaws.services.lambda.runtime.api.client.runtimeapi.dto.InvocationRequest;
import com.amazonaws.services.lambda.runtime.api.client.runtimeapi.dto.StackElement;
import com.amazonaws.services.lambda.runtime.api.client.runtimeapi.dto.XRayErrorCause;
import com.amazonaws.services.lambda.runtime.api.client.runtimeapi.dto.XRayException;
Expand Down Expand Up @@ -312,27 +313,62 @@ public void restoreNextWrongStatusCodeTest() {
}

@Test
public void nextTest() {
public void nextWithoutTenantIdHeaderTest() {
try {
MockResponse mockResponse = new MockResponse();
mockResponse.setResponseCode(HTTP_ACCEPTED);
mockResponse.setHeader("lambda-runtime-aws-request-id", "1234567890");
mockResponse.setHeader("Content-Type", "application/json");
MockResponse mockResponse = buildMockResponseForNextInvocation();
mockWebServer.enqueue(mockResponse);

lambdaRuntimeApiClientImpl.nextInvocation();
RecordedRequest recordedRequest = mockWebServer.takeRequest();
HttpUrl actualUrl = recordedRequest.getRequestUrl();
String expectedUrl = "http://" + getHostnamePort() + "/2018-06-01/runtime/invocation/next";
assertEquals(expectedUrl, actualUrl.toString());
InvocationRequest invocationRequest = lambdaRuntimeApiClientImpl.nextInvocation();
verifyNextInvocationRequest();
assertNull(invocationRequest.getTenantId());
} catch(Exception e) {
fail();
}
}

@Test
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great test! 💯

public void nextWithTenantIdHeaderTest() {
try {
MockResponse mockResponse = buildMockResponseForNextInvocation();
String expectedTenantId = "my-tenant-id";
mockResponse.setHeader("lambda-runtime-aws-tenant-id", expectedTenantId);
Copy link
Contributor

@maxday maxday Apr 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be great to have another set of tests for:

  1. header not present
  2. header present and empty string
  3. header present and null value

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion. Added these test cases.

mockWebServer.enqueue(mockResponse);

InvocationRequest invocationRequest = lambdaRuntimeApiClientImpl.nextInvocation();
verifyNextInvocationRequest();
assertEquals(expectedTenantId, invocationRequest.getTenantId());

String actualBody = recordedRequest.getBody().readUtf8();
assertEquals("", actualBody);
} catch(Exception e) {
fail();
}
}

@Test
public void nextWithEmptyTenantIdHeaderTest() {
try {
MockResponse mockResponse = buildMockResponseForNextInvocation();
mockResponse.setHeader("lambda-runtime-aws-tenant-id", "");
mockWebServer.enqueue(mockResponse);

InvocationRequest invocationRequest = lambdaRuntimeApiClientImpl.nextInvocation();
verifyNextInvocationRequest();
assertNull(invocationRequest.getTenantId());
} catch(Exception e) {
fail();
}
}

@Test
public void nextWithNullTenantIdHeaderTest() {
try {
MockResponse mockResponse = buildMockResponseForNextInvocation();
assertThrows(NullPointerException.class, () -> {
mockResponse.setHeader("lambda-runtime-aws-tenant-id", null);
});
} catch(Exception e) {
fail();
}
}

@Test
public void createUrlMalformedTest() {
Expand Down Expand Up @@ -376,6 +412,24 @@ public void lambdaReportErrorXRayHeaderTooLongTest() {
}
}

private MockResponse buildMockResponseForNextInvocation() {
MockResponse mockResponse = new MockResponse();
mockResponse.setResponseCode(HTTP_ACCEPTED);
mockResponse.setHeader("lambda-runtime-aws-request-id", "1234567890");
mockResponse.setHeader("Content-Type", "application/json");
return mockResponse;
}

private void verifyNextInvocationRequest() throws Exception {
RecordedRequest recordedRequest = mockWebServer.takeRequest();
HttpUrl actualUrl = recordedRequest.getRequestUrl();
String expectedUrl = "http://" + getHostnamePort() + "/2018-06-01/runtime/invocation/next";
assertEquals(expectedUrl, actualUrl.toString());

String actualBody = recordedRequest.getBody().readUtf8();
assertEquals("", actualBody);
}

private String getHostnamePort() {
return mockWebServer.getHostName() + ":" + mockWebServer.getPort();
}
Expand Down