Skip to content

Fix Adding Filter Relative to Custom Filter #9902

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -114,8 +114,18 @@ final class FilterOrderRegistration {
put(SwitchUserFilter.class, order.next());
}

private void put(Class<? extends Filter> filter, int position) {
/**
* Register a {@link Filter} with its specific position. If the {@link Filter} was
* already registered before, the position previously defined is not going to be
* overriden
* @param filter the {@link Filter} to register
* @param position the position to associate with the {@link Filter}
*/
void put(Class<? extends Filter> filter, int position) {
String className = filter.getName();
if (this.filterToOrder.containsKey(className)) {
return;
}
this.filterToOrder.put(className, position);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2653,6 +2653,7 @@ public HttpSecurity addFilterBefore(Filter filter, Class<? extends Filter> befor
private HttpSecurity addFilterAtOffsetOf(Filter filter, int offset, Class<? extends Filter> registeredFilter) {
int order = this.filterOrders.getOrder(registeredFilter) + offset;
this.filters.add(new OrderedFilter(filter, order));
this.filterOrders.put(filter.getClass(), order);
return this;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.config.annotation.web.builders;

import java.io.IOException;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;

import org.junit.Test;

import org.springframework.security.web.access.channel.ChannelProcessingFilter;

import static org.assertj.core.api.Assertions.assertThat;

public class FilterOrderRegistrationTests {

private final FilterOrderRegistration filterOrderRegistration = new FilterOrderRegistration();

@Test
public void putWhenNewFilterThenInsertCorrect() {
int position = 153;
this.filterOrderRegistration.put(MyFilter.class, position);
Integer order = this.filterOrderRegistration.getOrder(MyFilter.class);
assertThat(order).isEqualTo(position);
}

@Test
public void putWhenCustomFilterAlreadyExistsThenDoesNotOverride() {
int position = 160;
this.filterOrderRegistration.put(MyFilter.class, position);
this.filterOrderRegistration.put(MyFilter.class, 173);
Integer order = this.filterOrderRegistration.getOrder(MyFilter.class);
assertThat(order).isEqualTo(position);
}

@Test
public void putWhenPredefinedFilterThenDoesNotOverride() {
int position = 100;
Integer predefinedFilterOrderBefore = this.filterOrderRegistration.getOrder(ChannelProcessingFilter.class);
this.filterOrderRegistration.put(MyFilter.class, position);
Integer myFilterOrder = this.filterOrderRegistration.getOrder(MyFilter.class);
Integer predefinedFilterOrderAfter = this.filterOrderRegistration.getOrder(ChannelProcessingFilter.class);
assertThat(myFilterOrder).isEqualTo(position);
assertThat(predefinedFilterOrderAfter).isEqualTo(predefinedFilterOrderBefore).isEqualTo(position);
}

static class MyFilter implements Filter {

@Override
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
throws IOException, ServletException {
filterChain.doFilter(servletRequest, servletResponse);
}

}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -30,14 +30,18 @@
import org.junit.Rule;
import org.junit.Test;

import org.springframework.context.annotation.Bean;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
import org.springframework.security.config.test.SpringTestRule;
import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.access.ExceptionTranslationFilter;
import org.springframework.security.web.access.channel.ChannelProcessingFilter;
import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter;
import org.springframework.security.web.context.SecurityContextPersistenceFilter;
import org.springframework.security.web.context.request.async.WebAsyncManagerIntegrationFilter;
import org.springframework.security.web.header.HeaderWriterFilter;

import static org.assertj.core.api.Assertions.assertThat;

Expand Down Expand Up @@ -70,6 +74,46 @@ public void addFilterAtWhenSameFilterDifferentPlacesThenOrderCorrect() {
ExceptionTranslationFilter.class);
}

@Test
public void addFilterAfterWhenAfterCustomFilterThenOrderCorrect() {
this.spring.register(MyOtherFilterRelativeToMyFilterAfterConfig.class).autowire();

assertThatFilters().containsSubsequence(WebAsyncManagerIntegrationFilter.class, MyFilter.class,
MyOtherFilter.class);
}

@Test
public void addFilterBeforeWhenBeforeCustomFilterThenOrderCorrect() {
this.spring.register(MyOtherFilterRelativeToMyFilterBeforeConfig.class).autowire();

assertThatFilters().containsSubsequence(MyOtherFilter.class, MyFilter.class,
WebAsyncManagerIntegrationFilter.class);
}

@Test
public void addFilterAtWhenAtCustomFilterThenOrderCorrect() {
this.spring.register(MyOtherFilterRelativeToMyFilterAtConfig.class).autowire();

assertThatFilters().containsSubsequence(WebAsyncManagerIntegrationFilter.class, MyFilter.class,
MyOtherFilter.class, SecurityContextPersistenceFilter.class);
}

@Test
public void addFilterBeforeWhenCustomFilterDifferentPlacesThenOrderCorrect() {
this.spring.register(MyOtherFilterBeforeToMyFilterMultipleAfterConfig.class).autowire();

assertThatFilters().containsSubsequence(WebAsyncManagerIntegrationFilter.class, MyOtherFilter.class,
MyFilter.class, ExceptionTranslationFilter.class);
}

@Test
public void addFilterBeforeAndAfterWhenCustomFiltersDifferentPlacesThenOrderCorrect() {
this.spring.register(MyAnotherFilterRelativeToMyCustomFiltersMultipleConfig.class).autowire();

assertThatFilters().containsSubsequence(HeaderWriterFilter.class, MyFilter.class, MyOtherFilter.class,
MyOtherFilter.class, MyAnotherFilter.class, MyFilter.class, ExceptionTranslationFilter.class);
}

private ListAssert<Class<?>> assertThatFilters() {
FilterChainProxy filterChain = this.spring.getContext().getBean(FilterChainProxy.class);
List<Class<?>> filters = filterChain.getFilters("/").stream().map(Object::getClass)
Expand All @@ -87,6 +131,26 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo

}

static class MyOtherFilter implements Filter {

@Override
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
throws IOException, ServletException {
filterChain.doFilter(servletRequest, servletResponse);
}

}

static class MyAnotherFilter implements Filter {

@Override
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
throws IOException, ServletException {
filterChain.doFilter(servletRequest, servletResponse);
}

}

@EnableWebSecurity
static class MyFilterMultipleAfterConfig extends WebSecurityConfigurerAdapter {

Expand Down Expand Up @@ -129,4 +193,83 @@ protected void configure(HttpSecurity http) throws Exception {

}

@EnableWebSecurity
static class MyOtherFilterRelativeToMyFilterAfterConfig {

@Bean
SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
// @formatter:off
http
.addFilterAfter(new MyFilter(), WebAsyncManagerIntegrationFilter.class)
.addFilterAfter(new MyOtherFilter(), MyFilter.class);
// @formatter:on
return http.build();
}

}

@EnableWebSecurity
static class MyOtherFilterRelativeToMyFilterBeforeConfig {

@Bean
SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
// @formatter:off
http
.addFilterBefore(new MyFilter(), WebAsyncManagerIntegrationFilter.class)
.addFilterBefore(new MyOtherFilter(), MyFilter.class);
// @formatter:on
return http.build();
}

}

@EnableWebSecurity
static class MyOtherFilterRelativeToMyFilterAtConfig {

@Bean
SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
// @formatter:off
http
.addFilterAt(new MyFilter(), WebAsyncManagerIntegrationFilter.class)
.addFilterAt(new MyOtherFilter(), MyFilter.class);
// @formatter:on
return http.build();
}

}

@EnableWebSecurity
static class MyOtherFilterBeforeToMyFilterMultipleAfterConfig {

@Bean
SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
// @formatter:off
http
.addFilterAfter(new MyFilter(), WebAsyncManagerIntegrationFilter.class)
.addFilterAfter(new MyFilter(), ExceptionTranslationFilter.class)
.addFilterBefore(new MyOtherFilter(), MyFilter.class);
// @formatter:on
return http.build();
}

}

@EnableWebSecurity
static class MyAnotherFilterRelativeToMyCustomFiltersMultipleConfig {

@Bean
SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
// @formatter:off
http
.addFilterAfter(new MyFilter(), HeaderWriterFilter.class)
.addFilterBefore(new MyOtherFilter(), ExceptionTranslationFilter.class)
.addFilterAfter(new MyOtherFilter(), MyFilter.class)
.addFilterAt(new MyAnotherFilter(), MyOtherFilter.class)
.addFilterAfter(new MyFilter(), MyAnotherFilter.class);
// @formatter:on
return http.build();
}

}

}