Skip to content

Commit b97df7f

Browse files
committed
Add request filter for early user and query parsing
1 parent a37a345 commit b97df7f

23 files changed

+1166
-251
lines changed

gateway-ha/src/main/java/io/trino/gateway/baseapp/BaseApp.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
import io.trino.gateway.ha.router.RoutingRulesManager;
4040
import io.trino.gateway.ha.router.StochasticRoutingManager;
4141
import io.trino.gateway.ha.security.AuthorizedExceptionMapper;
42+
import io.trino.gateway.ha.security.QueryMetadataParser;
43+
import io.trino.gateway.ha.security.QueryUserInfoParser;
4244
import io.trino.gateway.proxyserver.ForProxy;
4345
import io.trino.gateway.proxyserver.ProxyRequestHandler;
4446
import io.trino.gateway.proxyserver.RouteToBackendResource;
@@ -187,6 +189,8 @@ private static void registerProxyResources(Binder binder)
187189
{
188190
jaxrsBinder(binder).bind(RouteToBackendResource.class);
189191
jaxrsBinder(binder).bind(RouterPreMatchContainerRequestFilter.class);
192+
jaxrsBinder(binder).bind(QueryUserInfoParser.class);
193+
jaxrsBinder(binder).bind(QueryMetadataParser.class);
190194
jaxrsBinder(binder).bind(ProxyRequestHandler.class);
191195
httpClientBinder(binder).bindHttpClient("proxy", ForProxy.class);
192196
httpClientBinder(binder).bindHttpClient("monitor", ForMonitor.class);

gateway-ha/src/main/java/io/trino/gateway/ha/handler/HttpUtils.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ public class HttpUtils
2525
public static final String TRINO_UI_PATH = "/ui";
2626
public static final String OAUTH_PATH = "/oauth2";
2727
public static final String USER_HEADER = "X-Trino-User";
28+
public static final String TRINO_REQUEST_USER = "trinoRequestUser";
29+
public static final String TRINO_QUERY_PROPERTIES = "trinoQueryProperties";
2830

2931
private HttpUtils() {}
3032
}

gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyUtils.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import java.util.regex.Pattern;
2929

3030
import static com.google.common.base.Strings.isNullOrEmpty;
31+
import static io.trino.gateway.ha.handler.HttpUtils.TRINO_QUERY_PROPERTIES;
3132
import static io.trino.gateway.ha.handler.HttpUtils.TRINO_UI_PATH;
3233
import static io.trino.gateway.ha.handler.HttpUtils.V1_QUERY_PATH;
3334
import static java.nio.charset.StandardCharsets.UTF_8;
@@ -74,7 +75,7 @@ public static Optional<String> extractQueryIdIfPresent(
7475
throw new RuntimeException("Error reading request body", e);
7576
}
7677
if (!isNullOrEmpty(queryText) && queryText.toLowerCase(ENGLISH).contains("kill_query")) {
77-
TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties(request, requestAnalyserClientsUseV2Format, requestAnalyserMaxBodySize);
78+
TrinoQueryProperties trinoQueryProperties = (TrinoQueryProperties) request.getAttribute(TRINO_QUERY_PROPERTIES);
7879
return trinoQueryProperties.getQueryId();
7980
}
8081
return Optional.empty();

gateway-ha/src/main/java/io/trino/gateway/ha/handler/RoutingTargetHandler.java

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,10 @@
3333
import java.util.List;
3434
import java.util.Map;
3535
import java.util.Optional;
36-
import java.util.regex.Pattern;
3736
import java.util.stream.Stream;
3837

