Skip to content

Commit 03c8506

Browse files
authored
Refactor session requester to one channel (#75)
1 parent 0085d1a commit 03c8506

File tree

5 files changed

+217
-127
lines changed

5 files changed

+217
-127
lines changed

src/client/mod.rs

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use std::borrow::Cow;
44
use std::fmt::Write as _;
55
use std::future::Future;
66
use std::mem::ManuallyDrop;
7+
use std::sync::Arc;
78
use std::time::Duration;
89

910
use const_format::formatcp;
@@ -14,7 +15,7 @@ use thiserror::Error;
1415
use tracing::instrument;
1516

1617
pub use self::watcher::{OneshotWatcher, PersistentWatcher, StateWatcher};
17-
use super::session::{Depot, MarshalledRequest, Session, SessionOperation, WatchReceiver};
18+
use super::session::{Depot, MarshalledRequest, Request, Session, SessionOperation, WatchReceiver};
1819
use crate::acl::{Acl, Acls, AuthUser};
1920
use crate::chroot::{Chroot, ChrootPath, OwnedChroot};
2021
use crate::endpoint::{self, IterableEndpoints};
@@ -221,7 +222,7 @@ pub struct Client {
221222
version: Version,
222223
session: SessionInfo,
223224
session_timeout: Duration,
224-
requester: mpsc::UnboundedSender<SessionOperation>,
225+
requester: Arc<mpsc::UnboundedSender<Request>>,
225226
state_watcher: StateWatcher,
226227
}
227228

@@ -243,7 +244,7 @@ impl Client {
243244
version: Version,
244245
session: SessionInfo,
245246
timeout: Duration,
246-
requester: mpsc::UnboundedSender<SessionOperation>,
247+
requester: Arc<mpsc::UnboundedSender<Request>>,
247248
state_watcher: StateWatcher,
248249
) -> Client {
249250
Client { chroot, version, session, session_timeout: timeout, requester, state_watcher }
@@ -316,9 +317,9 @@ impl Client {
316317

317318
fn send_marshalled_request(&self, request: MarshalledRequest) -> StateReceiver {
318319
let (operation, receiver) = SessionOperation::new_marshalled(request).with_responser();
319-
if let Err(err) = self.requester.unbounded_send(operation) {
320+
if let Err(err) = self.requester.unbounded_send(operation.into()) {
320321
let state = self.state();
321-
err.into_inner().responser.send(Err(state.to_error()));
322+
err.into_inner().into_responser().send(Err(state.to_error()));
322323
}
323324
receiver
324325
}
@@ -1655,7 +1656,9 @@ impl Connector {
16551656
let builder = builder.with_tls(self.tls.take());
16561657
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
16571658
let builder = builder.with_sasl(self.sasl.take());
1658-
let (mut session, state_receiver) = builder.build()?;
1659+
let (sender, receiver) = mpsc::unbounded();
1660+
let sender = Arc::new(sender);
1661+
let mut session = builder.build(Arc::downgrade(&sender))?;
16591662
let mut endpoints = IterableEndpoints::from(endpoints.as_slice());
16601663
endpoints.reset();
16611664
if !self.fail_eagerly {
@@ -1664,10 +1667,9 @@ impl Connector {
16641667
let mut buf = Vec::with_capacity(4096);
16651668
let mut depot = Depot::new();
16661669
let conn = session.start(&mut endpoints, &mut buf, &mut depot).await?;
1667-
let (sender, receiver) = mpsc::unbounded();
16681670
let session_info = session.session.clone();
16691671
let session_timeout = session.session_timeout;
1670-
let mut state_watcher = StateWatcher::new(state_receiver);
1672+
let mut state_watcher = StateWatcher::new(session.subscribe_state());
16711673
// Consume all state changes so far.
16721674
state_watcher.state();
16731675
asyncs::spawn(async move {

src/session/mod.rs

Lines changed: 31 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@ mod watch;
77
mod xid;
88

99
use std::pin::pin;
10+
use std::sync::Weak;
1011
use std::time::{Duration, Instant};
1112

1213
use async_io::Timer;
13-
use asyncs::select;
14+
use asyncs::{select, sync};
1415
use futures::channel::mpsc;
1516
use futures::{AsyncWriteExt, StreamExt};
1617
use ignore_result::Ignore;
@@ -25,13 +26,13 @@ pub use self::request::{
2526
MarshalledRequest,
2627
OpStat,
2728
Operation,
29+
Request,
2830
SessionOperation,
2931
StateReceiver,
30-
StateResponser,
3132
};
3233
pub use self::types::{EventType, SessionId, SessionInfo, SessionState, WatchedEvent};
34+
use self::watch::WatchManager;
3335
pub use self::watch::{OneshotReceiver, PersistentReceiver, WatchReceiver};
34-
use self::watch::{WatchManager, WatcherId};
3536
use crate::deadline::Deadline;
3637
use crate::endpoint::IterableEndpoints;
3738
use crate::error::Error;
@@ -45,22 +46,6 @@ use crate::tls::TlsOptions;
4546
pub const PASSWORD_LEN: usize = 16;
4647
pub const DEFAULT_SESSION_TIMEOUT: Duration = Duration::from_secs(6);
4748

48-
trait RequestOperation {
49-
fn into_responser(self) -> StateResponser;
50-
}
51-
52-
impl RequestOperation for SessionOperation {
53-
fn into_responser(self) -> StateResponser {
54-
self.responser
55-
}
56-
}
57-
58-
impl RequestOperation for (WatcherId, StateResponser) {
59-
fn into_responser(self) -> StateResponser {
60-
self.1
61-
}
62-
}
63-
6449
#[derive(Default)]
6550
pub struct Builder {
6651
#[cfg(feature = "tls")]
@@ -110,7 +95,7 @@ impl Builder {
11095
Self { connection_timeout, ..self }
11196
}
11297

113-
pub fn build(self) -> Result<(Session, asyncs::sync::watch::Receiver<SessionState>), Error> {
98+
pub fn build(self, requester: Weak<mpsc::UnboundedSender<Request>>) -> Result<Session, Error> {
11499
let session = match self.session {
115100
Some(session) => {
116101
if session.is_readonly() {
@@ -136,9 +121,9 @@ impl Builder {
136121
};
137122
#[cfg(not(feature = "tls"))]
138123
let connector = Connector::new();
139-
let (state_sender, state_receiver) = asyncs::sync::watch::channel(SessionState::Disconnected);
124+
let (state_sender, state_receiver) = sync::watch::channel(SessionState::Disconnected);
140125
let now = Instant::now();
141-
let (watch_manager, unwatch_receiver) = WatchManager::new();
126+
let watch_manager = WatchManager::new(requester, state_receiver);
142127
let mut session = Session {
143128
readonly: self.readonly,
144129
detached: self.detached,
@@ -165,11 +150,10 @@ impl Builder {
165150
authes: self.authes,
166151
state_sender,
167152
watch_manager,
168-
unwatch_receiver: Some(unwatch_receiver),
169153
};
170154
let timeout = if self.session_timeout.is_zero() { DEFAULT_SESSION_TIMEOUT } else { self.session_timeout };
171155
session.reset_timeout(timeout);
172-
Ok((session, state_receiver))
156+
Ok(session)
173157
}
174158
}
175159

@@ -199,10 +183,9 @@ pub struct Session {
199183
sasl_session: Option<SaslSession>,
200184

201185
pub authes: Vec<MarshalledRequest>,
202-
state_sender: asyncs::sync::watch::Sender<SessionState>,
186+
state_sender: sync::watch::Sender<SessionState>,
203187

204188
watch_manager: WatchManager,
205-
unwatch_receiver: Option<mpsc::UnboundedReceiver<(WatcherId, StateResponser)>>,
206189
}
207190

208191
impl Session {
@@ -215,27 +198,30 @@ impl Session {
215198
self.readonly && self.session.readonly
216199
}
217200

218-
async fn close_requester<T: RequestOperation>(mut requester: mpsc::UnboundedReceiver<T>, err: &Error) {
201+
async fn close_requester(mut requester: mpsc::UnboundedReceiver<Request>, err: &Error) {
219202
requester.close();
220-
while let Some(operation) = requester.next().await {
221-
let responser = operation.into_responser();
203+
while let Some(request) = requester.next().await {
204+
let responser = request.into_responser();
222205
responser.send(Err(err.clone()));
223206
}
224207
}
225208

209+
pub fn subscribe_state(&self) -> sync::watch::Receiver<SessionState> {
210+
self.state_sender.subscribe()
211+
}
212+
226213
#[instrument(name = "serve", skip_all, fields(session = display(self.session.id)))]
227214
pub async fn serve(
228215
&mut self,
229216
mut endpoints: IterableEndpoints,
230217
conn: Connection,
231218
mut buf: Vec<u8>,
232219
mut depot: Depot,
233-
mut requester: mpsc::UnboundedReceiver<SessionOperation>,
220+
mut requester: mpsc::UnboundedReceiver<Request>,
234221
) {
235-
let mut unwatch_requester = self.unwatch_receiver.take().unwrap();
236222
endpoints.cycle();
237223
endpoints.reset();
238-
self.serve_once(conn, &mut endpoints, &mut buf, &mut depot, &mut requester, &mut unwatch_requester).await;
224+
self.serve_once(conn, &mut endpoints, &mut buf, &mut depot, &mut requester).await;
239225
while !self.session_state.is_terminated() {
240226
let conn = match self.start(&mut endpoints, &mut buf, &mut depot).await {
241227
Err(err) => {
@@ -246,11 +232,10 @@ impl Session {
246232
Ok(conn) => conn,
247233
};
248234
endpoints.reset();
249-
self.serve_once(conn, &mut endpoints, &mut buf, &mut depot, &mut requester, &mut unwatch_requester).await;
235+
self.serve_once(conn, &mut endpoints, &mut buf, &mut depot, &mut requester).await;
250236
}
251237
let err = self.state_error();
252238
Self::close_requester(requester, &err).await;
253-
Self::close_requester(unwatch_requester, &err).await;
254239
depot.terminate(err);
255240
}
256241

@@ -292,10 +277,9 @@ impl Session {
292277
endpoints: &mut IterableEndpoints,
293278
buf: &mut Vec<u8>,
294279
depot: &mut Depot,
295-
requester: &mut mpsc::UnboundedReceiver<SessionOperation>,
296-
unwatch_requester: &mut mpsc::UnboundedReceiver<(WatcherId, StateResponser)>,
280+
requester: &mut mpsc::UnboundedReceiver<Request>,
297281
) {
298-
let err = self.serve_session(endpoints, &mut conn, buf, depot, requester, unwatch_requester).await.unwrap_err();
282+
let err = self.serve_session(endpoints, &mut conn, buf, depot, requester).await.unwrap_err();
299283
self.resolve_serve_error(&err);
300284
info!("enter state {} due to {}", self.session_state, err);
301285
depot.error(&err);
@@ -550,8 +534,7 @@ impl Session {
550534
conn: &mut Connection,
551535
buf: &mut Vec<u8>,
552536
depot: &mut Depot,
553-
requester: &mut mpsc::UnboundedReceiver<SessionOperation>,
554-
unwatch_requester: &mut mpsc::UnboundedReceiver<(WatcherId, StateResponser)>,
537+
requester: &mut mpsc::UnboundedReceiver<Request>,
555538
) -> Result<(), Error> {
556539
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
557540
self.sasl_session.take();
@@ -577,20 +560,19 @@ impl Session {
577560
r?;
578561
self.last_send = Instant::now();
579562
},
580-
r = requester.next(), if !channel_halted => {
581-
let operation = if let Some(operation) = r {
582-
operation
583-
} else {
563+
r = requester.next(), if !channel_halted => match r {
564+
None => {
584565
if !self.detached {
585566
depot.push_session(SessionOperation::new_without_body(OpCode::CloseSession));
586567
}
587568
channel_halted = true;
588-
continue;
589-
};
590-
depot.push_session(operation);
591-
},
592-
r = unwatch_requester.next() => if let Some((watcher_id, responser)) = r {
593-
self.watch_manager.remove_watcher(watcher_id, responser, depot);
569+
}
570+
Some(Request::Session(operation)) => depot.push_session(operation),
571+
Some(Request::RemoveWatcher {
572+
id, responser
573+
}) => {
574+
self.watch_manager.remove_watcher(id, responser, depot);
575+
}
594576
},
595577
now = tick.as_mut() => {
596578
if now >= self.last_recv + self.connector.timeout() {

src/session/request.rs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use futures::channel::oneshot;
77
use ignore_result::Ignore;
88

99
use super::types::WatchMode;
10-
use super::watch::WatchReceiver;
10+
use super::watch::{WatchReceiver, WatcherId};
1111
use crate::error::Error;
1212
use crate::proto::{self, AddWatchMode, ConnectRequest, OpCode, RequestHeader};
1313
use crate::record::{self, Record, StaticRecord};
@@ -126,6 +126,26 @@ pub enum Operation {
126126
Session(SessionOperation),
127127
}
128128

129+
pub enum Request {
130+
Session(SessionOperation),
131+
RemoveWatcher { id: WatcherId, responser: StateResponser },
132+
}
133+
134+
impl From<SessionOperation> for Request {
135+
fn from(operation: SessionOperation) -> Self {
136+
Self::Session(operation)
137+
}
138+
}
139+
140+
impl Request {
141+
pub fn into_responser(self) -> StateResponser {
142+
match self {
143+
Self::Session(operation) => operation.responser,
144+
Self::RemoveWatcher { responser, .. } => responser,
145+
}
146+
}
147+
}
148+
129149
impl Operation {
130150
pub fn get_data(&self) -> &[u8] {
131151
match self {

0 commit comments

Comments
 (0)