Skip to content

Commit 56608cc

Browse files
Jithendar12 authored and burhan94 committed
Abstract common OAuth handling and add OAuth support to Athena DataLake Gen2 Connector (awslabs#2932)
Co-authored-by: burhan94 <[email protected]>
1 parent 9d77d6b commit 56608cc

File tree

23 files changed

+1263
-189
lines changed

23 files changed

+1263
-189
lines changed

athena-cloudera-hive/src/main/java/com/amazonaws/athena/connectors/cloudera/HiveJdbcConnectionFactory.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
package com.amazonaws.athena.connectors.cloudera;
2222

23+
import com.amazonaws.athena.connector.credentials.CredentialsConstants;
2324
import com.amazonaws.athena.connector.credentials.CredentialsProvider;
2425
import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig;
2526
import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionInfo;
@@ -62,7 +63,7 @@ public Connection getConnection(final CredentialsProvider credentialsProvider)
6263
if (null != credentialsProvider) {
6364
Matcher secretMatcher = SECRET_NAME_PATTERN.matcher(databaseConnectionConfig.getJdbcConnectionString());
6465
final String secretReplacement = String.format("UID=%s;PWD=%s",
65-
credentialsProvider.getCredential().getUser(), credentialsProvider.getCredential().getPassword());
66+
credentialsProvider.getCredentialMap().get(CredentialsConstants.USER), credentialsProvider.getCredentialMap().get(CredentialsConstants.PASSWORD));
6667
derivedJdbcString = secretMatcher.replaceAll(Matcher.quoteReplacement(secretReplacement));
6768
}
6869
else {

athena-cloudera-impala/src/main/java/com/amazonaws/athena/connectors/cloudera/ImpalaJdbcConnectionFactory.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
package com.amazonaws.athena.connectors.cloudera;
2222

23+
import com.amazonaws.athena.connector.credentials.CredentialsConstants;
2324
import com.amazonaws.athena.connector.credentials.CredentialsProvider;
2425
import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig;
2526
import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionInfo;
@@ -63,7 +64,7 @@ public Connection getConnection(final CredentialsProvider credentialsProvider)
6364
if (null != credentialsProvider) {
6465
Matcher secretMatcher = SECRET_NAME_PATTERN.matcher(databaseConnectionConfig.getJdbcConnectionString());
6566
final String secretReplacement = String.format("UID=%s;PWD=%s",
66-
credentialsProvider.getCredential().getUser(), credentialsProvider.getCredential().getPassword());
67+
credentialsProvider.getCredentialMap().get(CredentialsConstants.USER), credentialsProvider.getCredentialMap().get(CredentialsConstants.PASSWORD));
6768
derivedJdbcString = secretMatcher.replaceAll(Matcher.quoteReplacement(secretReplacement));
6869
}
6970
else {

athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MetadataHandler.java

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
package com.amazonaws.athena.connectors.datalakegen2;
2121

2222
import com.amazonaws.athena.connector.credentials.CredentialsProvider;
23+
import com.amazonaws.athena.connector.credentials.CredentialsProviderFactory;
2324
import com.amazonaws.athena.connector.lambda.QueryStatusChecker;
2425
import com.amazonaws.athena.connector.lambda.data.Block;
2526
import com.amazonaws.athena.connector.lambda.data.BlockAllocator;
@@ -54,7 +55,6 @@
5455
import org.apache.arrow.vector.types.Types;
5556
import org.apache.arrow.vector.types.pojo.ArrowType;
5657
import org.apache.arrow.vector.types.pojo.Schema;
57-
import org.apache.commons.lang3.StringUtils;
5858
import org.slf4j.Logger;
5959
import org.slf4j.LoggerFactory;
6060
import software.amazon.awssdk.services.athena.AthenaClient;
@@ -98,7 +98,7 @@ public DataLakeGen2MetadataHandler(java.util.Map<String, String> configOptions)
9898
public DataLakeGen2MetadataHandler(DatabaseConnectionConfig databaseConnectionConfig, java.util.Map<String, String> configOptions)
9999
{
100100
this(databaseConnectionConfig,
101-
new DataLakeGen2JdbcConnectionFactory(databaseConnectionConfig, JDBC_PROPERTIES,
101+
new GenericJdbcConnectionFactory(databaseConnectionConfig, JDBC_PROPERTIES,
102102
new DatabaseConnectionInfo(DataLakeGen2Constants.DRIVER_CLASS, DataLakeGen2Constants.DEFAULT_PORT)),
103103
configOptions);
104104
}
@@ -289,11 +289,10 @@ protected Schema getSchema(Connection jdbcConnection, TableName tableName, Schem
289289
@Override
290290
protected CredentialsProvider getCredentialProvider()
291291
{
292-
final String secretName = getDatabaseConnectionConfig().getSecret();
293-
if (StringUtils.isNotBlank(secretName)) {
294-
return new DataLakeGen2CredentialsProvider(secretName);
295-
}
296-
297-
return null;
292+
return CredentialsProviderFactory.createCredentialProvider(
293+
getDatabaseConnectionConfig().getSecret(),
294+
getCachableSecretsManager(),
295+
new DataLakeGen2OAuthCredentialsProvider()
296+
);
298297
}
299298
}
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*-
2+
* #%L
3+
* athena-datalakegen2
4+
* %%
5+
* Copyright (C) 2019 - 2025 Amazon Web Services
6+
* %%
7+
* Licensed under the Apache License, Version 2.0 (the "License");
8+
* you may not use this file except in compliance with the License.
9+
* You may obtain a copy of the License at
10+
*
11+
* http://www.apache.org/licenses/LICENSE-2.0
12+
*
13+
* Unless required by applicable law or agreed to in writing, software
14+
* distributed under the License is distributed on an "AS IS" BASIS,
15+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
* See the License for the specific language governing permissions and
17+
* limitations under the License.
18+
* #L%
19+
*/
20+
package com.amazonaws.athena.connectors.datalakegen2;
21+
22+
import com.amazonaws.athena.connector.credentials.CredentialsConstants;
23+
import com.amazonaws.athena.connector.credentials.OAuthCredentialsProvider;
24+
import com.google.common.annotations.VisibleForTesting;
25+
26+
import java.net.URI;
27+
import java.net.http.HttpClient;
28+
import java.net.http.HttpRequest;
29+
import java.util.Map;
30+
31+
/**
32+
* OAuth credentials provider for Azure Data Lake Gen2.
33+
*/
34+
public class DataLakeGen2OAuthCredentialsProvider extends OAuthCredentialsProvider
35+
{
36+
private static final String TOKEN_ENDPOINT_FORMAT = "https://login.microsoftonline.com/%s/oauth2/v2.0/token";
37+
private static final String SCOPE = "https://sql.azuresynapse.net/.default";
38+
private static final String TENANT_ID = "tenant_id";
39+
40+
public DataLakeGen2OAuthCredentialsProvider()
41+
{
42+
super();
43+
}
44+
45+
@VisibleForTesting
46+
protected DataLakeGen2OAuthCredentialsProvider(HttpClient httpClient)
47+
{
48+
super(httpClient);
49+
}
50+
51+
@Override
52+
protected boolean isOAuthConfigured(Map<String, String> secretMap)
53+
{
54+
return secretMap.containsKey(CredentialsConstants.CLIENT_ID) &&
55+
!secretMap.get(CredentialsConstants.CLIENT_ID).isEmpty() &&
56+
secretMap.containsKey(CredentialsConstants.CLIENT_SECRET) &&
57+
!secretMap.get(CredentialsConstants.CLIENT_SECRET).isEmpty() &&
58+
secretMap.containsKey(TENANT_ID) &&
59+
!secretMap.get(TENANT_ID).isEmpty();
60+
}
61+
62+
@Override
63+
protected HttpRequest buildTokenRequest(Map<String, String> secretMap)
64+
{
65+
String clientId = secretMap.get(CredentialsConstants.CLIENT_ID);
66+
String clientSecret = secretMap.get(CredentialsConstants.CLIENT_SECRET);
67+
String tenantId = secretMap.get(TENANT_ID);
68+
String tokenEndpoint = String.format(TOKEN_ENDPOINT_FORMAT, tenantId);
69+
70+
String formData = String.format(
71+
"grant_type=client_credentials&scope=%s&client_id=%s&client_secret=%s",
72+
SCOPE, clientId, clientSecret);
73+
74+
return HttpRequest.newBuilder()
75+
.uri(URI.create(tokenEndpoint))
76+
.header("Content-Type", "application/x-www-form-urlencoded")
77+
.POST(HttpRequest.BodyPublishers.ofString(formData))
78+
.build();
79+
}
80+
}

athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2RecordHandler.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,19 @@
1919
*/
2020
package com.amazonaws.athena.connectors.datalakegen2;
2121
import com.amazonaws.athena.connector.credentials.CredentialsProvider;
22+
import com.amazonaws.athena.connector.credentials.CredentialsProviderFactory;
2223
import com.amazonaws.athena.connector.lambda.domain.Split;
2324
import com.amazonaws.athena.connector.lambda.domain.TableName;
2425
import com.amazonaws.athena.connector.lambda.domain.predicate.Constraints;
2526
import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig;
2627
import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionInfo;
28+
import com.amazonaws.athena.connectors.jdbc.connection.GenericJdbcConnectionFactory;
2729
import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory;
2830
import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil;
2931
import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler;
3032
import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder;
3133
import com.google.common.annotations.VisibleForTesting;
3234
import org.apache.arrow.vector.types.pojo.Schema;
33-
import org.apache.commons.lang3.StringUtils;
3435
import org.apache.commons.lang3.Validate;
3536
import software.amazon.awssdk.services.athena.AthenaClient;
3637
import software.amazon.awssdk.services.s3.S3Client;
@@ -53,7 +54,7 @@ public DataLakeGen2RecordHandler(java.util.Map<String, String> configOptions)
5354
public DataLakeGen2RecordHandler(DatabaseConnectionConfig databaseConnectionConfig, java.util.Map<String, String> configOptions)
5455
{
5556
this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(),
56-
new DataLakeGen2JdbcConnectionFactory(databaseConnectionConfig, DataLakeGen2MetadataHandler.JDBC_PROPERTIES,
57+
new GenericJdbcConnectionFactory(databaseConnectionConfig, DataLakeGen2MetadataHandler.JDBC_PROPERTIES,
5758
new DatabaseConnectionInfo(DataLakeGen2Constants.DRIVER_CLASS, DataLakeGen2Constants.DEFAULT_PORT)), new DataLakeGen2QueryStringBuilder(QUOTE_CHARACTER, new DataLakeGen2FederationExpressionParser(QUOTE_CHARACTER)), configOptions);
5859
}
5960
@VisibleForTesting
@@ -81,11 +82,10 @@ public PreparedStatement buildSplitSql(Connection jdbcConnection, String catalog
8182
@Override
8283
protected CredentialsProvider getCredentialProvider()
8384
{
84-
final String secretName = getDatabaseConnectionConfig().getSecret();
85-
if (StringUtils.isNotBlank(secretName)) {
86-
return new DataLakeGen2CredentialsProvider(secretName);
87-
}
88-
89-
return null;
85+
return CredentialsProviderFactory.createCredentialProvider(
86+
getDatabaseConnectionConfig().getSecret(),
87+
getCachableSecretsManager(),
88+
new DataLakeGen2OAuthCredentialsProvider()
89+
);
9090
}
9191
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
/*-
2+
* #%L
3+
* athena-datalakegen2
4+
* %%
5+
* Copyright (C) 2019 - 2025 Amazon Web Services
6+
* %%
7+
* Licensed under the Apache License, Version 2.0 (the "License");
8+
* you may not use this file except in compliance with the License.
9+
* You may obtain a copy of the License at
10+
*
11+
* http://www.apache.org/licenses/LICENSE-2.0
12+
*
13+
* Unless required by applicable law or agreed to in writing, software
14+
* distributed under the License is distributed on an "AS IS" BASIS,
15+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
* See the License for the specific language governing permissions and
17+
* limitations under the License.
18+
* #L%
19+
*/
20+
package com.amazonaws.athena.connectors.datalakegen2;
21+
22+
import com.amazonaws.athena.connector.credentials.OAuthAccessTokenCredentials;
23+
import com.amazonaws.athena.connector.lambda.security.CachableSecretsManager;
24+
import com.fasterxml.jackson.databind.ObjectMapper;
25+
import org.junit.Before;
26+
import org.junit.Test;
27+
import org.mockito.Mock;
28+
import org.mockito.MockitoAnnotations;
29+
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
30+
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest;
31+
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse;
32+
33+
import java.io.IOException;
34+
import java.net.http.HttpClient;
35+
import java.net.http.HttpResponse;
36+
import java.util.HashMap;
37+
import java.util.Map;
38+
39+
import static com.amazonaws.athena.connector.credentials.CredentialsConstants.ACCESS_TOKEN;
40+
import static com.amazonaws.athena.connector.credentials.CredentialsConstants.CLIENT_ID;
41+
import static com.amazonaws.athena.connector.credentials.CredentialsConstants.CLIENT_SECRET;
42+
import static com.amazonaws.athena.connector.credentials.CredentialsConstants.EXPIRES_IN;
43+
import static org.junit.Assert.assertEquals;
44+
import static org.junit.Assert.assertFalse;
45+
import static org.junit.Assert.assertTrue;
46+
import static org.mockito.ArgumentMatchers.any;
47+
import static org.mockito.Mockito.mock;
48+
import static org.mockito.Mockito.when;
49+
50+
public class DataLakeGen2OAuthCredentialsProviderTest
51+
{
52+
protected static final String SECRET_NAME = "test-secret";
53+
protected static final String TEST_CLIENT_ID = "test-client-id";
54+
protected static final String TEST_CLIENT_SECRET = "test-client-secret";
55+
protected static final String TENANT_ID = "tenant_id";
56+
protected static final String TEST_TENANT_ID = "test-tenant-id";
57+
protected static final String TEST_ACCESS_TOKEN = "test-access-token";
58+
59+
@Mock
60+
private SecretsManagerClient secretsManagerClient;
61+
62+
@Mock
63+
private HttpClient httpClient;
64+
65+
private DataLakeGen2OAuthCredentialsProvider credentialsProvider;
66+
private ObjectMapper objectMapper;
67+
68+
@Before
69+
public void setup()
70+
{
71+
MockitoAnnotations.openMocks(this);
72+
CachableSecretsManager cachableSecretsManager = new CachableSecretsManager(secretsManagerClient);
73+
credentialsProvider = new DataLakeGen2OAuthCredentialsProvider(httpClient);
74+
credentialsProvider.initialize(SECRET_NAME, new HashMap<>(), cachableSecretsManager);
75+
objectMapper = new ObjectMapper();
76+
}
77+
78+
@Test
79+
public void testIsOAuthConfigured_WithValidConfig()
80+
{
81+
Map<String, String> secretMap = new HashMap<>();
82+
secretMap.put(CLIENT_ID, TEST_CLIENT_ID);
83+
secretMap.put(CLIENT_SECRET, TEST_CLIENT_SECRET);
84+
secretMap.put(TENANT_ID, TEST_TENANT_ID);
85+
86+
assertTrue(credentialsProvider.isOAuthConfigured(secretMap));
87+
}
88+
89+
@Test
90+
public void testIsOAuthConfigured_WithMissingConfig()
91+
{
92+
Map<String, String> secretMap = new HashMap<>();
93+
secretMap.put(CLIENT_ID, TEST_CLIENT_ID);
94+
// Missing client_secret and tenant_id
95+
96+
assertFalse(credentialsProvider.isOAuthConfigured(secretMap));
97+
}
98+
99+
@Test
100+
public void testBuildTokenRequest()
101+
{
102+
Map<String, String> secretMap = new HashMap<>();
103+
secretMap.put(CLIENT_ID, TEST_CLIENT_ID);
104+
secretMap.put(CLIENT_SECRET, TEST_CLIENT_SECRET);
105+
secretMap.put(TENANT_ID, TEST_TENANT_ID);
106+
107+
var request = credentialsProvider.buildTokenRequest(secretMap);
108+
109+
assertEquals("POST", request.method());
110+
assertEquals("application/x-www-form-urlencoded", request.headers().firstValue("Content-Type").get());
111+
assertTrue(request.uri().toString().contains(TEST_TENANT_ID));
112+
}
113+
114+
@Test
115+
public void testGetCredential_WithValidOAuthConfig() throws IOException, InterruptedException
116+
{
117+
// Setup secret with OAuth config
118+
Map<String, String> secretMap = new HashMap<>();
119+
secretMap.put(CLIENT_ID, TEST_CLIENT_ID);
120+
secretMap.put(CLIENT_SECRET, TEST_CLIENT_SECRET);
121+
secretMap.put(TENANT_ID, TEST_TENANT_ID);
122+
String secretString = objectMapper.writeValueAsString(secretMap);
123+
124+
when(secretsManagerClient.getSecretValue(any(GetSecretValueRequest.class)))
125+
.thenReturn(GetSecretValueResponse.builder().secretString(secretString).build());
126+
127+
// Setup mock HTTP response with token
128+
Map<String, String> tokenResponse = new HashMap<>();
129+
tokenResponse.put(ACCESS_TOKEN, TEST_ACCESS_TOKEN);
130+
tokenResponse.put(EXPIRES_IN, "3600");
131+
String responseBody = objectMapper.writeValueAsString(tokenResponse);
132+
133+
@SuppressWarnings("unchecked")
134+
HttpResponse<Object> typedResponse = mock(HttpResponse.class);
135+
when(typedResponse.statusCode()).thenReturn(200);
136+
when(typedResponse.body()).thenReturn(responseBody);
137+
when(httpClient.send(any(), any())).thenReturn(typedResponse);
138+
139+
// Get and verify credential
140+
var credential = credentialsProvider.getCredential();
141+
assertTrue(credential instanceof OAuthAccessTokenCredentials);
142+
assertEquals(TEST_ACCESS_TOKEN, ((OAuthAccessTokenCredentials) credential).getAccessToken());
143+
}
144+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/*-
2+
* #%L
3+
* athena-federation-sdk
4+
* %%
5+
* Copyright (C) 2019 - 2025 Amazon Web Services
6+
* %%
7+
* Licensed under the Apache License, Version 2.0 (the "License");
8+
* you may not use this file except in compliance with the License.
9+
* You may obtain a copy of the License at
10+
*
11+
* http://www.apache.org/licenses/LICENSE-2.0
12+
*
13+
* Unless required by applicable law or agreed to in writing, software
14+
* distributed under the License is distributed on an "AS IS" BASIS,
15+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
* See the License for the specific language governing permissions and
17+
* limitations under the License.
18+
* #L%
19+
*/
20+
package com.amazonaws.athena.connector.credentials;
21+
22+
import java.util.Map;
23+
24+
/**
25+
* Represents a set of credentials required to authenticate and connect to a database.
26+
* Implementations may provide credentials in different forms, such as username/password or OAuth tokens.
27+
*/
28+
public interface Credentials
29+
{
30+
/**
31+
* Gets the credential properties for database authentication.
32+
* Keys are property names (e.g., "username", "password", "accessToken"),
33+
* and values are the associated property values.
34+
*
35+
* @return a map of credential property names to values
36+
*/
37+
Map<String, String> getProperties();
38+
}

0 commit comments

Comments
 (0)