Skip to content

Commit 239f66a

Browse files
committed
Add request filter for early user and query parsing
1 parent 44a332e commit 239f66a

16 files changed

+284
-252
lines changed

gateway-ha/src/main/java/io/trino/gateway/baseapp/BaseApp.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
import io.trino.gateway.ha.router.RoutingRulesManager;
4040
import io.trino.gateway.ha.router.StochasticRoutingManager;
4141
import io.trino.gateway.ha.security.AuthorizedExceptionMapper;
42+
import io.trino.gateway.ha.security.QueryMetadataParser;
43+
import io.trino.gateway.ha.security.QueryUserInfoParser;
4244
import io.trino.gateway.proxyserver.ForProxy;
4345
import io.trino.gateway.proxyserver.ProxyRequestHandler;
4446
import io.trino.gateway.proxyserver.RouteToBackendResource;
@@ -187,6 +189,8 @@ private static void registerProxyResources(Binder binder)
187189
{
188190
jaxrsBinder(binder).bind(RouteToBackendResource.class);
189191
jaxrsBinder(binder).bind(RouterPreMatchContainerRequestFilter.class);
192+
jaxrsBinder(binder).bind(QueryUserInfoParser.class);
193+
jaxrsBinder(binder).bind(QueryMetadataParser.class);
190194
jaxrsBinder(binder).bind(ProxyRequestHandler.class);
191195
httpClientBinder(binder).bindHttpClient("proxy", ForProxy.class);
192196
httpClientBinder(binder).bindHttpClient("monitor", ForMonitor.class);

gateway-ha/src/main/java/io/trino/gateway/ha/config/HaGatewayConfiguration.java

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,17 @@
2020
import java.util.HashMap;
2121
import java.util.List;
2222
import java.util.Map;
23-
23+
import java.util.regex.Pattern;
24+
25+
import static com.google.common.collect.ImmutableList.toImmutableList;
26+
import static io.trino.gateway.ha.handler.HttpUtils.OAUTH_PATH;
27+
import static io.trino.gateway.ha.handler.HttpUtils.TRINO_UI_PATH;
28+
import static io.trino.gateway.ha.handler.HttpUtils.UI_API_STATS_PATH;
29+
import static io.trino.gateway.ha.handler.HttpUtils.V1_INFO_PATH;
30+
import static io.trino.gateway.ha.handler.HttpUtils.V1_NODE_PATH;
31+
import static io.trino.gateway.ha.handler.HttpUtils.V1_QUERY_PATH;
2432
import static io.trino.gateway.ha.handler.HttpUtils.V1_STATEMENT_PATH;
33+
import static java.util.Objects.requireNonNull;
2534

2635
public class HaGatewayConfiguration
2736
{
@@ -289,6 +298,19 @@ private void validateStatementPath(String statementPath, List<String> statementP
289298
}
290299
}
291300

301+
public boolean isPathWhiteListed(String path)
302+
{
303+
List<Pattern> extraWhitelistPaths = requireNonNull(this.getExtraWhitelistPaths()).stream().map(Pattern::compile).collect(toImmutableList());
304+
return statementPaths.stream().anyMatch(path::startsWith)
305+
|| path.startsWith(V1_QUERY_PATH)
306+
|| path.startsWith(TRINO_UI_PATH)
307+
|| path.startsWith(V1_INFO_PATH)
308+
|| path.startsWith(V1_NODE_PATH)
309+
|| path.startsWith(UI_API_STATS_PATH)
310+
|| path.startsWith(OAUTH_PATH)
311+
|| extraWhitelistPaths.stream().anyMatch(pattern -> pattern.matcher(path).matches());
312+
}
313+
292314
public static class HaGatewayConfigurationException
293315
extends RuntimeException
294316
{

gateway-ha/src/main/java/io/trino/gateway/ha/handler/HttpUtils.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ public class HttpUtils
2525
public static final String TRINO_UI_PATH = "/ui";
2626
public static final String OAUTH_PATH = "/oauth2";
2727
public static final String USER_HEADER = "X-Trino-User";
28+
public static final String TRINO_REQUEST_USER = "trinoRequestUser";
29+
public static final String TRINO_QUERY_PROPERTIES = "trinoQueryProperties";
2830

2931
private HttpUtils() {}
3032
}

gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyUtils.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import java.util.regex.Pattern;
2929

3030
import static com.google.common.base.Strings.isNullOrEmpty;
31+
import static io.trino.gateway.ha.handler.HttpUtils.TRINO_QUERY_PROPERTIES;
3132
import static io.trino.gateway.ha.handler.HttpUtils.TRINO_UI_PATH;
3233
import static io.trino.gateway.ha.handler.HttpUtils.V1_QUERY_PATH;
3334
import static java.nio.charset.StandardCharsets.UTF_8;
@@ -74,7 +75,7 @@ public static Optional<String> extractQueryIdIfPresent(
7475
throw new RuntimeException("Error reading request body", e);
7576
}
7677
if (!isNullOrEmpty(queryText) && queryText.toLowerCase(ENGLISH).contains("kill_query")) {
77-
TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties(request, requestAnalyserClientsUseV2Format, requestAnalyserMaxBodySize);
78+
TrinoQueryProperties trinoQueryProperties = (TrinoQueryProperties) request.getAttribute(TRINO_QUERY_PROPERTIES);
7879
return trinoQueryProperties.getQueryId();
7980
}
8081
return Optional.empty();

gateway-ha/src/main/java/io/trino/gateway/ha/handler/RoutingTargetHandler.java

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,10 @@
3333
import java.util.List;
3434
import java.util.Map;
3535
import java.util.Optional;
36-
import java.util.regex.Pattern;
3736
import java.util.stream.Stream;
3837

3938
import static com.google.common.base.Strings.isNullOrEmpty;
40-
import static com.google.common.collect.ImmutableList.toImmutableList;
41-
import static io.trino.gateway.ha.handler.HttpUtils.OAUTH_PATH;
42-
import static io.trino.gateway.ha.handler.HttpUtils.TRINO_UI_PATH;
43-
import static io.trino.gateway.ha.handler.HttpUtils.UI_API_STATS_PATH;
4439
import static io.trino.gateway.ha.handler.HttpUtils.USER_HEADER;
45-
import static io.trino.gateway.ha.handler.HttpUtils.V1_INFO_PATH;
46-
import static io.trino.gateway.ha.handler.HttpUtils.V1_NODE_PATH;
47-
import static io.trino.gateway.ha.handler.HttpUtils.V1_QUERY_PATH;
4840
import static io.trino.gateway.ha.handler.ProxyUtils.buildUriWithNewCluster;
4941
import static io.trino.gateway.ha.handler.ProxyUtils.extractQueryIdIfPresent;
5042
import static java.util.Objects.requireNonNull;
@@ -56,7 +48,6 @@ public class RoutingTargetHandler
5648
private final RoutingGroupSelector routingGroupSelector;
5749
private final String defaultRoutingGroup;
5850
private final List<String> statementPaths;
59-
private final List<Pattern> extraWhitelistPaths;
6051
private final boolean requestAnalyserClientsUseV2Format;
6152
private final int requestAnalyserMaxBodySize;
6253
private final boolean cookiesEnabled;
@@ -71,7 +62,6 @@ public RoutingTargetHandler(
7162
this.routingGroupSelector = requireNonNull(routingGroupSelector);
7263
this.defaultRoutingGroup = haGatewayConfiguration.getRouting().getDefaultRoutingGroup();
7364
statementPaths = requireNonNull(haGatewayConfiguration.getStatementPaths());
74-
extraWhitelistPaths = requireNonNull(haGatewayConfiguration.getExtraWhitelistPaths()).stream().map(Pattern::compile).collect(toImmutableList());
7565
requestAnalyserClientsUseV2Format = haGatewayConfiguration.getRequestAnalyzerConfig().isClientsUseV2Format();
7666
requestAnalyserMaxBodySize = haGatewayConfiguration.getRequestAnalyzerConfig().getMaxBodySize();
7767
cookiesEnabled = GatewayCookieConfigurationPropertiesProvider.getInstance().isEnabled();
@@ -118,18 +108,6 @@ private RoutingTargetResponse getRoutingTargetResponse(HttpServletRequest reques
118108
modifiedRequest);
119109
}
120110

121-
public boolean isPathWhiteListed(String path)
122-
{
123-
return statementPaths.stream().anyMatch(path::startsWith)
124-
|| path.startsWith(V1_QUERY_PATH)
125-
|| path.startsWith(TRINO_UI_PATH)
126-
|| path.startsWith(V1_INFO_PATH)
127-
|| path.startsWith(V1_NODE_PATH)
128-
|| path.startsWith(UI_API_STATS_PATH)
129-
|| path.startsWith(OAUTH_PATH)
130-
|| extraWhitelistPaths.stream().anyMatch(pattern -> pattern.matcher(path).matches());
131-
}
132-
133111
/**
134112
* A wrapper for HttpServletRequest that allows modifying multiple headers.
135113
*/

gateway-ha/src/main/java/io/trino/gateway/ha/router/ExternalRoutingGroupSelector.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646
import static io.airlift.http.client.JsonResponseHandler.createJsonResponseHandler;
4747
import static io.airlift.http.client.Request.Builder.preparePost;
4848
import static io.airlift.json.JsonCodec.jsonCodec;
49+
import static io.trino.gateway.ha.handler.HttpUtils.TRINO_QUERY_PROPERTIES;
50+
import static io.trino.gateway.ha.handler.HttpUtils.TRINO_REQUEST_USER;
4951
import static java.util.Collections.list;
5052
import static java.util.Objects.requireNonNull;
5153

@@ -58,7 +60,6 @@ public class ExternalRoutingGroupSelector
5860
private final boolean propagateErrors;
5961
private final HttpClient httpClient;
6062
private final RequestAnalyzerConfig requestAnalyzerConfig;
61-
private final TrinoRequestUser.TrinoRequestUserProvider trinoRequestUserProvider;
6263
private static final JsonCodec<RoutingGroupExternalBody> ROUTING_GROUP_EXTERNAL_BODY_JSON_CODEC = jsonCodec(RoutingGroupExternalBody.class);
6364
private static final JsonResponseHandler<ExternalRouterResponse> ROUTING_GROUP_EXTERNAL_RESPONSE_JSON_RESPONSE_HANDLER =
6465
createJsonResponseHandler(jsonCodec(ExternalRouterResponse.class));
@@ -74,7 +75,6 @@ public class ExternalRoutingGroupSelector
7475
propagateErrors = rulesExternalConfiguration.isPropagateErrors();
7576

7677
this.requestAnalyzerConfig = requestAnalyzerConfig;
77-
trinoRequestUserProvider = new TrinoRequestUser.TrinoRequestUserProvider(requestAnalyzerConfig);
7878
try {
7979
this.uri = new URI(requireNonNull(rulesExternalConfiguration.getUrlPath(),
8080
"Invalid URL provided, using routing group header as default."));
@@ -143,8 +143,8 @@ private RoutingGroupExternalBody createRequestBody(HttpServletRequest request)
143143
TrinoQueryProperties trinoQueryProperties = null;
144144
TrinoRequestUser trinoRequestUser = null;
145145
if (requestAnalyzerConfig.isAnalyzeRequest()) {
146-
trinoQueryProperties = new TrinoQueryProperties(request, requestAnalyzerConfig.isClientsUseV2Format(), requestAnalyzerConfig.getMaxBodySize());
147-
trinoRequestUser = trinoRequestUserProvider.getInstance(request);
146+
trinoQueryProperties = (TrinoQueryProperties) request.getAttribute(TRINO_QUERY_PROPERTIES);
147+
trinoRequestUser = (TrinoRequestUser) request.getAttribute(TRINO_REQUEST_USER);
148148
}
149149

150150
return new RoutingGroupExternalBody(

gateway-ha/src/main/java/io/trino/gateway/ha/router/FileBasedRoutingGroupSelector.java

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
import java.util.Map;
3434

3535
import static com.google.common.base.Suppliers.memoizeWithExpiration;
36+
import static io.trino.gateway.ha.handler.HttpUtils.TRINO_QUERY_PROPERTIES;
37+
import static io.trino.gateway.ha.handler.HttpUtils.TRINO_REQUEST_USER;
3638
import static java.nio.charset.StandardCharsets.UTF_8;
3739
import static java.util.Collections.sort;
3840

@@ -46,16 +48,10 @@ public class FileBasedRoutingGroupSelector
4648

4749
private final Supplier<List<RoutingRule>> rules;
4850
private final boolean analyzeRequest;
49-
private final boolean clientsUseV2Format;
50-
private final int maxBodySize;
51-
private final TrinoRequestUser.TrinoRequestUserProvider trinoRequestUserProvider;
5251

5352
public FileBasedRoutingGroupSelector(String rulesPath, Duration rulesRefreshPeriod, RequestAnalyzerConfig requestAnalyzerConfig)
5453
{
5554
analyzeRequest = requestAnalyzerConfig.isAnalyzeRequest();
56-
clientsUseV2Format = requestAnalyzerConfig.isClientsUseV2Format();
57-
maxBodySize = requestAnalyzerConfig.getMaxBodySize();
58-
trinoRequestUserProvider = new TrinoRequestUser.TrinoRequestUserProvider(requestAnalyzerConfig);
5955

6056
rules = memoizeWithExpiration(() -> readRulesFromPath(Path.of(rulesPath)), rulesRefreshPeriod.toJavaTime());
6157
}
@@ -68,12 +64,9 @@ public RoutingSelectorResponse findRoutingDestination(HttpServletRequest request
6864

6965
Map<String, Object> data;
7066
if (analyzeRequest) {
71-
TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties(
72-
request,
73-
clientsUseV2Format,
74-
maxBodySize);
75-
TrinoRequestUser trinoRequestUser = trinoRequestUserProvider.getInstance(request);
76-
data = ImmutableMap.of("request", request, "trinoQueryProperties", trinoQueryProperties, "trinoRequestUser", trinoRequestUser);
67+
TrinoQueryProperties trinoQueryProperties = (TrinoQueryProperties) request.getAttribute(TRINO_QUERY_PROPERTIES);
68+
TrinoRequestUser trinoRequestUser = (TrinoRequestUser) request.getAttribute(TRINO_REQUEST_USER);
69+
data = ImmutableMap.of("request", request, TRINO_QUERY_PROPERTIES, trinoQueryProperties, TRINO_REQUEST_USER, trinoRequestUser);
7770
}
7871
else {
7972
data = ImmutableMap.of("request", request);

gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoQueryProperties.java

Lines changed: 67 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,18 @@
6565
import io.trino.sql.tree.Table;
6666
import io.trino.sql.tree.TableFunctionInvocation;
6767
import io.trino.sql.tree.WithQuery;
68-
import jakarta.servlet.http.HttpServletRequest;
6968
import jakarta.ws.rs.HttpMethod;
69+
import jakarta.ws.rs.container.ContainerRequestContext;
70+
import jakarta.ws.rs.core.MediaType;
7071

7172
import java.io.BufferedReader;
7273
import java.io.IOException;
74+
import java.io.InputStream;
75+
import java.io.InputStreamReader;
7376
import java.net.URLDecoder;
77+
import java.nio.charset.StandardCharsets;
7478
import java.util.ArrayList;
79+
import java.util.Collections;
7580
import java.util.Enumeration;
7681
import java.util.HashSet;
7782
import java.util.List;
@@ -142,35 +147,38 @@ public TrinoQueryProperties(
142147
maxBodySize = -1;
143148
}
144149

145-
public TrinoQueryProperties(HttpServletRequest request, boolean isClientsUseV2Format, int maxBodySize)
150+
public TrinoQueryProperties(ContainerRequestContext requestContext, boolean isClientsUseV2Format, int maxBodySize)
146151
{
147-
requireNonNull(request, "request is null");
152+
requireNonNull(requestContext, "requestContext is null");
148153
this.isClientsUseV2Format = isClientsUseV2Format;
149154
this.maxBodySize = maxBodySize;
150155

151-
defaultCatalog = Optional.ofNullable(request.getHeader(TRINO_CATALOG_HEADER_NAME));
152-
defaultSchema = Optional.ofNullable(request.getHeader(TRINO_SCHEMA_HEADER_NAME));
153-
if (request.getMethod().equals(HttpMethod.POST)) {
156+
defaultCatalog = Optional.ofNullable(requestContext.getHeaderString(TRINO_CATALOG_HEADER_NAME));
157+
defaultSchema = Optional.ofNullable(requestContext.getHeaderString(TRINO_SCHEMA_HEADER_NAME));
158+
if (requestContext.getMethod().equals(HttpMethod.POST)) {
154159
isNewQuerySubmission = true;
155-
processRequestBody(request);
160+
processRequestBody(requestContext);
156161
}
157162
}
158163

159-
private void processRequestBody(HttpServletRequest request)
164+
private void processRequestBody(BufferedReader reader, Map<String, String> preparedStatements)
160165
{
161-
try (BufferedReader reader = request.getReader()) {
166+
try (reader) {
162167
if (reader == null) {
163168
log.warn("HTTP request returned null reader");
164169
body = "";
165170
return;
166171
}
167172

168-
Map<String, String> preparedStatements = getPreparedStatements(request);
169173
SqlParser parser = new SqlParser();
170174
reader.mark(maxBodySize);
171175
char[] buffer = new char[maxBodySize];
172176
int nChars = reader.read(buffer, 0, maxBodySize);
173177
reader.reset();
178+
if (nChars <= 0) {
179+
log.warn("query text is empty");
180+
return;
181+
}
174182
if (nChars == maxBodySize) {
175183
log.warn("Query length greater or equal to requestAnalyzerConfig.maxBodySize detected");
176184
return;
@@ -238,11 +246,45 @@ else if (statement instanceof ExecuteImmediate executeImmediate) {
238246
}
239247
}
240248

241-
private Map<String, String> getPreparedStatements(HttpServletRequest request)
249+
private void processRequestBody(ContainerRequestContext requestContext)
250+
{
251+
if (!requestContext.hasEntity()) {
252+
return;
253+
}
254+
255+
MediaType mediaType = requestContext.getMediaType();
256+
if (mediaType == null) {
257+
return;
258+
}
259+
260+
String charset = mediaType.getParameters().get("charset");
261+
if (!StandardCharsets.UTF_8.name().equalsIgnoreCase(charset)) {
262+
return;
263+
}
264+
265+
InputStream entityStream = requestContext.getEntityStream();
266+
try (InputStreamReader entityReader = new InputStreamReader(entityStream, StandardCharsets.UTF_8);
267+
BufferedReader reader = new BufferedReader(entityReader)) {
268+
processRequestBody(reader, getPreparedStatements(requestContext));
269+
}
270+
catch (IOException e) {
271+
log.warn("Error extracting request body for rules processing: %s", e.getMessage());
272+
errorMessage = Optional.of(e.getMessage());
273+
}
274+
catch (ParsingException e) {
275+
log.info("Could not parse request body as SQL: %s; Message: %s", body, e.getMessage());
276+
errorMessage = Optional.of(e.getMessage());
277+
}
278+
catch (RequestParsingException e) {
279+
log.warn(e, "Error parsing request for rules");
280+
errorMessage = Optional.of(e.getMessage());
281+
}
282+
}
283+
284+
private Map<String, String> getPreparedStatements(Enumeration<String> headers)
242285
throws RequestParsingException
243286
{
244287
ImmutableMap.Builder<String, String> preparedStatementsMapBuilder = ImmutableMap.builder();
245-
Enumeration<String> headers = request.getHeaders(TRINO_PREPARED_STATEMENT_HEADER_NAME);
246288
if (headers == null) {
247289
return preparedStatementsMapBuilder.build();
248290
}
@@ -259,6 +301,19 @@ private Map<String, String> getPreparedStatements(HttpServletRequest request)
259301
return preparedStatementsMapBuilder.build();
260302
}
261303

304+
private Map<String, String> getPreparedStatements(ContainerRequestContext requestContext)
305+
throws RequestParsingException
306+
{
307+
if (requestContext.getHeaders() == null) {
308+
return ImmutableMap.of();
309+
}
310+
List<String> headers = requestContext.getHeaders().get(TRINO_PREPARED_STATEMENT_HEADER_NAME);
311+
if (headers == null || headers.isEmpty()) {
312+
return ImmutableMap.of();
313+
}
314+
return getPreparedStatements(Collections.enumeration(headers));
315+
}
316+
262317
private String decodePreparedStatementFromHeader(String headerValue)
263318
{
264319
// From io.trino.server.protocol.PreparedStatementEncoder

0 commit comments

Comments
 (0)