|
1 | 1 | /*
|
2 |
| - * Copyright 2019 the original author or authors. |
| 2 | + * Copyright 2019-2021 the original author or authors. |
3 | 3 | *
|
4 | 4 | * Licensed under the Apache License, Version 2.0 (the "License");
|
5 | 5 | * you may not use this file except in compliance with the License.
|
|
19 | 19 | import java.util.Arrays;
|
20 | 20 | import java.util.Collections;
|
21 | 21 | import java.util.List;
|
| 22 | +import java.util.concurrent.ExecutorService; |
| 23 | +import java.util.concurrent.Executors; |
22 | 24 |
|
23 | 25 | import io.rsocket.Payload;
|
24 | 26 | import io.rsocket.RSocket;
|
25 | 27 | import io.rsocket.metadata.WellKnownMimeType;
|
| 28 | +import io.rsocket.util.ByteBufPayload; |
| 29 | +import io.rsocket.util.DefaultPayload; |
26 | 30 | import io.rsocket.util.RSocketProxy;
|
27 | 31 | import org.junit.Test;
|
28 | 32 | import org.junit.runner.RunWith;
|
|
32 | 36 | import org.mockito.runners.MockitoJUnitRunner;
|
33 | 37 | import org.mockito.stubbing.Answer;
|
34 | 38 | import org.reactivestreams.Publisher;
|
| 39 | +import org.reactivestreams.Subscription; |
| 40 | +import reactor.core.CoreSubscriber; |
35 | 41 | import reactor.core.publisher.Flux;
|
36 | 42 | import reactor.core.publisher.Mono;
|
37 | 43 | import reactor.test.StepVerifier;
|
38 | 44 | import reactor.test.publisher.PublisherProbe;
|
39 | 45 | import reactor.test.publisher.TestPublisher;
|
| 46 | +import reactor.util.context.Context; |
40 | 47 |
|
41 | 48 | import org.springframework.http.MediaType;
|
| 49 | +import org.springframework.security.access.AccessDeniedException; |
42 | 50 | import org.springframework.security.authentication.TestingAuthenticationToken;
|
43 | 51 | import org.springframework.security.core.Authentication;
|
44 | 52 | import org.springframework.security.core.context.ReactiveSecurityContextHolder;
|
|
56 | 64 | import static org.mockito.ArgumentMatchers.any;
|
57 | 65 | import static org.mockito.ArgumentMatchers.eq;
|
58 | 66 | import static org.mockito.BDDMockito.given;
|
| 67 | +import static org.mockito.Mockito.times; |
59 | 68 | import static org.mockito.Mockito.verify;
|
60 | 69 | import static org.mockito.Mockito.verifyZeroInteractions;
|
61 | 70 |
|
@@ -265,6 +274,57 @@ public void requestChannelWhenInterceptorCompletesThenDelegateSubscribed() {
|
265 | 274 | verify(this.delegate).requestChannel(any());
|
266 | 275 | }
|
267 | 276 |
|
| 277 | + // gh-9345 |
| 278 | + @Test |
| 279 | + public void requestChannelWhenInterceptorCompletesThenAllPayloadsRetained() { |
| 280 | + ExecutorService executors = Executors.newSingleThreadExecutor(); |
| 281 | + Payload payload = ByteBufPayload.create("data"); |
| 282 | + Payload payloadTwo = ByteBufPayload.create("moredata"); |
| 283 | + Payload payloadThree = ByteBufPayload.create("stillmoredata"); |
| 284 | + Context ctx = Context.empty(); |
| 285 | + Flux<Payload> payloads = this.payloadResult.flux(); |
| 286 | + given(this.interceptor.intercept(any(), any())).willReturn(Mono.empty()) |
| 287 | + .willReturn(Mono.error(() -> new AccessDeniedException("Access Denied"))); |
| 288 | + given(this.delegate.requestChannel(any())).willAnswer((invocation) -> { |
| 289 | + Flux<Payload> input = invocation.getArgument(0); |
| 290 | + return Flux.from(input).switchOnFirst((signal, innerFlux) -> innerFlux.map(Payload::getDataUtf8) |
| 291 | + .transform((data) -> Flux.<String>create((emitter) -> { |
| 292 | + Runnable run = () -> data.subscribe(new CoreSubscriber<String>() { |
| 293 | + @Override |
| 294 | + public void onSubscribe(Subscription s) { |
| 295 | + s.request(3); |
| 296 | + } |
| 297 | + |
| 298 | + @Override |
| 299 | + public void onNext(String s) { |
| 300 | + emitter.next(s); |
| 301 | + } |
| 302 | + |
| 303 | + @Override |
| 304 | + public void onError(Throwable t) { |
| 305 | + emitter.error(t); |
| 306 | + } |
| 307 | + |
| 308 | + @Override |
| 309 | + public void onComplete() { |
| 310 | + emitter.complete(); |
| 311 | + } |
| 312 | + }); |
| 313 | + executors.execute(run); |
| 314 | + })).map(DefaultPayload::create)); |
| 315 | + }); |
| 316 | + PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, |
| 317 | + Arrays.asList(this.interceptor), this.metadataMimeType, this.dataMimeType, ctx); |
| 318 | + StepVerifier.create(interceptor.requestChannel(payloads).doOnDiscard(Payload.class, Payload::release)) |
| 319 | + .then(() -> this.payloadResult.assertSubscribers()) |
| 320 | + .then(() -> this.payloadResult.emit(payload, payloadTwo, payloadThree)) |
| 321 | + .assertNext((next) -> assertThat(next.getDataUtf8()).isEqualTo(payload.getDataUtf8())) |
| 322 | + .verifyError(AccessDeniedException.class); |
| 323 | + verify(this.interceptor, times(2)).intercept(this.exchange.capture(), any()); |
| 324 | + assertThat(this.exchange.getValue().getPayload()).isEqualTo(payloadTwo); |
| 325 | + verify(this.delegate).requestChannel(any()); |
| 326 | + } |
| 327 | + |
268 | 328 | @Test
|
269 | 329 | public void requestChannelWhenInterceptorErrorsThenDelegateNotSubscribed() {
|
270 | 330 | RuntimeException expected = new RuntimeException("Oops");
|
|
0 commit comments