Skip to content

Commit c10151e

Browse files
committed
Extract getCredentialProvider logic into utility class and add unit tests
1 parent 8282e75 commit c10151e

File tree

6 files changed

+368
-64
lines changed

6 files changed

+368
-64
lines changed
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/*-
2+
* #%L
3+
* athena-datalakegen2
4+
* %%
5+
* Copyright (C) 2019 - 2025 Amazon Web Services
6+
* %%
7+
* Licensed under the Apache License, Version 2.0 (the "License");
8+
* you may not use this file except in compliance with the License.
9+
* You may obtain a copy of the License at
10+
*
11+
* http://www.apache.org/licenses/LICENSE-2.0
12+
*
13+
* Unless required by applicable law or agreed to in writing, software
14+
* distributed under the License is distributed on an "AS IS" BASIS,
15+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
* See the License for the specific language governing permissions and
17+
* limitations under the License.
18+
* #L%
19+
*/
20+
package com.amazonaws.athena.connectors.datalakegen2;
21+
22+
import com.amazonaws.athena.connector.credentials.CredentialsProvider;
23+
import com.amazonaws.athena.connector.credentials.DefaultCredentialsProvider;
24+
import com.amazonaws.athena.connector.lambda.exceptions.AthenaConnectorException;
25+
import com.amazonaws.athena.connector.lambda.security.CachableSecretsManager;
26+
import com.fasterxml.jackson.databind.ObjectMapper;
27+
import org.apache.commons.lang3.StringUtils;
28+
import software.amazon.awssdk.services.glue.model.ErrorDetails;
29+
import software.amazon.awssdk.services.glue.model.FederationSourceErrorCode;
30+
31+
import java.io.IOException;
32+
import java.util.Map;
33+
34+
/**
35+
* Utility class for handling credential provider functionality.
36+
*/
37+
public final class DataLakeGen2CredentialProviderUtils
38+
{
39+
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
40+
41+
private DataLakeGen2CredentialProviderUtils()
42+
{
43+
}
44+
45+
/**
46+
* Gets the credentials provider based on the secret configuration.
47+
* If OAuth is configured, returns DataLakeGen2 OAuth credentials provider.
48+
* Otherwise, falls back to default username/password credentials.
49+
*/
50+
public static CredentialsProvider getCredentialProvider(String secretName, CachableSecretsManager secretsManager)
51+
{
52+
if (StringUtils.isNotBlank(secretName)) {
53+
try {
54+
String secretString = secretsManager.getSecret(secretName);
55+
Map<String, String> secretMap = OBJECT_MAPPER.readValue(secretString, Map.class);
56+
57+
// Check if OAuth is configured
58+
if (DataLakeGen2OAuthCredentialsProvider.isOAuthConfigured(secretMap)) {
59+
return new DataLakeGen2OAuthCredentialsProvider(secretName, secretMap, secretsManager);
60+
}
61+
62+
// Fall back to default credentials if OAuth is not configured
63+
return new DefaultCredentialsProvider(secretString);
64+
}
65+
catch (IOException ioException) {
66+
throw new AthenaConnectorException("Could not deserialize RDS credentials into HashMap: ",
67+
ErrorDetails.builder().errorCode(FederationSourceErrorCode.INTERNAL_SERVICE_EXCEPTION.toString()).errorMessage(ioException.getMessage()).build());
68+
}
69+
}
70+
71+
return null;
72+
}
73+
}

athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MetadataHandler.java

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
package com.amazonaws.athena.connectors.datalakegen2;
2121

2222
import com.amazonaws.athena.connector.credentials.CredentialsProvider;
23-
import com.amazonaws.athena.connector.credentials.DefaultCredentialsProvider;
2423
import com.amazonaws.athena.connector.lambda.QueryStatusChecker;
2524
import com.amazonaws.athena.connector.lambda.data.Block;
2625
import com.amazonaws.athena.connector.lambda.data.BlockAllocator;
@@ -31,7 +30,6 @@
3130
import com.amazonaws.athena.connector.lambda.domain.Split;
3231
import com.amazonaws.athena.connector.lambda.domain.TableName;
3332
import com.amazonaws.athena.connector.lambda.domain.predicate.functions.StandardFunctions;
34-
import com.amazonaws.athena.connector.lambda.exceptions.AthenaConnectorException;
3533
import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesRequest;
3634
import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesResponse;
3735
import com.amazonaws.athena.connector.lambda.metadata.GetSplitsRequest;
@@ -51,22 +49,17 @@
5149
import com.amazonaws.athena.connectors.jdbc.manager.JdbcArrowTypeConverter;
5250
import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler;
5351
import com.amazonaws.athena.connectors.jdbc.resolver.JDBCCaseResolver;
54-
import com.fasterxml.jackson.databind.ObjectMapper;
5552
import com.google.common.annotations.VisibleForTesting;
5653
import com.google.common.collect.ImmutableMap;
5754
import com.google.common.collect.ImmutableSet;
5855
import org.apache.arrow.vector.types.Types;
5956
import org.apache.arrow.vector.types.pojo.ArrowType;
6057
import org.apache.arrow.vector.types.pojo.Schema;
61-
import org.apache.commons.lang3.StringUtils;
6258
import org.slf4j.Logger;
6359
import org.slf4j.LoggerFactory;
6460
import software.amazon.awssdk.services.athena.AthenaClient;
65-
import software.amazon.awssdk.services.glue.model.ErrorDetails;
66-
import software.amazon.awssdk.services.glue.model.FederationSourceErrorCode;
6761
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
6862

