Skip to content

Commit 5a5bed4

Browse files
committed
Add subscriberContext to PayloadSocketAcceptor delegate.accept
Closes gh-8654
1 parent 8ff3d66 commit 5a5bed4

File tree

3 files changed

+83
-7
lines changed

3 files changed

+83
-7
lines changed

rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadSocketAcceptor.java

+1
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ public Mono<RSocket> accept(ConnectionSetupPayload setup, RSocket sendingSocket)
7272
return intercept(setup, dataMimeType, metadataMimeType)
7373
.flatMap(ctx -> this.delegate.accept(setup, sendingSocket)
7474
.map(acceptingSocket -> new PayloadInterceptorRSocket(acceptingSocket, this.interceptors, metadataMimeType, dataMimeType, ctx))
75+
.subscriberContext(ctx)
7576
);
7677
}
7778

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Copyright 2020 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.rsocket.core;
18+
19+
import io.rsocket.ConnectionSetupPayload;
20+
import io.rsocket.RSocket;
21+
import io.rsocket.SocketAcceptor;
22+
import reactor.core.publisher.Mono;
23+
24+
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
25+
import org.springframework.security.core.context.SecurityContext;
26+
27+
/**
28+
* A {@link SocketAcceptor} that captures the {@link SecurityContext} and then continues with the {@link RSocket}
29+
* @author Rob Winch
30+
*/
31+
class CaptureSecurityContextSocketAcceptor implements SocketAcceptor {
32+
private final RSocket accept;
33+
34+
private SecurityContext securityContext;
35+
36+
CaptureSecurityContextSocketAcceptor(RSocket accept) {
37+
this.accept = accept;
38+
}
39+
40+
@Override
41+
public Mono<RSocket> accept(ConnectionSetupPayload setup, RSocket sendingSocket) {
42+
return ReactiveSecurityContextHolder.getContext()
43+
.doOnNext(securityContext -> this.securityContext = securityContext)
44+
.thenReturn(this.accept);
45+
}
46+
47+
public SecurityContext getSecurityContext() {
48+
return this.securityContext;
49+
}
50+
}

rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadSocketAcceptorTests.java

+32-7
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616

1717
package org.springframework.security.rsocket.core;
1818

19+
import java.util.Arrays;
20+
import java.util.Collections;
21+
import java.util.List;
22+
1923
import io.rsocket.ConnectionSetupPayload;
2024
import io.rsocket.Payload;
2125
import io.rsocket.RSocket;
@@ -27,16 +31,16 @@
2731
import org.mockito.ArgumentCaptor;
2832
import org.mockito.Mock;
2933
import org.mockito.runners.MockitoJUnitRunner;
34+
import reactor.core.publisher.Mono;
35+
import reactor.util.context.Context;
36+
3037
import org.springframework.http.MediaType;
38+
import org.springframework.security.authentication.TestingAuthenticationToken;
39+
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
40+
import org.springframework.security.core.context.SecurityContext;
41+
import org.springframework.security.core.context.SecurityContextImpl;
3142
import org.springframework.security.rsocket.api.PayloadExchange;
3243
import org.springframework.security.rsocket.api.PayloadInterceptor;
33-
import org.springframework.security.rsocket.core.PayloadInterceptorRSocket;
34-
import org.springframework.security.rsocket.core.PayloadSocketAcceptor;
35-
import reactor.core.publisher.Mono;
36-
37-
import java.util.Arrays;
38-
import java.util.Collections;
39-
import java.util.List;
4044

4145
import static org.assertj.core.api.Assertions.assertThat;
4246
import static org.assertj.core.api.Assertions.assertThatCode;
@@ -144,6 +148,27 @@ public void acceptWhenExplicitMimeTypeThenThenOverrideDefault() {
144148
assertThat(exchange.getDataMimeType()).isEqualTo(MediaType.APPLICATION_JSON);
145149
}
146150

151+
152+
@Test
153+
// gh-8654
154+
public void acceptWhenDelegateAcceptRequiresReactiveSecurityContext() {
155+
when(this.setupPayload.metadataMimeType()).thenReturn(MediaType.TEXT_PLAIN_VALUE);
156+
when(this.setupPayload.dataMimeType()).thenReturn(MediaType.APPLICATION_JSON_VALUE);
157+
SecurityContext expectedSecurityContext = new SecurityContextImpl(new TestingAuthenticationToken("user", "password", "ROLE_USER"));
158+
CaptureSecurityContextSocketAcceptor captureSecurityContext = new CaptureSecurityContextSocketAcceptor(this.rSocket);
159+
PayloadInterceptor authenticateInterceptor = (exchange, chain) -> {
160+
Context withSecurityContext = ReactiveSecurityContextHolder.withSecurityContext(Mono.just(expectedSecurityContext));
161+
return chain.next(exchange)
162+
.subscriberContext(withSecurityContext);
163+
};
164+
List<PayloadInterceptor> interceptors = Arrays.asList(authenticateInterceptor);
165+
this.acceptor = new PayloadSocketAcceptor(captureSecurityContext, interceptors);
166+
167+
this.acceptor.accept(this.setupPayload, this.rSocket).block();
168+
169+
assertThat(captureSecurityContext.getSecurityContext()).isEqualTo(expectedSecurityContext);
170+
}
171+
147172
private PayloadExchange captureExchange() {
148173
when(this.delegate.accept(any(), any())).thenReturn(Mono.just(this.rSocket));
149174
when(this.interceptor.intercept(any(), any())).thenReturn(Mono.empty());

0 commit comments

Comments
 (0)