19
19
import javax .servlet .http .HttpServletRequest ;
20
20
import javax .servlet .http .HttpServletResponse ;
21
21
22
- import org .springframework .http .HttpMethod ;
23
22
import org .springframework .security .core .Authentication ;
24
23
import org .springframework .security .core .AuthenticationException ;
25
- import org .springframework .security .saml2 .provider .service .authentication .Saml2AuthenticationException ;
26
- import org .springframework .security .saml2 .provider .service .authentication .Saml2AuthenticationToken ;
27
24
import org .springframework .security .saml2 .core .Saml2Error ;
28
- import org .springframework .security .saml2 .provider .service .registration . RelyingPartyRegistration ;
25
+ import org .springframework .security .saml2 .provider .service .authentication . Saml2AuthenticationException ;
29
26
import org .springframework .security .saml2 .provider .service .registration .RelyingPartyRegistrationRepository ;
27
+ import org .springframework .security .saml2 .provider .service .web .DefaultRelyingPartyRegistrationResolver ;
28
+ import org .springframework .security .saml2 .provider .service .web .Saml2AuthenticationTokenConverter ;
30
29
import org .springframework .security .web .authentication .AbstractAuthenticationProcessingFilter ;
30
+ import org .springframework .security .web .authentication .AuthenticationConverter ;
31
31
import org .springframework .security .web .authentication .session .ChangeSessionIdAuthenticationStrategy ;
32
- import org .springframework .security .web .util .matcher .AntPathRequestMatcher ;
33
- import org .springframework .security .web .util .matcher .RequestMatcher ;
34
32
import org .springframework .util .Assert ;
35
33
36
- import static java .nio .charset .StandardCharsets .UTF_8 ;
37
34
import static org .springframework .security .saml2 .core .Saml2ErrorCodes .RELYING_PARTY_REGISTRATION_NOT_FOUND ;
38
- import static org .springframework .security .saml2 .provider .service .registration .RelyingPartyRegistration .withRelyingPartyRegistration ;
39
35
import static org .springframework .util .StringUtils .hasText ;
40
36
41
37
/**
44
40
public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProcessingFilter {
45
41
46
42
public static final String DEFAULT_FILTER_PROCESSES_URI = "/login/saml2/sso/{registrationId}" ;
47
- private final RequestMatcher matcher ;
48
- private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository ;
43
+ private final AuthenticationConverter authenticationConverter ;
49
44
50
45
/**
51
46
* Creates a {@code Saml2WebSsoAuthenticationFilter} authentication filter that is configured
@@ -64,16 +59,30 @@ public Saml2WebSsoAuthenticationFilter(RelyingPartyRegistrationRepository relyin
64
59
public Saml2WebSsoAuthenticationFilter (
65
60
RelyingPartyRegistrationRepository relyingPartyRegistrationRepository ,
66
61
String filterProcessesUrl ) {
67
- super (filterProcessesUrl );
68
- Assert .notNull (relyingPartyRegistrationRepository , "relyingPartyRegistrationRepository cannot be null" );
69
- Assert .hasText (filterProcessesUrl , "filterProcessesUrl must contain a URL pattern" );
62
+ this (new Saml2AuthenticationTokenConverter
63
+ (new DefaultRelyingPartyRegistrationResolver (relyingPartyRegistrationRepository )),
64
+ filterProcessesUrl );
65
+ }
66
+
67
+ /**
68
+ * Creates a {@link Saml2WebSsoAuthenticationFilter} given the provided parameters
69
+ *
70
+ * @param authenticationConverter the strategy for converting an {@link HttpServletRequest}
71
+ * into an {@link Authentication}
72
+ * @param filterProcessingUrl the processing URL, must contain a {registrationId} variable
73
+ * @since 5.4
74
+ */
75
+ public Saml2WebSsoAuthenticationFilter (
76
+ AuthenticationConverter authenticationConverter ,
77
+ String filterProcessingUrl ) {
78
+ super (filterProcessingUrl );
79
+ Assert .notNull (authenticationConverter , "authenticationConverter cannot be null" );
80
+ Assert .hasText (filterProcessingUrl , "filterProcessesUrl must contain a URL pattern" );
70
81
Assert .isTrue (
71
- filterProcessesUrl .contains ("{registrationId}" ),
82
+ filterProcessingUrl .contains ("{registrationId}" ),
72
83
"filterProcessesUrl must contain a {registrationId} match variable"
73
84
);
74
- this .matcher = new AntPathRequestMatcher (filterProcessesUrl );
75
- setRequiresAuthenticationRequestMatcher (this .matcher );
76
- this .relyingPartyRegistrationRepository = relyingPartyRegistrationRepository ;
85
+ this .authenticationConverter = authenticationConverter ;
77
86
setAllowSessionCreation (true );
78
87
setSessionAuthenticationStrategy (new ChangeSessionIdAuthenticationStrategy ());
79
88
}
@@ -86,37 +95,12 @@ protected boolean requiresAuthentication(HttpServletRequest request, HttpServlet
86
95
@ Override
87
96
public Authentication attemptAuthentication (HttpServletRequest request , HttpServletResponse response )
88
97
throws AuthenticationException {
89
- String saml2Response = request .getParameter ("SAMLResponse" );
90
- byte [] b = Saml2Utils .samlDecode (saml2Response );
91
-
92
- String responseXml = inflateIfRequired (request , b );
93
- String registrationId = this .matcher .matcher (request ).getVariables ().get ("registrationId" );
94
- RelyingPartyRegistration rp =
95
- this .relyingPartyRegistrationRepository .findByRegistrationId (registrationId );
96
- if (rp == null ) {
98
+ Authentication authentication = this .authenticationConverter .convert (request );
99
+ if (authentication == null ) {
97
100
Saml2Error saml2Error = new Saml2Error (RELYING_PARTY_REGISTRATION_NOT_FOUND ,
98
- "Relying Party Registration not found with ID: " + registrationId );
101
+ "No relying party registration found" );
99
102
throw new Saml2AuthenticationException (saml2Error );
100
103
}
101
- String applicationUri = Saml2ServletUtils .getApplicationUri (request );
102
- String relyingPartyEntityId = Saml2ServletUtils .resolveUrlTemplate (rp .getEntityId (), applicationUri , rp );
103
- String assertionConsumerServiceLocation = Saml2ServletUtils .resolveUrlTemplate (
104
- rp .getAssertionConsumerServiceLocation (), applicationUri , rp );
105
- RelyingPartyRegistration relyingPartyRegistration = withRelyingPartyRegistration (rp )
106
- .entityId (relyingPartyEntityId )
107
- .assertionConsumerServiceLocation (assertionConsumerServiceLocation )
108
- .build ();
109
- Saml2AuthenticationToken authentication = new Saml2AuthenticationToken (
110
- relyingPartyRegistration , responseXml );
111
104
return getAuthenticationManager ().authenticate (authentication );
112
105
}
113
-
114
- private String inflateIfRequired (HttpServletRequest request , byte [] b ) {
115
- if (HttpMethod .GET .matches (request .getMethod ())) {
116
- return Saml2Utils .samlInflate (b );
117
- }
118
- else {
119
- return new String (b , UTF_8 );
120
- }
121
- }
122
106
}
0 commit comments