Skip to content

Commit 65ccb34

Browse files
Trianz-Akshayozakidaichngpe
authored andcommitted
Implement Key-Pair Authentication for Athena Snowflake Connector (awslabs#2857)
Co-authored-by: Dai Ozaki <[email protected]> Co-authored-by: chngpe <[email protected]>
1 parent 92d6a31 commit 65ccb34

File tree

11 files changed

+1270
-56
lines changed

11 files changed

+1270
-56
lines changed

athena-snowflake/pom.xml

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,20 @@
3232
<artifactId>snowflake-jdbc</artifactId>
3333
<version>3.25.1</version>
3434
</dependency>
35-
<!-- https://mvnrepository.com/artifact/org.json/json -->
3635
<dependency>
37-
<groupId>org.json</groupId>
38-
<artifactId>json</artifactId>
39-
<version>20250517</version>
36+
<groupId>org.bouncycastle</groupId>
37+
<artifactId>bcprov-jdk18on</artifactId>
38+
<version>1.78.1</version>
39+
</dependency>
40+
<dependency>
41+
<groupId>org.bouncycastle</groupId>
42+
<artifactId>bcpkix-jdk18on</artifactId>
43+
<version>1.78.1</version>
44+
</dependency>
45+
<dependency>
46+
<groupId>org.bouncycastle</groupId>
47+
<artifactId>bcutil-jdk18on</artifactId>
48+
<version>1.78.1</version>
4049
</dependency>
4150
<!-- https://mvnrepository.com/artifact/software.amazon.awssdk/rds -->
4251
<dependency>

athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeConstants.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,36 @@ public final class SnowflakeConstants
3535
*/
3636
public static final int SINGLE_SPLIT_LIMIT_COUNT = 10000;
3737
public static final String SNOWFLAKE_QUOTE_CHARACTER = "\"";
38+
39+
/** Configuration key for specifying the authentication method */
40+
public static final String AUTHENTICATOR = "authenticator";
41+
42+
/**
43+
* OAuth 2.0 Authentication Constants
44+
* These constants are used for configuring OAuth-based authentication with Snowflake.
45+
*/
3846
public static final String AUTH_CODE = "auth_code";
3947
public static final String CLIENT_ID = "client_id";
4048
public static final String TOKEN_URL = "token_url";
4149
public static final String REDIRECT_URI = "redirect_uri";
4250
public static final String CLIENT_SECRET = "client_secret";
4351

52+
/**
53+
* Key-Pair Authentication Constants
54+
* These constants are used for configuring public/private key pair authentication with Snowflake.
55+
*/
56+
public static final String SF_USER = "sfUser";
57+
public static final String PEM_PRIVATE_KEY = "pem_private_key";
58+
public static final String PEM_PRIVATE_KEY_PASSPHRASE = "pem_private_key_passphrase";
59+
public static final String PRIVATE_KEY = "privateKey";
60+
61+
/**
62+
* Password Authentication Constants
63+
* These constants are used for traditional username/password authentication with Snowflake.
64+
*/
65+
public static final String USERNAME = "username";
66+
public static final String PASSWORD = "password";
67+
public static final String USER = "user";
68+
4469
private SnowflakeConstants() {}
4570
}

athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeCredentialsProvider.java

Lines changed: 93 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,11 @@
2222
import com.amazonaws.athena.connector.credentials.CredentialsProvider;
2323
import com.amazonaws.athena.connector.credentials.DefaultCredentials;
2424
import com.amazonaws.athena.connector.lambda.security.CachableSecretsManager;
25+
import com.amazonaws.athena.connectors.snowflake.utils.SnowflakeAuthType;
26+
import com.amazonaws.athena.connectors.snowflake.utils.SnowflakeAuthUtils;
2527
import com.fasterxml.jackson.databind.ObjectMapper;
28+
import com.fasterxml.jackson.databind.node.ObjectNode;
2629
import com.google.common.annotations.VisibleForTesting;
27-
import org.json.JSONObject;
2830
import org.slf4j.Logger;
2931
import org.slf4j.LoggerFactory;
3032
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
@@ -43,10 +45,14 @@
4345
import java.util.HashMap;
4446
import java.util.Map;
4547

