2222import com .amazonaws .athena .connector .credentials .CredentialsProvider ;
2323import com .amazonaws .athena .connector .credentials .DefaultCredentials ;
2424import com .amazonaws .athena .connector .lambda .security .CachableSecretsManager ;
25+ import com .amazonaws .athena .connectors .snowflake .utils .SnowflakeAuthType ;
26+ import com .amazonaws .athena .connectors .snowflake .utils .SnowflakeAuthUtils ;
2527import com .fasterxml .jackson .databind .ObjectMapper ;
28+ import com .fasterxml .jackson .databind .node .ObjectNode ;
2629import com .google .common .annotations .VisibleForTesting ;
27- import org .json .JSONObject ;
2830import org .slf4j .Logger ;
2931import org .slf4j .LoggerFactory ;
3032import software .amazon .awssdk .services .secretsmanager .SecretsManagerClient ;
4345import java .util .HashMap ;
4446import java .util .Map ;
4547
48+ import static com .amazonaws .athena .connectors .snowflake .utils .SnowflakeAuthType .OAUTH ;
49+ import static com .amazonaws .athena .connectors .snowflake .utils .SnowflakeAuthUtils .getUsername ;
50+ import static com .amazonaws .athena .connectors .snowflake .utils .SnowflakeAuthUtils .validateCredentials ;
51+
4652/**
47- * Snowflake OAuth credentials provider that manages OAuth token lifecycle .
48- * This provider handles token refresh, expiration , and provides credential properties
49- * for Snowflake OAuth connections .
53+ * Snowflake credentials provider that manages multiple authentication methods .
54+ * This provider handles OAuth token lifecycle, key-pair authentication , and password authentication.
55+ * Authentication method is automatically determined based on the secret contents .
5056 */
5157public class SnowflakeCredentialsProvider implements CredentialsProvider
5258{
@@ -56,10 +62,6 @@ public class SnowflakeCredentialsProvider implements CredentialsProvider
5662 public static final String FETCHED_AT = "fetched_at" ;
5763 public static final String REFRESH_TOKEN = "refresh_token" ;
5864 public static final String EXPIRES_IN = "expires_in" ;
59- public static final String USERNAME = "username" ;
60- public static final String PASSWORD = "password" ;
61- public static final String USER = "user" ;
62-
6365 private final String oauthSecretName ;
6466 private final CachableSecretsManager secretsManager ;
6567 private final ObjectMapper objectMapper ;
@@ -82,8 +84,8 @@ public DefaultCredentials getCredential()
8284 {
8385 Map <String , String > credentialMap = getCredentialMap ();
8486 return new DefaultCredentials (
85- credentialMap .get (USER ),
86- credentialMap .get (PASSWORD )
87+ credentialMap .get (SnowflakeConstants . USER ),
88+ credentialMap .get (SnowflakeConstants . PASSWORD )
8789 );
8890 }
8991
@@ -92,32 +94,71 @@ public Map<String, String> getCredentialMap()
9294 {
9395 try {
9496 String secretString = secretsManager .getSecret (oauthSecretName );
95- Map <String , String > oauthConfig = objectMapper .readValue (secretString , Map .class );
97+ Map <String , String > secretMap = objectMapper .readValue (secretString , Map .class );
9698
97- if (oauthConfig .containsKey (SnowflakeConstants .AUTH_CODE ) && !oauthConfig .get (SnowflakeConstants .AUTH_CODE ).isEmpty ()) {
98- // OAuth flow
99- String accessToken = fetchAccessTokenFromSecret (oauthConfig );
100-
101- Map <String , String > credentialMap = new HashMap <>();
102- credentialMap .put (USER , oauthConfig .get (USERNAME ));
103- credentialMap .put (PASSWORD , accessToken );
104- credentialMap .put ("authenticator" , "oauth" );
105-
106- return credentialMap ;
107- }
108- else {
109- // Fallback to standard credentials
110- return Map .of (
111- USER , oauthConfig .get (USERNAME ),
112- PASSWORD , oauthConfig .get (PASSWORD )
113- );
99+ // Determine authentication type based on secret contents
100+ SnowflakeAuthType authType = SnowflakeAuthUtils .determineAuthType (secretMap );
101+ // Validate credentials once after determining auth type
102+ validateCredentials (secretMap , authType );
103+ switch (authType ) {
104+ case SNOWFLAKE_JWT :
105+ // Key-pair authentication
106+ return handleKeyPairAuthentication (secretMap );
107+ case OAUTH :
108+ // OAuth flow
109+ return handleOAuthAuthentication (secretMap );
110+ case SNOWFLAKE :
111+ default :
112+ // Password authentication (backward compatible)
113+ return handlePasswordAuthentication (secretMap );
114114 }
115115 }
116116 catch (Exception ex ) {
117117 throw new RuntimeException ("Error retrieving Snowflake credentials: " + ex .getMessage (), ex );
118118 }
119119 }
120120
121+ /**
122+ * Handles key-pair authentication.
123+ */
124+ private Map <String , String > handleKeyPairAuthentication (Map <String , String > oauthConfig )
125+ {
126+ Map <String , String > credentialMap = new HashMap <>();
127+ String username = getUsername (oauthConfig );
128+ credentialMap .put (USER , username );
129+ credentialMap .put (SnowflakeConstants .PEM_PRIVATE_KEY , oauthConfig .get (SnowflakeConstants .PEM_PRIVATE_KEY ));
130+ credentialMap .put (SnowflakeConstants .PEM_PRIVATE_KEY_PASSPHRASE , oauthConfig .get (SnowflakeConstants .PEM_PRIVATE_KEY_PASSPHRASE ));
131+ LOGGER .debug ("Using key-pair authentication for user: {}" , username );
132+ return credentialMap ;
133+ }
134+
135+ /**
136+ * Handles OAuth authentication.
137+ */
138+ private Map <String , String > handleOAuthAuthentication (Map <String , String > oauthConfig ) throws Exception
139+ {
140+ String accessToken = fetchAccessTokenFromSecret (oauthConfig );
141+ Map <String , String > credentialMap = new HashMap <>();
142+ String username = getUsername (oauthConfig );
143+ credentialMap .put (SnowflakeConstants .USER , username );
144+ credentialMap .put (SnowflakeConstants .PASSWORD , accessToken );
145+ credentialMap .put (SnowflakeConstants .AUTHENTICATOR , OAUTH .getValue ());
146+ LOGGER .debug ("Using OAuth authentication for user: {}" , username );
147+ return credentialMap ;
148+ }
149+
150+ /**
151+ * Handles password authentication (backward compatible).
152+ */
153+ private Map <String , String > handlePasswordAuthentication (Map <String , String > oauthConfig )
154+ {
155+ Map <String , String > credentialMap = new HashMap <>();
156+ credentialMap .put (SnowflakeConstants .USER , getUsername (oauthConfig ));
157+ credentialMap .put (SnowflakeConstants .PASSWORD , oauthConfig .get (SnowflakeConstants .PASSWORD ));
158+ LOGGER .debug ("Using password authentication for user: {}" , getUsername (oauthConfig ));
159+ return credentialMap ;
160+ }
161+
121162 private String loadTokenFromSecretsManager (Map <String , String > oauthConfig )
122163 {
123164 if (oauthConfig .containsKey (ACCESS_TOKEN )) {
@@ -126,20 +167,27 @@ private String loadTokenFromSecretsManager(Map<String, String> oauthConfig)
126167 return null ;
127168 }
128169
129- private void saveTokenToSecretsManager (JSONObject tokenJson , Map <String , String > oauthConfig )
170+ private void saveTokenToSecretsManager (ObjectNode tokenJson , Map <String , String > oauthConfig )
130171 {
131172 // Update token related fields
132173 tokenJson .put (FETCHED_AT , System .currentTimeMillis () / 1000 );
133- oauthConfig .put (ACCESS_TOKEN , tokenJson .getString (ACCESS_TOKEN ));
134- oauthConfig .put (REFRESH_TOKEN , tokenJson .getString (REFRESH_TOKEN ));
135- oauthConfig .put (EXPIRES_IN , String .valueOf (tokenJson .getInt (EXPIRES_IN )));
136- oauthConfig .put (FETCHED_AT , String .valueOf (tokenJson .getLong (FETCHED_AT )));
174+ oauthConfig .put (ACCESS_TOKEN , tokenJson .get (ACCESS_TOKEN ). asText ( ));
175+ oauthConfig .put (REFRESH_TOKEN , tokenJson .get (REFRESH_TOKEN ). asText ( ));
176+ oauthConfig .put (EXPIRES_IN , String .valueOf (tokenJson .get (EXPIRES_IN ). asInt ( )));
177+ oauthConfig .put (FETCHED_AT , String .valueOf (tokenJson .get (FETCHED_AT ). asLong ( )));
137178
138179 // Save updated secret
139- secretsManager .getSecretsManager ().putSecretValue (builder -> builder
140- .secretId (this .oauthSecretName )
141- .secretString (String .valueOf (new JSONObject (oauthConfig )))
142- .build ());
180+ try {
181+ String updatedSecretString = objectMapper .writeValueAsString (oauthConfig );
182+ secretsManager .getSecretsManager ().putSecretValue (builder -> builder
183+ .secretId (this .oauthSecretName )
184+ .secretString (updatedSecretString )
185+ .build ());
186+ }
187+ catch (Exception e ) {
188+ LOGGER .error ("Failed to save updated secret: " , e );
189+ throw new RuntimeException ("Failed to save updated secret: " , e );
190+ }
143191 }
144192
145193 private String fetchAccessTokenFromSecret (Map <String , String > oauthConfig ) throws Exception
@@ -155,9 +203,9 @@ private String fetchAccessTokenFromSecret(Map<String, String> oauthConfig) throw
155203
156204 if (accessToken == null ) {
157205 LOGGER .debug ("First time auth. Using authorization_code..." );
158- JSONObject tokenJson = getTokenFromAuthCode (authCode , redirectUri , tokenEndpoint , clientId , clientSecret );
206+ ObjectNode tokenJson = getTokenFromAuthCode (authCode , redirectUri , tokenEndpoint , clientId , clientSecret );
159207 saveTokenToSecretsManager (tokenJson , oauthConfig );
160- accessToken = tokenJson .getString (ACCESS_TOKEN );
208+ accessToken = tokenJson .get (ACCESS_TOKEN ). asText ( );
161209 }
162210 else {
163211 long expiresIn = Long .parseLong (oauthConfig .get (EXPIRES_IN ));
@@ -169,16 +217,16 @@ private String fetchAccessTokenFromSecret(Map<String, String> oauthConfig) throw
169217 }
170218 else {
171219 LOGGER .debug ("Access token expired. Using refresh_token..." );
172- JSONObject refreshed = refreshAccessToken (oauthConfig .get (REFRESH_TOKEN ), tokenEndpoint , clientId , clientSecret );
220+ ObjectNode refreshed = refreshAccessToken (oauthConfig .get (REFRESH_TOKEN ), tokenEndpoint , clientId , clientSecret );
173221 refreshed .put (REFRESH_TOKEN , oauthConfig .get (REFRESH_TOKEN ));
174222 saveTokenToSecretsManager (refreshed , oauthConfig );
175- accessToken = refreshed .getString (ACCESS_TOKEN );
223+ accessToken = refreshed .get (ACCESS_TOKEN ). asText ( );
176224 }
177225 }
178226 return accessToken ;
179227 }
180228
181- private JSONObject getTokenFromAuthCode (String authCode , String redirectUri , String tokenEndpoint , String clientId , String clientSecret ) throws Exception
229+ private ObjectNode getTokenFromAuthCode (String authCode , String redirectUri , String tokenEndpoint , String clientId , String clientSecret ) throws Exception
182230 {
183231 String body = "grant_type=authorization_code"
184232 + "&code=" + authCode
@@ -187,15 +235,15 @@ private JSONObject getTokenFromAuthCode(String authCode, String redirectUri, Str
187235 return requestToken (body , tokenEndpoint , clientId , clientSecret );
188236 }
189237
190- private JSONObject refreshAccessToken (String refreshToken , String tokenEndpoint , String clientId , String clientSecret ) throws Exception
238+ private ObjectNode refreshAccessToken (String refreshToken , String tokenEndpoint , String clientId , String clientSecret ) throws Exception
191239 {
192240 String body = "grant_type=refresh_token"
193241 + "&refresh_token=" + URLEncoder .encode (refreshToken , StandardCharsets .UTF_8 );
194242
195243 return requestToken (body , tokenEndpoint , clientId , clientSecret );
196244 }
197245
198- private JSONObject requestToken (String requestBody , String tokenEndpoint , String clientId , String clientSecret ) throws Exception
246+ private ObjectNode requestToken (String requestBody , String tokenEndpoint , String clientId , String clientSecret ) throws Exception
199247 {
200248 HttpURLConnection conn = getHttpURLConnection (tokenEndpoint , clientId , clientSecret );
201249
@@ -215,7 +263,7 @@ private JSONObject requestToken(String requestBody, String tokenEndpoint, String
215263 throw new RuntimeException ("Failed: " + responseCode + " - " + response );
216264 }
217265
218- JSONObject tokenJson = new JSONObject (response );
266+ ObjectNode tokenJson = objectMapper . readValue (response , ObjectNode . class );
219267 tokenJson .put (FETCHED_AT , System .currentTimeMillis () / 1000 );
220268 return tokenJson ;
221269 }
0 commit comments