Skip to content
Merged
6 changes: 0 additions & 6 deletions athena-snowflake/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,6 @@
<artifactId>snowflake-jdbc</artifactId>
<version>3.24.2</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.json/json -->
<dependency>
<groupId>org.json</groupId>
<artifactId>json</artifactId>
<version>20250107</version>
</dependency>
<!-- https://mvnrepository.com/artifact/software.amazon.awssdk/rds -->
<dependency>
<groupId>software.amazon.awssdk</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ public final class SnowflakeConstants
public static final String TOKEN_URL = "token_url";
public static final String REDIRECT_URI = "redirect_uri";
public static final String CLIENT_SECRET = "client_secret";
public static final String PRIVATE_KEY = "privateKey";
public static final String AUTHENTICATOR = "authenticator";
public static final String USERNAME = "username";
public static final String PASSWORD = "password";
public static final String USER = "user";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I saw from comment Gets the username from credentials, checking both "username" and "user" fields.

Who is using the field user?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I synced with our internal team, we will be using sfUser and pem_private_key as keys in secret manager. Can we please using those 2 keys for key-pair auth?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The user field is a JDBC parameter used by Snowflake for authentication.
Previously, the getUsername method was also invoked during the validation flow, which is why both username and user fields were being checked.

With the recent simplification of the flow, credential validation is now handled within SnowflakeCredentialsProvider before setting up the credentials, eliminating the need to check both fields at a later stage.

Additionally, I have updated the secret structure to use the following fields:

{
  "sfUser": "your_snowflake_user",
  "pem_private_key": "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----",
  "pem_private_key_passphrase": "passphrase_in_case_of_encrypted_private_key (optional)"
}

Currently, sfUser is used only for key-pair authentication. If you expect it to be used for other authentication mechanisms as well, please let me know and I will align the implementation accordingly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we are using pem_private key, shouldn't we delete PRIVATE_KEY = "privateKey";?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, can we group the static variable and put comment on which one used for? like sfUser and pem* are for key-pari, ASDF for oAuth

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated comments on constant as suggested.


private SnowflakeConstants() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@
import com.amazonaws.athena.connector.credentials.CredentialsProvider;
import com.amazonaws.athena.connector.credentials.DefaultCredentials;
import com.amazonaws.athena.connector.lambda.security.CachableSecretsManager;
import com.amazonaws.athena.connectors.snowflake.utils.SnowflakeAuthType;
import com.amazonaws.athena.connectors.snowflake.utils.SnowflakeAuthUtils;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.common.annotations.VisibleForTesting;
import org.json.JSONObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
Expand All @@ -43,10 +45,13 @@
import java.util.HashMap;
import java.util.Map;

import static com.amazonaws.athena.connectors.snowflake.utils.SnowflakeAuthType.OAUTH;
import static com.amazonaws.athena.connectors.snowflake.utils.SnowflakeAuthUtils.getUsername;

