1313 * See the License for the specific language governing permissions and
1414 * limitations under the License.
1515 */
16+
1617package org .springframework .security .web .csrf ;
1718
1819import java .io .IOException ;
20+ import java .security .MessageDigest ;
1921import java .util .Arrays ;
2022import java .util .HashSet ;
2123
2830import org .apache .commons .logging .Log ;
2931import org .apache .commons .logging .LogFactory ;
3032
33+ import org .springframework .core .log .LogMessage ;
34+ import org .springframework .security .access .AccessDeniedException ;
35+ import org .springframework .security .crypto .codec .Utf8 ;
3136import org .springframework .security .web .access .AccessDeniedHandler ;
3237import org .springframework .security .web .access .AccessDeniedHandlerImpl ;
3338import org .springframework .security .web .util .UrlUtils ;
3439import org .springframework .security .web .util .matcher .RequestMatcher ;
3540import org .springframework .util .Assert ;
3641import org .springframework .web .filter .OncePerRequestFilter ;
3742
38- import static java .lang .Boolean .TRUE ;
39-
4043/**
4144 * <p>
4245 * Applies
5861 * @since 3.2
5962 */
6063public final class CsrfFilter extends OncePerRequestFilter {
64+
6165 /**
6266 * The default {@link RequestMatcher} that indicates if CSRF protection is required or
6367 * not. The default is to ignore GET, HEAD, TRACE, OPTIONS and process all other
@@ -66,18 +70,21 @@ public final class CsrfFilter extends OncePerRequestFilter {
6670 public static final RequestMatcher DEFAULT_CSRF_MATCHER = new DefaultRequiresCsrfMatcher ();
6771
6872 /**
69- * The attribute name to use when marking a given request as one that should not be filtered.
73+ * The attribute name to use when marking a given request as one that should not be
74+ * filtered.
7075 *
71- * To use, set the attribute on your {@link HttpServletRequest}:
72- * <pre>
76+ * To use, set the attribute on your {@link HttpServletRequest}: <pre>
7377 * CsrfFilter.skipRequest(request);
7478 * </pre>
7579 */
7680 private static final String SHOULD_NOT_FILTER = "SHOULD_NOT_FILTER" + CsrfFilter .class .getName ();
7781
7882 private final Log logger = LogFactory .getLog (getClass ());
83+
7984 private final CsrfTokenRepository tokenRepository ;
85+
8086 private RequestMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER ;
87+
8188 private AccessDeniedHandler accessDeniedHandler = new AccessDeniedHandlerImpl ();
8289
8390 public CsrfFilter (CsrfTokenRepository csrfTokenRepository ) {
@@ -87,62 +94,46 @@ public CsrfFilter(CsrfTokenRepository csrfTokenRepository) {
8794
8895 @ Override
8996 protected boolean shouldNotFilter (HttpServletRequest request ) throws ServletException {
90- return TRUE .equals (request .getAttribute (SHOULD_NOT_FILTER ));
97+ return Boolean . TRUE .equals (request .getAttribute (SHOULD_NOT_FILTER ));
9198 }
9299
93- /*
94- * (non-Javadoc)
95- *
96- * @see
97- * org.springframework.web.filter.OncePerRequestFilter#doFilterInternal(javax.servlet
98- * .http.HttpServletRequest, javax.servlet.http.HttpServletResponse,
99- * javax.servlet.FilterChain)
100- */
101100 @ Override
102- protected void doFilterInternal (HttpServletRequest request ,
103- HttpServletResponse response , FilterChain filterChain )
104- throws ServletException , IOException {
101+ protected void doFilterInternal (HttpServletRequest request , HttpServletResponse response , FilterChain filterChain )
102+ throws ServletException , IOException {
105103 request .setAttribute (HttpServletResponse .class .getName (), response );
106-
107104 CsrfToken csrfToken = this .tokenRepository .loadToken (request );
108- final boolean missingToken = csrfToken == null ;
105+ boolean missingToken = ( csrfToken == null ) ;
109106 if (missingToken ) {
110107 csrfToken = this .tokenRepository .generateToken (request );
111108 this .tokenRepository .saveToken (csrfToken , request , response );
112109 }
113110 request .setAttribute (CsrfToken .class .getName (), csrfToken );
114111 request .setAttribute (csrfToken .getParameterName (), csrfToken );
115-
116112 if (!this .requireCsrfProtectionMatcher .matches (request )) {
113+ if (this .logger .isTraceEnabled ()) {
114+ this .logger .trace ("Did not protect against CSRF since request did not match "
115+ + this .requireCsrfProtectionMatcher );
116+ }
117117 filterChain .doFilter (request , response );
118118 return ;
119119 }
120-
121120 String actualToken = request .getHeader (csrfToken .getHeaderName ());
122121 if (actualToken == null ) {
123122 actualToken = request .getParameter (csrfToken .getParameterName ());
124123 }
125- if (!csrfToken .getToken ().equals (actualToken )) {
126- if (this .logger .isDebugEnabled ()) {
127- this .logger .debug ("Invalid CSRF token found for "
128- + UrlUtils .buildFullRequestUrl (request ));
129- }
130- if (missingToken ) {
131- this .accessDeniedHandler .handle (request , response ,
132- new MissingCsrfTokenException (actualToken ));
133- }
134- else {
135- this .accessDeniedHandler .handle (request , response ,
136- new InvalidCsrfTokenException (csrfToken , actualToken ));
137- }
124+ if (!equalsConstantTime (csrfToken .getToken (), actualToken )) {
125+ this .logger .debug (
126+ LogMessage .of (() -> "Invalid CSRF token found for " + UrlUtils .buildFullRequestUrl (request )));
127+ AccessDeniedException exception = (!missingToken ) ? new InvalidCsrfTokenException (csrfToken , actualToken )
128+ : new MissingCsrfTokenException (actualToken );
129+ this .accessDeniedHandler .handle (request , response , exception );
138130 return ;
139131 }
140-
141132 filterChain .doFilter (request , response );
142133 }
143134
144135 public static void skipRequest (HttpServletRequest request ) {
145- request .setAttribute (SHOULD_NOT_FILTER , TRUE );
136+ request .setAttribute (SHOULD_NOT_FILTER , Boolean . TRUE );
146137 }
147138
148139 /**
@@ -154,14 +145,11 @@ public static void skipRequest(HttpServletRequest request) {
154145 * The default is to apply CSRF protection for any HTTP method other than GET, HEAD,
155146 * TRACE, OPTIONS.
156147 * </p>
157- *
158148 * @param requireCsrfProtectionMatcher the {@link RequestMatcher} used to determine if
159149 * CSRF protection should be applied.
160150 */
161- public void setRequireCsrfProtectionMatcher (
162- RequestMatcher requireCsrfProtectionMatcher ) {
163- Assert .notNull (requireCsrfProtectionMatcher ,
164- "requireCsrfProtectionMatcher cannot be null" );
151+ public void setRequireCsrfProtectionMatcher (RequestMatcher requireCsrfProtectionMatcher ) {
152+ Assert .notNull (requireCsrfProtectionMatcher , "requireCsrfProtectionMatcher cannot be null" );
165153 this .requireCsrfProtectionMatcher = requireCsrfProtectionMatcher ;
166154 }
167155
@@ -172,28 +160,45 @@ public void setRequireCsrfProtectionMatcher(
172160 * <p>
173161 * The default is to use AccessDeniedHandlerImpl with no arguments.
174162 * </p>
175- *
176163 * @param accessDeniedHandler the {@link AccessDeniedHandler} to use
177164 */
178165 public void setAccessDeniedHandler (AccessDeniedHandler accessDeniedHandler ) {
179166 Assert .notNull (accessDeniedHandler , "accessDeniedHandler cannot be null" );
180167 this .accessDeniedHandler = accessDeniedHandler ;
181168 }
182169
170+ /**
171+ * Constant time comparison to prevent against timing attacks.
172+ * @param expected
173+ * @param actual
174+ * @return
175+ */
176+ private static boolean equalsConstantTime (String expected , String actual ) {
177+ byte [] expectedBytes = bytesUtf8 (expected );
178+ byte [] actualBytes = bytesUtf8 (actual );
179+ return MessageDigest .isEqual (expectedBytes , actualBytes );
180+ }
181+
182+ private static byte [] bytesUtf8 (String s ) {
183+ // need to check if Utf8.encode() runs in constant time (probably not).
184+ // This may leak length of string.
185+ return (s != null ) ? Utf8 .encode (s ) : null ;
186+ }
187+
183188 private static final class DefaultRequiresCsrfMatcher implements RequestMatcher {
184- private final HashSet <String > allowedMethods = new HashSet <>(
185- Arrays .asList ("GET" , "HEAD" , "TRACE" , "OPTIONS" ));
186-
187- /*
188- * (non-Javadoc)
189- *
190- * @see
191- * org.springframework.security.web.util.matcher.RequestMatcher#matches(javax.
192- * servlet.http.HttpServletRequest)
193- */
189+
190+ private final HashSet <String > allowedMethods = new HashSet <>(Arrays .asList ("GET" , "HEAD" , "TRACE" , "OPTIONS" ));
191+
194192 @ Override
195193 public boolean matches (HttpServletRequest request ) {
196194 return !this .allowedMethods .contains (request .getMethod ());
197195 }
196+
197+ @ Override
198+ public String toString () {
199+ return "CsrfNotRequired " + this .allowedMethods ;
200+ }
201+
198202 }
203+
199204}
0 commit comments