Skip to content

Fix #8817 #8820 Improve CookieRequestCache #8818

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import org.springframework.security.web.util.matcher.AnyRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;
import org.springframework.web.util.WebUtils;
Expand All @@ -29,6 +31,7 @@
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.Base64;
import java.util.HashMap;


/**
Expand All @@ -52,7 +55,7 @@ public void saveRequest(HttpServletRequest request, HttpServletResponse response
Cookie savedCookie = new Cookie(COOKIE_NAME, encodeCookie(redirectUrl));
savedCookie.setMaxAge(COOKIE_MAX_AGE);
savedCookie.setSecure(request.isSecure());
savedCookie.setPath(request.getContextPath());
savedCookie.setPath(getCookiePath(request));
savedCookie.setHttpOnly(true);

response.addCookie(savedCookie);
Expand All @@ -65,7 +68,7 @@ public void saveRequest(HttpServletRequest request, HttpServletResponse response
public SavedRequest getRequest(HttpServletRequest request, HttpServletResponse response) {
Cookie savedRequestCookie = WebUtils.getCookie(request, COOKIE_NAME);
if (savedRequestCookie != null) {
String originalURI = decodeCookie(savedRequestCookie.getValue());
final String originalURI = decodeCookie(savedRequestCookie.getValue());
UriComponents uriComponents = UriComponentsBuilder.fromUriString(originalURI).build();
DefaultSavedRequest.Builder builder = new DefaultSavedRequest.Builder();

Expand All @@ -77,32 +80,44 @@ public SavedRequest getRequest(HttpServletRequest request, HttpServletResponse r
port = 80;
}
}

final MultiValueMap<String, String> queryParams = uriComponents.getQueryParams();

if (!queryParams.isEmpty()) {
final HashMap<String, String[]> parameters = new HashMap<>(queryParams.size());
queryParams.forEach((key, value) -> parameters.put(key, value.toArray(new String[]{})));
builder.setParameters(parameters);
}

return builder.setScheme(uriComponents.getScheme())
.setServerName(uriComponents.getHost())
.setRequestURI(uriComponents.getPath())
.setQueryString(uriComponents.getQuery())
.setServerPort(port)
.setMethod(request.getMethod())
.build();
}
return null;
}

@Override
public HttpServletRequest getMatchingRequest(HttpServletRequest request, HttpServletResponse response) {
SavedRequest savedRequest = getRequest(request, response);
if (savedRequest != null) {
removeRequest(request, response);
return new SavedRequestAwareWrapper(savedRequest, request);
SavedRequest saved = this.getRequest(request, response);
if (!this.matchesSavedRequest(request, saved)) {
this.logger.debug("saved request doesn't match");
return null;
} else {
this.removeRequest(request, response);
return new SavedRequestAwareWrapper(saved, request);
}
return null;
}

@Override
public void removeRequest(HttpServletRequest request, HttpServletResponse response) {
Cookie removeSavedRequestCookie = new Cookie(COOKIE_NAME, "");
removeSavedRequestCookie.setSecure(request.isSecure());
removeSavedRequestCookie.setHttpOnly(true);
removeSavedRequestCookie.setPath(request.getContextPath());
removeSavedRequestCookie.setPath(getCookiePath(request));
removeSavedRequestCookie.setMaxAge(0);
response.addCookie(removeSavedRequestCookie);
}
Expand All @@ -115,6 +130,23 @@ private static String decodeCookie(String encodedCookieValue) {
return new String(Base64.getDecoder().decode(encodedCookieValue.getBytes()));
}

private static String getCookiePath(HttpServletRequest request) {
final String contextPath = request.getContextPath();
if (StringUtils.isEmpty(contextPath)) {
return "/";
}
return contextPath;
}

private boolean matchesSavedRequest(HttpServletRequest request, SavedRequest savedRequest) {
if (savedRequest == null) {
return false;
} else {
String currentUrl = UrlUtils.buildFullRequestUrl(request);
return savedRequest.getRedirectUrl().equals(currentUrl);
}
}

/**
* Allows selective use of saved requests for a subset of requests. By default any
* request will be cached by the {@code saveRequest} method.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.junit.Test;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.util.StringUtils;

import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
Expand Down Expand Up @@ -50,7 +51,7 @@ public void saveRequestWhenMatchesThenSavedRequestInACookieOnResponse() {
assertThat(redirectUrl).isEqualTo("https://abc.com/destination?param1=a&param2=b&param3=1122");

assertThat(savedCookie.getMaxAge()).isEqualTo(-1);
assertThat(savedCookie.getPath()).isEqualTo(request.getContextPath());
assertThat(savedCookie.getPath()).isEqualTo(StringUtils.isEmpty(request.getContextPath()) ? "/" : request.getContextPath());
assertThat(savedCookie.isHttpOnly()).isTrue();
assertThat(savedCookie.getSecure()).isTrue();

Expand Down Expand Up @@ -123,7 +124,8 @@ public void matchingRequestWhenRequestDoesNotContainSavedRequestCookieThenReturn
@Test
public void matchingRequestWhenRequestContainsSavedRequestCookieThenSetsAnExpiredCookieInResponse() {
CookieRequestCache cookieRequestCache = new CookieRequestCache();
MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletRequest request = requestToSave();

String redirectUrl = "https://abc.com/destination?param1=a&param2=b&param3=1122";
request.setCookies(new Cookie(DEFAULT_COOKIE_NAME, encodeCookie(redirectUrl)));
MockHttpServletResponse response = new MockHttpServletResponse();
Expand All @@ -135,6 +137,22 @@ public void matchingRequestWhenRequestContainsSavedRequestCookieThenSetsAnExpire
assertThat(expiredCookie.getMaxAge()).isZero();
}

@Test
public void notMatchingRequestWhenRequestNotContainsSavedRequestCookie() {
CookieRequestCache cookieRequestCache = new CookieRequestCache();
MockHttpServletRequest request = requestToSave();

String redirectUrl = "https://abc.com/api";
request.setCookies(new Cookie(DEFAULT_COOKIE_NAME, encodeCookie(redirectUrl)));
MockHttpServletResponse response = new MockHttpServletResponse();

final HttpServletRequest matchingRequest = cookieRequestCache.getMatchingRequest(request, response);
assertThat(matchingRequest).isNull();
Cookie expiredCookie = response.getCookie(DEFAULT_COOKIE_NAME);
assertThat(expiredCookie).isNull();

}

@Test
public void removeRequestWhenInvokedThenSetsAnExpiredCookieOnResponse() {
CookieRequestCache cookieRequestCache = new CookieRequestCache();
Expand Down