/**
* Snowflake OAuth credentials provider that manages OAuth token lifecycle.
* This provider handles token refresh, expiration, and provides credential properties
* for Snowflake OAuth connections.
* Snowflake credentials provider that manages multiple authentication methods.
* This provider handles OAuth token lifecycle, key-pair authentication, and password authentication.
* Authentication method is automatically determined based on the secret contents.
*/
public class SnowflakeCredentialsProvider implements CredentialsProvider
{
Expand All @@ -56,10 +61,6 @@ public class SnowflakeCredentialsProvider implements CredentialsProvider
public static final String FETCHED_AT = "fetched_at";
public static final String REFRESH_TOKEN = "refresh_token";
public static final String EXPIRES_IN = "expires_in";
public static final String USERNAME = "username";
public static final String PASSWORD = "password";
public static final String USER = "user";

private final String oauthSecretName;
private final CachableSecretsManager secretsManager;
private final ObjectMapper objectMapper;
Expand All @@ -82,8 +83,8 @@ public DefaultCredentials getCredential()
{
Map<String, String> credentialMap = getCredentialMap();
return new DefaultCredentials(
credentialMap.get(USER),
credentialMap.get(PASSWORD)
credentialMap.get(SnowflakeConstants.USER),
credentialMap.get(SnowflakeConstants.PASSWORD)
);
}

Expand All @@ -94,30 +95,67 @@ public Map<String, String> getCredentialMap()
String secretString = secretsManager.getSecret(oauthSecretName);
Map<String, String> oauthConfig = objectMapper.readValue(secretString, Map.class);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's rename this as secret_map or something to avoid confusion please

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated code as suggested.


if (oauthConfig.containsKey(SnowflakeConstants.AUTH_CODE) && !oauthConfig.get(SnowflakeConstants.AUTH_CODE).isEmpty()) {
// OAuth flow
String accessToken = fetchAccessTokenFromSecret(oauthConfig);

Map<String, String> credentialMap = new HashMap<>();
credentialMap.put(USER, oauthConfig.get(USERNAME));
credentialMap.put(PASSWORD, accessToken);
credentialMap.put("authenticator", "oauth");

return credentialMap;
}
else {
// Fallback to standard credentials
return Map.of(
USER, oauthConfig.get(USERNAME),
PASSWORD, oauthConfig.get(PASSWORD)
);
// Determine authentication type based on secret contents
SnowflakeAuthType authType = SnowflakeAuthUtils.determineAuthType(oauthConfig);

switch (authType) {
case SNOWFLAKE_JWT:
// Key-pair authentication
return handleKeyPairAuthentication(oauthConfig);
case OAUTH:
// OAuth flow
return handleOAuthAuthentication(oauthConfig);
case SNOWFLAKE:
default:
// Password authentication (backward compatible)
return handlePasswordAuthentication(oauthConfig);
}
}
catch (Exception ex) {
throw new RuntimeException("Error retrieving Snowflake credentials: " + ex.getMessage(), ex);
}
}

/**
* Handles key-pair authentication.
*/
private Map<String, String> handleKeyPairAuthentication(Map<String, String> oauthConfig)
{
Map<String, String> credentialMap = new HashMap<>();
String username = getUsername(oauthConfig);
credentialMap.put(USER, username);
credentialMap.put(SnowflakeConstants.PRIVATE_KEY, oauthConfig.get(SnowflakeConstants.PRIVATE_KEY));
LOGGER.debug("Using key-pair authentication for user: {}", username);
return credentialMap;
}

/**
* Handles OAuth authentication.
*/
private Map<String, String> handleOAuthAuthentication(Map<String, String> oauthConfig) throws Exception
{
String accessToken = fetchAccessTokenFromSecret(oauthConfig);
Map<String, String> credentialMap = new HashMap<>();
String username = getUsername(oauthConfig);
credentialMap.put(SnowflakeConstants.USER, username);
credentialMap.put(SnowflakeConstants.PASSWORD, accessToken);
credentialMap.put(SnowflakeConstants.AUTHENTICATOR, OAUTH.getValue());
LOGGER.debug("Using OAuth authentication for user: {}", username);
return credentialMap;
}

/**
* Handles password authentication (backward compatible).
*/
private Map<String, String> handlePasswordAuthentication(Map<String, String> oauthConfig)
{
Map<String, String> credentialMap = new HashMap<>();
credentialMap.put(SnowflakeConstants.USER, getUsername(oauthConfig));
credentialMap.put(SnowflakeConstants.PASSWORD, oauthConfig.get(SnowflakeConstants.PASSWORD));
LOGGER.debug("Using password authentication for user: {}", getUsername(oauthConfig));
return credentialMap;
}