48+
import static com.amazonaws.athena.connectors.snowflake.utils.SnowflakeAuthType.OAUTH;
49+
import static com.amazonaws.athena.connectors.snowflake.utils.SnowflakeAuthUtils.getUsername;
50+
import static com.amazonaws.athena.connectors.snowflake.utils.SnowflakeAuthUtils.validateCredentials;
51+
4652
/**
47-
* Snowflake OAuth credentials provider that manages OAuth token lifecycle.
48-
* This provider handles token refresh, expiration, and provides credential properties
49-
* for Snowflake OAuth connections.
53+
* Snowflake credentials provider that manages multiple authentication methods.
54+
* This provider handles OAuth token lifecycle, key-pair authentication, and password authentication.
55+
* Authentication method is automatically determined based on the secret contents.
5056
*/
5157
public class SnowflakeCredentialsProvider implements CredentialsProvider
5258
{
@@ -56,10 +62,6 @@ public class SnowflakeCredentialsProvider implements CredentialsProvider
5662
public static final String FETCHED_AT = "fetched_at";
5763
public static final String REFRESH_TOKEN = "refresh_token";
5864
public static final String EXPIRES_IN = "expires_in";
59-
public static final String USERNAME = "username";
60-
public static final String PASSWORD = "password";
61-
public static final String USER = "user";
62-
6365
private final String oauthSecretName;
6466
private final CachableSecretsManager secretsManager;
6567
private final ObjectMapper objectMapper;
@@ -82,8 +84,8 @@ public DefaultCredentials getCredential()
8284
{
8385
Map<String, String> credentialMap = getCredentialMap();
8486
return new DefaultCredentials(
85-
credentialMap.get(USER),
86-
credentialMap.get(PASSWORD)
87+
credentialMap.get(SnowflakeConstants.USER),
88+
credentialMap.get(SnowflakeConstants.PASSWORD)
8789
);
8890
}
8991

@@ -92,32 +94,71 @@ public Map<String, String> getCredentialMap()
9294
{
9395
try {
9496
String secretString = secretsManager.getSecret(oauthSecretName);
95-
Map<String, String> oauthConfig = objectMapper.readValue(secretString, Map.class);
97+
Map<String, String> secretMap = objectMapper.readValue(secretString, Map.class);
9698

97-
if (oauthConfig.containsKey(SnowflakeConstants.AUTH_CODE) && !oauthConfig.get(SnowflakeConstants.AUTH_CODE).isEmpty()) {
98-
// OAuth flow
99-
String accessToken = fetchAccessTokenFromSecret(oauthConfig);
100-
101-
Map<String, String> credentialMap = new HashMap<>();
102-
credentialMap.put(USER, oauthConfig.get(USERNAME));
103-
credentialMap.put(PASSWORD, accessToken);
104-
credentialMap.put("authenticator", "oauth");
105-
106-
return credentialMap;
107-
}
108-
else {
109-
// Fallback to standard credentials
110-
return Map.of(
111-
USER, oauthConfig.get(USERNAME),
112-
PASSWORD, oauthConfig.get(PASSWORD)
113-
);
99+
// Determine authentication type based on secret contents
100+
SnowflakeAuthType authType = SnowflakeAuthUtils.determineAuthType(secretMap);
101+
// Validate credentials once after determining auth type
102+
validateCredentials(secretMap, authType);
103+
switch (authType) {
104+
case SNOWFLAKE_JWT:
105+
// Key-pair authentication
106+
return handleKeyPairAuthentication(secretMap);
107+
case OAUTH:
108+
// OAuth flow
109+
return handleOAuthAuthentication(secretMap);
110+
case SNOWFLAKE:
111+
default:
112+
// Password authentication (backward compatible)
113+
return handlePasswordAuthentication(secretMap);
114114
}
115115
}
116116
catch (Exception ex) {
117117
throw new RuntimeException("Error retrieving Snowflake credentials: " + ex.getMessage(), ex);
118118
}
119119
}
120120

121+
/**
122+
* Handles key-pair authentication.
123+
*/
124+
private Map<String, String> handleKeyPairAuthentication(Map<String, String> oauthConfig)
125+
{
126+
Map<String, String> credentialMap = new HashMap<>();
127+
String username = getUsername(oauthConfig);
128+
credentialMap.put(USER, username);
129+
credentialMap.put(SnowflakeConstants.PEM_PRIVATE_KEY, oauthConfig.get(SnowflakeConstants.PEM_PRIVATE_KEY));
130+
credentialMap.put(SnowflakeConstants.PEM_PRIVATE_KEY_PASSPHRASE, oauthConfig.get(SnowflakeConstants.PEM_PRIVATE_KEY_PASSPHRASE));
131+
LOGGER.debug("Using key-pair authentication for user: {}", username);
132+
return credentialMap;
133+
}
134+
135+
/**
136+
* Handles OAuth authentication.
137+
*/
138+
private Map<String, String> handleOAuthAuthentication(Map<String, String> oauthConfig) throws Exception
139+
{
140+
String accessToken = fetchAccessTokenFromSecret(oauthConfig);
141+
Map<String, String> credentialMap = new HashMap<>();
142+
String username = getUsername(oauthConfig);
143+
credentialMap.put(SnowflakeConstants.USER, username);
144+
credentialMap.put(SnowflakeConstants.PASSWORD, accessToken);
145+
credentialMap.put(SnowflakeConstants.AUTHENTICATOR, OAUTH.getValue());
146+
LOGGER.debug("Using OAuth authentication for user: {}", username);
147+
return credentialMap;
148+
}
149+
150+
/**
151+
* Handles password authentication (backward compatible).
152+
*/
153+
private Map<String, String> handlePasswordAuthentication(Map<String, String> oauthConfig)
154+
{
155+
Map<String, String> credentialMap = new HashMap<>();
156+
credentialMap.put(SnowflakeConstants.USER, getUsername(oauthConfig));
157+
credentialMap.put(SnowflakeConstants.PASSWORD, oauthConfig.get(SnowflakeConstants.PASSWORD));
158+
LOGGER.debug("Using password authentication for user: {}", getUsername(oauthConfig));
159+
return credentialMap;
160+
}
161+
121162
private String loadTokenFromSecretsManager(Map<String, String> oauthConfig)
122163
{
123164
if (oauthConfig.containsKey(ACCESS_TOKEN)) {
@@ -126,20 +167,27 @@ private String loadTokenFromSecretsManager(Map<String, String> oauthConfig)
126167
return null;
127168
}
128169

129-
private void saveTokenToSecretsManager(JSONObject tokenJson, Map<String, String> oauthConfig)
170+
private void saveTokenToSecretsManager(ObjectNode tokenJson, Map<String, String> oauthConfig)
130171
{
131172
// Update token related fields
132173
tokenJson.put(FETCHED_AT, System.currentTimeMillis() / 1000);
133-
oauthConfig.put(ACCESS_TOKEN, tokenJson.getString(ACCESS_TOKEN));
134-
oauthConfig.put(REFRESH_TOKEN, tokenJson.getString(REFRESH_TOKEN));
135-
oauthConfig.put(EXPIRES_IN, String.valueOf(tokenJson.getInt(EXPIRES_IN)));
136-
oauthConfig.put(FETCHED_AT, String.valueOf(tokenJson.getLong(FETCHED_AT)));
174+
oauthConfig.put(ACCESS_TOKEN, tokenJson.get(ACCESS_TOKEN).asText());
175+
oauthConfig.put(REFRESH_TOKEN, tokenJson.get(REFRESH_TOKEN).asText());
176+
oauthConfig.put(EXPIRES_IN, String.valueOf(tokenJson.get(EXPIRES_IN).asInt()));
177+
oauthConfig.put(FETCHED_AT, String.valueOf(tokenJson.get(FETCHED_AT).asLong()));
137178

138179
// Save updated secret
139-
secretsManager.getSecretsManager().putSecretValue(builder -> builder
140-
.secretId(this.oauthSecretName)
141-
.secretString(String.valueOf(new JSONObject(oauthConfig)))
142-
.build());
180+
try {
181+
String updatedSecretString = objectMapper.writeValueAsString(oauthConfig);
182+
secretsManager.getSecretsManager().putSecretValue(builder -> builder
183+
.secretId(this.oauthSecretName)
184+
.secretString(updatedSecretString)
185+
.build());
186+
}
187+
catch (Exception e) {
188+
LOGGER.error("Failed to save updated secret: ", e);
189+
throw new RuntimeException("Failed to save updated secret: ", e);
190+
}
143191
}
144192

145193
private String fetchAccessTokenFromSecret(Map<String, String> oauthConfig) throws Exception
@@ -155,9 +203,9 @@ private String fetchAccessTokenFromSecret(Map<String, String> oauthConfig) throw
155203

156204
if (accessToken == null) {
157205
LOGGER.debug("First time auth. Using authorization_code...");
158-
JSONObject tokenJson = getTokenFromAuthCode(authCode, redirectUri, tokenEndpoint, clientId, clientSecret);
206+
ObjectNode tokenJson = getTokenFromAuthCode(authCode, redirectUri, tokenEndpoint, clientId, clientSecret);
159207
saveTokenToSecretsManager(tokenJson, oauthConfig);
160-
accessToken = tokenJson.getString(ACCESS_TOKEN);
208+
accessToken = tokenJson.get(ACCESS_TOKEN).asText();
161209
}
162210
else {
163211
long expiresIn = Long.parseLong(oauthConfig.get(EXPIRES_IN));
@@ -169,16 +217,16 @@ private String fetchAccessTokenFromSecret(Map<String, String> oauthConfig) throw
169217
}
170218
else {
171219
LOGGER.debug("Access token expired. Using refresh_token...");
172-
JSONObject refreshed = refreshAccessToken(oauthConfig.get(REFRESH_TOKEN), tokenEndpoint, clientId, clientSecret);
220+
ObjectNode refreshed = refreshAccessToken(oauthConfig.get(REFRESH_TOKEN), tokenEndpoint, clientId, clientSecret);
173221
refreshed.put(REFRESH_TOKEN, oauthConfig.get(REFRESH_TOKEN));
174222
saveTokenToSecretsManager(refreshed, oauthConfig);
175-
accessToken = refreshed.getString(ACCESS_TOKEN);
223+
accessToken = refreshed.get(ACCESS_TOKEN).asText();
176224
}
177225
}
178226
return accessToken;
179227
}
180228

181-
private JSONObject getTokenFromAuthCode(String authCode, String redirectUri, String tokenEndpoint, String clientId, String clientSecret) throws Exception
229+
private ObjectNode getTokenFromAuthCode(String authCode, String redirectUri, String tokenEndpoint, String clientId, String clientSecret) throws Exception
182230
{
183231
String body = "grant_type=authorization_code"
184232
+ "&code=" + authCode
@@ -187,15 +235,15 @@ private JSONObject getTokenFromAuthCode(String authCode, String redirectUri, Str
187235
return requestToken(body, tokenEndpoint, clientId, clientSecret);
188236
}
189237

190-
private JSONObject refreshAccessToken(String refreshToken, String tokenEndpoint, String clientId, String clientSecret) throws Exception
238+
private ObjectNode refreshAccessToken(String refreshToken, String tokenEndpoint, String clientId, String clientSecret) throws Exception
191239
{
192240
String body = "grant_type=refresh_token"
193241
+ "&refresh_token=" + URLEncoder.encode(refreshToken, StandardCharsets.UTF_8);
194242

195243
return requestToken(body, tokenEndpoint, clientId, clientSecret);
196244
}
197245

198-
private JSONObject requestToken(String requestBody, String tokenEndpoint, String clientId, String clientSecret) throws Exception
246+
private ObjectNode requestToken(String requestBody, String tokenEndpoint, String clientId, String clientSecret) throws Exception
199247
{
200248
HttpURLConnection conn = getHttpURLConnection(tokenEndpoint, clientId, clientSecret);
201249

@@ -215,7 +263,7 @@ private JSONObject requestToken(String requestBody, String tokenEndpoint, String
215263
throw new RuntimeException("Failed: " + responseCode + " - " + response);
216264
}
217265

218-
JSONObject tokenJson = new JSONObject(response);
266+
ObjectNode tokenJson = objectMapper.readValue(response, ObjectNode.class);
219267
tokenJson.put(FETCHED_AT, System.currentTimeMillis() / 1000);
220268
return tokenJson;
221269
}

athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMetadataHandler.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,13 @@
4949
import com.amazonaws.athena.connector.util.PaginationHelper;
5050
import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig;
5151
import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionInfo;
52-
import com.amazonaws.athena.connectors.jdbc.connection.GenericJdbcConnectionFactory;
5352
import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory;
5453
import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil;
5554
import com.amazonaws.athena.connectors.jdbc.manager.JdbcArrowTypeConverter;
5655
import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler;
5756
import com.amazonaws.athena.connectors.jdbc.manager.PreparedStatementBuilder;
5857
import com.amazonaws.athena.connectors.jdbc.resolver.JDBCCaseResolver;
58+
import com.amazonaws.athena.connectors.snowflake.connection.SnowflakeConnectionFactory;
5959
import com.amazonaws.athena.connectors.snowflake.resolver.SnowflakeJDBCCaseResolver;
6060
import com.google.common.annotations.VisibleForTesting;
6161
import com.google.common.base.Strings;
@@ -154,7 +154,7 @@ public SnowflakeMetadataHandler(java.util.Map<String, String> configOptions)
154154
public SnowflakeMetadataHandler(DatabaseConnectionConfig databaseConnectionConfig, java.util.Map<String, String> configOptions)
155155
{
156156
this(databaseConnectionConfig,
157-
new GenericJdbcConnectionFactory(databaseConnectionConfig, SnowflakeEnvironmentProperties.getSnowFlakeParameter(JDBC_PROPERTIES, configOptions),
157+
new SnowflakeConnectionFactory(databaseConnectionConfig, SnowflakeEnvironmentProperties.getSnowFlakeParameter(JDBC_PROPERTIES, configOptions),
158158
new DatabaseConnectionInfo(SnowflakeConstants.SNOWFLAKE_DRIVER_CLASS, SnowflakeConstants.SNOWFLAKE_DEFAULT_PORT)),
159159
configOptions);
160160
}

athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeRecordHandler.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@
2626
import com.amazonaws.athena.connector.lambda.exceptions.AthenaConnectorException;
2727
import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig;
2828
import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionInfo;
29-
import com.amazonaws.athena.connectors.jdbc.connection.GenericJdbcConnectionFactory;
3029
import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory;
3130
import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil;
3231
import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler;
3332
import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder;
33+
import com.amazonaws.athena.connectors.snowflake.connection.SnowflakeConnectionFactory;
3434
import com.google.common.annotations.VisibleForTesting;
3535
import org.apache.arrow.vector.types.pojo.Schema;
3636
import org.apache.commons.lang3.StringUtils;
@@ -63,10 +63,10 @@ public SnowflakeRecordHandler(java.util.Map<String, String> configOptions)
6363
}
6464
public SnowflakeRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, java.util.Map<String, String> configOptions)
6565
{
66-
this(databaseConnectionConfig, new GenericJdbcConnectionFactory(databaseConnectionConfig, SnowflakeEnvironmentProperties.getSnowFlakeParameter(JDBC_PROPERTIES, configOptions),
66+
this(databaseConnectionConfig, new SnowflakeConnectionFactory(databaseConnectionConfig, SnowflakeEnvironmentProperties.getSnowFlakeParameter(JDBC_PROPERTIES, configOptions),
6767
new DatabaseConnectionInfo(SnowflakeConstants.SNOWFLAKE_DRIVER_CLASS, SnowflakeConstants.SNOWFLAKE_DEFAULT_PORT)), configOptions);
6868
}
69-
public SnowflakeRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, GenericJdbcConnectionFactory jdbcConnectionFactory, java.util.Map<String, String> configOptions)
69+
public SnowflakeRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, SnowflakeConnectionFactory jdbcConnectionFactory, java.util.Map<String, String> configOptions)
7070
{
7171
this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(),
7272
jdbcConnectionFactory, new SnowflakeQueryStringBuilder(SNOWFLAKE_QUOTE_CHARACTER, new SnowflakeFederationExpressionParser(SNOWFLAKE_QUOTE_CHARACTER)), configOptions);

0 commit comments

Comments
 (0)