generated from amazon-archives/__template_Apache-2.0
-
Notifications
You must be signed in to change notification settings - Fork 333
Implement Key-Pair Authentication for Athena Snowflake Connector #2857
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
chngpe
merged 11 commits into
awslabs:master
from
Trianz-Akshay:feature/snowflake-oauth-keypair-integration
Aug 15, 2025
Merged
Changes from 6 commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
5626c8b
feat: implement key-pair authentication for athena-snowflake connector
ozakidai 52c9ecd
feat: implement test code
ozakidai ff452c8
fixed based on review comments
ozakidai 43aa4ce
key-pair auth integration
Trianz-Akshay fbb5c7e
Merge branch 'feature/Athena-Snowflake-privatekey' into feature/snowf…
Trianz-Akshay 06e9c03
key-pair auth integration
Trianz-Akshay b2e0ee8
review comment changes
Trianz-Akshay f01ff4e
handle encrypted privateKey
Trianz-Akshay 8940f03
review comment changes
Trianz-Akshay b02ae83
review comment changes
Trianz-Akshay 4977c99
Merge branch 'master' into feature/snowflake-oauth-keypair-integration
chngpe File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,9 +22,11 @@ | |
| import com.amazonaws.athena.connector.credentials.CredentialsProvider; | ||
| import com.amazonaws.athena.connector.credentials.DefaultCredentials; | ||
| import com.amazonaws.athena.connector.lambda.security.CachableSecretsManager; | ||
| import com.amazonaws.athena.connectors.snowflake.utils.SnowflakeAuthType; | ||
| import com.amazonaws.athena.connectors.snowflake.utils.SnowflakeAuthUtils; | ||
| import com.fasterxml.jackson.databind.ObjectMapper; | ||
| import com.fasterxml.jackson.databind.node.ObjectNode; | ||
| import com.google.common.annotations.VisibleForTesting; | ||
| import org.json.JSONObject; | ||
| import org.slf4j.Logger; | ||
| import org.slf4j.LoggerFactory; | ||
| import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; | ||
|
|
@@ -43,10 +45,13 @@ | |
| import java.util.HashMap; | ||
| import java.util.Map; | ||
|
|
||
| import static com.amazonaws.athena.connectors.snowflake.utils.SnowflakeAuthType.OAUTH; | ||
| import static com.amazonaws.athena.connectors.snowflake.utils.SnowflakeAuthUtils.getUsername; | ||
|
|
||
| /** | ||
| * Snowflake OAuth credentials provider that manages OAuth token lifecycle. | ||
| * This provider handles token refresh, expiration, and provides credential properties | ||
| * for Snowflake OAuth connections. | ||
| * Snowflake credentials provider that manages multiple authentication methods. | ||
| * This provider handles OAuth token lifecycle, key-pair authentication, and password authentication. | ||
| * Authentication method is automatically determined based on the secret contents. | ||
| */ | ||
| public class SnowflakeCredentialsProvider implements CredentialsProvider | ||
| { | ||
|
|
@@ -56,10 +61,6 @@ public class SnowflakeCredentialsProvider implements CredentialsProvider | |
| public static final String FETCHED_AT = "fetched_at"; | ||
| public static final String REFRESH_TOKEN = "refresh_token"; | ||
| public static final String EXPIRES_IN = "expires_in"; | ||
| public static final String USERNAME = "username"; | ||
| public static final String PASSWORD = "password"; | ||
| public static final String USER = "user"; | ||
|
|
||
| private final String oauthSecretName; | ||
| private final CachableSecretsManager secretsManager; | ||
| private final ObjectMapper objectMapper; | ||
|
|
@@ -82,8 +83,8 @@ public DefaultCredentials getCredential() | |
| { | ||
| Map<String, String> credentialMap = getCredentialMap(); | ||
| return new DefaultCredentials( | ||
| credentialMap.get(USER), | ||
| credentialMap.get(PASSWORD) | ||
| credentialMap.get(SnowflakeConstants.USER), | ||
| credentialMap.get(SnowflakeConstants.PASSWORD) | ||
| ); | ||
| } | ||
|
|
||
|
|
@@ -94,30 +95,67 @@ public Map<String, String> getCredentialMap() | |
| String secretString = secretsManager.getSecret(oauthSecretName); | ||
| Map<String, String> oauthConfig = objectMapper.readValue(secretString, Map.class); | ||
|
||
|
|
||
| if (oauthConfig.containsKey(SnowflakeConstants.AUTH_CODE) && !oauthConfig.get(SnowflakeConstants.AUTH_CODE).isEmpty()) { | ||
| // OAuth flow | ||
| String accessToken = fetchAccessTokenFromSecret(oauthConfig); | ||
|
|
||
| Map<String, String> credentialMap = new HashMap<>(); | ||
| credentialMap.put(USER, oauthConfig.get(USERNAME)); | ||
| credentialMap.put(PASSWORD, accessToken); | ||
| credentialMap.put("authenticator", "oauth"); | ||
|
|
||
| return credentialMap; | ||
| } | ||
| else { | ||
| // Fallback to standard credentials | ||
| return Map.of( | ||
| USER, oauthConfig.get(USERNAME), | ||
| PASSWORD, oauthConfig.get(PASSWORD) | ||
| ); | ||
| // Determine authentication type based on secret contents | ||
| SnowflakeAuthType authType = SnowflakeAuthUtils.determineAuthType(oauthConfig); | ||
|
|
||
| switch (authType) { | ||
| case SNOWFLAKE_JWT: | ||
| // Key-pair authentication | ||
| return handleKeyPairAuthentication(oauthConfig); | ||
| case OAUTH: | ||
| // OAuth flow | ||
| return handleOAuthAuthentication(oauthConfig); | ||
| case SNOWFLAKE: | ||
| default: | ||
| // Password authentication (backward compatible) | ||
| return handlePasswordAuthentication(oauthConfig); | ||
| } | ||
| } | ||
| catch (Exception ex) { | ||
| throw new RuntimeException("Error retrieving Snowflake credentials: " + ex.getMessage(), ex); | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Handles key-pair authentication. | ||
| */ | ||
| private Map<String, String> handleKeyPairAuthentication(Map<String, String> oauthConfig) | ||
| { | ||
| Map<String, String> credentialMap = new HashMap<>(); | ||
| String username = getUsername(oauthConfig); | ||
| credentialMap.put(USER, username); | ||
| credentialMap.put(SnowflakeConstants.PRIVATE_KEY, oauthConfig.get(SnowflakeConstants.PRIVATE_KEY)); | ||
| LOGGER.debug("Using key-pair authentication for user: {}", username); | ||
| return credentialMap; | ||
| } | ||
|
|
||
| /** | ||
| * Handles OAuth authentication. | ||
| */ | ||
| private Map<String, String> handleOAuthAuthentication(Map<String, String> oauthConfig) throws Exception | ||
| { | ||
| String accessToken = fetchAccessTokenFromSecret(oauthConfig); | ||
| Map<String, String> credentialMap = new HashMap<>(); | ||
| String username = getUsername(oauthConfig); | ||
| credentialMap.put(SnowflakeConstants.USER, username); | ||
| credentialMap.put(SnowflakeConstants.PASSWORD, accessToken); | ||
| credentialMap.put(SnowflakeConstants.AUTHENTICATOR, OAUTH.getValue()); | ||
| LOGGER.debug("Using OAuth authentication for user: {}", username); | ||
| return credentialMap; | ||
| } | ||
|
|
||
| /** | ||
| * Handles password authentication (backward compatible). | ||
| */ | ||
| private Map<String, String> handlePasswordAuthentication(Map<String, String> oauthConfig) | ||
| { | ||
| Map<String, String> credentialMap = new HashMap<>(); | ||
| credentialMap.put(SnowflakeConstants.USER, getUsername(oauthConfig)); | ||
| credentialMap.put(SnowflakeConstants.PASSWORD, oauthConfig.get(SnowflakeConstants.PASSWORD)); | ||
| LOGGER.debug("Using password authentication for user: {}", getUsername(oauthConfig)); | ||
| return credentialMap; | ||
| } | ||
|
|
||
| private String loadTokenFromSecretsManager(Map<String, String> oauthConfig) | ||
| { | ||
| if (oauthConfig.containsKey(ACCESS_TOKEN)) { | ||
|
|
@@ -126,20 +164,27 @@ private String loadTokenFromSecretsManager(Map<String, String> oauthConfig) | |
| return null; | ||
| } | ||
|
|
||
| private void saveTokenToSecretsManager(JSONObject tokenJson, Map<String, String> oauthConfig) | ||
| private void saveTokenToSecretsManager(ObjectNode tokenJson, Map<String, String> oauthConfig) | ||
| { | ||
| // Update token related fields | ||
| tokenJson.put(FETCHED_AT, System.currentTimeMillis() / 1000); | ||
| oauthConfig.put(ACCESS_TOKEN, tokenJson.getString(ACCESS_TOKEN)); | ||
| oauthConfig.put(REFRESH_TOKEN, tokenJson.getString(REFRESH_TOKEN)); | ||
| oauthConfig.put(EXPIRES_IN, String.valueOf(tokenJson.getInt(EXPIRES_IN))); | ||
| oauthConfig.put(FETCHED_AT, String.valueOf(tokenJson.getLong(FETCHED_AT))); | ||
| oauthConfig.put(ACCESS_TOKEN, tokenJson.get(ACCESS_TOKEN).asText()); | ||
| oauthConfig.put(REFRESH_TOKEN, tokenJson.get(REFRESH_TOKEN).asText()); | ||
| oauthConfig.put(EXPIRES_IN, String.valueOf(tokenJson.get(EXPIRES_IN).asInt())); | ||
| oauthConfig.put(FETCHED_AT, String.valueOf(tokenJson.get(FETCHED_AT).asLong())); | ||
|
|
||
| // Save updated secret | ||
| secretsManager.getSecretsManager().putSecretValue(builder -> builder | ||
| .secretId(this.oauthSecretName) | ||
| .secretString(String.valueOf(new JSONObject(oauthConfig))) | ||
| .build()); | ||
| try { | ||
| String updatedSecretString = objectMapper.writeValueAsString(oauthConfig); | ||
| secretsManager.getSecretsManager().putSecretValue(builder -> builder | ||
| .secretId(this.oauthSecretName) | ||
| .secretString(updatedSecretString) | ||
| .build()); | ||
| } | ||
| catch (Exception e) { | ||
| LOGGER.error("Failed to save updated secret: ", e); | ||
| throw new RuntimeException("Failed to save updated secret: ", e); | ||
| } | ||
| } | ||
|
|
||
| private String fetchAccessTokenFromSecret(Map<String, String> oauthConfig) throws Exception | ||
|
|
@@ -155,9 +200,9 @@ private String fetchAccessTokenFromSecret(Map<String, String> oauthConfig) throw | |
|
|
||
| if (accessToken == null) { | ||
| LOGGER.debug("First time auth. Using authorization_code..."); | ||
| JSONObject tokenJson = getTokenFromAuthCode(authCode, redirectUri, tokenEndpoint, clientId, clientSecret); | ||
| ObjectNode tokenJson = getTokenFromAuthCode(authCode, redirectUri, tokenEndpoint, clientId, clientSecret); | ||
| saveTokenToSecretsManager(tokenJson, oauthConfig); | ||
| accessToken = tokenJson.getString(ACCESS_TOKEN); | ||
| accessToken = tokenJson.get(ACCESS_TOKEN).asText(); | ||
| } | ||
| else { | ||
| long expiresIn = Long.parseLong(oauthConfig.get(EXPIRES_IN)); | ||
|
|
@@ -169,16 +214,16 @@ private String fetchAccessTokenFromSecret(Map<String, String> oauthConfig) throw | |
| } | ||
| else { | ||
| LOGGER.debug("Access token expired. Using refresh_token..."); | ||
| JSONObject refreshed = refreshAccessToken(oauthConfig.get(REFRESH_TOKEN), tokenEndpoint, clientId, clientSecret); | ||
| ObjectNode refreshed = refreshAccessToken(oauthConfig.get(REFRESH_TOKEN), tokenEndpoint, clientId, clientSecret); | ||
| refreshed.put(REFRESH_TOKEN, oauthConfig.get(REFRESH_TOKEN)); | ||
| saveTokenToSecretsManager(refreshed, oauthConfig); | ||
| accessToken = refreshed.getString(ACCESS_TOKEN); | ||
| accessToken = refreshed.get(ACCESS_TOKEN).asText(); | ||
| } | ||
| } | ||
| return accessToken; | ||
| } | ||
|
|
||
| private JSONObject getTokenFromAuthCode(String authCode, String redirectUri, String tokenEndpoint, String clientId, String clientSecret) throws Exception | ||
| private ObjectNode getTokenFromAuthCode(String authCode, String redirectUri, String tokenEndpoint, String clientId, String clientSecret) throws Exception | ||
| { | ||
| String body = "grant_type=authorization_code" | ||
| + "&code=" + authCode | ||
|
|
@@ -187,15 +232,15 @@ private JSONObject getTokenFromAuthCode(String authCode, String redirectUri, Str | |
| return requestToken(body, tokenEndpoint, clientId, clientSecret); | ||
| } | ||
|
|
||
| private JSONObject refreshAccessToken(String refreshToken, String tokenEndpoint, String clientId, String clientSecret) throws Exception | ||
| private ObjectNode refreshAccessToken(String refreshToken, String tokenEndpoint, String clientId, String clientSecret) throws Exception | ||
| { | ||
| String body = "grant_type=refresh_token" | ||
| + "&refresh_token=" + URLEncoder.encode(refreshToken, StandardCharsets.UTF_8); | ||
|
|
||
| return requestToken(body, tokenEndpoint, clientId, clientSecret); | ||
| } | ||
|
|
||
| private JSONObject requestToken(String requestBody, String tokenEndpoint, String clientId, String clientSecret) throws Exception | ||
| private ObjectNode requestToken(String requestBody, String tokenEndpoint, String clientId, String clientSecret) throws Exception | ||
| { | ||
| HttpURLConnection conn = getHttpURLConnection(tokenEndpoint, clientId, clientSecret); | ||
|
|
||
|
|
@@ -215,7 +260,7 @@ private JSONObject requestToken(String requestBody, String tokenEndpoint, String | |
| throw new RuntimeException("Failed: " + responseCode + " - " + response); | ||
| } | ||
|
|
||
| JSONObject tokenJson = new JSONObject(response); | ||
| ObjectNode tokenJson = objectMapper.readValue(response, ObjectNode.class); | ||
| tokenJson.put(FETCHED_AT, System.currentTimeMillis() / 1000); | ||
| return tokenJson; | ||
| } | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I saw from comment
Gets the username from credentials, checking both "username" and "user" fields.Who is using the field
user?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I synced with our internal team, we will be using
sfUserandpem_private_keyas keys in secret manager. Can we please using those 2 keys for key-pair auth?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
userfield is a JDBC parameter used by Snowflake for authentication.Previously, the
getUsernamemethod was also invoked during the validation flow, which is why bothusernameanduserfields were being checked.With the recent simplification of the flow, credential validation is now handled within
SnowflakeCredentialsProviderbefore setting up the credentials, eliminating the need to check both fields at a later stage.Additionally, I have updated the secret structure to use the following fields:
{ "sfUser": "your_snowflake_user", "pem_private_key": "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----", "pem_private_key_passphrase": "passphrase_in_case_of_encrypted_private_key (optional)" }Currently,
sfUseris used only for key-pair authentication. If you expect it to be used for other authentication mechanisms as well, please let me know and I will align the implementation accordingly.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if we are using pem_private key, shouldn't we delete
PRIVATE_KEY = "privateKey";?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, can we group the static variable and put comment on which one used for? like
sfUserandpem*are for key-pari,ASDFfor oAuthThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated comments on constant as suggested.