diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/credentials/CredentialsProvider.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/credentials/CredentialsProvider.java index 838e7a0e05..b2c377fde2 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/credentials/CredentialsProvider.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/credentials/CredentialsProvider.java @@ -19,15 +19,47 @@ */ package com.amazonaws.athena.connector.credentials; +import java.util.Map; + /** * JDBC username and password provider. */ public interface CredentialsProvider { + String USER = "user"; + String PASSWORD = "password"; + /** * Retrieves credential for database. * * @return JDBC credential. See {@link DefaultCredentials}. */ DefaultCredentials getCredential(); + + /** + * Retrieves credential properties as a map for database connection. + * + * Default Behavior: + * The default implementation returns a map containing only the basic "user" and "password" + * properties extracted from the {@link DefaultCredentials} object returned by {@link #getCredential()}. + * This maintains backward compatibility with existing JDBC connection patterns. + * + * Extended Behavior: + * Implementations can override this method to provide additional connection properties beyond + * just username and password. This enables support for advanced authentication mechanisms. + * + * Usage: + * The returned map is directly applied to JDBC connection properties, allowing for seamless + * integration with various database drivers and authentication schemes without requiring + * custom connection factory implementations. + * + * @return Map containing credential properties for database connection. The default implementation + * returns a map with "user" and "password" keys. Overriding implementations may return + * additional properties as needed for their specific authentication requirements. + */ + default Map getCredentialMap() + { + DefaultCredentials credential = getCredential(); + return Map.of(USER, credential.getUser(), PASSWORD, credential.getPassword()); + } } diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/security/CachableSecretsManager.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/security/CachableSecretsManager.java index 07b4ac4d71..43f607ea71 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/security/CachableSecretsManager.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/security/CachableSecretsManager.java @@ -61,6 +61,16 @@ public CachableSecretsManager(SecretsManagerClient secretsManager) this.secretsManager = secretsManager; } + /** + * Gets the underlying SecretsManagerClient instance. + * + * @return The SecretsManagerClient instance. + */ + public SecretsManagerClient getSecretsManager() + { + return secretsManager; + } + /** * Resolves any secrets found in the supplied string, for example: MyString${WithSecret} would have ${WithSecret} * repalced by the corresponding value of the secret in AWS Secrets Manager with that name. If no such secret is found diff --git a/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/connection/GenericJdbcConnectionFactory.java b/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/connection/GenericJdbcConnectionFactory.java index bb43ecef23..0d986a324b 100644 --- a/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/connection/GenericJdbcConnectionFactory.java +++ b/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/connection/GenericJdbcConnectionFactory.java @@ -83,8 +83,7 @@ public Connection getConnection(final CredentialsProvider credentialsProvider) Matcher secretMatcher = SECRET_NAME_PATTERN.matcher(databaseConnectionConfig.getJdbcConnectionString()); derivedJdbcString = secretMatcher.replaceAll(Matcher.quoteReplacement("")); - jdbcProperties.put("user", credentialsProvider.getCredential().getUser()); - jdbcProperties.put("password", credentialsProvider.getCredential().getPassword()); + jdbcProperties.putAll(credentialsProvider.getCredentialMap()); } else { derivedJdbcString = databaseConnectionConfig.getJdbcConnectionString(); diff --git a/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcMetadataHandler.java b/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcMetadataHandler.java index 56c2993f7d..ec60b73175 100644 --- a/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcMetadataHandler.java +++ b/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcMetadataHandler.java @@ -168,6 +168,11 @@ protected JdbcConnectionFactory getJdbcConnectionFactory() return jdbcConnectionFactory; } + protected DatabaseConnectionConfig getDatabaseConnectionConfig() + { + return databaseConnectionConfig; + } + protected CredentialsProvider getCredentialProvider() { final String secretName = databaseConnectionConfig.getSecret(); diff --git a/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcRecordHandler.java b/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcRecordHandler.java index 4ec824a1ca..2c82afd1fb 100644 --- a/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcRecordHandler.java +++ b/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcRecordHandler.java @@ -132,6 +132,11 @@ protected JdbcConnectionFactory getJdbcConnectionFactory() return jdbcConnectionFactory; } + protected DatabaseConnectionConfig getDatabaseConnectionConfig() + { + return databaseConnectionConfig; + } + protected CredentialsProvider getCredentialProvider() { final String secretName = this.databaseConnectionConfig.getSecret(); diff --git a/athena-snowflake/athena-snowflake-connection.yaml b/athena-snowflake/athena-snowflake-connection.yaml index e7261f178d..83eb9c136a 100644 --- a/athena-snowflake/athena-snowflake-connection.yaml +++ b/athena-snowflake/athena-snowflake-connection.yaml @@ -102,6 +102,7 @@ Resources: Statement: - Action: - secretsmanager:GetSecretValue + - secretsmanager:PutSecretValue Effect: Allow Resource: !Sub 'arn:${AWS::Partition}:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${SecretName}*' - Action: diff --git a/athena-snowflake/athena-snowflake.yaml b/athena-snowflake/athena-snowflake.yaml index 7b86121621..22a9030986 100644 --- a/athena-snowflake/athena-snowflake.yaml +++ b/athena-snowflake/athena-snowflake.yaml @@ -85,6 +85,7 @@ Resources: - Statement: - Action: - secretsmanager:GetSecretValue + - secretsmanager:PutSecretValue Effect: Allow Resource: !Sub 'arn:${AWS::Partition}:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${SecretNamePrefix}*' Version: '2012-10-17' diff --git a/athena-snowflake/pom.xml b/athena-snowflake/pom.xml index 79cceee448..d0c743b821 100644 --- a/athena-snowflake/pom.xml +++ b/athena-snowflake/pom.xml @@ -32,6 +32,12 @@ snowflake-jdbc 3.24.2 + + + org.json + json + 20250107 + software.amazon.awssdk diff --git a/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeConstants.java b/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeConstants.java index 9ff862c2df..a90b51b2c7 100644 --- a/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeConstants.java +++ b/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeConstants.java @@ -35,6 +35,11 @@ public final class SnowflakeConstants */ public static final int SINGLE_SPLIT_LIMIT_COUNT = 10000; public static final String SNOWFLAKE_QUOTE_CHARACTER = "\""; + public static final String AUTH_CODE = "auth_code"; + public static final String CLIENT_ID = "client_id"; + public static final String TOKEN_URL = "token_url"; + public static final String REDIRECT_URI = "redirect_uri"; + public static final String CLIENT_SECRET = "client_secret"; private SnowflakeConstants() {} } diff --git a/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeCredentialsProvider.java b/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeCredentialsProvider.java new file mode 100644 index 0000000000..544104949f --- /dev/null +++ b/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeCredentialsProvider.java @@ -0,0 +1,237 @@ +/*- + * #%L + * athena-snowflake + * %% + * Copyright (C) 2019 - 2025 Amazon Web Services + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ +package com.amazonaws.athena.connectors.snowflake; + +import com.amazonaws.athena.connector.credentials.CredentialsProvider; +import com.amazonaws.athena.connector.credentials.DefaultCredentials; +import com.amazonaws.athena.connector.lambda.security.CachableSecretsManager; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.annotations.VisibleForTesting; +import org.json.JSONObject; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.utils.Validate; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.net.HttpURLConnection; +import java.net.URL; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.HashMap; +import java.util.Map; + +/** + * Snowflake OAuth credentials provider that manages OAuth token lifecycle. + * This provider handles token refresh, expiration, and provides credential properties + * for Snowflake OAuth connections. + */ +public class SnowflakeCredentialsProvider implements CredentialsProvider +{ + private static final Logger LOGGER = LoggerFactory.getLogger(SnowflakeCredentialsProvider.class); + + public static final String ACCESS_TOKEN = "access_token"; + public static final String FETCHED_AT = "fetched_at"; + public static final String REFRESH_TOKEN = "refresh_token"; + public static final String EXPIRES_IN = "expires_in"; + public static final String USERNAME = "username"; + public static final String PASSWORD = "password"; + public static final String USER = "user"; + + private final String oauthSecretName; + private final CachableSecretsManager secretsManager; + private final ObjectMapper objectMapper; + + public SnowflakeCredentialsProvider(String oauthSecretName) + { + this(oauthSecretName, SecretsManagerClient.create()); + } + + @VisibleForTesting + public SnowflakeCredentialsProvider(String oauthSecretName, SecretsManagerClient secretsClient) + { + this.oauthSecretName = Validate.notNull(oauthSecretName, "oauthSecretName must not be null"); + this.secretsManager = new CachableSecretsManager(secretsClient); + this.objectMapper = new ObjectMapper(); + } + + @Override + public DefaultCredentials getCredential() + { + Map credentialMap = getCredentialMap(); + return new DefaultCredentials( + credentialMap.get(USER), + credentialMap.get(PASSWORD) + ); + } + + @Override + public Map getCredentialMap() + { + try { + String secretString = secretsManager.getSecret(oauthSecretName); + Map oauthConfig = objectMapper.readValue(secretString, Map.class); + + if (oauthConfig.containsKey(SnowflakeConstants.AUTH_CODE) && !oauthConfig.get(SnowflakeConstants.AUTH_CODE).isEmpty()) { + // OAuth flow + String accessToken = fetchAccessTokenFromSecret(oauthConfig); + + Map credentialMap = new HashMap<>(); + credentialMap.put(USER, oauthConfig.get(USERNAME)); + credentialMap.put(PASSWORD, accessToken); + credentialMap.put("authenticator", "oauth"); + + return credentialMap; + } + else { + // Fallback to standard credentials + return Map.of( + USER, oauthConfig.get(USERNAME), + PASSWORD, oauthConfig.get(PASSWORD) + ); + } + } + catch (Exception ex) { + throw new RuntimeException("Error retrieving Snowflake credentials: " + ex.getMessage(), ex); + } + } + + private String loadTokenFromSecretsManager(Map oauthConfig) + { + if (oauthConfig.containsKey(ACCESS_TOKEN)) { + return oauthConfig.get(ACCESS_TOKEN); + } + return null; + } + + private void saveTokenToSecretsManager(JSONObject tokenJson, Map oauthConfig) + { + // Update token related fields + tokenJson.put(FETCHED_AT, System.currentTimeMillis() / 1000); + oauthConfig.put(ACCESS_TOKEN, tokenJson.getString(ACCESS_TOKEN)); + oauthConfig.put(REFRESH_TOKEN, tokenJson.getString(REFRESH_TOKEN)); + oauthConfig.put(EXPIRES_IN, String.valueOf(tokenJson.getInt(EXPIRES_IN))); + oauthConfig.put(FETCHED_AT, String.valueOf(tokenJson.getLong(FETCHED_AT))); + + // Save updated secret + secretsManager.getSecretsManager().putSecretValue(builder -> builder + .secretId(this.oauthSecretName) + .secretString(String.valueOf(new JSONObject(oauthConfig))) + .build()); + } + + private String fetchAccessTokenFromSecret(Map oauthConfig) throws Exception + { + String accessToken; + String clientId = Validate.notNull(oauthConfig.get(SnowflakeConstants.CLIENT_ID), "Missing required property: client_id"); + String tokenEndpoint = Validate.notNull(oauthConfig.get(SnowflakeConstants.TOKEN_URL), "Missing required property: token_url"); + String redirectUri = Validate.notNull(oauthConfig.get(SnowflakeConstants.REDIRECT_URI), "Missing required property: redirect_uri"); + String clientSecret = Validate.notNull(oauthConfig.get(SnowflakeConstants.CLIENT_SECRET), "Missing required property: client_secret"); + String authCode = Validate.notNull(oauthConfig.get(SnowflakeConstants.AUTH_CODE), "Missing required property: auth_code"); + + accessToken = loadTokenFromSecretsManager(oauthConfig); + + if (accessToken == null) { + LOGGER.debug("First time auth. Using authorization_code..."); + JSONObject tokenJson = getTokenFromAuthCode(authCode, redirectUri, tokenEndpoint, clientId, clientSecret); + saveTokenToSecretsManager(tokenJson, oauthConfig); + accessToken = tokenJson.getString(ACCESS_TOKEN); + } + else { + long expiresIn = Long.parseLong(oauthConfig.get(EXPIRES_IN)); + long fetchedAt = Long.parseLong(oauthConfig.getOrDefault(FETCHED_AT, String.valueOf(0L))); + long now = System.currentTimeMillis() / 1000; + + if ((now - fetchedAt) < expiresIn - 60) { + LOGGER.debug("Access token still valid."); + } + else { + LOGGER.debug("Access token expired. Using refresh_token..."); + JSONObject refreshed = refreshAccessToken(oauthConfig.get(REFRESH_TOKEN), tokenEndpoint, clientId, clientSecret); + refreshed.put(REFRESH_TOKEN, oauthConfig.get(REFRESH_TOKEN)); + saveTokenToSecretsManager(refreshed, oauthConfig); + accessToken = refreshed.getString(ACCESS_TOKEN); + } + } + return accessToken; + } + + private JSONObject getTokenFromAuthCode(String authCode, String redirectUri, String tokenEndpoint, String clientId, String clientSecret) throws Exception + { + String body = "grant_type=authorization_code" + + "&code=" + authCode + + "&redirect_uri=" + redirectUri; + + return requestToken(body, tokenEndpoint, clientId, clientSecret); + } + + private JSONObject refreshAccessToken(String refreshToken, String tokenEndpoint, String clientId, String clientSecret) throws Exception + { + String body = "grant_type=refresh_token" + + "&refresh_token=" + URLEncoder.encode(refreshToken, StandardCharsets.UTF_8); + + return requestToken(body, tokenEndpoint, clientId, clientSecret); + } + + private JSONObject requestToken(String requestBody, String tokenEndpoint, String clientId, String clientSecret) throws Exception + { + HttpURLConnection conn = getHttpURLConnection(tokenEndpoint, clientId, clientSecret); + + try (OutputStream os = conn.getOutputStream()) { + os.write(requestBody.getBytes(StandardCharsets.UTF_8)); + } + + int responseCode = conn.getResponseCode(); + InputStream is = (responseCode >= 200 && responseCode < 300) ? + conn.getInputStream() : conn.getErrorStream(); + + String response = new BufferedReader(new InputStreamReader(is)) + .lines() + .reduce("", (acc, line) -> acc + line); + + if (responseCode != 200) { + throw new RuntimeException("Failed: " + responseCode + " - " + response); + } + + JSONObject tokenJson = new JSONObject(response); + tokenJson.put(FETCHED_AT, System.currentTimeMillis() / 1000); + return tokenJson; + } + + static HttpURLConnection getHttpURLConnection(String tokenEndpoint, String clientId, String clientSecret) throws IOException + { + URL url = new URL(tokenEndpoint); + HttpURLConnection conn = (HttpURLConnection) url.openConnection(); + + String authHeader = Base64.getEncoder() + .encodeToString((clientId + ":" + clientSecret).getBytes(StandardCharsets.UTF_8)); + + conn.setRequestMethod("POST"); + conn.setRequestProperty("Authorization", "Basic " + authHeader); + conn.setRequestProperty("Content-Type", "application/x-www-form-urlencoded"); + conn.setDoOutput(true); + return conn; + } +} diff --git a/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMetadataHandler.java b/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMetadataHandler.java index 7e82afa7fc..7a0455e0fc 100644 --- a/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMetadataHandler.java +++ b/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMetadataHandler.java @@ -1,4 +1,3 @@ - /*- * #%L * athena-snowflake @@ -21,6 +20,7 @@ package com.amazonaws.athena.connectors.snowflake; +import com.amazonaws.athena.connector.credentials.CredentialsProvider; import com.amazonaws.athena.connector.lambda.QueryStatusChecker; import com.amazonaws.athena.connector.lambda.data.Block; import com.amazonaws.athena.connector.lambda.data.BlockAllocator; @@ -66,6 +66,7 @@ import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; @@ -527,4 +528,15 @@ protected Set listDatabaseNames(final Connection jdbcConnection) return schemaNames.build(); } } + + @Override + protected CredentialsProvider getCredentialProvider() + { + final String secretName = getDatabaseConnectionConfig().getSecret(); + if (StringUtils.isNotBlank(secretName)) { + return new SnowflakeCredentialsProvider(secretName); + } + + return null; + } } diff --git a/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeRecordHandler.java b/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeRecordHandler.java index 3a5ea2c8aa..abbb55129c 100644 --- a/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeRecordHandler.java +++ b/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeRecordHandler.java @@ -1,4 +1,3 @@ - /*- * #%L * athena-snowflake @@ -20,6 +19,7 @@ */ package com.amazonaws.athena.connectors.snowflake; +import com.amazonaws.athena.connector.credentials.CredentialsProvider; import com.amazonaws.athena.connector.lambda.domain.Split; import com.amazonaws.athena.connector.lambda.domain.TableName; import com.amazonaws.athena.connector.lambda.domain.predicate.Constraints; @@ -33,6 +33,7 @@ import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; import com.google.common.annotations.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.Validate; import software.amazon.awssdk.services.athena.AthenaClient; import software.amazon.awssdk.services.glue.model.ErrorDetails; @@ -62,10 +63,8 @@ public SnowflakeRecordHandler(java.util.Map configOptions) } public SnowflakeRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, java.util.Map configOptions) { - this(databaseConnectionConfig, new GenericJdbcConnectionFactory(databaseConnectionConfig, - SnowflakeEnvironmentProperties.getSnowFlakeParameter(JDBC_PROPERTIES, configOptions), - new DatabaseConnectionInfo(SnowflakeConstants.SNOWFLAKE_DRIVER_CLASS, - SnowflakeConstants.SNOWFLAKE_DEFAULT_PORT)), configOptions); + this(databaseConnectionConfig, new GenericJdbcConnectionFactory(databaseConnectionConfig, SnowflakeEnvironmentProperties.getSnowFlakeParameter(JDBC_PROPERTIES, configOptions), + new DatabaseConnectionInfo(SnowflakeConstants.SNOWFLAKE_DRIVER_CLASS, SnowflakeConstants.SNOWFLAKE_DEFAULT_PORT)), configOptions); } public SnowflakeRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, GenericJdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { @@ -100,4 +99,15 @@ public PreparedStatement buildSplitSql(Connection jdbcConnection, String catalog } return preparedStatement; } + + @Override + protected CredentialsProvider getCredentialProvider() + { + final String secretName = getDatabaseConnectionConfig().getSecret(); + if (StringUtils.isNotBlank(secretName)) { + return new SnowflakeCredentialsProvider(secretName); + } + + return null; + } } diff --git a/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeCredentialsProviderTest.java b/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeCredentialsProviderTest.java new file mode 100644 index 0000000000..e346faae22 --- /dev/null +++ b/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeCredentialsProviderTest.java @@ -0,0 +1,494 @@ +/*- + * #%L + * athena-snowflake + * %% + * Copyright (C) 2019 - 2025 Amazon Web Services + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ +package com.amazonaws.athena.connectors.snowflake; + +import com.amazonaws.athena.connector.credentials.DefaultCredentials; +import com.amazonaws.athena.connector.lambda.security.CachableSecretsManager; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.json.JSONObject; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.MockedConstruction; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.net.HttpURLConnection; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockConstruction; +import static org.mockito.Mockito.when; + +@RunWith(MockitoJUnitRunner.class) +public class SnowflakeCredentialsProviderTest +{ + private static final String TEST_SECRET_NAME = "test-oauth-secret"; + private static final String TEST_CLIENT_ID = "test-client-id"; + private static final String TEST_CLIENT_SECRET = "test-client-secret"; + private static final String TEST_TOKEN_URL = "https://test.snowflakecomputing.com/oauth/token-request"; + private static final String TEST_REDIRECT_URI = "https://test.com/callback"; + private static final String TEST_AUTH_CODE = "test-auth-code"; + private static final String TEST_USERNAME = "test-user"; + private static final String TEST_PASSWORD = "test-password"; + private static final String TEST_ACCESS_TOKEN = "test-access-token"; + private static final String TEST_REFRESH_TOKEN = "test-refresh-token"; + + private SecretsManagerClient mockSecretsClient; + + @Before + public void setUp() + { + mockSecretsClient = mock(SecretsManagerClient.class); + } + + @Test + public void testGetCredentialWithOAuthFlow() throws Exception + { + String secretJson = createOAuthSecretJson(); + + try (MockedConstruction mockedConstruction = mockConstruction(CachableSecretsManager.class, + (mock, context) -> { + when(mock.getSecret(TEST_SECRET_NAME)).thenReturn(secretJson); + when(mock.getSecretsManager()).thenReturn(mockSecretsClient); + })) { + + SnowflakeCredentialsProvider provider = new SnowflakeCredentialsProvider(TEST_SECRET_NAME, mockSecretsClient); + + try (MockedStatic mockedStatic = Mockito.mockStatic(SnowflakeCredentialsProvider.class)) { + HttpURLConnection mockConnection = createMockHttpConnection(200, createTokenResponse()); + mockedStatic.when(() -> SnowflakeCredentialsProvider.getHttpURLConnection(anyString(), anyString(), anyString())) + .thenReturn(mockConnection); + + DefaultCredentials credentials = provider.getCredential(); + + assertNotNull(credentials); + assertEquals(TEST_USERNAME, credentials.getUser()); + assertEquals(TEST_ACCESS_TOKEN, credentials.getPassword()); + } + } + } + + @Test + public void testGetCredentialMapWithOAuthFlow() throws Exception + { + String secretJson = createOAuthSecretJson(); + + try (MockedConstruction mockedConstruction = mockConstruction(CachableSecretsManager.class, + (mock, context) -> { + when(mock.getSecret(TEST_SECRET_NAME)).thenReturn(secretJson); + when(mock.getSecretsManager()).thenReturn(mockSecretsClient); + })) { + + SnowflakeCredentialsProvider provider = new SnowflakeCredentialsProvider(TEST_SECRET_NAME, mockSecretsClient); + + try (MockedStatic mockedStatic = Mockito.mockStatic(SnowflakeCredentialsProvider.class)) { + HttpURLConnection mockConnection = createMockHttpConnection(200, createTokenResponse()); + mockedStatic.when(() -> SnowflakeCredentialsProvider.getHttpURLConnection(anyString(), anyString(), anyString())) + .thenReturn(mockConnection); + + Map credentialMap = provider.getCredentialMap(); + + assertNotNull(credentialMap); + assertEquals(TEST_USERNAME, credentialMap.get("user")); + assertEquals(TEST_ACCESS_TOKEN, credentialMap.get("password")); + assertEquals("oauth", credentialMap.get("authenticator")); + } + } + } + + @Test + public void testGetCredentialMapWithExistingToken() + { + String secretJson = createOAuthSecretJsonWithExistingToken(); + + try (MockedConstruction mockedConstruction = mockConstruction(CachableSecretsManager.class, + (mock, context) -> { + when(mock.getSecret(TEST_SECRET_NAME)).thenReturn(secretJson); + when(mock.getSecretsManager()).thenReturn(mockSecretsClient); + })) { + + SnowflakeCredentialsProvider provider = new SnowflakeCredentialsProvider(TEST_SECRET_NAME, mockSecretsClient); + + Map credentialMap = provider.getCredentialMap(); + + assertNotNull(credentialMap); + assertEquals(TEST_USERNAME, credentialMap.get("user")); + assertEquals(TEST_ACCESS_TOKEN, credentialMap.get("password")); + assertEquals("oauth", credentialMap.get("authenticator")); + } + } + + @Test + public void testGetCredentialMapWithExpiredToken() throws Exception + { + String secretJson = createOAuthSecretJsonWithExpiredToken(); + + try (MockedConstruction mockedConstruction = mockConstruction(CachableSecretsManager.class, + (mock, context) -> { + when(mock.getSecret(TEST_SECRET_NAME)).thenReturn(secretJson); + when(mock.getSecretsManager()).thenReturn(mockSecretsClient); + })) { + + SnowflakeCredentialsProvider provider = new SnowflakeCredentialsProvider(TEST_SECRET_NAME, mockSecretsClient); + + try (MockedStatic mockedStatic = Mockito.mockStatic(SnowflakeCredentialsProvider.class)) { + HttpURLConnection mockConnection = createMockHttpConnection(200, createTokenResponse()); + mockedStatic.when(() -> SnowflakeCredentialsProvider.getHttpURLConnection(anyString(), anyString(), anyString())) + .thenReturn(mockConnection); + + Map credentialMap = provider.getCredentialMap(); + + assertNotNull(credentialMap); + assertEquals(TEST_USERNAME, credentialMap.get("user")); + assertEquals(TEST_ACCESS_TOKEN, credentialMap.get("password")); + assertEquals("oauth", credentialMap.get("authenticator")); + } + } + } + + @Test + public void testGetCredentialMapWithStandardCredentials() + { + String secretJson = createStandardSecretJson(); + + try (MockedConstruction mockedConstruction = mockConstruction(CachableSecretsManager.class, + (mock, context) -> { + when(mock.getSecret(TEST_SECRET_NAME)).thenReturn(secretJson); + when(mock.getSecretsManager()).thenReturn(mockSecretsClient); + })) { + + SnowflakeCredentialsProvider provider = new SnowflakeCredentialsProvider(TEST_SECRET_NAME, mockSecretsClient); + + Map credentialMap = provider.getCredentialMap(); + + assertNotNull(credentialMap); + assertEquals(TEST_USERNAME, credentialMap.get("user")); + assertEquals(TEST_PASSWORD, credentialMap.get("password")); + assertNull(credentialMap.get("authenticator")); + } + } + + @Test + public void testGetCredentialMapWithEmptyAuthCode() + { + String secretJson = createOAuthSecretJsonWithEmptyAuthCode(); + + try (MockedConstruction mockedConstruction = mockConstruction(CachableSecretsManager.class, + (mock, context) -> { + when(mock.getSecret(TEST_SECRET_NAME)).thenReturn(secretJson); + when(mock.getSecretsManager()).thenReturn(mockSecretsClient); + })) { + + SnowflakeCredentialsProvider provider = new SnowflakeCredentialsProvider(TEST_SECRET_NAME, mockSecretsClient); + + Map credentialMap = provider.getCredentialMap(); + + assertNotNull(credentialMap); + assertEquals(TEST_USERNAME, credentialMap.get("user")); + assertEquals(TEST_PASSWORD, credentialMap.get("password")); + assertNull(credentialMap.get("authenticator")); + } + } + + @Test + public void testGetCredentialMapWithMissingRequiredProperties() + { + String secretJson = createOAuthSecretJsonMissingClientId(); + + try (MockedConstruction mockedConstruction = mockConstruction(CachableSecretsManager.class, + (mock, context) -> { + when(mock.getSecret(TEST_SECRET_NAME)).thenReturn(secretJson); + when(mock.getSecretsManager()).thenReturn(mockSecretsClient); + })) { + + SnowflakeCredentialsProvider provider = new SnowflakeCredentialsProvider(TEST_SECRET_NAME, mockSecretsClient); + + RuntimeException exception = assertThrows(RuntimeException.class, () -> { + provider.getCredentialMap(); + }); + + assertTrue(exception.getMessage().contains("Missing required property: client_id")); + } + } + + @Test + public void testGetCredentialMapWithHttpError() throws Exception + { + String secretJson = createOAuthSecretJson(); + + try (MockedConstruction mockedConstruction = mockConstruction(CachableSecretsManager.class, + (mock, context) -> { + when(mock.getSecret(TEST_SECRET_NAME)).thenReturn(secretJson); + when(mock.getSecretsManager()).thenReturn(mockSecretsClient); + })) { + + SnowflakeCredentialsProvider provider = new SnowflakeCredentialsProvider(TEST_SECRET_NAME, mockSecretsClient); + + try (MockedStatic mockedStatic = Mockito.mockStatic(SnowflakeCredentialsProvider.class)) { + HttpURLConnection mockConnection = createMockHttpConnection(400, "{\"error\":\"invalid_request\"}"); + mockedStatic.when(() -> SnowflakeCredentialsProvider.getHttpURLConnection(anyString(), anyString(), anyString())) + .thenReturn(mockConnection); + + RuntimeException exception = assertThrows(RuntimeException.class, () -> { + provider.getCredentialMap(); + }); + + assertTrue(exception.getMessage().contains("Error retrieving Snowflake credential")); + } + } + } + + @Test + public void testGetCredentialMapWithInvalidJsonResponse() throws Exception + { + String secretJson = createOAuthSecretJson(); + + try (MockedConstruction mockedConstruction = mockConstruction(CachableSecretsManager.class, + (mock, context) -> { + when(mock.getSecret(TEST_SECRET_NAME)).thenReturn(secretJson); + when(mock.getSecretsManager()).thenReturn(mockSecretsClient); + })) { + + SnowflakeCredentialsProvider provider = new SnowflakeCredentialsProvider(TEST_SECRET_NAME, mockSecretsClient); + + try (MockedStatic mockedStatic = Mockito.mockStatic(SnowflakeCredentialsProvider.class)) { + HttpURLConnection mockConnection = createMockHttpConnection(200, "invalid json response"); + mockedStatic.when(() -> SnowflakeCredentialsProvider.getHttpURLConnection(anyString(), anyString(), anyString())) + .thenReturn(mockConnection); + + RuntimeException exception = assertThrows(RuntimeException.class, () -> { + provider.getCredentialMap(); + }); + + assertTrue(exception.getMessage().contains("Error retrieving Snowflake credentials")); + } + } + } + + @Test + public void testGetCredentialMapWithIOException() throws Exception + { + String secretJson = createOAuthSecretJson(); + + try (MockedConstruction mockedConstruction = mockConstruction(CachableSecretsManager.class, + (mock, context) -> { + when(mock.getSecret(TEST_SECRET_NAME)).thenReturn(secretJson); + when(mock.getSecretsManager()).thenReturn(mockSecretsClient); + })) { + + SnowflakeCredentialsProvider provider = new SnowflakeCredentialsProvider(TEST_SECRET_NAME, mockSecretsClient); + + try (MockedStatic mockedStatic = Mockito.mockStatic(SnowflakeCredentialsProvider.class)) { + HttpURLConnection mockConnection = mock(HttpURLConnection.class); + when(mockConnection.getOutputStream()).thenThrow(new IOException("Connection failed")); + + mockedStatic.when(() -> SnowflakeCredentialsProvider.getHttpURLConnection(anyString(), anyString(), anyString())) + .thenReturn(mockConnection); + + RuntimeException exception = assertThrows(RuntimeException.class, () -> { + provider.getCredentialMap(); + }); + + assertTrue(exception.getMessage().contains("Error retrieving Snowflake credentials")); + } + } + } + + @Test + public void testGetCredentialMapWithNullSecretString() + { + try (MockedConstruction mockedConstruction = mockConstruction(CachableSecretsManager.class, + (mock, context) -> { + when(mock.getSecret(TEST_SECRET_NAME)).thenReturn(null); + when(mock.getSecretsManager()).thenReturn(mockSecretsClient); + })) { + + SnowflakeCredentialsProvider provider = new SnowflakeCredentialsProvider(TEST_SECRET_NAME, mockSecretsClient); + + RuntimeException exception = assertThrows(RuntimeException.class, () -> { + provider.getCredentialMap(); + }); + + assertTrue(exception.getMessage().contains("Error retrieving Snowflake credentials")); + } + } + + @Test + public void testGetCredentialMapWithInvalidSecretJson() + { + try (MockedConstruction mockedConstruction = mockConstruction(CachableSecretsManager.class, + (mock, context) -> { + when(mock.getSecret(TEST_SECRET_NAME)).thenReturn("invalid json"); + when(mock.getSecretsManager()).thenReturn(mockSecretsClient); + })) { + + SnowflakeCredentialsProvider provider = new SnowflakeCredentialsProvider(TEST_SECRET_NAME, mockSecretsClient); + + RuntimeException exception = assertThrows(RuntimeException.class, () -> { + provider.getCredentialMap(); + }); + + assertTrue(exception.getMessage().contains("Error retrieving Snowflake credentials")); + } + } + + @Test + public void testGetCredentialMapWithPutSecretValueException() throws Exception + { + String secretJson = createOAuthSecretJson(); + + try (MockedConstruction mockedConstruction = mockConstruction(CachableSecretsManager.class, + (mock, context) -> { + when(mock.getSecret(TEST_SECRET_NAME)).thenReturn(secretJson); + when(mock.getSecretsManager()).thenReturn(mockSecretsClient); + })) { + + SnowflakeCredentialsProvider provider = new SnowflakeCredentialsProvider(TEST_SECRET_NAME, mockSecretsClient); + + try (MockedStatic mockedStatic = Mockito.mockStatic(SnowflakeCredentialsProvider.class)) { + HttpURLConnection mockConnection = createMockHttpConnection(200, createTokenResponse()); + mockedStatic.when(() -> SnowflakeCredentialsProvider.getHttpURLConnection(anyString(), anyString(), anyString())) + .thenReturn(mockConnection); + + // This should not throw an exception because the token request succeeds + // and the exception is only thrown when saving the token + Map credentialMap = provider.getCredentialMap(); + + assertNotNull(credentialMap); + assertEquals(TEST_USERNAME, credentialMap.get("user")); + assertEquals(TEST_ACCESS_TOKEN, credentialMap.get("password")); + assertEquals("oauth", credentialMap.get("authenticator")); + } + } + } + + // Helper methods + private String createOAuthSecretJson() + { + return new ObjectMapper().createObjectNode() + .put("client_id", TEST_CLIENT_ID) + .put("client_secret", TEST_CLIENT_SECRET) + .put("token_url", TEST_TOKEN_URL) + .put("redirect_uri", TEST_REDIRECT_URI) + .put("auth_code", TEST_AUTH_CODE) + .put("username", TEST_USERNAME) + .toString(); + } + + private String createOAuthSecretJsonWithExistingToken() + { + return new ObjectMapper().createObjectNode() + .put("client_id", TEST_CLIENT_ID) + .put("client_secret", TEST_CLIENT_SECRET) + .put("token_url", TEST_TOKEN_URL) + .put("redirect_uri", TEST_REDIRECT_URI) + .put("auth_code", TEST_AUTH_CODE) + .put("username", TEST_USERNAME) + .put("access_token", TEST_ACCESS_TOKEN) + .put("refresh_token", TEST_REFRESH_TOKEN) + .put("expires_in", "3600") + .put("fetched_at", String.valueOf(System.currentTimeMillis() / 1000 - 1000)) + .toString(); + } + + private String createOAuthSecretJsonWithExpiredToken() + { + return new ObjectMapper().createObjectNode() + .put("client_id", TEST_CLIENT_ID) + .put("client_secret", TEST_CLIENT_SECRET) + .put("token_url", TEST_TOKEN_URL) + .put("redirect_uri", TEST_REDIRECT_URI) + .put("auth_code", TEST_AUTH_CODE) + .put("username", TEST_USERNAME) + .put("access_token", "expired-token") + .put("refresh_token", TEST_REFRESH_TOKEN) + .put("expires_in", "3600") + .put("fetched_at", String.valueOf(System.currentTimeMillis() / 1000 - 4000)) + .toString(); + } + + private String createStandardSecretJson() + { + return new ObjectMapper().createObjectNode() + .put("username", TEST_USERNAME) + .put("password", TEST_PASSWORD) + .toString(); + } + + private String createOAuthSecretJsonWithEmptyAuthCode() + { + return new ObjectMapper().createObjectNode() + .put("client_id", TEST_CLIENT_ID) + .put("client_secret", TEST_CLIENT_SECRET) + .put("token_url", TEST_TOKEN_URL) + .put("redirect_uri", TEST_REDIRECT_URI) + .put("auth_code", "") + .put("username", TEST_USERNAME) + .put("password", TEST_PASSWORD) + .toString(); + } + + private String createOAuthSecretJsonMissingClientId() + { + return new ObjectMapper().createObjectNode() + .put("client_secret", TEST_CLIENT_SECRET) + .put("token_url", TEST_TOKEN_URL) + .put("redirect_uri", TEST_REDIRECT_URI) + .put("auth_code", TEST_AUTH_CODE) + .put("username", TEST_USERNAME) + .toString(); + } + + private String createTokenResponse() + { + return new JSONObject() + .put("access_token", TEST_ACCESS_TOKEN) + .put("token_type", "Bearer") + .put("expires_in", 3600) + .put("refresh_token", TEST_REFRESH_TOKEN) + .toString(); + } + + private HttpURLConnection createMockHttpConnection(int responseCode, String responseBody) throws IOException + { + HttpURLConnection mockConnection = mock(HttpURLConnection.class); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + ByteArrayInputStream inputStream = new ByteArrayInputStream(responseBody.getBytes()); + + when(mockConnection.getOutputStream()).thenReturn(outputStream); + when(mockConnection.getResponseCode()).thenReturn(responseCode); + when(mockConnection.getInputStream()).thenReturn(inputStream); + when(mockConnection.getErrorStream()).thenReturn(null); + + return mockConnection; + } +}