Skip to content

Commit 636a749

Browse files
committed
Pubsub: Keep stream running after sink was closed.
1 parent 2e1b5c1 commit 636a749

File tree

2 files changed

+48
-17
lines changed

2 files changed

+48
-17
lines changed

redis/src/aio/pubsub.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -247,10 +247,17 @@ impl PubSubSink {
247247
{
248248
let (sender, mut receiver) = unbounded_channel();
249249
let sink = PipelineSink::new(sink_stream, messages_sender);
250-
let f = stream::poll_fn(move |cx| receiver.poll_recv(cx))
251-
.map(Ok)
252-
.forward(sink)
253-
.map(|_| ());
250+
let f = stream::poll_fn(move |cx| {
251+
let res = receiver.poll_recv(cx);
252+
match res {
253+
// We don't want to stop the backing task for the stream, even if the sink was closed.
254+
Poll::Ready(None) => Poll::Pending,
255+
_ => res,
256+
}
257+
})
258+
.map(Ok)
259+
.forward(sink)
260+
.map(|_| ());
254261
(PubSubSink { sender }, f)
255262
}
256263

redis/tests/test_async.rs

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -686,8 +686,17 @@ mod basic_async {
686686
let mut publish_conn = ctx.async_connection().await?;
687687
let _: () = publish_conn.publish("phonewave", "banana").await?;
688688

689-
let msg_payload: String = pubsub_stream.next().await.unwrap().get_payload()?;
690-
assert_eq!("banana".to_string(), msg_payload);
689+
let repeats = 6;
690+
for _ in 0..repeats {
691+
let _: () = publish_conn.publish("phonewave", "banana").await?;
692+
}
693+
694+
for _ in 0..repeats {
695+
let message: String =
696+
pubsub_stream.next().await.unwrap().get_payload().unwrap();
697+
698+
assert_eq!("banana".to_string(), message);
699+
}
691700

692701
Ok(())
693702
})
@@ -748,14 +757,18 @@ mod basic_async {
748757
block_on_all(async move {
749758
let (mut sink, mut stream) = ctx.async_pubsub().await?.split();
750759
let mut publish_conn = ctx.async_connection().await?;
751-
let spawned_read = tokio::spawn(async move { stream.next().await });
752760

753761
let _: () = sink.subscribe("phonewave").await?;
754-
let _: () = publish_conn.publish("phonewave", "banana").await?;
762+
let repeats = 6;
763+
for _ in 0..repeats {
764+
let _: () = publish_conn.publish("phonewave", "banana").await?;
765+
}
755766

756-
let message: String = spawned_read.await.unwrap().unwrap().get_payload().unwrap();
767+
for _ in 0..repeats {
768+
let message: String = stream.next().await.unwrap().get_payload().unwrap();
757769

758-
assert_eq!("banana".to_string(), message);
770+
assert_eq!("banana".to_string(), message);
771+
}
759772

760773
Ok(())
761774
})
@@ -768,15 +781,19 @@ mod basic_async {
768781
block_on_all(async move {
769782
let (mut sink, mut stream) = ctx.async_pubsub().await?.split();
770783
let mut publish_conn = ctx.async_connection().await?;
771-
let spawned_read = tokio::spawn(async move { stream.next().await });
772784

773785
let _: () = sink.subscribe("phonewave").await?;
774786
drop(sink);
775-
let _: () = publish_conn.publish("phonewave", "banana").await?;
787+
let repeats = 6;
788+
for _ in 0..repeats {
789+
let _: () = publish_conn.publish("phonewave", "banana").await?;
790+
}
776791

777-
let message: String = spawned_read.await.unwrap().unwrap().get_payload().unwrap();
792+
for _ in 0..repeats {
793+
let message: String = stream.next().await.unwrap().get_payload().unwrap();
778794

779-
assert_eq!("banana".to_string(), message);
795+
assert_eq!("banana".to_string(), message);
796+
}
780797

781798
Ok(())
782799
})
@@ -792,11 +809,18 @@ mod basic_async {
792809

793810
let _: () = pubsub.subscribe("phonewave").await?;
794811
let mut stream = pubsub.into_on_message();
795-
let _: () = publish_conn.publish("phonewave", "banana").await?;
812+
// wait a bit
813+
sleep(Duration::from_secs(2).into()).await;
814+
let repeats = 6;
815+
for _ in 0..repeats {
816+
let _: () = publish_conn.publish("phonewave", "banana").await?;
817+
}
796818

797-
let message: String = stream.next().await.unwrap().get_payload().unwrap();
819+
for _ in 0..repeats {
820+
let message: String = stream.next().await.unwrap().get_payload().unwrap();
798821

799-
assert_eq!("banana".to_string(), message);
822+
assert_eq!("banana".to_string(), message);
823+
}
800824

801825
Ok(())
802826
})

0 commit comments

Comments
 (0)