diff --git a/web/src/main/java/org/springframework/security/web/savedrequest/CookieRequestCache.java b/web/src/main/java/org/springframework/security/web/savedrequest/CookieRequestCache.java index 482b5699f8b..960f8758937 100644 --- a/web/src/main/java/org/springframework/security/web/savedrequest/CookieRequestCache.java +++ b/web/src/main/java/org/springframework/security/web/savedrequest/CookieRequestCache.java @@ -21,6 +21,8 @@ import org.springframework.security.web.util.matcher.AnyRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; import org.springframework.web.util.UriComponents; import org.springframework.web.util.UriComponentsBuilder; import org.springframework.web.util.WebUtils; @@ -29,6 +31,7 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.util.Base64; +import java.util.HashMap; /** @@ -52,7 +55,7 @@ public void saveRequest(HttpServletRequest request, HttpServletResponse response Cookie savedCookie = new Cookie(COOKIE_NAME, encodeCookie(redirectUrl)); savedCookie.setMaxAge(COOKIE_MAX_AGE); savedCookie.setSecure(request.isSecure()); - savedCookie.setPath(request.getContextPath()); + savedCookie.setPath(getCookiePath(request)); savedCookie.setHttpOnly(true); response.addCookie(savedCookie); @@ -65,7 +68,7 @@ public void saveRequest(HttpServletRequest request, HttpServletResponse response public SavedRequest getRequest(HttpServletRequest request, HttpServletResponse response) { Cookie savedRequestCookie = WebUtils.getCookie(request, COOKIE_NAME); if (savedRequestCookie != null) { - String originalURI = decodeCookie(savedRequestCookie.getValue()); + final String originalURI = decodeCookie(savedRequestCookie.getValue()); UriComponents uriComponents = UriComponentsBuilder.fromUriString(originalURI).build(); DefaultSavedRequest.Builder builder = new DefaultSavedRequest.Builder(); @@ -77,11 +80,21 @@ public SavedRequest getRequest(HttpServletRequest request, HttpServletResponse r port = 80; } } + + final MultiValueMap queryParams = uriComponents.getQueryParams(); + + if (!queryParams.isEmpty()) { + final HashMap parameters = new HashMap<>(queryParams.size()); + queryParams.forEach((key, value) -> parameters.put(key, value.toArray(new String[]{}))); + builder.setParameters(parameters); + } + return builder.setScheme(uriComponents.getScheme()) .setServerName(uriComponents.getHost()) .setRequestURI(uriComponents.getPath()) .setQueryString(uriComponents.getQuery()) .setServerPort(port) + .setMethod(request.getMethod()) .build(); } return null; @@ -89,12 +102,14 @@ public SavedRequest getRequest(HttpServletRequest request, HttpServletResponse r @Override public HttpServletRequest getMatchingRequest(HttpServletRequest request, HttpServletResponse response) { - SavedRequest savedRequest = getRequest(request, response); - if (savedRequest != null) { - removeRequest(request, response); - return new SavedRequestAwareWrapper(savedRequest, request); + SavedRequest saved = this.getRequest(request, response); + if (!this.matchesSavedRequest(request, saved)) { + this.logger.debug("saved request doesn't match"); + return null; + } else { + this.removeRequest(request, response); + return new SavedRequestAwareWrapper(saved, request); } - return null; } @Override @@ -102,7 +117,7 @@ public void removeRequest(HttpServletRequest request, HttpServletResponse respon Cookie removeSavedRequestCookie = new Cookie(COOKIE_NAME, ""); removeSavedRequestCookie.setSecure(request.isSecure()); removeSavedRequestCookie.setHttpOnly(true); - removeSavedRequestCookie.setPath(request.getContextPath()); + removeSavedRequestCookie.setPath(getCookiePath(request)); removeSavedRequestCookie.setMaxAge(0); response.addCookie(removeSavedRequestCookie); } @@ -115,6 +130,23 @@ private static String decodeCookie(String encodedCookieValue) { return new String(Base64.getDecoder().decode(encodedCookieValue.getBytes())); } + private static String getCookiePath(HttpServletRequest request) { + final String contextPath = request.getContextPath(); + if (StringUtils.isEmpty(contextPath)) { + return "/"; + } + return contextPath; + } + + private boolean matchesSavedRequest(HttpServletRequest request, SavedRequest savedRequest) { + if (savedRequest == null) { + return false; + } else { + String currentUrl = UrlUtils.buildFullRequestUrl(request); + return savedRequest.getRedirectUrl().equals(currentUrl); + } + } + /** * Allows selective use of saved requests for a subset of requests. By default any * request will be cached by the {@code saveRequest} method. diff --git a/web/src/test/java/org/springframework/security/web/savedrequest/CookieRequestCacheTests.java b/web/src/test/java/org/springframework/security/web/savedrequest/CookieRequestCacheTests.java index 584ac951d61..cc2714e40a3 100644 --- a/web/src/test/java/org/springframework/security/web/savedrequest/CookieRequestCacheTests.java +++ b/web/src/test/java/org/springframework/security/web/savedrequest/CookieRequestCacheTests.java @@ -18,6 +18,7 @@ import org.junit.Test; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.util.StringUtils; import javax.servlet.http.Cookie; import javax.servlet.http.HttpServletRequest; @@ -50,7 +51,7 @@ public void saveRequestWhenMatchesThenSavedRequestInACookieOnResponse() { assertThat(redirectUrl).isEqualTo("https://abc.com/destination?param1=a¶m2=b¶m3=1122"); assertThat(savedCookie.getMaxAge()).isEqualTo(-1); - assertThat(savedCookie.getPath()).isEqualTo(request.getContextPath()); + assertThat(savedCookie.getPath()).isEqualTo(StringUtils.isEmpty(request.getContextPath()) ? "/" : request.getContextPath()); assertThat(savedCookie.isHttpOnly()).isTrue(); assertThat(savedCookie.getSecure()).isTrue(); @@ -123,7 +124,8 @@ public void matchingRequestWhenRequestDoesNotContainSavedRequestCookieThenReturn @Test public void matchingRequestWhenRequestContainsSavedRequestCookieThenSetsAnExpiredCookieInResponse() { CookieRequestCache cookieRequestCache = new CookieRequestCache(); - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = requestToSave(); + String redirectUrl = "https://abc.com/destination?param1=a¶m2=b¶m3=1122"; request.setCookies(new Cookie(DEFAULT_COOKIE_NAME, encodeCookie(redirectUrl))); MockHttpServletResponse response = new MockHttpServletResponse(); @@ -135,6 +137,22 @@ public void matchingRequestWhenRequestContainsSavedRequestCookieThenSetsAnExpire assertThat(expiredCookie.getMaxAge()).isZero(); } + @Test + public void notMatchingRequestWhenRequestNotContainsSavedRequestCookie() { + CookieRequestCache cookieRequestCache = new CookieRequestCache(); + MockHttpServletRequest request = requestToSave(); + + String redirectUrl = "https://abc.com/api"; + request.setCookies(new Cookie(DEFAULT_COOKIE_NAME, encodeCookie(redirectUrl))); + MockHttpServletResponse response = new MockHttpServletResponse(); + + final HttpServletRequest matchingRequest = cookieRequestCache.getMatchingRequest(request, response); + assertThat(matchingRequest).isNull(); + Cookie expiredCookie = response.getCookie(DEFAULT_COOKIE_NAME); + assertThat(expiredCookie).isNull(); + + } + @Test public void removeRequestWhenInvokedThenSetsAnExpiredCookieOnResponse() { CookieRequestCache cookieRequestCache = new CookieRequestCache();