Skip to content

Commit 5666246

Browse files
committed
PayloadInterceptorRSocket retains all payloads
Flux#skip discards its corresponding elements, meaning that they aren't intended for reuse. When using RSocket's ByteBufPayloads, this means that the bytes are releaseed back into RSocket's pool. Since the downstream request may still need the skipped payload, we should construct the publisher in a different way so as to avoid the preemptive release. Deferring Spring JavaFormat to clarify what changed. Closes gh-9345
1 parent 5243b1b commit 5666246

File tree

2 files changed

+67
-4
lines changed

2 files changed

+67
-4
lines changed

rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocket.java

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2019 the original author or authors.
2+
* Copyright 2019-2021 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -92,13 +92,16 @@ public Flux<Payload> requestChannel(Publisher<Payload> payloads) {
9292
return Flux.from(payloads).switchOnFirst((signal, innerFlux) -> {
9393
Payload firstPayload = signal.get();
9494
return intercept(PayloadExchangeType.REQUEST_CHANNEL, firstPayload).flatMapMany((context) -> innerFlux
95-
.skip(1).flatMap((p) -> intercept(PayloadExchangeType.PAYLOAD, p).thenReturn(p))
96-
.transform((securedPayloads) -> Flux.concat(Flux.just(firstPayload), securedPayloads))
95+
.index().concatMap((tuple) -> justOrIntercept(tuple.getT1(), tuple.getT2()))
9796
.transform((securedPayloads) -> this.source.requestChannel(securedPayloads))
9897
.subscriberContext(context));
9998
});
10099
}
101100

101+
private Mono<Payload> justOrIntercept(Long index, Payload payload) {
102+
return (index == 0) ? Mono.just(payload) : intercept(PayloadExchangeType.PAYLOAD, payload).thenReturn(payload);
103+
}
104+
102105
@Override
103106
public Mono<Void> metadataPush(Payload payload) {
104107
return intercept(PayloadExchangeType.METADATA_PUSH, payload)

rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocketTests.java

+61-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2019 the original author or authors.
2+
* Copyright 2019-2021 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -19,10 +19,14 @@
1919
import java.util.Arrays;
2020
import java.util.Collections;
2121
import java.util.List;
22+
import java.util.concurrent.ExecutorService;
23+
import java.util.concurrent.Executors;
2224

2325
import io.rsocket.Payload;
2426
import io.rsocket.RSocket;
2527
import io.rsocket.metadata.WellKnownMimeType;
28+
import io.rsocket.util.ByteBufPayload;
29+
import io.rsocket.util.DefaultPayload;
2630
import io.rsocket.util.RSocketProxy;
2731
import org.junit.Test;
2832
import org.junit.runner.RunWith;
@@ -32,13 +36,17 @@
3236
import org.mockito.runners.MockitoJUnitRunner;
3337
import org.mockito.stubbing.Answer;
3438
import org.reactivestreams.Publisher;
39+
import org.reactivestreams.Subscription;
40+
import reactor.core.CoreSubscriber;
3541
import reactor.core.publisher.Flux;
3642
import reactor.core.publisher.Mono;
3743
import reactor.test.StepVerifier;
3844
import reactor.test.publisher.PublisherProbe;
3945
import reactor.test.publisher.TestPublisher;
46+
import reactor.util.context.Context;
4047

4148
import org.springframework.http.MediaType;
49+
import org.springframework.security.access.AccessDeniedException;
4250
import org.springframework.security.authentication.TestingAuthenticationToken;
4351
import org.springframework.security.core.Authentication;
4452
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
@@ -56,6 +64,7 @@
5664
import static org.mockito.ArgumentMatchers.any;
5765
import static org.mockito.ArgumentMatchers.eq;
5866
import static org.mockito.BDDMockito.given;
67+
import static org.mockito.Mockito.times;
5968
import static org.mockito.Mockito.verify;
6069
import static org.mockito.Mockito.verifyZeroInteractions;
6170

@@ -265,6 +274,57 @@ public void requestChannelWhenInterceptorCompletesThenDelegateSubscribed() {
265274
verify(this.delegate).requestChannel(any());
266275
}
267276

277+
// gh-9345
278+
@Test
279+
public void requestChannelWhenInterceptorCompletesThenAllPayloadsRetained() {
280+
ExecutorService executors = Executors.newSingleThreadExecutor();
281+
Payload payload = ByteBufPayload.create("data");
282+
Payload payloadTwo = ByteBufPayload.create("moredata");
283+
Payload payloadThree = ByteBufPayload.create("stillmoredata");
284+
Context ctx = Context.empty();
285+
Flux<Payload> payloads = this.payloadResult.flux();
286+
given(this.interceptor.intercept(any(), any())).willReturn(Mono.empty())
287+
.willReturn(Mono.error(() -> new AccessDeniedException("Access Denied")));
288+
given(this.delegate.requestChannel(any())).willAnswer((invocation) -> {
289+
Flux<Payload> input = invocation.getArgument(0);
290+
return Flux.from(input).switchOnFirst((signal, innerFlux) -> innerFlux.map(Payload::getDataUtf8)
291+
.transform((data) -> Flux.<String>create((emitter) -> {
292+
Runnable run = () -> data.subscribe(new CoreSubscriber<String>() {
293+
@Override
294+
public void onSubscribe(Subscription s) {
295+
s.request(3);
296+
}
297+
298+
@Override
299+
public void onNext(String s) {
300+
emitter.next(s);
301+
}
302+
303+
@Override
304+
public void onError(Throwable t) {
305+
emitter.error(t);
306+
}
307+
308+
@Override
309+
public void onComplete() {
310+
emitter.complete();
311+
}
312+
});
313+
executors.execute(run);
314+
})).map(DefaultPayload::create));
315+
});
316+
PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate,
317+
Arrays.asList(this.interceptor), this.metadataMimeType, this.dataMimeType, ctx);
318+
StepVerifier.create(interceptor.requestChannel(payloads).doOnDiscard(Payload.class, Payload::release))
319+
.then(() -> this.payloadResult.assertSubscribers())
320+
.then(() -> this.payloadResult.emit(payload, payloadTwo, payloadThree))
321+
.assertNext((next) -> assertThat(next.getDataUtf8()).isEqualTo(payload.getDataUtf8()))
322+
.verifyError(AccessDeniedException.class);
323+
verify(this.interceptor, times(2)).intercept(this.exchange.capture(), any());
324+
assertThat(this.exchange.getValue().getPayload()).isEqualTo(payloadTwo);
325+
verify(this.delegate).requestChannel(any());
326+
}
327+
268328
@Test
269329
public void requestChannelWhenInterceptorErrorsThenDelegateNotSubscribed() {
270330
RuntimeException expected = new RuntimeException("Oops");

0 commit comments

Comments
 (0)