69-
import java.io.IOException;
7063
import java.sql.Connection;
7164
import java.sql.PreparedStatement;
7265
import java.sql.ResultSet;
@@ -85,7 +78,6 @@
8578
public class DataLakeGen2MetadataHandler extends JdbcMetadataHandler
8679
{
8780
private static final Logger LOGGER = LoggerFactory.getLogger(DataLakeGen2MetadataHandler.class);
88-
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
8981

9082
static final Map<String, String> JDBC_PROPERTIES = ImmutableMap.of("databaseTerm", "SCHEMA");
9183
static final String PARTITION_NUMBER = "partition_number";
@@ -297,26 +289,9 @@ protected Schema getSchema(Connection jdbcConnection, TableName tableName, Schem
297289
@Override
298290
protected CredentialsProvider getCredentialProvider()
299291
{
300-
final String secretName = getDatabaseConnectionConfig().getSecret();
301-
if (StringUtils.isNotBlank(secretName)) {
302-
try {
303-
String secretString = getCachableSecretsManager().getSecret(secretName);
304-
Map<String, String> secretMap = OBJECT_MAPPER.readValue(secretString, Map.class);
305-
306-
// Check if OAuth is configured
307-
if (DataLakeGen2OAuthCredentialsProvider.isOAuthConfigured(secretMap)) {
308-
return new DataLakeGen2OAuthCredentialsProvider(secretName, secretMap, getCachableSecretsManager());
309-
}
310-
311-
// Fall back to default credentials if OAuth is not configured
312-
return new DefaultCredentialsProvider(secretString);
313-
}
314-
catch (IOException ioException) {
315-
throw new AthenaConnectorException("Could not deserialize RDS credentials into HashMap: ",
316-
ErrorDetails.builder().errorCode(FederationSourceErrorCode.INTERNAL_SERVICE_EXCEPTION.toString()).errorMessage(ioException.getMessage()).build());
317-
}
318-
}
319-
320-
return null;
292+
return DataLakeGen2CredentialProviderUtils.getCredentialProvider(
293+
getDatabaseConnectionConfig().getSecret(),
294+
getCachableSecretsManager()
295+
);
321296
}
322297
}

athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2RecordHandler.java

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -19,41 +19,32 @@
1919
*/
2020
package com.amazonaws.athena.connectors.datalakegen2;
2121
import com.amazonaws.athena.connector.credentials.CredentialsProvider;
22-
import com.amazonaws.athena.connector.credentials.DefaultCredentialsProvider;
2322
import com.amazonaws.athena.connector.lambda.domain.Split;
2423
import com.amazonaws.athena.connector.lambda.domain.TableName;
2524
import com.amazonaws.athena.connector.lambda.domain.predicate.Constraints;
26-
import com.amazonaws.athena.connector.lambda.exceptions.AthenaConnectorException;
2725
import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig;
2826
import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionInfo;
2927
import com.amazonaws.athena.connectors.jdbc.connection.GenericJdbcConnectionFactory;
3028
import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory;
3129
import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil;
3230
import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler;
3331
import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder;
34-
import com.fasterxml.jackson.databind.ObjectMapper;
3532
import com.google.common.annotations.VisibleForTesting;
3633
import org.apache.arrow.vector.types.pojo.Schema;
37-
import org.apache.commons.lang3.StringUtils;
3834
import org.apache.commons.lang3.Validate;
3935
import software.amazon.awssdk.services.athena.AthenaClient;
40-
import software.amazon.awssdk.services.glue.model.ErrorDetails;
41-
import software.amazon.awssdk.services.glue.model.FederationSourceErrorCode;
4236
import software.amazon.awssdk.services.s3.S3Client;
4337
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
4438

45-
import java.io.IOException;
4639
import java.sql.Connection;
4740
import java.sql.PreparedStatement;
4841
import java.sql.SQLException;
49-
import java.util.Map;
5042

5143
import static com.amazonaws.athena.connectors.datalakegen2.DataLakeGen2Constants.QUOTE_CHARACTER;
5244

5345
public class DataLakeGen2RecordHandler extends JdbcRecordHandler
5446
{
5547
private static final int FETCH_SIZE = 1000;
56-
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
5748
private final JdbcSplitQueryBuilder jdbcSplitQueryBuilder;
5849
public DataLakeGen2RecordHandler(java.util.Map<String, String> configOptions)
5950
{
@@ -90,25 +81,9 @@ public PreparedStatement buildSplitSql(Connection jdbcConnection, String catalog
9081
@Override
9182
protected CredentialsProvider getCredentialProvider()
9283
{
93-
final String secretName = getDatabaseConnectionConfig().getSecret();
94-
if (StringUtils.isNotBlank(secretName)) {
95-
try {
96-
String secretString = getCachableSecretsManager().getSecret(secretName);
97-
Map<String, String> secretMap = OBJECT_MAPPER.readValue(secretString, Map.class);
98-
99-
// Check if OAuth is configured
100-
if (DataLakeGen2OAuthCredentialsProvider.isOAuthConfigured(secretMap)) {
101-
return new DataLakeGen2OAuthCredentialsProvider(secretName, secretMap, getCachableSecretsManager());
102-
}
103-
104-
// Fall back to default credentials if OAuth is not configured
105-
return new DefaultCredentialsProvider(secretString);
106-
}
107-
catch (IOException ioException) {
108-
throw new AthenaConnectorException("Could not deserialize RDS credentials into HashMap: ",
109-
ErrorDetails.builder().errorCode(FederationSourceErrorCode.INTERNAL_SERVICE_EXCEPTION.toString()).errorMessage(ioException.getMessage()).build());
110-
}
111-
}
112-
return null;
84+
return DataLakeGen2CredentialProviderUtils.getCredentialProvider(
85+
getDatabaseConnectionConfig().getSecret(),
86+
getCachableSecretsManager()
87+
);
11388
}
11489
}
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
/*-
2+
* #%L
3+
* athena-datalakegen2
4+
* %%
5+
* Copyright (C) 2019 - 2025 Amazon Web Services
6+
* %%
7+
* Licensed under the Apache License, Version 2.0 (the "License");
8+
* you may not use this file except in compliance with the License.
9+
* You may obtain a copy of the License at
10+
*
11+
* http://www.apache.org/licenses/LICENSE-2.0
12+
*
13+
* Unless required by applicable law or agreed to in writing, software
14+
* distributed under the License is distributed on an "AS IS" BASIS,
15+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
* See the License for the specific language governing permissions and
17+
* limitations under the License.
18+
* #L%
19+
*/
20+
package com.amazonaws.athena.connectors.datalakegen2;
21+
22+
import com.amazonaws.athena.connector.credentials.CredentialsProvider;
23+
import com.amazonaws.athena.connector.credentials.DefaultCredentialsProvider;
24+
import com.amazonaws.athena.connector.lambda.exceptions.AthenaConnectorException;
25+
import com.amazonaws.athena.connector.lambda.security.CachableSecretsManager;
26+
import com.fasterxml.jackson.core.JsonProcessingException;
27+
import com.fasterxml.jackson.databind.ObjectMapper;
28+
import org.junit.Before;
29+
import org.junit.Test;
30+
import org.junit.runner.RunWith;
31+
import org.mockito.Mock;
32+
import org.mockito.junit.MockitoJUnitRunner;
33+
import software.amazon.awssdk.services.glue.model.FederationSourceErrorCode;
34+
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
35+
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest;
36+
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse;
37+
38+
import java.util.Map;
39+
40+
import static com.amazonaws.athena.connector.credentials.CredentialsConstants.CLIENT_ID;
41+
import static com.amazonaws.athena.connector.credentials.CredentialsConstants.CLIENT_SECRET;
42+
import static com.amazonaws.athena.connector.credentials.CredentialsConstants.PASSWORD;
43+
import static com.amazonaws.athena.connector.credentials.CredentialsConstants.USERNAME;
44+
import static com.amazonaws.athena.connectors.datalakegen2.DataLakeGen2OAuthCredentialsProviderTest.SECRET_NAME;
45+
import static com.amazonaws.athena.connectors.datalakegen2.DataLakeGen2OAuthCredentialsProviderTest.TENANT_ID;
46+
import static com.amazonaws.athena.connectors.datalakegen2.DataLakeGen2OAuthCredentialsProviderTest.TEST_CLIENT_ID;
47+
import static com.amazonaws.athena.connectors.datalakegen2.DataLakeGen2OAuthCredentialsProviderTest.TEST_CLIENT_SECRET;
48+
import static com.amazonaws.athena.connectors.datalakegen2.DataLakeGen2OAuthCredentialsProviderTest.TEST_TENANT_ID;
49+
import static org.junit.Assert.assertEquals;
50+
import static org.junit.Assert.assertNull;
51+
import static org.junit.Assert.assertTrue;
52+
import static org.junit.Assert.fail;
53+
import static org.mockito.ArgumentMatchers.any;
54+
import static org.mockito.Mockito.when;
55+
56+
@RunWith(MockitoJUnitRunner.class)
57+
public class DataLakeGen2CredentialProviderUtilsTest
58+
{
59+
@Mock
60+
private SecretsManagerClient secretsManager;
61+
62+
private CachableSecretsManager cachableSecretsManager;
63+
64+
@Before
65+
public void setup()
66+
{
67+
cachableSecretsManager = new CachableSecretsManager(secretsManager);
68+
}
69+
70+
@Test
71+
public void testGetCredentialProvider_whenOAuthConfigured() throws JsonProcessingException
72+
{
73+
// Mock OAuth secret response
74+
String oauthSecret = new ObjectMapper().writeValueAsString(Map.of(
75+
CLIENT_ID, TEST_CLIENT_ID,
76+
CLIENT_SECRET, TEST_CLIENT_SECRET,
77+
TENANT_ID, TEST_TENANT_ID
78+
));
79+
when(secretsManager.getSecretValue(any(GetSecretValueRequest.class))).thenReturn(
80+
GetSecretValueResponse.builder().secretString(oauthSecret).build()
81+
);
82+
83+
CredentialsProvider provider = DataLakeGen2CredentialProviderUtils.getCredentialProvider(SECRET_NAME, cachableSecretsManager);
84+
assertTrue(provider instanceof DataLakeGen2OAuthCredentialsProvider);
85+
}
86+
87+
@Test
88+
public void testGetCredentialProvider_whenUsernamePasswordConfigured() throws JsonProcessingException
89+
{
90+
// Mock username/password secret response
91+
String standardSecret = new ObjectMapper().writeValueAsString(Map.of(
92+
USERNAME, "test-user",
93+
PASSWORD, "test-password"
94+
));
95+
when(secretsManager.getSecretValue(any(GetSecretValueRequest.class))).thenReturn(
96+
GetSecretValueResponse.builder().secretString(standardSecret).build()
97+
);
98+
99+
CredentialsProvider provider = DataLakeGen2CredentialProviderUtils.getCredentialProvider(SECRET_NAME, cachableSecretsManager);
100+
assertTrue(provider instanceof DefaultCredentialsProvider);
101+
}
102+
103+
@Test
104+
public void testGetCredentialProvider_whenNoSecret()
105+
{
106+
CredentialsProvider provider = DataLakeGen2CredentialProviderUtils.getCredentialProvider("", cachableSecretsManager);
107+
assertNull(provider);
108+
}
109+
110+
@Test
111+
public void testGetCredentialProvider_whenInvalidJson_throwsException()
112+
{
113+
// Mock invalid JSON response
114+
when(secretsManager.getSecretValue(any(GetSecretValueRequest.class))).thenReturn(
115+
GetSecretValueResponse.builder().secretString("invalid-json{").build()
116+
);
117+
118+
try {
119+
DataLakeGen2CredentialProviderUtils.getCredentialProvider(SECRET_NAME, cachableSecretsManager);
120+
fail("Expected AthenaConnectorException");
121+
}
122+
catch (AthenaConnectorException e) {
123+
assertEquals(FederationSourceErrorCode.INTERNAL_SERVICE_EXCEPTION.toString(),
124+
e.getErrorDetails().errorCode());
125+
assertTrue(e.getMessage().contains("Could not deserialize RDS credentials into HashMap"));
126+
}
127+
}
128+
}

athena-datalakegen2/src/test/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2OAuthCredentialsProviderTest.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,12 @@
4949

5050
public class DataLakeGen2OAuthCredentialsProviderTest
5151
{
52-
private static final String SECRET_NAME = "test-secret";
53-
private static final String TEST_CLIENT_ID = "test-client-id";
54-
private static final String TEST_CLIENT_SECRET = "test-client-secret";
55-
private static final String TENANT_ID = "tenant_id";
56-
private static final String TEST_TENANT_ID = "test-tenant-id";
57-
private static final String TEST_ACCESS_TOKEN = "test-access-token";
52+
protected static final String SECRET_NAME = "test-secret";
53+
protected static final String TEST_CLIENT_ID = "test-client-id";
54+
protected static final String TEST_CLIENT_SECRET = "test-client-secret";
55+
protected static final String TENANT_ID = "tenant_id";
56+
protected static final String TEST_TENANT_ID = "test-tenant-id";
57+
protected static final String TEST_ACCESS_TOKEN = "test-access-token";
5858

5959
@Mock
6060
private SecretsManagerClient secretsManagerClient;

0 commit comments

Comments
 (0)