3938
import static com.google.common.base.Strings.isNullOrEmpty;
40-
import static com.google.common.collect.ImmutableList.toImmutableList;
41-
import static io.trino.gateway.ha.handler.HttpUtils.OAUTH_PATH;
42-
import static io.trino.gateway.ha.handler.HttpUtils.TRINO_UI_PATH;
43-
import static io.trino.gateway.ha.handler.HttpUtils.UI_API_STATS_PATH;
4439
import static io.trino.gateway.ha.handler.HttpUtils.USER_HEADER;
45-
import static io.trino.gateway.ha.handler.HttpUtils.V1_INFO_PATH;
46-
import static io.trino.gateway.ha.handler.HttpUtils.V1_NODE_PATH;
47-
import static io.trino.gateway.ha.handler.HttpUtils.V1_QUERY_PATH;
4840
import static io.trino.gateway.ha.handler.ProxyUtils.buildUriWithNewCluster;
4941
import static io.trino.gateway.ha.handler.ProxyUtils.extractQueryIdIfPresent;
5042
import static java.util.Objects.requireNonNull;
@@ -56,7 +48,6 @@ public class RoutingTargetHandler
5648
private final RoutingGroupSelector routingGroupSelector;
5749
private final String defaultRoutingGroup;
5850
private final List<String> statementPaths;
59-
private final List<Pattern> extraWhitelistPaths;
6051
private final boolean requestAnalyserClientsUseV2Format;
6152
private final int requestAnalyserMaxBodySize;
6253
private final boolean cookiesEnabled;
@@ -71,7 +62,6 @@ public RoutingTargetHandler(
7162
this.routingGroupSelector = requireNonNull(routingGroupSelector);
7263
this.defaultRoutingGroup = haGatewayConfiguration.getRouting().getDefaultRoutingGroup();
7364
statementPaths = requireNonNull(haGatewayConfiguration.getStatementPaths());
74-
extraWhitelistPaths = requireNonNull(haGatewayConfiguration.getExtraWhitelistPaths()).stream().map(Pattern::compile).collect(toImmutableList());
7565
requestAnalyserClientsUseV2Format = haGatewayConfiguration.getRequestAnalyzerConfig().isClientsUseV2Format();
7666
requestAnalyserMaxBodySize = haGatewayConfiguration.getRequestAnalyzerConfig().getMaxBodySize();
7767
cookiesEnabled = GatewayCookieConfigurationPropertiesProvider.getInstance().isEnabled();
@@ -118,18 +108,6 @@ private RoutingTargetResponse getRoutingTargetResponse(HttpServletRequest reques
118108
modifiedRequest);
119109
}
120110

121-
public boolean isPathWhiteListed(String path)
122-
{
123-
return statementPaths.stream().anyMatch(path::startsWith)
124-
|| path.startsWith(V1_QUERY_PATH)
125-
|| path.startsWith(TRINO_UI_PATH)
126-
|| path.startsWith(V1_INFO_PATH)
127-
|| path.startsWith(V1_NODE_PATH)
128-
|| path.startsWith(UI_API_STATS_PATH)
129-
|| path.startsWith(OAUTH_PATH)
130-
|| extraWhitelistPaths.stream().anyMatch(pattern -> pattern.matcher(path).matches());
131-
}
132-
133111
/**
134112
* A wrapper for HttpServletRequest that allows modifying multiple headers.
135113
*/

gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import com.google.inject.AbstractModule;
1818
import com.google.inject.Provides;
1919
import com.google.inject.Singleton;
20+
import com.google.inject.name.Named;
2021
import io.airlift.http.client.HttpClient;
2122
import io.trino.gateway.ha.clustermonitor.ClusterStatsHttpMonitor;
2223
import io.trino.gateway.ha.clustermonitor.ClusterStatsInfoApiMonitor;
@@ -46,6 +47,7 @@
4647
import io.trino.gateway.ha.router.HaGatewayManager;
4748
import io.trino.gateway.ha.router.HaQueryHistoryManager;
4849
import io.trino.gateway.ha.router.HaResourceGroupsManager;
50+
import io.trino.gateway.ha.router.PathFilter;
4951
import io.trino.gateway.ha.router.QueryHistoryManager;
5052
import io.trino.gateway.ha.router.ResourceGroupsManager;
5153
import io.trino.gateway.ha.router.RoutingGroupSelector;
@@ -96,6 +98,7 @@ protected void configure()
9698
binder().bind(ResourceGroupsManager.class).toInstance(resourceGroupsManager);
9799
binder().bind(GatewayBackendManager.class).toInstance(gatewayBackendManager);
98100
binder().bind(QueryHistoryManager.class).toInstance(queryHistoryManager);
101+
binder().bind(PathFilter.class);
99102
}
100103

101104
public HaGatewayProviderModule(HaGatewayConfiguration configuration)
@@ -278,4 +281,20 @@ public MonitorConfiguration getMonitorConfiguration()
278281
{
279282
return configuration.getMonitor();
280283
}
284+
285+
@Provides
286+
@Singleton
287+
@Named("statementPaths")
288+
public List<String> getStatementPaths()
289+
{
290+
return configuration.getStatementPaths();
291+
}
292+
293+
@Provides
294+
@Singleton
295+
@Named("extraWhitelistPaths")
296+
public List<String> getExtraWhitelistPaths()
297+
{
298+
return configuration.getExtraWhitelistPaths();
299+
}
281300
}

