Skip to content

Commit 561c786

Browse files
committed
Repair Flaky Tests
The issue turned out to be that OpenSAML first sends two HEAD requests before sending a GET to retrieve the metadata. The way the MockWebServer dispatcher was configured, it would send back the metadata on each request. This created a situation where sockets were being closed by the client before the server had sent all the response, resulting in a broken pipe. The tests would succeed most of the time due to lucky timing between the client closing the socket and the server having sent all of its (unrequested) data. This version sends an expected HEAD response when requested. Issue gh-15395
1 parent e90a6b6 commit 561c786

File tree

2 files changed

+194
-128
lines changed

2 files changed

+194
-128
lines changed

saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/registration/OpenSaml4AssertingPartyMetadataRepositoryTests.java

+97-64
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,23 @@
2020
import java.io.File;
2121
import java.io.IOException;
2222
import java.io.InputStreamReader;
23+
import java.io.UncheckedIOException;
2324
import java.util.ArrayList;
2425
import java.util.Collection;
2526
import java.util.List;
27+
import java.util.Map;
2628
import java.util.Set;
29+
import java.util.UUID;
30+
import java.util.concurrent.ConcurrentHashMap;
2731
import java.util.stream.Collectors;
2832

2933
import net.shibboleth.utilities.java.support.xml.SerializeSupport;
34+
import okhttp3.mockwebserver.Dispatcher;
3035
import okhttp3.mockwebserver.MockResponse;
3136
import okhttp3.mockwebserver.MockWebServer;
32-
import org.junit.jupiter.api.BeforeEach;
37+
import okhttp3.mockwebserver.RecordedRequest;
38+
import org.junit.jupiter.api.AfterAll;
39+
import org.junit.jupiter.api.BeforeAll;
3340
import org.junit.jupiter.api.Test;
3441
import org.opensaml.core.xml.XMLObject;
3542
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
@@ -68,52 +75,59 @@ public class OpenSaml4AssertingPartyMetadataRepositoryTests {
6875
OpenSamlInitializationService.initialize();
6976
}
7077

71-
private String metadata;
78+
private static MetadataDispatcher dispatcher = new MetadataDispatcher()
79+
.addResponse("/entity.xml", readFile("test-metadata.xml"))
80+
.addResponse("/entities.xml", readFile("test-entitiesdescriptor.xml"));
7281

73-
private String entitiesDescriptor;
82+
private static MockWebServer web = new MockWebServer();
7483

