Skip to content

Commit 28caef2

Browse files
committed
Add request filter for early user and query parsing
1 parent d6fce93 commit 28caef2

20 files changed

+780
-229
lines changed

gateway-ha/src/main/java/io/trino/gateway/baseapp/BaseApp.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
import io.trino.gateway.ha.router.RoutingRulesManager;
4040
import io.trino.gateway.ha.router.StochasticRoutingManager;
4141
import io.trino.gateway.ha.security.AuthorizedExceptionMapper;
42+
import io.trino.gateway.ha.security.QueryMetadataParser;
43+
import io.trino.gateway.ha.security.QueryUserInfoParser;
4244
import io.trino.gateway.proxyserver.ForProxy;
4345
import io.trino.gateway.proxyserver.ProxyRequestHandler;
4446
import io.trino.gateway.proxyserver.RouteToBackendResource;
@@ -187,6 +189,8 @@ private static void registerProxyResources(Binder binder)
187189
{
188190
jaxrsBinder(binder).bind(RouteToBackendResource.class);
189191
jaxrsBinder(binder).bind(RouterPreMatchContainerRequestFilter.class);
192+
jaxrsBinder(binder).bind(QueryUserInfoParser.class);
193+
jaxrsBinder(binder).bind(QueryMetadataParser.class);
190194
jaxrsBinder(binder).bind(ProxyRequestHandler.class);
191195
httpClientBinder(binder).bindHttpClient("proxy", ForProxy.class);
192196
httpClientBinder(binder).bindHttpClient("monitor", ForMonitor.class);

gateway-ha/src/main/java/io/trino/gateway/ha/config/HaGatewayConfiguration.java

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,17 @@
2020
import java.util.HashMap;
2121
import java.util.List;
2222
import java.util.Map;
23-
23+
import java.util.regex.Pattern;
24+
25+
import static com.google.common.collect.ImmutableList.toImmutableList;
26+
import static io.trino.gateway.ha.handler.HttpUtils.OAUTH_PATH;
27+
import static io.trino.gateway.ha.handler.HttpUtils.TRINO_UI_PATH;
28+
import static io.trino.gateway.ha.handler.HttpUtils.UI_API_STATS_PATH;
29+
import static io.trino.gateway.ha.handler.HttpUtils.V1_INFO_PATH;
30+
import static io.trino.gateway.ha.handler.HttpUtils.V1_NODE_PATH;
31+
import static io.trino.gateway.ha.handler.HttpUtils.V1_QUERY_PATH;
2432
import static io.trino.gateway.ha.handler.HttpUtils.V1_STATEMENT_PATH;
33+
import static java.util.Objects.requireNonNull;
2534

2635
public class HaGatewayConfiguration
2736
{
@@ -289,6 +298,19 @@ private void validateStatementPath(String statementPath, List<String> statementP
289298
}
290299
}
291300

301+
public boolean isPathWhiteListed(String path)
302+
{
303+
List<Pattern> extraWhitelistPaths = requireNonNull(this.getExtraWhitelistPaths()).stream().map(Pattern::compile).collect(toImmutableList());
304+
return statementPaths.stream().anyMatch(path::startsWith)
305+
|| path.startsWith(V1_QUERY_PATH)
306+
|| path.startsWith(TRINO_UI_PATH)
307+
|| path.startsWith(V1_INFO_PATH)
308+
|| path.startsWith(V1_NODE_PATH)
309+
|| path.startsWith(UI_API_STATS_PATH)
310+
|| path.startsWith(OAUTH_PATH)
311+
|| extraWhitelistPaths.stream().anyMatch(pattern -> pattern.matcher(path).matches());
312+
}
313+
292314
public static class HaGatewayConfigurationException
293315
extends RuntimeException
294316
{

gateway-ha/src/main/java/io/trino/gateway/ha/handler/HttpUtils.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ public class HttpUtils
2525
public static final String TRINO_UI_PATH = "/ui";
2626
public static final String OAUTH_PATH = "/oauth2";
2727
public static final String USER_HEADER = "X-Trino-User";
28+
public static final String TRINO_REQUEST_USER = "trinoRequestUser";
29+
public static final String TRINO_QUERY_PROPERTIES = "trinoQueryProperties";
2830

2931
private HttpUtils() {}
3032
}

gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyUtils.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import java.util.regex.Pattern;
2929

3030
import static com.google.common.base.Strings.isNullOrEmpty;
31+
import static io.trino.gateway.ha.handler.HttpUtils.TRINO_QUERY_PROPERTIES;
3132
import static io.trino.gateway.ha.handler.HttpUtils.TRINO_UI_PATH;
3233
import static io.trino.gateway.ha.handler.HttpUtils.V1_QUERY_PATH;
3334
import static java.nio.charset.StandardCharsets.UTF_8;
@@ -75,7 +76,7 @@ public static Optional<String> extractQueryIdIfPresent(
7576
throw new RuntimeException("Error reading request body", e);
7677
}
7778
if (!isNullOrEmpty(queryText) && queryText.toLowerCase(ENGLISH).contains("kill_query")) {
78-
TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties(request, requestAnalyserClientsUseV2Format, requestAnalyserMaxBodySize);
79+
TrinoQueryProperties trinoQueryProperties = (TrinoQueryProperties) request.getAttribute(TRINO_QUERY_PROPERTIES);
7980
return trinoQueryProperties.getQueryId();
8081
}
8182
return Optional.empty();

gateway-ha/src/main/java/io/trino/gateway/ha/router/ExternalRoutingGroupSelector.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646
import static io.airlift.http.client.JsonResponseHandler.createJsonResponseHandler;
4747
import static io.airlift.http.client.Request.Builder.preparePost;
4848
import static io.airlift.json.JsonCodec.jsonCodec;
49+
import static io.trino.gateway.ha.handler.HttpUtils.TRINO_QUERY_PROPERTIES;
50+
import static io.trino.gateway.ha.handler.HttpUtils.TRINO_REQUEST_USER;
4951
import static java.util.Collections.list;
5052
import static java.util.Objects.requireNonNull;
5153

@@ -58,7 +60,6 @@ public class ExternalRoutingGroupSelector
5860
private final boolean propagateErrors;
5961
private final HttpClient httpClient;
6062
private final RequestAnalyzerConfig requestAnalyzerConfig;
61-
private final TrinoRequestUser.TrinoRequestUserProvider trinoRequestUserProvider;
6263
private static final JsonCodec<RoutingGroupExternalBody> ROUTING_GROUP_EXTERNAL_BODY_JSON_CODEC = jsonCodec(RoutingGroupExternalBody.class);
6364
private static final JsonResponseHandler<ExternalRouterResponse> ROUTING_GROUP_EXTERNAL_RESPONSE_JSON_RESPONSE_HANDLER =
6465
createJsonResponseHandler(jsonCodec(ExternalRouterResponse.class));
@@ -74,7 +75,6 @@ public class ExternalRoutingGroupSelector
7475
propagateErrors = rulesExternalConfiguration.isPropagateErrors();
7576

7677
this.requestAnalyzerConfig = requestAnalyzerConfig;
77-
trinoRequestUserProvider = new TrinoRequestUser.TrinoRequestUserProvider(requestAnalyzerConfig);
7878
try {
7979
this.uri = new URI(requireNonNull(rulesExternalConfiguration.getUrlPath(),
8080
"Invalid URL provided, using routing group header as default."));
@@ -143,8 +143,8 @@ private RoutingGroupExternalBody createRequestBody(HttpServletRequest request)
143143
TrinoQueryProperties trinoQueryProperties = null;
144144
TrinoRequestUser trinoRequestUser = null;
145145
if (requestAnalyzerConfig.isAnalyzeRequest()) {
146-
trinoQueryProperties = new TrinoQueryProperties(request, requestAnalyzerConfig.isClientsUseV2Format(), requestAnalyzerConfig.getMaxBodySize());
147-
trinoRequestUser = trinoRequestUserProvider.getInstance(request);
146+
trinoQueryProperties = (TrinoQueryProperties) request.getAttribute(TRINO_QUERY_PROPERTIES);
147+
trinoRequestUser = (TrinoRequestUser) request.getAttribute(TRINO_REQUEST_USER);
148148
}
149149

150150
return new RoutingGroupExternalBody(

gateway-ha/src/main/java/io/trino/gateway/ha/router/FileBasedRoutingGroupSelector.java

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
import java.util.Map;
3434

3535
import static com.google.common.base.Suppliers.memoizeWithExpiration;
36+
import static io.trino.gateway.ha.handler.HttpUtils.TRINO_QUERY_PROPERTIES;
37+
import static io.trino.gateway.ha.handler.HttpUtils.TRINO_REQUEST_USER;
3638
import static java.nio.charset.StandardCharsets.UTF_8;
3739
import static java.util.Collections.sort;
3840

@@ -46,16 +48,10 @@ public class FileBasedRoutingGroupSelector
4648

4749
private final Supplier<List<RoutingRule>> rules;
4850
private final boolean analyzeRequest;
49-
private final boolean clientsUseV2Format;
50-
private final int maxBodySize;
51-
private final TrinoRequestUser.TrinoRequestUserProvider trinoRequestUserProvider;
5251

5352
public FileBasedRoutingGroupSelector(String rulesPath, Duration rulesRefreshPeriod, RequestAnalyzerConfig requestAnalyzerConfig)
5453
{
5554
analyzeRequest = requestAnalyzerConfig.isAnalyzeRequest();
56-
clientsUseV2Format = requestAnalyzerConfig.isClientsUseV2Format();
57-
maxBodySize = requestAnalyzerConfig.getMaxBodySize();
58-
trinoRequestUserProvider = new TrinoRequestUser.TrinoRequestUserProvider(requestAnalyzerConfig);
5955

6056
rules = memoizeWithExpiration(() -> readRulesFromPath(Path.of(rulesPath)), rulesRefreshPeriod.toJavaTime());
6157
}
@@ -68,12 +64,9 @@ public RoutingSelectorResponse findRoutingDestination(HttpServletRequest request
6864

6965
Map<String, Object> data;
7066
if (analyzeRequest) {
71-
TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties(
72-
request,
73-
clientsUseV2Format,
74-
maxBodySize);
75-
TrinoRequestUser trinoRequestUser = trinoRequestUserProvider.getInstance(request);
76-
data = ImmutableMap.of("request", request, "trinoQueryProperties", trinoQueryProperties, "trinoRequestUser", trinoRequestUser);
67+
TrinoQueryProperties trinoQueryProperties = (TrinoQueryProperties) request.getAttribute(TRINO_QUERY_PROPERTIES);
68+
TrinoRequestUser trinoRequestUser = (TrinoRequestUser) request.getAttribute(TRINO_REQUEST_USER);
69+
data = ImmutableMap.of("request", request, TRINO_QUERY_PROPERTIES, trinoQueryProperties, TRINO_REQUEST_USER, trinoRequestUser);
7770
}
7871
else {
7972
data = ImmutableMap.of("request", request);

gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoQueryProperties.java

Lines changed: 63 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,18 @@
6565
import io.trino.sql.tree.Table;
6666
import io.trino.sql.tree.TableFunctionInvocation;
6767
import io.trino.sql.tree.WithQuery;
68-
import jakarta.servlet.http.HttpServletRequest;
6968
import jakarta.ws.rs.HttpMethod;
69+
import jakarta.ws.rs.container.ContainerRequestContext;
70+
import jakarta.ws.rs.core.MediaType;
7071

7172
import java.io.BufferedReader;
7273
import java.io.IOException;
74+
import java.io.InputStream;
75+
import java.io.InputStreamReader;
7376
import java.net.URLDecoder;
77+
import java.nio.charset.StandardCharsets;
7478
import java.util.ArrayList;
79+
import java.util.Collections;
7580
import java.util.Enumeration;
7681
import java.util.HashSet;
7782
import java.util.List;
@@ -142,30 +147,29 @@ public TrinoQueryProperties(
142147
maxBodySize = -1;
143148
}
144149

145-
public TrinoQueryProperties(HttpServletRequest request, boolean isClientsUseV2Format, int maxBodySize)
150+
public TrinoQueryProperties(ContainerRequestContext requestContext, boolean isClientsUseV2Format, int maxBodySize)
146151
{
147-
requireNonNull(request, "request is null");
152+
requireNonNull(requestContext, "requestContext is null");
148153
this.isClientsUseV2Format = isClientsUseV2Format;
149154
this.maxBodySize = maxBodySize;
150155

151-
defaultCatalog = Optional.ofNullable(request.getHeader(TRINO_CATALOG_HEADER_NAME));
152-
defaultSchema = Optional.ofNullable(request.getHeader(TRINO_SCHEMA_HEADER_NAME));
153-
if (request.getMethod().equals(HttpMethod.POST)) {
156+
defaultCatalog = Optional.ofNullable(requestContext.getHeaderString(TRINO_CATALOG_HEADER_NAME));
157+
defaultSchema = Optional.ofNullable(requestContext.getHeaderString(TRINO_SCHEMA_HEADER_NAME));
158+
if (requestContext.getMethod().equals(HttpMethod.POST)) {
154159
isNewQuerySubmission = true;
155-
processRequestBody(request);
160+
processRequestBody(requestContext);
156161
}
157162
}
158163

159-
private void processRequestBody(HttpServletRequest request)
164+
private void processRequestBody(BufferedReader reader, Map<String, String> preparedStatements)
160165
{
161-
try (BufferedReader reader = request.getReader()) {
166+
try (reader) {
162167
if (reader == null) {
163168
log.warn("HTTP request returned null reader");
164169
body = "";
165170
return;
166171
}
167172

168-
Map<String, String> preparedStatements = getPreparedStatements(request);
169173
SqlParser parser = new SqlParser();
170174
reader.mark(maxBodySize);
171175
char[] buffer = new char[maxBodySize];
@@ -238,11 +242,45 @@ else if (statement instanceof ExecuteImmediate executeImmediate) {
238242
}
239243
}
240244

241-
private Map<String, String> getPreparedStatements(HttpServletRequest request)
245+
private void processRequestBody(ContainerRequestContext requestContext)
246+
{
247+
if (!requestContext.hasEntity()) {
248+
return;
249+
}
250+
251+
MediaType mediaType = requestContext.getMediaType();
252+
if (mediaType == null) {
253+
return;
254+
}
255+
256+
String charset = mediaType.getParameters().get("charset");
257+
if (!StandardCharsets.UTF_8.name().equalsIgnoreCase(charset)) {
258+
return;
259+
}
260+
261+
InputStream entityStream = requestContext.getEntityStream();
262+
try (InputStreamReader entityReader = new InputStreamReader(entityStream, StandardCharsets.UTF_8);
263+
BufferedReader reader = new BufferedReader(entityReader)) {
264+
processRequestBody(reader, getPreparedStatements(requestContext));
265+
}
266+
catch (IOException e) {
267+
log.warn("Error extracting request body for rules processing: %s", e.getMessage());
268+
errorMessage = Optional.of(e.getMessage());
269+
}
270+
catch (ParsingException e) {
271+
log.info("Could not parse request body as SQL: %s; Message: %s", body, e.getMessage());
272+
errorMessage = Optional.of(e.getMessage());
273+
}
274+
catch (RequestParsingException e) {
275+
log.warn(e, "Error parsing request for rules");
276+
errorMessage = Optional.of(e.getMessage());
277+
}
278+
}
279+
280+
private Map<String, String> getPreparedStatements(Enumeration<String> headers)
242281
throws RequestParsingException
243282
{
244283
ImmutableMap.Builder<String, String> preparedStatementsMapBuilder = ImmutableMap.builder();
245-
Enumeration<String> headers = request.getHeaders(TRINO_PREPARED_STATEMENT_HEADER_NAME);
246284
if (headers == null) {
247285
return preparedStatementsMapBuilder.build();
248286
}
@@ -259,6 +297,19 @@ private Map<String, String> getPreparedStatements(HttpServletRequest request)
259297
return preparedStatementsMapBuilder.build();
260298
}
261299

300+
private Map<String, String> getPreparedStatements(ContainerRequestContext requestContext)
301+
throws RequestParsingException
302+
{
303+
if (requestContext.getHeaders() == null) {
304+
return ImmutableMap.of();
305+
}
306+
List<String> headers = requestContext.getHeaders().get(TRINO_PREPARED_STATEMENT_HEADER_NAME);
307+
if (headers == null || headers.isEmpty()) {
308+
return ImmutableMap.of();
309+
}
310+
return getPreparedStatements(Collections.enumeration(headers));
311+
}
312+
262313
private String decodePreparedStatementFromHeader(String headerValue)
263314
{
264315
// From io.trino.server.protocol.PreparedStatementEncoder

gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoRequestUser.java

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,14 @@
3333
import com.nimbusds.openid.connect.sdk.claims.UserInfo;
3434
import io.airlift.log.Logger;
3535
import io.trino.gateway.ha.config.RequestAnalyzerConfig;
36-
import jakarta.servlet.http.Cookie;
37-
import jakarta.servlet.http.HttpServletRequest;
36+
import jakarta.ws.rs.container.ContainerRequestContext;
37+
import jakarta.ws.rs.core.HttpHeaders;
3838

3939
import java.io.IOException;
4040
import java.net.URI;
4141
import java.nio.charset.StandardCharsets;
42-
import java.util.Arrays;
4342
import java.util.Base64;
43+
import java.util.Map;
4444
import java.util.Optional;
4545
import java.util.concurrent.ExecutionException;
4646
import java.util.concurrent.TimeUnit;
@@ -62,7 +62,7 @@ public class TrinoRequestUser
6262

6363
private final Optional<LoadingCache<String, UserInfo>> userInfoCache;
6464

65-
private TrinoRequestUser(HttpServletRequest request, String userField, Optional<LoadingCache<String, UserInfo>> userInfoCache)
65+
private TrinoRequestUser(ContainerRequestContext request, String userField, Optional<LoadingCache<String, UserInfo>> userInfoCache)
6666
{
6767
this.userInfoCache = requireNonNull(userInfoCache);
6868
user = extractUser(request, userField);
@@ -106,15 +106,17 @@ public boolean userExistsAndEquals(String testUser)
106106
return user.filter(testUser::equals).isPresent();
107107
}
108108

109-
private Optional<String> extractUserFromCookies(HttpServletRequest request, String userField)
109+
private Optional<String> extractUserFromCookies(ContainerRequestContext requestContext, String userField)
110110
{
111-
if (request.getCookies() == null) {
111+
Map<String, jakarta.ws.rs.core.Cookie> cookies = requestContext.getCookies();
112+
if (cookies == null || cookies.isEmpty()) {
113+
log.debug("cookies are empty");
112114
return Optional.empty();
113115
}
114-
log.debug("Trying to get user from cookie");
115-
Optional<Cookie> uiToken = Arrays.stream(request.getCookies())
116-
.filter(cookie -> cookie.getName().equals(TRINO_UI_TOKEN_NAME) || cookie.getName().equals(TRINO_SECURE_UI_TOKEN_NAME))
117-
.findAny();
116+
117+
log.debug("Trying to get user from cookie from ContainerRequestContext");
118+
Optional<jakarta.ws.rs.core.Cookie> uiToken = Optional.ofNullable(cookies.get(TRINO_UI_TOKEN_NAME))
119+
.or(() -> Optional.ofNullable(cookies.get(TRINO_SECURE_UI_TOKEN_NAME)));
118120

119121
return uiToken.map(t -> {
120122
try {
@@ -129,20 +131,20 @@ private Optional<String> extractUserFromCookies(HttpServletRequest request, Stri
129131
});
130132
}
131133

132-
private Optional<String> extractUser(HttpServletRequest request, String userField)
134+
private Optional<String> extractUser(ContainerRequestContext requestContext, String userField)
133135
{
134136
String header;
135-
header = request.getHeader(TRINO_USER_HEADER_NAME);
137+
header = requestContext.getHeaderString(TRINO_USER_HEADER_NAME);
136138
if (header != null) {
137139
return Optional.of(header);
138140
}
139141

140-
Optional<String> user = extractUserFromAuthorizationHeader(request.getHeader("Authorization"), userField);
142+
Optional<String> user = extractUserFromAuthorizationHeader(requestContext.getHeaderString(HttpHeaders.AUTHORIZATION), userField);
141143
if (user.isPresent()) {
142144
return user;
143145
}
144146

145-
return extractUserFromCookies(request, userField);
147+
return extractUserFromCookies(requestContext, userField);
146148
}
147149

148150
private Optional<String> extractUserFromAuthorizationHeader(String header, String userField)
@@ -225,7 +227,7 @@ public TrinoRequestUserProvider(RequestAnalyzerConfig config)
225227
}
226228
}
227229

228-
public TrinoRequestUser getInstance(HttpServletRequest request)
230+
public TrinoRequestUser getInstance(ContainerRequestContext request)
229231
{
230232
return new TrinoRequestUser(request, userField, userInfoCache);
231233
}

0 commit comments

Comments
 (0)