private String loadTokenFromSecretsManager(Map<String, String> oauthConfig)
{
if (oauthConfig.containsKey(ACCESS_TOKEN)) {
Expand All @@ -126,20 +164,27 @@ private String loadTokenFromSecretsManager(Map<String, String> oauthConfig)
return null;
}

private void saveTokenToSecretsManager(JSONObject tokenJson, Map<String, String> oauthConfig)
private void saveTokenToSecretsManager(ObjectNode tokenJson, Map<String, String> oauthConfig)
{
// Update token related fields
tokenJson.put(FETCHED_AT, System.currentTimeMillis() / 1000);
oauthConfig.put(ACCESS_TOKEN, tokenJson.getString(ACCESS_TOKEN));
oauthConfig.put(REFRESH_TOKEN, tokenJson.getString(REFRESH_TOKEN));
oauthConfig.put(EXPIRES_IN, String.valueOf(tokenJson.getInt(EXPIRES_IN)));
oauthConfig.put(FETCHED_AT, String.valueOf(tokenJson.getLong(FETCHED_AT)));
oauthConfig.put(ACCESS_TOKEN, tokenJson.get(ACCESS_TOKEN).asText());
oauthConfig.put(REFRESH_TOKEN, tokenJson.get(REFRESH_TOKEN).asText());
oauthConfig.put(EXPIRES_IN, String.valueOf(tokenJson.get(EXPIRES_IN).asInt()));
oauthConfig.put(FETCHED_AT, String.valueOf(tokenJson.get(FETCHED_AT).asLong()));

// Save updated secret
secretsManager.getSecretsManager().putSecretValue(builder -> builder
.secretId(this.oauthSecretName)
.secretString(String.valueOf(new JSONObject(oauthConfig)))
.build());
try {
String updatedSecretString = objectMapper.writeValueAsString(oauthConfig);
secretsManager.getSecretsManager().putSecretValue(builder -> builder
.secretId(this.oauthSecretName)
.secretString(updatedSecretString)
.build());
}
catch (Exception e) {
LOGGER.error("Failed to save updated secret: ", e);
throw new RuntimeException("Failed to save updated secret: ", e);
}
}

private String fetchAccessTokenFromSecret(Map<String, String> oauthConfig) throws Exception
Expand All @@ -155,9 +200,9 @@ private String fetchAccessTokenFromSecret(Map<String, String> oauthConfig) throw

if (accessToken == null) {
LOGGER.debug("First time auth. Using authorization_code...");
JSONObject tokenJson = getTokenFromAuthCode(authCode, redirectUri, tokenEndpoint, clientId, clientSecret);
ObjectNode tokenJson = getTokenFromAuthCode(authCode, redirectUri, tokenEndpoint, clientId, clientSecret);
saveTokenToSecretsManager(tokenJson, oauthConfig);
accessToken = tokenJson.getString(ACCESS_TOKEN);
accessToken = tokenJson.get(ACCESS_TOKEN).asText();
}
else {
long expiresIn = Long.parseLong(oauthConfig.get(EXPIRES_IN));
Expand All @@ -169,16 +214,16 @@ private String fetchAccessTokenFromSecret(Map<String, String> oauthConfig) throw
}
else {
LOGGER.debug("Access token expired. Using refresh_token...");
JSONObject refreshed = refreshAccessToken(oauthConfig.get(REFRESH_TOKEN), tokenEndpoint, clientId, clientSecret);
ObjectNode refreshed = refreshAccessToken(oauthConfig.get(REFRESH_TOKEN), tokenEndpoint, clientId, clientSecret);
refreshed.put(REFRESH_TOKEN, oauthConfig.get(REFRESH_TOKEN));
saveTokenToSecretsManager(refreshed, oauthConfig);
accessToken = refreshed.getString(ACCESS_TOKEN);
accessToken = refreshed.get(ACCESS_TOKEN).asText();
}
}
return accessToken;
}