75-
@BeforeEach
76-
public void setup() throws Exception {
77-
ClassPathResource resource = new ClassPathResource("test-metadata.xml");
78-
try (BufferedReader reader = new BufferedReader(new InputStreamReader(resource.getInputStream()))) {
79-
this.metadata = reader.lines().collect(Collectors.joining());
84+
private static String readFile(String fileName) {
85+
try {
86+
ClassPathResource resource = new ClassPathResource(fileName);
87+
try (BufferedReader reader = new BufferedReader(new InputStreamReader(resource.getInputStream()))) {
88+
return reader.lines().collect(Collectors.joining());
89+
}
8090
}
81-
resource = new ClassPathResource("test-entitiesdescriptor.xml");
82-
try (BufferedReader reader = new BufferedReader(new InputStreamReader(resource.getInputStream()))) {
83-
this.entitiesDescriptor = reader.lines().collect(Collectors.joining());
91+
catch (IOException ex) {
92+
throw new UncheckedIOException(ex);
8493
}
8594
}
8695

96+
@BeforeAll
97+
public static void start() throws Exception {
98+
web.setDispatcher(dispatcher);
99+
web.start();
100+
}
101+
102+
@AfterAll
103+
public static void shutdown() throws Exception {
104+
web.shutdown();
105+
}
106+
87107
@Test
88108
public void withMetadataUrlLocationWhenResolvableThenFindByEntityIdReturns() throws Exception {
89-
try (MockWebServer server = new MockWebServer()) {
90-
enqueue(server, this.metadata, 3);
91-
AssertingPartyMetadataRepository parties = OpenSaml4AssertingPartyMetadataRepository
92-
.withTrustedMetadataLocation(server.url("/").toString())
93-
.build();
94-
AssertingPartyMetadata party = parties.findByEntityId("https://idp.example.com/idp/shibboleth");
95-
assertThat(party.getEntityId()).isEqualTo("https://idp.example.com/idp/shibboleth");
96-
assertThat(party.getSingleSignOnServiceLocation())
97-
.isEqualTo("https://idp.example.com/idp/profile/SAML2/POST/SSO");
98-
assertThat(party.getSingleSignOnServiceBinding()).isEqualTo(Saml2MessageBinding.POST);
99-
assertThat(party.getVerificationX509Credentials()).hasSize(1);
100-
assertThat(party.getEncryptionX509Credentials()).hasSize(1);
101-
}
109+
AssertingPartyMetadataRepository parties = OpenSaml4AssertingPartyMetadataRepository
110+
.withTrustedMetadataLocation(web.url("/entity.xml").toString())
111+
.build();
112+
AssertingPartyMetadata party = parties.findByEntityId("https://idp.example.com/idp/shibboleth");
113+
assertThat(party.getEntityId()).isEqualTo("https://idp.example.com/idp/shibboleth");
114+
assertThat(party.getSingleSignOnServiceLocation())
115+
.isEqualTo("https://idp.example.com/idp/profile/SAML2/POST/SSO");
116+
assertThat(party.getSingleSignOnServiceBinding()).isEqualTo(Saml2MessageBinding.POST);
117+
assertThat(party.getVerificationX509Credentials()).hasSize(1);
118+
assertThat(party.getEncryptionX509Credentials()).hasSize(1);
102119
}
103120

104121
@Test
105122
public void withMetadataUrlLocationnWhenResolvableThenIteratorReturns() throws Exception {
106-
try (MockWebServer server = new MockWebServer()) {
107-
enqueue(server, this.entitiesDescriptor, 3);
108-
List<AssertingPartyMetadata> parties = new ArrayList<>();
109-
OpenSaml4AssertingPartyMetadataRepository.withTrustedMetadataLocation(server.url("/").toString())
110-
.build()
111-
.iterator()
112-
.forEachRemaining(parties::add);
113-
assertThat(parties).hasSize(2);
114-
assertThat(parties).extracting(AssertingPartyMetadata::getEntityId)
115-
.contains("https://ap.example.org/idp/shibboleth", "https://idp.example.com/idp/shibboleth");
116-
}
123+
List<AssertingPartyMetadata> parties = new ArrayList<>();
124+
OpenSaml4AssertingPartyMetadataRepository.withTrustedMetadataLocation(web.url("/entities.xml").toString())
125+
.build()
126+
.iterator()
127+
.forEachRemaining(parties::add);
128+
assertThat(parties).hasSize(2);
129+
assertThat(parties).extracting(AssertingPartyMetadata::getEntityId)
130+
.contains("https://ap.example.org/idp/shibboleth", "https://idp.example.com/idp/shibboleth");
117131
}
118132

119133
@Test
@@ -128,12 +142,10 @@ public void withMetadataUrlLocationWhenUnresolvableThenThrowsSaml2Exception() th
128142

129143
@Test
130144
public void withMetadataUrlLocationWhenMalformedResponseThenSaml2Exception() throws Exception {
131-
try (MockWebServer server = new MockWebServer()) {
132-
enqueue(server, "malformed", 3);
133-
String url = server.url("/").toString();
134-
assertThatExceptionOfType(Saml2Exception.class)
135-
.isThrownBy(() -> OpenSaml4AssertingPartyMetadataRepository.withTrustedMetadataLocation(url).build());
136-
}
145+
dispatcher.addResponse("/malformed", "malformed");
146+
String url = web.url("/malformed").toString();
147+
assertThatExceptionOfType(Saml2Exception.class)
148+
.isThrownBy(() -> OpenSaml4AssertingPartyMetadataRepository.withTrustedMetadataLocation(url).build());
137149
}
138150

139151
@Test
@@ -211,14 +223,13 @@ public void withTrustedMetadataLocationWhenMatchingCredentialsThenVerifiesSignat
211223
String serialized = serialize(descriptor);
212224
Credential credential = TestOpenSamlObjects
213225
.getSigningCredential(TestSaml2X509Credentials.relyingPartyVerifyingCredential(), descriptor.getEntityID());
214-
try (MockWebServer server = new MockWebServer()) {
215-
enqueue(server, serialized, 3);
216-
AssertingPartyMetadataRepository parties = OpenSaml4AssertingPartyMetadataRepository
217-
.withTrustedMetadataLocation(server.url("/").toString())
218-
.verificationCredentials((c) -> c.add(credential))
219-
.build();
220-
assertThat(parties.findByEntityId(registration.getAssertingPartyDetails().getEntityId())).isNotNull();
221-
}
226+
String endpoint = "/" + UUID.randomUUID().toString();
227+
dispatcher.addResponse(endpoint, serialized);
228+
AssertingPartyMetadataRepository parties = OpenSaml4AssertingPartyMetadataRepository
229+
.withTrustedMetadataLocation(web.url(endpoint).toString())
230+
.verificationCredentials((c) -> c.add(credential))
231+
.build();
232+
assertThat(parties.findByEntityId(registration.getAssertingPartyDetails().getEntityId())).isNotNull();
222233
}
223234

224235
@Test
@@ -230,13 +241,12 @@ public void withTrustedMetadataLocationWhenMismatchingCredentialsThenSaml2Except
230241
String serialized = serialize(descriptor);
231242
Credential credential = TestOpenSamlObjects
232243
.getSigningCredential(TestSaml2X509Credentials.relyingPartyVerifyingCredential(), descriptor.getEntityID());
233-
try (MockWebServer server = new MockWebServer()) {
234-
enqueue(server, serialized, 3);
235-
assertThatExceptionOfType(Saml2Exception.class).isThrownBy(() -> OpenSaml4AssertingPartyMetadataRepository
236-
.withTrustedMetadataLocation(server.url("/").toString())
237-
.verificationCredentials((c) -> c.add(credential))
238-
.build());
239-
}
244+
String endpoint = "/" + UUID.randomUUID().toString();
245+
dispatcher.addResponse(endpoint, serialized);
246+
assertThatExceptionOfType(Saml2Exception.class).isThrownBy(() -> OpenSaml4AssertingPartyMetadataRepository
247+
.withTrustedMetadataLocation(web.url(endpoint).toString())
248+
.verificationCredentials((c) -> c.add(credential))
249+
.build());
240250
}
241251

242252
@Test
@@ -326,14 +336,13 @@ public void withMetadataLocationWhenMatchingCredentialsThenVerifiesSignature() t
326336
String serialized = serialize(descriptor);
327337
Credential credential = TestOpenSamlObjects
328338
.getSigningCredential(TestSaml2X509Credentials.relyingPartyVerifyingCredential(), descriptor.getEntityID());
329-
try (MockWebServer server = new MockWebServer()) {
330-
enqueue(server, serialized, 3);
331-
AssertingPartyMetadataRepository parties = OpenSaml4AssertingPartyMetadataRepository
332-
.withMetadataLocation(server.url("/").toString())
333-
.verificationCredentials((c) -> c.add(credential))
334-
.build();
335-
assertThat(parties.findByEntityId(registration.getAssertingPartyDetails().getEntityId())).isNotNull();
336-
}
339+
String endpoint = "/" + UUID.randomUUID().toString();
340+
dispatcher.addResponse(endpoint, serialized);
341+
AssertingPartyMetadataRepository parties = OpenSaml4AssertingPartyMetadataRepository
342+
.withMetadataLocation(web.url(endpoint).toString())
343+
.verificationCredentials((c) -> c.add(credential))
344+
.build();
345+
assertThat(parties.findByEntityId(registration.getAssertingPartyDetails().getEntityId())).isNotNull();
337346
}
338347

339348
private static String serialize(XMLObject object) {
@@ -353,4 +362,28 @@ private static void enqueue(MockWebServer web, String body, int times) {
353362
}
354363
}
355364

365+
private static final class MetadataDispatcher extends Dispatcher {
366+
367+
private final MockResponse head = new MockResponse();
368+
369+
private final Map<String, MockResponse> responses = new ConcurrentHashMap<>();
370+
371+
private MetadataDispatcher() {
372+
}
373+
374+
@Override
375+
public MockResponse dispatch(RecordedRequest request) throws InterruptedException {
376+
if ("HEAD".equals(request.getMethod())) {
377+
return this.head;
378+
}
379+
return this.responses.get(request.getPath());
380+
}
381+
382+
private MetadataDispatcher addResponse(String path, String body) {
383+
this.responses.put(path, new MockResponse().setBody(body).setResponseCode(200));
384+
return this;
385+
}
386+
387+
}
388+
356389
}

0 commit comments

Comments
 (0)