gateway-ha/src/main/java/io/trino/gateway/ha/router/ExternalRoutingGroupSelector.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646
import static io.airlift.http.client.JsonResponseHandler.createJsonResponseHandler;
4747
import static io.airlift.http.client.Request.Builder.preparePost;
4848
import static io.airlift.json.JsonCodec.jsonCodec;
49+
import static io.trino.gateway.ha.handler.HttpUtils.TRINO_QUERY_PROPERTIES;
50+
import static io.trino.gateway.ha.handler.HttpUtils.TRINO_REQUEST_USER;
4951
import static java.util.Collections.list;
5052
import static java.util.Objects.requireNonNull;
5153

@@ -58,7 +60,6 @@ public class ExternalRoutingGroupSelector
5860
private final boolean propagateErrors;
5961
private final HttpClient httpClient;
6062
private final RequestAnalyzerConfig requestAnalyzerConfig;
61-
private final TrinoRequestUser.TrinoRequestUserProvider trinoRequestUserProvider;
6263
private static final JsonCodec<RoutingGroupExternalBody> ROUTING_GROUP_EXTERNAL_BODY_JSON_CODEC = jsonCodec(RoutingGroupExternalBody.class);
6364
private static final JsonResponseHandler<ExternalRouterResponse> ROUTING_GROUP_EXTERNAL_RESPONSE_JSON_RESPONSE_HANDLER =
6465
createJsonResponseHandler(jsonCodec(ExternalRouterResponse.class));
@@ -74,7 +75,6 @@ public class ExternalRoutingGroupSelector
7475
propagateErrors = rulesExternalConfiguration.isPropagateErrors();
7576