private JSONObject getTokenFromAuthCode(String authCode, String redirectUri, String tokenEndpoint, String clientId, String clientSecret) throws Exception
private ObjectNode getTokenFromAuthCode(String authCode, String redirectUri, String tokenEndpoint, String clientId, String clientSecret) throws Exception
{
String body = "grant_type=authorization_code"
+ "&code=" + authCode
Expand All @@ -187,15 +232,15 @@ private JSONObject getTokenFromAuthCode(String authCode, String redirectUri, Str
return requestToken(body, tokenEndpoint, clientId, clientSecret);
}

private JSONObject refreshAccessToken(String refreshToken, String tokenEndpoint, String clientId, String clientSecret) throws Exception
private ObjectNode refreshAccessToken(String refreshToken, String tokenEndpoint, String clientId, String clientSecret) throws Exception
{
String body = "grant_type=refresh_token"
+ "&refresh_token=" + URLEncoder.encode(refreshToken, StandardCharsets.UTF_8);

return requestToken(body, tokenEndpoint, clientId, clientSecret);
}

private JSONObject requestToken(String requestBody, String tokenEndpoint, String clientId, String clientSecret) throws Exception
private ObjectNode requestToken(String requestBody, String tokenEndpoint, String clientId, String clientSecret) throws Exception
{
HttpURLConnection conn = getHttpURLConnection(tokenEndpoint, clientId, clientSecret);

Expand All @@ -215,7 +260,7 @@ private JSONObject requestToken(String requestBody, String tokenEndpoint, String
throw new RuntimeException("Failed: " + responseCode + " - " + response);
}

JSONObject tokenJson = new JSONObject(response);
ObjectNode tokenJson = objectMapper.readValue(response, ObjectNode.class);
tokenJson.put(FETCHED_AT, System.currentTimeMillis() / 1000);
return tokenJson;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@
import com.amazonaws.athena.connector.util.PaginationHelper;
import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig;
import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionInfo;
import com.amazonaws.athena.connectors.jdbc.connection.GenericJdbcConnectionFactory;
import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory;
import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil;
import com.amazonaws.athena.connectors.jdbc.manager.JdbcArrowTypeConverter;
import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler;
import com.amazonaws.athena.connectors.jdbc.manager.PreparedStatementBuilder;
import com.amazonaws.athena.connectors.jdbc.resolver.JDBCCaseResolver;
import com.amazonaws.athena.connectors.snowflake.connection.SnowflakeConnectionFactory;
import com.amazonaws.athena.connectors.snowflake.resolver.SnowflakeJDBCCaseResolver;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Strings;
Expand Down Expand Up @@ -154,7 +154,7 @@ public SnowflakeMetadataHandler(java.util.Map<String, String> configOptions)
public SnowflakeMetadataHandler(DatabaseConnectionConfig databaseConnectionConfig, java.util.Map<String, String> configOptions)
{
this(databaseConnectionConfig,
new GenericJdbcConnectionFactory(databaseConnectionConfig, SnowflakeEnvironmentProperties.getSnowFlakeParameter(JDBC_PROPERTIES, configOptions),
new SnowflakeConnectionFactory(databaseConnectionConfig, SnowflakeEnvironmentProperties.getSnowFlakeParameter(JDBC_PROPERTIES, configOptions),
new DatabaseConnectionInfo(SnowflakeConstants.SNOWFLAKE_DRIVER_CLASS, SnowflakeConstants.SNOWFLAKE_DEFAULT_PORT)),
configOptions);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@
import com.amazonaws.athena.connector.lambda.exceptions.AthenaConnectorException;
import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig;
import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionInfo;
import com.amazonaws.athena.connectors.jdbc.connection.GenericJdbcConnectionFactory;
import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory;
import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil;
import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler;
import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder;
import com.amazonaws.athena.connectors.snowflake.connection.SnowflakeConnectionFactory;
import com.google.common.annotations.VisibleForTesting;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.commons.lang3.StringUtils;
Expand Down Expand Up @@ -63,10 +63,10 @@ public SnowflakeRecordHandler(java.util.Map<String, String> configOptions)
}
public SnowflakeRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, java.util.Map<String, String> configOptions)
{
this(databaseConnectionConfig, new GenericJdbcConnectionFactory(databaseConnectionConfig, SnowflakeEnvironmentProperties.getSnowFlakeParameter(JDBC_PROPERTIES, configOptions),
this(databaseConnectionConfig, new SnowflakeConnectionFactory(databaseConnectionConfig, SnowflakeEnvironmentProperties.getSnowFlakeParameter(JDBC_PROPERTIES, configOptions),
new DatabaseConnectionInfo(SnowflakeConstants.SNOWFLAKE_DRIVER_CLASS, SnowflakeConstants.SNOWFLAKE_DEFAULT_PORT)), configOptions);
}
public SnowflakeRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, GenericJdbcConnectionFactory jdbcConnectionFactory, java.util.Map<String, String> configOptions)
public SnowflakeRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, SnowflakeConnectionFactory jdbcConnectionFactory, java.util.Map<String, String> configOptions)
{
this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(),
jdbcConnectionFactory, new SnowflakeQueryStringBuilder(SNOWFLAKE_QUOTE_CHARACTER, new SnowflakeFederationExpressionParser(SNOWFLAKE_QUOTE_CHARACTER)), configOptions);
Expand Down
Loading