Skip to content

Commit 6c5e92f

Browse files
committed
Fix HttpSecurity.addFilter* Ordering
Closes gh-9633
1 parent d3af4f7 commit 6c5e92f

File tree

3 files changed

+199
-92
lines changed

3 files changed

+199
-92
lines changed

config/src/main/java/org/springframework/security/config/annotation/web/builders/FilterComparator.java renamed to config/src/main/java/org/springframework/security/config/annotation/web/builders/FilterOrderRegistration.java

+3-75
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
*/
1616
package org.springframework.security.config.annotation.web.builders;
1717

18-
import java.io.Serializable;
1918
import java.util.Comparator;
2019
import java.util.HashMap;
2120
import java.util.Map;
@@ -53,14 +52,12 @@
5352
* @author Rob Winch
5453
* @since 3.2
5554
*/
56-
57-
@SuppressWarnings("serial")
58-
final class FilterComparator implements Comparator<Filter>, Serializable {
55+
final class FilterOrderRegistration {
5956
private static final int INITIAL_ORDER = 100;
6057
private static final int ORDER_STEP = 100;
6158
private final Map<String, Integer> filterToOrder = new HashMap<>();
6259

63-
FilterComparator() {
60+
FilterOrderRegistration() {
6461
Step order = new Step(INITIAL_ORDER, ORDER_STEP);
6562
put(ChannelProcessingFilter.class, order.next());
6663
put(ConcurrentSessionFilter.class, order.next());
@@ -111,75 +108,6 @@ final class FilterComparator implements Comparator<Filter>, Serializable {
111108
put(SwitchUserFilter.class, order.next());
112109
}
113110

114-
public int compare(Filter lhs, Filter rhs) {
115-
Integer left = getOrder(lhs.getClass());
116-
Integer right = getOrder(rhs.getClass());
117-
return left - right;
118-
}
119-
120-
/**
121-
* Determines if a particular {@link Filter} is registered to be sorted
122-
*
123-
* @param filter
124-
* @return
125-
*/
126-
public boolean isRegistered(Class<? extends Filter> filter) {
127-
return getOrder(filter) != null;
128-
}
129-
130-
/**
131-
* Registers a {@link Filter} to exist after a particular {@link Filter} that is
132-
* already registered.
133-
* @param filter the {@link Filter} to register
134-
* @param afterFilter the {@link Filter} that is already registered and that
135-
* {@code filter} should be placed after.
136-
*/
137-
public void registerAfter(Class<? extends Filter> filter,
138-
Class<? extends Filter> afterFilter) {
139-
Integer position = getOrder(afterFilter);
140-
if (position == null) {
141-
throw new IllegalArgumentException(
142-
"Cannot register after unregistered Filter " + afterFilter);
143-
}
144-
145-
put(filter, position + 1);
146-
}
147-
148-
/**
149-
* Registers a {@link Filter} to exist at a particular {@link Filter} position
150-
* @param filter the {@link Filter} to register
151-
* @param atFilter the {@link Filter} that is already registered and that
152-
* {@code filter} should be placed at.
153-
*/
154-
public void registerAt(Class<? extends Filter> filter,
155-
Class<? extends Filter> atFilter) {
156-
Integer position = getOrder(atFilter);
157-
if (position == null) {
158-
throw new IllegalArgumentException(
159-
"Cannot register after unregistered Filter " + atFilter);
160-
}
161-
162-
put(filter, position);
163-
}
164-
165-
/**
166-
* Registers a {@link Filter} to exist before a particular {@link Filter} that is
167-
* already registered.
168-
* @param filter the {@link Filter} to register
169-
* @param beforeFilter the {@link Filter} that is already registered and that
170-
* {@code filter} should be placed before.
171-
*/
172-
public void registerBefore(Class<? extends Filter> filter,
173-
Class<? extends Filter> beforeFilter) {
174-
Integer position = getOrder(beforeFilter);
175-
if (position == null) {
176-
throw new IllegalArgumentException(
177-
"Cannot register after unregistered Filter " + beforeFilter);
178-
}
179-
180-
put(filter, position - 1);
181-
}
182-
183111
private void put(Class<? extends Filter> filter, int position) {
184112
String className = filter.getName();
185113
filterToOrder.put(className, position);
@@ -192,7 +120,7 @@ private void put(Class<? extends Filter> filter, int position) {
192120
* @param clazz the {@link Filter} class to determine the sort order
193121
* @return the sort order or null if not defined
194122
*/
195-
private Integer getOrder(Class<?> clazz) {
123+
Integer getOrder(Class<?> clazz) {
196124
while (clazz != null) {
197125
Integer result = filterToOrder.get(clazz.getName());
198126
if (result != null) {

config/src/main/java/org/springframework/security/config/annotation/web/builders/HttpSecurity.java

+64-17
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
package org.springframework.security.config.annotation.web.builders;
1717

1818
import org.springframework.context.ApplicationContext;
19+
import org.springframework.core.OrderComparator;
20+
import org.springframework.core.Ordered;
1921
import org.springframework.http.HttpMethod;
2022
import org.springframework.security.authentication.AuthenticationManager;
2123
import org.springframework.security.authentication.AuthenticationProvider;
@@ -78,10 +80,16 @@
7880
import org.springframework.web.filter.CorsFilter;
7981
import org.springframework.web.servlet.handler.HandlerMappingIntrospector;
8082

83+
import java.io.IOException;
8184
import java.util.ArrayList;
8285
import java.util.List;
8386
import java.util.Map;
87+
8488
import javax.servlet.Filter;
89+
import javax.servlet.FilterChain;
90+
import javax.servlet.ServletException;
91+
import javax.servlet.ServletRequest;
92+
import javax.servlet.ServletResponse;
8593
import javax.servlet.http.HttpServletRequest;
8694

8795
/**
@@ -125,9 +133,9 @@ public final class HttpSecurity extends
125133
implements SecurityBuilder<DefaultSecurityFilterChain>,
126134
HttpSecurityBuilder<HttpSecurity> {
127135
private final RequestMatcherConfigurer requestMatcherConfigurer;
128-
private List<Filter> filters = new ArrayList<>();
136+
private List<OrderedFilter> filters = new ArrayList<>();
129137
private RequestMatcher requestMatcher = AnyRequestMatcher.INSTANCE;
130-
private FilterComparator comparator = new FilterComparator();
138+
private FilterOrderRegistration filterOrders = new FilterOrderRegistration();
131139

132140
/**
133141
* Creates a new instance
@@ -2532,8 +2540,12 @@ protected void beforeConfigure() throws Exception {
25322540

25332541
@Override
25342542
protected DefaultSecurityFilterChain performBuild() {
2535-
filters.sort(comparator);
2536-
return new DefaultSecurityFilterChain(requestMatcher, filters);
2543+
this.filters.sort(OrderComparator.INSTANCE);
2544+
List<Filter> sortedFilters = new ArrayList<>(this.filters.size());
2545+
for (Filter filter : this.filters) {
2546+
sortedFilters.add(((OrderedFilter) filter).filter);
2547+
}
2548+
return new DefaultSecurityFilterChain(this.requestMatcher, sortedFilters);
25372549
}
25382550

25392551
/*
@@ -2574,8 +2586,7 @@ private AuthenticationManagerBuilder getAuthenticationRegistry() {
25742586
* .servlet.Filter, java.lang.Class)
25752587
*/
25762588
public HttpSecurity addFilterAfter(Filter filter, Class<? extends Filter> afterFilter) {
2577-
comparator.registerAfter(filter.getClass(), afterFilter);
2578-
return addFilter(filter);
2589+
return addFilterAtOffsetOf(filter, 1, afterFilter);
25792590
}
25802591

25812592
/*
@@ -2587,8 +2598,13 @@ public HttpSecurity addFilterAfter(Filter filter, Class<? extends Filter> afterF
25872598
*/
25882599
public HttpSecurity addFilterBefore(Filter filter,
25892600
Class<? extends Filter> beforeFilter) {
2590-
comparator.registerBefore(filter.getClass(), beforeFilter);
2591-
return addFilter(filter);
2601+
return addFilterAtOffsetOf(filter, -1, beforeFilter);
2602+
}
2603+
2604+
private HttpSecurity addFilterAtOffsetOf(Filter filter, int offset, Class<? extends Filter> registeredFilter) {
2605+
int order = this.filterOrders.getOrder(registeredFilter) + offset;
2606+
this.filters.add(new OrderedFilter(filter, order));
2607+
return this;
25922608
}
25932609

25942610
/*
@@ -2599,14 +2615,12 @@ public HttpSecurity addFilterBefore(Filter filter,
25992615
* servlet.Filter)
26002616
*/
26012617
public HttpSecurity addFilter(Filter filter) {
2602-
Class<? extends Filter> filterClass = filter.getClass();
2603-
if (!comparator.isRegistered(filterClass)) {
2604-
throw new IllegalArgumentException(
2605-
"The Filter class "
2606-
+ filterClass.getName()
2607-
+ " does not have a registered order and cannot be added without a specified order. Consider using addFilterBefore or addFilterAfter instead.");
2618+
Integer order = this.filterOrders.getOrder(filter.getClass());
2619+
if (order == null) {
2620+
throw new IllegalArgumentException("The Filter class " + filter.getClass().getName()
2621+
+ " does not have a registered order and cannot be added without a specified order. Consider using addFilterBefore or addFilterAfter instead.");
26082622
}
2609-
this.filters.add(filter);
2623+
this.filters.add(new OrderedFilter(filter, order));
26102624
return this;
26112625
}
26122626

@@ -2630,8 +2644,7 @@ public HttpSecurity addFilter(Filter filter) {
26302644
* @return the {@link HttpSecurity} for further customizations
26312645
*/
26322646
public HttpSecurity addFilterAt(Filter filter, Class<? extends Filter> atFilter) {
2633-
this.comparator.registerAt(filter.getClass(), atFilter);
2634-
return addFilter(filter);
2647+
return addFilterAtOffsetOf(filter, 0, atFilter);
26352648
}
26362649

26372650
/**
@@ -3027,4 +3040,38 @@ private <C extends SecurityConfigurerAdapter<DefaultSecurityFilterChain, HttpSec
30273040
}
30283041
return apply(configurer);
30293042
}
3043+
3044+
/*
3045+
* A Filter that implements Ordered to be sorted. After sorting occurs, the original
3046+
* filter is what is used by FilterChainProxy
3047+
*/
3048+
private static final class OrderedFilter implements Ordered, Filter {
3049+
3050+
private final Filter filter;
3051+
3052+
private final int order;
3053+
3054+
private OrderedFilter(Filter filter, int order) {
3055+
this.filter = filter;
3056+
this.order = order;
3057+
}
3058+
3059+
@Override
3060+
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
3061+
throws IOException, ServletException {
3062+
this.filter.doFilter(servletRequest, servletResponse, filterChain);
3063+
}
3064+
3065+
@Override
3066+
public int getOrder() {
3067+
return this.order;
3068+
}
3069+
3070+
@Override
3071+
public String toString() {
3072+
return "OrderedFilter{" + "filter=" + this.filter + ", order=" + this.order + '}';
3073+
}
3074+
3075+
}
3076+
30303077
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
/*
2+
* Copyright 2002-2020 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.config.annotation.web.builders;
18+
19+
import java.io.IOException;
20+
import java.util.List;
21+
import java.util.stream.Collectors;
22+
23+
import javax.servlet.Filter;
24+
import javax.servlet.FilterChain;
25+
import javax.servlet.ServletException;
26+
import javax.servlet.ServletRequest;
27+
import javax.servlet.ServletResponse;
28+
29+
import org.assertj.core.api.ListAssert;
30+
import org.junit.Rule;
31+
import org.junit.Test;
32+
33+
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
34+
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
35+
import org.springframework.security.config.test.SpringTestRule;
36+
import org.springframework.security.web.FilterChainProxy;
37+
import org.springframework.security.web.access.ExceptionTranslationFilter;
38+
import org.springframework.security.web.access.channel.ChannelProcessingFilter;
39+
import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter;
40+
import org.springframework.security.web.context.request.async.WebAsyncManagerIntegrationFilter;
41+
42+
import static org.assertj.core.api.Assertions.assertThat;
43+
44+
public class HttpSecurityAddFilterTest {
45+
46+
@Rule
47+
public final SpringTestRule spring = new SpringTestRule();
48+
49+
@Test
50+
public void addFilterAfterWhenSameFilterDifferentPlacesThenOrderCorrect() {
51+
this.spring.register(MyFilterMultipleAfterConfig.class).autowire();
52+
53+
assertThatFilters().containsSubsequence(WebAsyncManagerIntegrationFilter.class, MyFilter.class,
54+
ExceptionTranslationFilter.class, MyFilter.class);
55+
}
56+
57+
@Test
58+
public void addFilterBeforeWhenSameFilterDifferentPlacesThenOrderCorrect() {
59+
this.spring.register(MyFilterMultipleBeforeConfig.class).autowire();
60+
61+
assertThatFilters().containsSubsequence(MyFilter.class, WebAsyncManagerIntegrationFilter.class, MyFilter.class,
62+
ExceptionTranslationFilter.class);
63+
}
64+
65+
@Test
66+
public void addFilterAtWhenSameFilterDifferentPlacesThenOrderCorrect() {
67+
this.spring.register(MyFilterMultipleAtConfig.class).autowire();
68+
69+
assertThatFilters().containsSubsequence(MyFilter.class, WebAsyncManagerIntegrationFilter.class, MyFilter.class,
70+
ExceptionTranslationFilter.class);
71+
}
72+
73+
private ListAssert<Class<?>> assertThatFilters() {
74+
FilterChainProxy filterChain = this.spring.getContext().getBean(FilterChainProxy.class);
75+
List<Class<?>> filters = filterChain.getFilters("/").stream().map(Object::getClass)
76+
.collect(Collectors.toList());
77+
return assertThat(filters);
78+
}
79+
80+
public static class MyFilter implements Filter {
81+
82+
@Override
83+
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
84+
throws IOException, ServletException {
85+
filterChain.doFilter(servletRequest, servletResponse);
86+
}
87+
88+
}
89+
90+
@EnableWebSecurity
91+
static class MyFilterMultipleAfterConfig extends WebSecurityConfigurerAdapter {
92+
93+
@Override
94+
protected void configure(HttpSecurity http) throws Exception {
95+
// @formatter:off
96+
http
97+
.addFilterAfter(new MyFilter(), WebAsyncManagerIntegrationFilter.class)
98+
.addFilterAfter(new MyFilter(), ExceptionTranslationFilter.class);
99+
// @formatter:on
100+
}
101+
102+
}
103+
104+
@EnableWebSecurity
105+
static class MyFilterMultipleBeforeConfig extends WebSecurityConfigurerAdapter {
106+
107+
@Override
108+
protected void configure(HttpSecurity http) throws Exception {
109+
// @formatter:off
110+
http
111+
.addFilterBefore(new MyFilter(), WebAsyncManagerIntegrationFilter.class)
112+
.addFilterBefore(new MyFilter(), ExceptionTranslationFilter.class);
113+
// @formatter:on
114+
}
115+
116+
}
117+
118+
@EnableWebSecurity
119+
static class MyFilterMultipleAtConfig extends WebSecurityConfigurerAdapter {
120+
121+
@Override
122+
protected void configure(HttpSecurity http) throws Exception {
123+
// @formatter:off
124+
http
125+
.addFilterAt(new MyFilter(), ChannelProcessingFilter.class)
126+
.addFilterAt(new MyFilter(), UsernamePasswordAuthenticationFilter.class);
127+
// @formatter:on
128+
}
129+
130+
}
131+
132+
}

0 commit comments

Comments
 (0)