7677
this.requestAnalyzerConfig = requestAnalyzerConfig;
77-
trinoRequestUserProvider = new TrinoRequestUser.TrinoRequestUserProvider(requestAnalyzerConfig);
7878
try {
7979
this.uri = new URI(requireNonNull(rulesExternalConfiguration.getUrlPath(),
8080
"Invalid URL provided, using routing group header as default."));
@@ -143,8 +143,8 @@ private RoutingGroupExternalBody createRequestBody(HttpServletRequest request)
143143
TrinoQueryProperties trinoQueryProperties = null;
144144
TrinoRequestUser trinoRequestUser = null;
145145
if (requestAnalyzerConfig.isAnalyzeRequest()) {
146-
trinoQueryProperties = new TrinoQueryProperties(request, requestAnalyzerConfig.isClientsUseV2Format(), requestAnalyzerConfig.getMaxBodySize());
147-
trinoRequestUser = trinoRequestUserProvider.getInstance(request);
146+
trinoQueryProperties = (TrinoQueryProperties) request.getAttribute(TRINO_QUERY_PROPERTIES);
147+
trinoRequestUser = (TrinoRequestUser) request.getAttribute(TRINO_REQUEST_USER);
148148
}
149149

150150
return new RoutingGroupExternalBody(

gateway-ha/src/main/java/io/trino/gateway/ha/router/FileBasedRoutingGroupSelector.java

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
import java.util.Map;
3434

3535
import static com.google.common.base.Suppliers.memoizeWithExpiration;
36+
import static io.trino.gateway.ha.handler.HttpUtils.TRINO_QUERY_PROPERTIES;
37+
import static io.trino.gateway.ha.handler.HttpUtils.TRINO_REQUEST_USER;
3638
import static java.nio.charset.StandardCharsets.UTF_8;
3739
import static java.util.Collections.sort;
3840

@@ -46,16 +48,10 @@ public class FileBasedRoutingGroupSelector
4648

4749
private final Supplier<List<RoutingRule>> rules;
4850
private final boolean analyzeRequest;
49-
private final boolean clientsUseV2Format;
50-
private final int maxBodySize;
51-
private final TrinoRequestUser.TrinoRequestUserProvider trinoRequestUserProvider;
5251

5352
public FileBasedRoutingGroupSelector(String rulesPath, Duration rulesRefreshPeriod, RequestAnalyzerConfig requestAnalyzerConfig)
5453
{
5554
analyzeRequest = requestAnalyzerConfig.isAnalyzeRequest();
56-
clientsUseV2Format = requestAnalyzerConfig.isClientsUseV2Format();
57-
maxBodySize = requestAnalyzerConfig.getMaxBodySize();
58-
trinoRequestUserProvider = new TrinoRequestUser.TrinoRequestUserProvider(requestAnalyzerConfig);
5955

6056
rules = memoizeWithExpiration(() -> readRulesFromPath(Path.of(rulesPath)), rulesRefreshPeriod.toJavaTime());
6157
}
@@ -68,12 +64,9 @@ public RoutingSelectorResponse findRoutingDestination(HttpServletRequest request
6864

6965
Map<String, Object> data;
7066
if (analyzeRequest) {
71-
TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties(
72-
request,
73-
clientsUseV2Format,
74-
maxBodySize);
75-
TrinoRequestUser trinoRequestUser = trinoRequestUserProvider.getInstance(request);
76-
data = ImmutableMap.of("request", request, "trinoQueryProperties", trinoQueryProperties, "trinoRequestUser", trinoRequestUser);
67+
TrinoQueryProperties trinoQueryProperties = (TrinoQueryProperties) request.getAttribute(TRINO_QUERY_PROPERTIES);
68+
TrinoRequestUser trinoRequestUser = (TrinoRequestUser) request.getAttribute(TRINO_REQUEST_USER);
69+
data = ImmutableMap.of("request", request, TRINO_QUERY_PROPERTIES, trinoQueryProperties, TRINO_REQUEST_USER, trinoRequestUser);
7770
}
7871
else {
7972
data = ImmutableMap.of("request", request);
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package io.trino.gateway.ha.router;
15+
16+
import com.google.inject.Inject;
17+
import com.google.inject.Singleton;
18+
import com.google.inject.name.Named;
19+
20+
import java.util.List;
21+
import java.util.Set;
22+
import java.util.regex.Pattern;
23+
24+
import static com.google.common.collect.ImmutableList.toImmutableList;
25+
import static io.trino.gateway.ha.handler.HttpUtils.OAUTH_PATH;
26+
import static io.trino.gateway.ha.handler.HttpUtils.TRINO_UI_PATH;
27+
import static io.trino.gateway.ha.handler.HttpUtils.UI_API_STATS_PATH;
28+
import static io.trino.gateway.ha.handler.HttpUtils.V1_INFO_PATH;
29+
import static io.trino.gateway.ha.handler.HttpUtils.V1_NODE_PATH;
30+
import static io.trino.gateway.ha.handler.HttpUtils.V1_QUERY_PATH;
31+
import static java.util.Objects.requireNonNull;
32+
33+
/**
34+
* A filter component that determines whether a given path should be whitelisted
35+
* for routing to Trino clusters.
36+
*/
37+
@Singleton
38+
public class PathFilter
39+
{
40+
private final Set<String> statementPaths;
41+
private final List<Pattern> extraWhitelistPatterns;
42+
43+
@Inject
44+
public PathFilter(
45+
@Named("statementPaths") List<String> statementPaths,
46+
@Named("extraWhitelistPaths") List<String> extraWhitelistPaths)
47+
{
48+
this.statementPaths = Set.copyOf(statementPaths);
49+
this.extraWhitelistPatterns = requireNonNull(extraWhitelistPaths).stream()
50+
.map(Pattern::compile)
51+
.collect(toImmutableList());
52+
}
53+
54+
/**
55+
* Determines if the given path is whitelisted for routing to backend.
56+
*
57+
* @param path the request path to check
58+
* @return true if the path should be routed to backend, false otherwise
59+
*/
60+
public boolean isPathWhiteListed(String path)
61+
{
62+
return statementPaths.stream().anyMatch(path::startsWith)
63+
|| path.startsWith(V1_QUERY_PATH)
64+
|| path.startsWith(TRINO_UI_PATH)
65+
|| path.startsWith(V1_INFO_PATH)
66+
|| path.startsWith(V1_NODE_PATH)
67+
|| path.startsWith(UI_API_STATS_PATH)
68+
|| path.startsWith(OAUTH_PATH)
69+
|| extraWhitelistPatterns.stream().anyMatch(pattern -> pattern.matcher(path).matches());
70+
}
71+
}

0 commit comments

Comments
 (0)