Skip to content

Commit 7b39800

Browse files
committed
Add CachingRelyingPartyRegistrationRepository
Closes gh-15341
1 parent 1e29003 commit 7b39800

File tree

3 files changed

+227
-0
lines changed

3 files changed

+227
-0
lines changed

docs/modules/ROOT/pages/servlet/saml2/login/overview.adoc

+51
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,57 @@ class MyCustomSecurityConfiguration {
588588
A relying party can be multi-tenant by registering more than one relying party in the `RelyingPartyRegistrationRepository`.
589589
====
590590

591+
[[servlet-saml2login-relyingpartyregistrationrepository-caching]]
592+
If you want your metadata to be refreshable on a periodic basis, you can wrap your repository in `CachingRelyingPartyRegistrationRepository` like so:
593+
594+
.Caching Relying Party Registration Repository
595+
[tabs]
596+
======
597+
Java::
598+
+
599+
[source,java,role="primary"]
600+
----
601+
@Configuration
602+
@EnableWebSecurity
603+
public class MyCustomSecurityConfiguration {
604+
@Bean
605+
public RelyingPartyRegistrationRepository registrations(CacheManager cacheManager) {
606+
Supplier<IterableRelyingPartyRegistrationRepository> delegate = () ->
607+
new InMemoryRelyingPartyRegistrationRepository(RelyingPartyRegistrations
608+
.fromMetadataLocation("https://idp.example.org/ap/metadata")
609+
.registrationId("ap").build());
610+
CachingRelyingPartyRegistrationRepository registrations =
611+
new CachingRelyingPartyRegistrationRepository(delegate);
612+
registrations.setCache(cacheManager.getCache("my-cache-name"));
613+
return registrations;
614+
}
615+
}
616+
----
617+
618+
Kotlin::
619+
+
620+
[source,kotlin,role="secondary"]
621+
----
622+
@Configuration
623+
@EnableWebSecurity
624+
class MyCustomSecurityConfiguration {
625+
@Bean
626+
fun registrations(cacheManager: CacheManager): RelyingPartyRegistrationRepository {
627+
val delegate = Supplier<IterableRelyingPartyRegistrationRepository> {
628+
InMemoryRelyingPartyRegistrationRepository(RelyingPartyRegistrations
629+
.fromMetadataLocation("https://idp.example.org/ap/metadata")
630+
.registrationId("ap").build())
631+
}
632+
val registrations = CachingRelyingPartyRegistrationRepository(delegate)
633+
registrations.setCache(cacheManager.getCache("my-cache-name"))
634+
return registrations
635+
}
636+
}
637+
----
638+
======
639+
640+
In this way, the set of `RelyingPartyRegistration`s will refresh based on {spring-framework-reference-url}integration/cache/store-configuration.html[the cache's eviction schedule].
641+
591642
[[servlet-saml2login-relyingpartyregistration]]
592643
== RelyingPartyRegistration
593644
A {security-api-url}org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.html[`RelyingPartyRegistration`]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*
2+
* Copyright 2002-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.saml2.provider.service.registration;
18+
19+
import java.util.Iterator;
20+
import java.util.Spliterator;
21+
import java.util.concurrent.Callable;
22+
import java.util.function.Consumer;
23+
24+
import org.springframework.cache.Cache;
25+
import org.springframework.cache.concurrent.ConcurrentMapCache;
26+
import org.springframework.util.Assert;
27+
28+
/**
29+
* An {@link IterableRelyingPartyRegistrationRepository} that lazily queries and caches
30+
* metadata from a backing {@link IterableRelyingPartyRegistrationRepository}. Delegates
31+
* caching policies to Spring Cache.
32+
*
33+
* @author Josh Cummings
34+
* @since 6.4
35+
*/
36+
public final class CachingRelyingPartyRegistrationRepository implements IterableRelyingPartyRegistrationRepository {
37+
38+
private final Callable<IterableRelyingPartyRegistrationRepository> registrationLoader;
39+
40+
private Cache cache = new ConcurrentMapCache("registrations");
41+
42+
public CachingRelyingPartyRegistrationRepository(Callable<IterableRelyingPartyRegistrationRepository> loader) {
43+
this.registrationLoader = loader;
44+
}
45+
46+
/**
47+
* {@inheritDoc}
48+
*/
49+
@Override
50+
public Iterator<RelyingPartyRegistration> iterator() {
51+
return registrations().iterator();
52+
}
53+
54+
/**
55+
* {@inheritDoc}
56+
*/
57+
@Override
58+
public RelyingPartyRegistration findByRegistrationId(String registrationId) {
59+
return registrations().findByRegistrationId(registrationId);
60+
}
61+
62+
@Override
63+
public RelyingPartyRegistration findUniqueByAssertingPartyEntityId(String entityId) {
64+
return registrations().findUniqueByAssertingPartyEntityId(entityId);
65+
}
66+
67+
@Override
68+
public void forEach(Consumer<? super RelyingPartyRegistration> action) {
69+
registrations().forEach(action);
70+
}
71+
72+
@Override
73+
public Spliterator<RelyingPartyRegistration> spliterator() {
74+
return registrations().spliterator();
75+
}
76+
77+
private IterableRelyingPartyRegistrationRepository registrations() {
78+
return this.cache.get("registrations", this.registrationLoader);
79+
}
80+
81+
/**
82+
* Use this cache for the completed {@link RelyingPartyRegistration} instances.
83+
*
84+
* <p>
85+
* Defaults to {@link ConcurrentMapCache}, meaning that the registrations are cached
86+
* without expiry. To turn off the cache, use
87+
* {@link org.springframework.cache.support.NoOpCache}.
88+
* @param cache the {@link Cache} to use
89+
*/
90+
public void setCache(Cache cache) {
91+
Assert.notNull(cache, "cache cannot be null");
92+
this.cache = cache;
93+
}
94+
95+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
/*
2+
* Copyright 2002-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.saml2.provider.service.registration;
18+
19+
import java.util.concurrent.Callable;
20+
21+
import org.junit.jupiter.api.Test;
22+
import org.junit.jupiter.api.extension.ExtendWith;
23+
import org.mockito.InjectMocks;
24+
import org.mockito.Mock;
25+
import org.mockito.junit.jupiter.MockitoExtension;
26+
27+
import org.springframework.cache.Cache;
28+
29+
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
30+
import static org.mockito.BDDMockito.given;
31+
import static org.mockito.Mockito.mock;
32+
import static org.mockito.Mockito.verify;
33+
import static org.mockito.Mockito.verifyNoMoreInteractions;
34+
35+
/**
36+
* Tests for {@link CachingRelyingPartyRegistrationRepository}
37+
*/
38+
@ExtendWith(MockitoExtension.class)
39+
public class CachingRelyingPartyRegistrationRepositoryTests {
40+
41+
@Mock
42+
Callable<Iterable<RelyingPartyRegistration>> callable;
43+
44+
@InjectMocks
45+
CachingRelyingPartyRegistrationRepository registrations;
46+
47+
@Test
48+
public void iteratorWhenResolvableThenPopulatesCache() throws Exception {
49+
given(this.callable.call()).willReturn(mock(IterableRelyingPartyRegistrationRepository.class));
50+
this.registrations.iterator();
51+
verify(this.callable).call();
52+
this.registrations.iterator();
53+
verifyNoMoreInteractions(this.callable);
54+
}
55+
56+
@Test
57+
public void iteratorWhenExceptionThenPropagates() throws Exception {
58+
given(this.callable.call()).willThrow(IllegalStateException.class);
59+
assertThatExceptionOfType(Cache.ValueRetrievalException.class).isThrownBy(this.registrations::iterator)
60+
.withCauseInstanceOf(IllegalStateException.class);
61+
}
62+
63+
@Test
64+
public void findByRegistrationIdWhenResolvableThenPopulatesCache() throws Exception {
65+
given(this.callable.call()).willReturn(mock(IterableRelyingPartyRegistrationRepository.class));
66+
this.registrations.findByRegistrationId("id");
67+
verify(this.callable).call();
68+
this.registrations.findByRegistrationId("id");
69+
verifyNoMoreInteractions(this.callable);
70+
}
71+
72+
@Test
73+
public void findUniqueByAssertingPartyEntityIdWhenResolvableThenPopulatesCache() throws Exception {
74+
given(this.callable.call()).willReturn(mock(IterableRelyingPartyRegistrationRepository.class));
75+
this.registrations.findUniqueByAssertingPartyEntityId("id");
76+
verify(this.callable).call();
77+
this.registrations.findUniqueByAssertingPartyEntityId("id");
78+
verifyNoMoreInteractions(this.callable);
79+
}
80+
81+
}

0 commit comments

Comments
 (0)