diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java b/gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java index 87d9e1da7..ad74f734f 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java @@ -104,9 +104,10 @@ public HaGatewayProviderModule(HaGatewayConfiguration configuration) Map presetUsers = configuration.getPresetUsers(); oauthManager = getOAuthManager(configuration); - formAuthManager = getFormAuthManager(configuration); authorizationManager = new AuthorizationManager(configuration.getAuthorization(), presetUsers); + formAuthManager = getFormAuthManager(configuration); + resourceSecurityDynamicFeature = getAuthFilter(configuration); backendStateConnectionManager = new BackendStateManager(); @@ -137,7 +138,7 @@ private LbFormAuthManager getFormAuthManager(HaGatewayConfiguration configuratio AuthenticationConfiguration authenticationConfiguration = configuration.getAuthentication(); if (authenticationConfiguration != null && authenticationConfiguration.getForm() != null) { return new LbFormAuthManager(authenticationConfiguration.getForm(), - configuration.getPresetUsers(), configuration.getPagePermissions()); + configuration.getPresetUsers(), configuration.getPagePermissions(), authorizationManager); } return null; } @@ -156,7 +157,7 @@ private ChainedAuthFilter getAuthenticationFilters(AuthenticationConfiguration c if (formAuthManager != null) { authFilters.add(new LbFilter( - new FormAuthenticator(formAuthManager, authorizationManager), + new FormAuthenticator(formAuthManager), authorizer, "Bearer", new LbUnauthorizedHandler(defaultType))); @@ -174,7 +175,7 @@ private ResourceSecurityDynamicFeature getAuthFilter(HaGatewayConfiguration conf { AuthorizationConfiguration authorizationConfig = configuration.getAuthorization(); Authorizer authorizer = (authorizationConfig != null) - ? new LbAuthorizer(authorizationConfig) : new NoopAuthorizer(); + ? new LbAuthorizer() : new NoopAuthorizer(); AuthenticationConfiguration authenticationConfig = configuration.getAuthentication(); diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/security/AuthorizationManager.java b/gateway-ha/src/main/java/io/trino/gateway/ha/security/AuthorizationManager.java index ae89c862a..88b06cbd0 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/security/AuthorizationManager.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/security/AuthorizationManager.java @@ -13,10 +13,13 @@ */ package io.trino.gateway.ha.security; +import com.google.common.annotations.VisibleForTesting; import io.trino.gateway.ha.config.AuthorizationConfiguration; import io.trino.gateway.ha.config.LdapConfiguration; import io.trino.gateway.ha.config.UserConfiguration; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import java.util.Optional; @@ -24,6 +27,7 @@ public class AuthorizationManager { private final Map presetUsers; private final LbLdapClient lbLdapClient; + private final AuthorizationConfiguration authorizationConfiguration; public AuthorizationManager(AuthorizationConfiguration configuration, Map presetUsers) @@ -35,9 +39,42 @@ public AuthorizationManager(AuthorizationConfiguration configuration, else { lbLdapClient = null; } + this.authorizationConfiguration = configuration; } - public Optional getPrivileges(String username) + @VisibleForTesting + public AuthorizationManager(Map presetUsers, LbLdapClient lbLdapClient, AuthorizationConfiguration authorizationConfiguration) + { + this.presetUsers = presetUsers; + this.lbLdapClient = lbLdapClient; + this.authorizationConfiguration = authorizationConfiguration; + } + + public String getPrivileges(String username) + { + if (authorizationConfiguration == null) { + return "ADMIN_USER_API"; + } + Optional memberOf = getMemberOf(username); + List privileges = new ArrayList(); + + if (authorizationConfiguration.getAdmin() != null) { + memberOf.filter(m -> m.matches(authorizationConfiguration.getAdmin())).ifPresent(m -> privileges.add("ADMIN")); + } + if (authorizationConfiguration.getUser() != null) { + memberOf.filter(m -> m.matches(authorizationConfiguration.getUser())).ifPresent(m -> privileges.add("USER")); + } + if (authorizationConfiguration.getApi() != null) { + memberOf.filter(m -> m.matches(authorizationConfiguration.getApi())).ifPresent(m -> privileges.add("API")); + } + + if (privileges.isEmpty()) { + return ""; + } + return String.join("_", privileges); + } + + public Optional getMemberOf(String username) { //check the preset users String privs = ""; diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/security/FormAuthenticator.java b/gateway-ha/src/main/java/io/trino/gateway/ha/security/FormAuthenticator.java index a475516f9..2f8537823 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/security/FormAuthenticator.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/security/FormAuthenticator.java @@ -13,22 +13,23 @@ */ package io.trino.gateway.ha.security; +import com.auth0.jwt.interfaces.Claim; +import io.airlift.log.Logger; import io.trino.gateway.ha.security.util.AuthenticationException; import io.trino.gateway.ha.security.util.IdTokenAuthenticator; +import java.util.Map; import java.util.Optional; public class FormAuthenticator implements IdTokenAuthenticator { + private static final Logger log = Logger.get(FormAuthenticator.class); private final LbFormAuthManager formAuthManager; - private final AuthorizationManager authorizationManager; - public FormAuthenticator(LbFormAuthManager formAuthManager, - AuthorizationManager authorizationManager) + public FormAuthenticator(LbFormAuthManager formAuthManager) { this.formAuthManager = formAuthManager; - this.authorizationManager = authorizationManager; } /** @@ -43,11 +44,27 @@ public Optional authenticate(String idToken) throws AuthenticationException { String userIdField = formAuthManager.getUserIdField(); - return formAuthManager - .getClaimsFromIdToken(idToken) - .map(c -> c.get(userIdField)) - .map(Object::toString) - .map(s -> s.replace("\"", "")) - .map(sub -> new LbPrincipal(sub, authorizationManager.getPrivileges(sub))); + String privilegesField = formAuthManager.getPrivilegesField(); + + Map claims = null; + try { + claims = formAuthManager.getClaimsFromIdToken(idToken).orElseThrow(); + } + catch (Exception e) { + return Optional.empty(); + } + String userId = claims.get(userIdField).asString().replace("\"", ""); + + Claim claim = claims.get(privilegesField); + if (claim == null || claim.asString() == null) { + log.warn("No privileges found for user %s in idToken", userId); + throw new AuthenticationException("No privileges found for user " + userId + " in idToken"); + } + + String privileges = claim.asString(); + if (privileges == null) { + privileges = ""; + } + return Optional.of(new LbPrincipal(userId, privileges)); } } diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/security/LbAuthenticator.java b/gateway-ha/src/main/java/io/trino/gateway/ha/security/LbAuthenticator.java index b24064ed0..9424d3ec5 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/security/LbAuthenticator.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/security/LbAuthenticator.java @@ -70,7 +70,7 @@ public Optional authenticate(String idToken) privileges = Optional.ofNullable(role); } - return Optional.of(new LbPrincipal(userId, privileges)); + return Optional.of(new LbPrincipal(userId, privileges.orElse(""))); } return oauthManager .getClaimsFromIdToken(idToken) diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/security/LbAuthorizer.java b/gateway-ha/src/main/java/io/trino/gateway/ha/security/LbAuthorizer.java index 878eb6c24..fa871d26d 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/security/LbAuthorizer.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/security/LbAuthorizer.java @@ -14,7 +14,6 @@ package io.trino.gateway.ha.security; import io.airlift.log.Logger; -import io.trino.gateway.ha.config.AuthorizationConfiguration; import io.trino.gateway.ha.security.util.Authorizer; import jakarta.annotation.Nullable; import jakarta.ws.rs.container.ContainerRequestContext; @@ -23,11 +22,9 @@ public class LbAuthorizer implements Authorizer { private static final Logger log = Logger.get(LbAuthorizer.class); - private final AuthorizationConfiguration configuration; - public LbAuthorizer(AuthorizationConfiguration configuration) + public LbAuthorizer() { - this.configuration = configuration; } @Override @@ -37,23 +34,26 @@ public boolean authorize(LbPrincipal principal, { switch (role) { case "ADMIN": - log.info("User '%s' with memberOf(%s) was identified as ADMIN(%s)", - principal.getName(), principal.getMemberOf(), configuration.getAdmin()); - return principal.getMemberOf() - .filter(m -> m.matches(configuration.getAdmin())) - .isPresent(); + if (principal.getPrivileges().contains("ADMIN")) { + log.info("User '%s' with memberOf(%s) was identified as ADMIN", + principal.getName(), principal.getPrivileges()); + return true; + } + return false; case "USER": - log.info("User '%s' with memberOf(%s) identified as USER(%s)", - principal.getName(), principal.getMemberOf(), configuration.getUser()); - return principal.getMemberOf() - .filter(m -> m.matches(configuration.getUser())) - .isPresent(); + if (principal.getPrivileges().contains("USER")) { + log.info("User '%s' with memberOf(%s) identified as USER", + principal.getName(), principal.getPrivileges()); + return true; + } + return false; case "API": - log.info("User '%s' with memberOf(%s) identified as API(%s)", - principal.getName(), principal.getMemberOf(), configuration.getApi()); - return principal.getMemberOf() - .filter(m -> m.matches(configuration.getApi())) - .isPresent(); + if (principal.getPrivileges().contains("API")) { + log.info("User '%s' with memberOf(%s) identified as API", + principal.getName(), principal.getPrivileges()); + return true; + } + return false; default: log.warn("User '%s' with role %s has no regex match based on ldap search", principal.getName(), role); diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/security/LbFormAuthManager.java b/gateway-ha/src/main/java/io/trino/gateway/ha/security/LbFormAuthManager.java index 34b5a89ec..35dfb0772 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/security/LbFormAuthManager.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/security/LbFormAuthManager.java @@ -45,10 +45,12 @@ public class LbFormAuthManager private final Map presetUsers; private final Map pagePermissions; private final LbLdapClient lbLdapClient; + private final AuthorizationManager authorizationManager; public LbFormAuthManager(FormAuthConfiguration configuration, Map presetUsers, - Map pagePermissions) + Map pagePermissions, + AuthorizationManager authorizationManager) { this.presetUsers = presetUsers; this.pagePermissions = pagePermissions.entrySet().stream() @@ -69,6 +71,7 @@ public LbFormAuthManager(FormAuthConfiguration configuration, else { lbLdapClient = null; } + this.authorizationManager = authorizationManager; } public String getUserIdField() @@ -76,6 +79,11 @@ public String getUserIdField() return "sub"; } + public String getPrivilegesField() + { + return "privileges"; + } + /** * Login API * @@ -128,6 +136,7 @@ private String getSelfSignedToken(String username) .withHeader(headers) .withIssuer(SessionCookie.SELF_ISSUER_ID) .withSubject(username) + .withClaim(getPrivilegesField(), authorizationManager.getPrivileges(username)) .sign(algorithm); } catch (JWTCreationException exception) { diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/security/LbPrincipal.java b/gateway-ha/src/main/java/io/trino/gateway/ha/security/LbPrincipal.java index a388809d1..1b6a7af7e 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/security/LbPrincipal.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/security/LbPrincipal.java @@ -15,18 +15,17 @@ import java.security.Principal; import java.util.Objects; -import java.util.Optional; public class LbPrincipal implements Principal { private final String name; - private final Optional memberOf; + private final String privileges; - public LbPrincipal(String name, Optional memberOf) + public LbPrincipal(String name, String privileges) { this.name = name; - this.memberOf = memberOf; + this.privileges = privileges; } @Override @@ -39,13 +38,13 @@ public boolean equals(Object o) return false; } LbPrincipal that = (LbPrincipal) o; - return name.equals(that.name) && memberOf.equals(that.memberOf); + return name.equals(that.name) && privileges.equals(that.privileges); } @Override public int hashCode() { - return Objects.hash(name, memberOf); + return Objects.hash(name, privileges); } @Override @@ -54,8 +53,8 @@ public String getName() return name; } - public Optional getMemberOf() + public String getPrivileges() { - return this.memberOf; + return this.privileges; } } diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/security/NoopFilter.java b/gateway-ha/src/main/java/io/trino/gateway/ha/security/NoopFilter.java index bf615bfa6..96e84272d 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/security/NoopFilter.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/security/NoopFilter.java @@ -20,7 +20,6 @@ import java.io.IOException; import java.security.Principal; -import java.util.Optional; import static jakarta.ws.rs.Priorities.AUTHENTICATION; @@ -37,7 +36,7 @@ public void filter(final ContainerRequestContext requestContext) @Override public Principal getUserPrincipal() { - return new LbPrincipal("user", Optional.of("ADMIN_USER_API")); + return new LbPrincipal("user", "ADMIN_USER_API"); } @Override diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestLbAuthenticator.java b/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestLbAuthenticator.java index d51b54ff9..7decf4ddc 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestLbAuthenticator.java +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestLbAuthenticator.java @@ -18,6 +18,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.log.Logger; +import io.trino.gateway.ha.config.AuthorizationConfiguration; import io.trino.gateway.ha.config.FormAuthConfiguration; import io.trino.gateway.ha.config.SelfSignKeyPairConfiguration; import io.trino.gateway.ha.config.UserConfiguration; @@ -48,7 +49,7 @@ final class TestLbAuthenticator private static final Logger log = Logger.get(TestLbAuthenticator.class); private static final String USER = "username"; - private static final Optional MEMBER_OF = Optional.of("PVFX_DATA_31"); + private static final String privileges = "ADMIN_USER_API"; private static final String ID_TOKEN = "TOKEN"; @Test @@ -63,7 +64,7 @@ void testAuthenticatorGetsPrincipal() Mockito .when(authorization.getPrivileges(USER)) - .thenReturn(MEMBER_OF); + .thenReturn(privileges); LbOAuthManager authentication = Mockito.mock(LbOAuthManager.class); Mockito @@ -74,7 +75,7 @@ void testAuthenticatorGetsPrincipal() .when(authentication.getUserIdField()) .thenReturn("sub"); - LbPrincipal principal = new LbPrincipal(USER, MEMBER_OF); + LbPrincipal principal = new LbPrincipal(USER, privileges); LbAuthenticator lbAuth = new LbAuthenticator(authentication, authorization); @@ -104,7 +105,7 @@ void testAuthorizationListFromOAuthField() LbAuthenticator lbAuthenticator = new LbAuthenticator(oAuthManager, Mockito.mock(AuthorizationManager.class)); Optional principal = lbAuthenticator.authenticate(ID_TOKEN); - assertThat(principal).hasValue(new LbPrincipal(USER, Optional.of("admin_api_user"))); + assertThat(principal).hasValue(new LbPrincipal(USER, "admin_api_user")); } @Test @@ -131,7 +132,7 @@ void testAuthorizationFieldFromOAuthField() LbAuthenticator lbAuthenticator = new LbAuthenticator(oAuthManager, Mockito.mock(AuthorizationManager.class)); Optional principal = lbAuthenticator.authenticate(ID_TOKEN); - assertThat(principal).hasValue(new LbPrincipal(USER, Optional.of("admin_api"))); + assertThat(principal).hasValue(new LbPrincipal(USER, "admin_api")); } @Test @@ -184,7 +185,8 @@ void testPresetUsers() "user1", new UserConfiguration("priv1, priv2", "pass1"), "user2", new UserConfiguration("priv2, priv2", "pass2")); - LbFormAuthManager authentication = new LbFormAuthManager(null, presetUsers, new HashMap<>()); + AuthorizationManager authorizationManager = new AuthorizationManager(new AuthorizationConfiguration(), presetUsers); + LbFormAuthManager authentication = new LbFormAuthManager(null, presetUsers, new HashMap<>(), authorizationManager); assertThat(authentication.authenticate(new BasicCredentials("user1", "pass1"))) .isTrue(); @@ -198,7 +200,7 @@ void testPresetUsers() void testNoLdapNoPresetUsers() throws Exception { - LbFormAuthManager authentication = new LbFormAuthManager(null, null, ImmutableMap.of()); + LbFormAuthManager authentication = new LbFormAuthManager(null, null, ImmutableMap.of(), null); assertThat(authentication.authenticate(new BasicCredentials("user1", "pass1"))) .isFalse(); } @@ -207,7 +209,7 @@ void testNoLdapNoPresetUsers() void testWrongLdapConfig() throws Exception { - LbFormAuthManager authentication = new LbFormAuthManager(null, null, ImmutableMap.of()); + LbFormAuthManager authentication = new LbFormAuthManager(null, null, ImmutableMap.of(), null); assertThat(authentication.authenticate(new BasicCredentials("user1", "pass1"))) .isFalse(); } @@ -219,7 +221,7 @@ void testNullInPagePermission() Map pagePermission = new HashMap<>(); pagePermission.put("user", null); - LbFormAuthManager authentication = new LbFormAuthManager(null, presetUsers, unmodifiableMap(pagePermission)); + LbFormAuthManager authentication = new LbFormAuthManager(null, presetUsers, unmodifiableMap(pagePermission), null); assertThat(authentication.authenticate(new BasicCredentials("user1", "pass1"))) .isTrue(); } @@ -249,7 +251,8 @@ void testLoginForm() "user1", new UserConfiguration("priv1, priv2", "pass1"), "user2", new UserConfiguration("priv2, priv2", "pass2")); - LbFormAuthManager lbFormAuthManager = new LbFormAuthManager(formAuthConfig, presetUsers, new HashMap<>()); + AuthorizationManager authorizationManager = new AuthorizationManager(null, presetUsers); + LbFormAuthManager lbFormAuthManager = new LbFormAuthManager(formAuthConfig, presetUsers, new HashMap<>(), authorizationManager); RestLoginRequest restLoginRequest = new RestLoginRequest("user1", "pass1"); Result r = lbFormAuthManager.processRESTLogin(restLoginRequest); assertThat(Result.isSuccess(r)).isTrue(); diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestLbAuthorizer.java b/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestLbAuthorizer.java index 6459f3bfb..bb2b19030 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestLbAuthorizer.java +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestLbAuthorizer.java @@ -13,138 +13,103 @@ */ package io.trino.gateway.ha.security; -import io.airlift.log.Logger; -import io.trino.gateway.ha.config.AuthorizationConfiguration; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; -import java.util.Optional; -import java.util.regex.PatternSyntaxException; - import static org.assertj.core.api.Assertions.assertThat; final class TestLbAuthorizer { - private static final Logger log = Logger.get(TestLbAuthorizer.class); - private static final String USER = "username"; private static final String ADMIN_ROLE = "ADMIN"; private static final String USER_ROLE = "USER"; private static final String API_ROLE = "API"; private static final String UNKNOWN_ROLE = "UNKNOWN"; - private static final String PVFX_DATA_31 = "PVFX_DATA_31"; - private static LbPrincipal principal; private static LbAuthorizer authorizer; - private static AuthorizationConfiguration configuration; @BeforeAll - static void setup() + public static void setup() { - configuration = new AuthorizationConfiguration(); - principal = new LbPrincipal(USER, Optional.of(PVFX_DATA_31)); + authorizer = new LbAuthorizer(); } - static void configureRole(String regex, String role) - { - if (role.equalsIgnoreCase(ADMIN_ROLE)) { - configuration.setAdmin(regex); - authorizer = new LbAuthorizer(configuration); - } - if (role.equalsIgnoreCase(USER_ROLE)) { - configuration.setUser(regex); - authorizer = new LbAuthorizer(configuration); - } - if (role.equalsIgnoreCase(API_ROLE)) { - configuration.setApi(regex); - authorizer = new LbAuthorizer(configuration); - } - } - - static void assertMatch(String role) + static void assertMatch(LbPrincipal principal, String role) { assertThat(authorizer.authorize(principal, role, null)).isTrue(); } - static void assertNotMatch(String role) + static void assertNotMatch(LbPrincipal principal, String role) { assertThat(authorizer.authorize(principal, role, null)).isFalse(); } - static void assertBadPattern(String role) + @Test + public void testBasic() { - log.info("Configured bad regex pattern for role [%s]", role); - try { - assertNotMatch(role); - } - catch (PatternSyntaxException e) { - log.info("Failed to compile ==> OKAY"); - } + LbPrincipal principal = new LbPrincipal(USER, "ADMIN_USER_API"); + assertMatch(principal, ADMIN_ROLE); + assertMatch(principal, USER_ROLE); + assertMatch(principal, API_ROLE); + assertNotMatch(principal, UNKNOWN_ROLE); // UNKNOWN ROLE should always return FALSE } @Test - void testBasic() + public void testMultiplePrivileges() { - configureRole(PVFX_DATA_31, ADMIN_ROLE); - assertMatch(ADMIN_ROLE); - - configureRole(PVFX_DATA_31, UNKNOWN_ROLE); - assertNotMatch(UNKNOWN_ROLE); // UNKNOWN ROLE should always return FALSE - - configureRole("PVFX", USER_ROLE); - assertNotMatch(USER_ROLE); - - configureRole("DATA", API_ROLE); - assertNotMatch(API_ROLE); - - configureRole("31", ADMIN_ROLE); - assertNotMatch(ADMIN_ROLE); + LbPrincipal principal = new LbPrincipal(USER, "ADMIN_USER"); + assertMatch(principal, ADMIN_ROLE); + assertMatch(principal, USER_ROLE); + assertNotMatch(principal, API_ROLE); + assertNotMatch(principal, UNKNOWN_ROLE); } @Test - void testZeroOrMoreCharacters() + public void testUserApiPrivileges() { - configureRole("PVFX(.*)", ADMIN_ROLE); - assertMatch(ADMIN_ROLE); - - configureRole("(?i)pvfx(.*)", USER_ROLE); - assertMatch(USER_ROLE); - - configureRole("(.*)", API_ROLE); - assertMatch(API_ROLE); - - configureRole("PVFX_DATA_31(.*)", ADMIN_ROLE); - assertMatch(ADMIN_ROLE); - - configureRole("(.*)_31", USER_ROLE); - assertMatch(USER_ROLE); - - configureRole("(.*)DATA(.*)", API_ROLE); - assertMatch(API_ROLE); - - configureRole("^.+$", ADMIN_ROLE); - assertMatch(ADMIN_ROLE); - - configureRole("^.+$", UNKNOWN_ROLE); - assertNotMatch(UNKNOWN_ROLE); // UNKNOWN ROLE should always return FALSE - - configureRole("(.*)DATA", USER_ROLE); - assertNotMatch(USER_ROLE); + LbPrincipal principal = new LbPrincipal(USER, "USER_API"); + assertNotMatch(principal, ADMIN_ROLE); + assertMatch(principal, USER_ROLE); + assertMatch(principal, API_ROLE); + assertNotMatch(principal, UNKNOWN_ROLE); + } - configureRole("PVFX__(.*)", API_ROLE); - assertNotMatch(API_ROLE); + @Test + public void testAdminOnlyPrivilege() + { + LbPrincipal principal = new LbPrincipal(USER, "ADMIN"); + assertMatch(principal, ADMIN_ROLE); + assertNotMatch(principal, USER_ROLE); + assertNotMatch(principal, API_ROLE); + assertNotMatch(principal, UNKNOWN_ROLE); } @Test - void testBadPatterns() - throws Exception + public void testUserOnlyPrivilege() { - configureRole("^[a-zA--Z0-9_]+$", ADMIN_ROLE); // bad range - assertBadPattern(ADMIN_ROLE); + LbPrincipal principal = new LbPrincipal(USER, "USER"); + assertNotMatch(principal, ADMIN_ROLE); + assertMatch(principal, USER_ROLE); + assertNotMatch(principal, API_ROLE); + assertNotMatch(principal, UNKNOWN_ROLE); + } - configureRole("^[a-zA-Z0-9_*$", USER_ROLE); // missing ] - assertBadPattern(USER_ROLE); + @Test + public void testApiOnlyPrivilege() + { + LbPrincipal principal = new LbPrincipal(USER, "API"); + assertNotMatch(principal, ADMIN_ROLE); + assertNotMatch(principal, USER_ROLE); + assertMatch(principal, API_ROLE); + assertNotMatch(principal, UNKNOWN_ROLE); + } - configureRole("^[a-zA-Z0-9_]+$\\", API_ROLE); // nothing to escape - assertBadPattern(API_ROLE); + @Test + public void testNoPrivileges() + { + LbPrincipal principal = new LbPrincipal(USER, ""); + assertNotMatch(principal, ADMIN_ROLE); + assertNotMatch(principal, USER_ROLE); + assertNotMatch(principal, API_ROLE); + assertNotMatch(principal, UNKNOWN_ROLE); } } diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestLbFilter.java b/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestLbFilter.java index 20611a84f..99f60784a 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestLbFilter.java +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestLbFilter.java @@ -27,6 +27,7 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mockito; +import java.util.HashMap; import java.util.Map; import java.util.Optional; @@ -37,12 +38,12 @@ final class TestLbFilter { private static final String USER = "username"; - private static final Optional MEMBER_OF = Optional.of("PVFX_DATA_31"); + private static final String MEMBER_OF = "PVFX_DATA_31"; private static final String ID_TOKEN = "TOKEN"; private LbOAuthManager oauthManager; - private AuthorizationManager authorizationManager; private ContainerRequestContext requestContext; + private LbAuthorizer authorizer; @BeforeAll void setup() @@ -59,14 +60,9 @@ void setup() .thenReturn(Optional.of(Map.of("sub", claim))); Mockito.when(oauthManager.getUserIdField()).thenReturn("sub"); - // Set authorization manager with membership - authorizationManager = Mockito.mock(AuthorizationManager.class); - Mockito - .when(authorizationManager.getPrivileges(USER)) - .thenReturn(MEMBER_OF); - // Request context for the auth filter requestContext = Mockito.mock(ContainerRequestContext.class); + authorizer = new LbAuthorizer(); } @Test @@ -75,7 +71,7 @@ void testSuccessfulCookieAuthentication() { AuthorizationConfiguration configuration = new AuthorizationConfiguration(); configuration.setAdmin("NO_MEMBER"); - configuration.setUser(MEMBER_OF.orElseThrow()); + configuration.setUser(MEMBER_OF); Mockito .when(requestContext.getCookies()) @@ -86,11 +82,11 @@ void testSuccessfulCookieAuthentication() .when(requestContext.getHeaders()) .thenReturn(new MultivaluedHashMap()); + AuthorizationManager authorizationManager = getAuthorizationManager(configuration); LbAuthenticator authenticator = new LbAuthenticator( oauthManager, authorizationManager); - LbAuthorizer authorizer = new LbAuthorizer(configuration); LbFilter lbFilter = new LbFilter( authenticator, authorizer, @@ -116,8 +112,8 @@ void testSuccessfulHeaderAuthentication() throws Exception { AuthorizationConfiguration configuration = new AuthorizationConfiguration(); - configuration.setAdmin(MEMBER_OF.orElseThrow()); - configuration.setUser(MEMBER_OF.orElseThrow()); + configuration.setAdmin(MEMBER_OF); + configuration.setUser(MEMBER_OF); MultivaluedHashMap headers = new MultivaluedHashMap<>(); headers.addFirst(HttpHeaders.AUTHORIZATION, "Bearer " + ID_TOKEN); @@ -128,10 +124,11 @@ void testSuccessfulHeaderAuthentication() Mockito .when(requestContext.getHeaders()) .thenReturn(headers); + + AuthorizationManager authorizationManager = getAuthorizationManager(configuration); LbAuthenticator authenticator = new LbAuthenticator( oauthManager, authorizationManager); - LbAuthorizer authorizer = new LbAuthorizer(configuration); LbFilter lbFilter = new LbFilter( authenticator, authorizer, @@ -165,10 +162,11 @@ void testMissingAuthenticationToken() .thenReturn(Map.of()); Mockito.when(requestContext.getHeaders()) .thenReturn(headers); + + AuthorizationManager authorizationManager = getAuthorizationManager(configuration); LbAuthenticator authenticator = new LbAuthenticator( oauthManager, authorizationManager); - LbAuthorizer authorizer = new LbAuthorizer(configuration); LbFilter lbFilter = new LbFilter( authenticator, authorizer, @@ -179,4 +177,12 @@ void testMissingAuthenticationToken() lbFilter.filter(requestContext); }).isInstanceOf(WebApplicationException.class); } + + private AuthorizationManager getAuthorizationManager(AuthorizationConfiguration configuration) + { + LbLdapClient lbLdapClient = Mockito.mock(LbLdapClient.class); + Mockito.when(lbLdapClient.getMemberOf(USER)).thenReturn(MEMBER_OF); + + return new AuthorizationManager(new HashMap<>(), lbLdapClient, configuration); + } }