Skip to content

Commit d5998d8

Browse files
committed
Add Notifications::next_block_for method
The setup is a little hairy, but seems correct. cc #19
1 parent 75641e1 commit d5998d8

File tree

8 files changed

+141
-46
lines changed

8 files changed

+141
-46
lines changed

Cargo.toml

+2-5
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,18 @@ name = "test"
2020
path = "tests/test.rs"
2121

2222
[features]
23-
default = ["uuid", "time"]
23+
default = ["uuid"]
2424

2525
[dependencies]
2626
phf = "0.1"
2727
phf_mac = "0.1"
2828
openssl = "0.2.1"
29+
time = "0.1"
2930

3031
[dependencies.uuid]
3132
optional = true
3233
version = "0.1"
3334

34-
[dependencies.time]
35-
optional = true
36-
version = "0.1"
37-
3835
[dev-dependencies]
3936
url = "0.2"
4037

README.md

+3-16
Original file line numberDiff line numberDiff line change
@@ -204,10 +204,7 @@ types. The driver currently supports the following conversions:
204204
<td>JSON</td>
205205
</tr>
206206
<tr>
207-
<td>
208-
<a href="https://github.com/rust-lang/time">time::Timespec</a>
209-
(<a href="#optional-features">optional</a>)
210-
</td>
207+
<td>time::Timespec</td>
211208
<td>TIMESTAMP, TIMESTAMP WITH TIME ZONE</td>
212209
</tr>
213210
<tr>
@@ -226,10 +223,7 @@ types. The driver currently supports the following conversions:
226223
<td>INT8RANGE</td>
227224
</tr>
228225
<tr>
229-
<td>
230-
<a href="https://github.com/rust-lang/time">types::range::Range&lt;Timespec&gt;</a>
231-
(<a href="#optional-features">optional</a>)
232-
</td>
226+
<td>types::range::Range&lt;Timespec&gt;</td>
233227
<td>TSRANGE, TSTZRANGE</td>
234228
</tr>
235229
<tr>
@@ -265,10 +259,7 @@ types. The driver currently supports the following conversions:
265259
<td>INT8[], INT8[][], ...</td>
266260
</tr>
267261
<tr>
268-
<td>
269-
<a href="https://github.com/rust-lang/time">types::array::ArrayBase&lt;Option&lt;Timespec&gt;&gt;</a>
270-
(<a href="#optional-features">optional</a>)
271-
</td>
262+
<td>types::array::ArrayBase&lt;Option&lt;Timespec&gt;&gt;</td>
272263
<td>TIMESTAMP[], TIMESTAMPTZ[], TIMESTAMP[][], ...</td>
273264
</tr>
274265
<tr>
@@ -308,10 +299,6 @@ traits.
308299
[UUID](http://www.postgresql.org/docs/9.4/static/datatype-uuid.html) support is
309300
provided optionally by the `uuid` feature. It is enabled by default.
310301

311-
### Time types
312-
[Time](http://www.postgresql.org/docs/9.3/static/datatype-datetime.html)
313-
support is provided optionally by the `time` feature. It is enabled by default.
314-
315302
To disable support for optional features, add `default-features = false` to
316303
your Cargo manifest:
317304

src/io.rs

+21-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
use openssl::ssl::{SslStream, MaybeSslStream};
2+
use std::io::BufferedStream;
23
use std::io::net::ip::Port;
34
use std::io::net::tcp::TcpStream;
45
use std::io::net::pipe::UnixStream;
5-
use std::io::IoResult;
6+
use std::io::{IoResult, Stream};
67

78
use {ConnectParams, SslMode, ConnectTarget, ConnectError};
89
use message;
@@ -11,6 +12,23 @@ use message::FrontendMessage::SslRequest;
1112

1213
const DEFAULT_PORT: Port = 5432;
1314

15+
#[doc(hidden)]
16+
pub trait Timeout {
17+
fn set_read_timeout(&mut self, timeout_ms: Option<u64>);
18+
}
19+
20+
impl<S: Stream+Timeout> Timeout for MaybeSslStream<S> {
21+
fn set_read_timeout(&mut self, timeout_ms: Option<u64>) {
22+
self.get_mut().set_read_timeout(timeout_ms);
23+
}
24+
}
25+
26+
impl<S: Stream+Timeout> Timeout for BufferedStream<S> {
27+
fn set_read_timeout(&mut self, timeout_ms: Option<u64>) {
28+
self.get_mut().set_read_timeout(timeout_ms);
29+
}
30+
}
31+
1432
pub enum InternalStream {
1533
Tcp(TcpStream),
1634
Unix(UnixStream),
@@ -41,9 +59,8 @@ impl Writer for InternalStream {
4159
}
4260
}
4361

44-
impl InternalStream {
45-
#[allow(dead_code)]
46-
pub fn set_read_timeout(&mut self, timeout_ms: Option<u64>) {
62+
impl Timeout for InternalStream {
63+
fn set_read_timeout(&mut self, timeout_ms: Option<u64>) {
4764
match *self {
4865
InternalStream::Tcp(ref mut s) => s.set_read_timeout(timeout_ms),
4966
InternalStream::Unix(ref mut s) => s.set_read_timeout(timeout_ms),

src/lib.rs

+62-15
Original file line numberDiff line numberDiff line change
@@ -65,21 +65,23 @@ extern crate phf;
6565
extern crate phf_mac;
6666
#[phase(plugin, link)]
6767
extern crate log;
68+
extern crate time;
6869

6970
use url::Url;
7071
use openssl::crypto::hash::{HashType, Hasher};
7172
use openssl::ssl::{SslContext, MaybeSslStream};
7273
use serialize::hex::ToHex;
7374
use std::cell::{Cell, RefCell};
7475
use std::collections::{RingBuf, HashMap};
75-
use std::io::{BufferedStream, IoResult};
76+
use std::io::{BufferedStream, IoResult, IoError, IoErrorKind};
7677
use std::io::net::ip::Port;
7778
use std::iter::IteratorCloneExt;
79+
use std::time::Duration;
7880
use std::mem;
7981
use std::fmt;
8082
use std::result;
8183

82-
use io::InternalStream;
84+
use io::{InternalStream, Timeout};
8385
use message::{FrontendMessage, BackendMessage, RowDescriptionEntry};
8486
use message::FrontendMessage::*;
8587
use message::BackendMessage::*;
@@ -248,8 +250,9 @@ impl<'conn> Notifications<'conn> {
248250
return Ok(notification);
249251
}
250252

251-
check_desync!(self.conn.conn.borrow());
252-
match try!(self.conn.conn.borrow_mut().read_message_with_notification()) {
253+
let mut conn = self.conn.conn.borrow_mut();
254+
check_desync!(conn);
255+
match try!(conn.read_message_with_notification()) {
253256
NotificationResponse { pid, channel, payload } => {
254257
Ok(Notification {
255258
pid: pid,
@@ -260,6 +263,42 @@ impl<'conn> Notifications<'conn> {
260263
_ => unreachable!()
261264
}
262265
}
266+
267+
/// Returns the oldest pending notification
268+
///
269+
/// If no notifications are pending, blocks for up to `timeout` time, after
270+
/// which an `IoError` with the `TimedOut` kind is returned.
271+
pub fn next_block_for(&mut self, timeout: Duration) -> Result<Notification> {
272+
if let Some(notification) = self.next() {
273+
return Ok(notification);
274+
}
275+
276+
let mut conn = self.conn.conn.borrow_mut();
277+
check_desync!(conn);
278+
279+
let end = time::now().to_timespec() + timeout;
280+
loop {
281+
let now = time::now().to_timespec();
282+
conn.stream.set_read_timeout(Some((end - now).num_milliseconds() as u64));
283+
match conn.read_one_message() {
284+
Ok(Some(NotificationResponse { pid, channel, payload })) => {
285+
return Ok(Notification {
286+
pid: pid,
287+
channel: channel,
288+
payload: payload
289+
})
290+
}
291+
Ok(Some(_)) => unreachable!(),
292+
Ok(None) => {}
293+
Err(e @ IoError { kind: IoErrorKind::TimedOut, .. }) => {
294+
conn.desynchronized = false;
295+
return Err(Error::IoError(e));
296+
}
297+
Err(e) => return Err(Error::IoError(e)),
298+
}
299+
300+
}
301+
}
263302
}
264303

265304
/// Contains information necessary to cancel queries for a session
@@ -394,19 +433,27 @@ impl InnerConnection {
394433
Ok(try_desync!(self, self.stream.flush()))
395434
}
396435

397-
fn read_message_with_notification(&mut self) -> IoResult<BackendMessage> {
436+
fn read_one_message(&mut self) -> IoResult<Option<BackendMessage>> {
398437
debug_assert!(!self.desynchronized);
399-
loop {
400-
match try_desync!(self, self.stream.read_message()) {
401-
NoticeResponse { fields } => {
402-
if let Ok(err) = DbError::new_raw(fields) {
403-
self.notice_handler.handle(err);
404-
}
405-
}
406-
ParameterStatus { parameter, value } => {
407-
debug!("Parameter {} = {}", parameter, value)
438+
match try_desync!(self, self.stream.read_message()) {
439+
NoticeResponse { fields } => {
440+
if let Ok(err) = DbError::new_raw(fields) {
441+
self.notice_handler.handle(err);
408442
}
409-
val => return Ok(val)
443+
Ok(None)
444+
}
445+
ParameterStatus { parameter, value } => {
446+
debug!("Parameter {} = {}", parameter, value);
447+
Ok(None)
448+
}
449+
val => Ok(Some(val))
450+
}
451+
}
452+
453+
fn read_message_with_notification(&mut self) -> IoResult<BackendMessage> {
454+
loop {
455+
if let Some(msg) = try!(self.read_one_message()) {
456+
return Ok(msg);
410457
}
411458
}
412459
}

src/message.rs

+11-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::io::{IoResult, IoError, OtherIoError, MemReader};
22
use std::mem;
33

4+
use io::Timeout;
45
use types::Oid;
56

67
use self::BackendMessage::*;
@@ -272,9 +273,17 @@ pub trait ReadMessage {
272273
fn read_message(&mut self) -> IoResult<BackendMessage>;
273274
}
274275

275-
impl<R: Reader> ReadMessage for R {
276+
impl<R: Reader+Timeout> ReadMessage for R {
276277
fn read_message(&mut self) -> IoResult<BackendMessage> {
277-
let ident = try!(self.read_u8());
278+
// The first byte read is a bit complex to make
279+
// Notifications#next_block_for work.
280+
let ident = self.read_u8();
281+
// At this point we've got to turn off any read timeout to prevent
282+
// stream desynchronization. We're assuming that if we've got the first
283+
// byte, there's more stuff to follow.
284+
self.set_read_timeout(None);
285+
let ident = try!(ident);
286+
278287
// subtract size of length value
279288
let len = try!(self.read_be_u32()) as uint - mem::size_of::<i32>();
280289
let mut buf = MemReader::new(try!(self.read_exact(len)));

src/types/mod.rs

-1
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,6 @@ pub mod array;
307307
pub mod range;
308308
#[cfg(feature = "uuid")]
309309
mod uuid;
310-
#[cfg(feature = "time")]
311310
mod time;
312311

313312
/// A Postgres OID

src/types/time.rs

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
extern crate time;
2-
3-
use self::time::Timespec;
1+
use time::Timespec;
42
use Result;
53
use types::{RawFromSql, Type, RawToSql};
64
use types::range::{Range, RangeBound, BoundSided, Normalizable};

tests/test.rs

+41
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ extern crate openssl;
88

99
use openssl::ssl::SslContext;
1010
use openssl::ssl::SslMethod::Sslv3;
11+
use std::io::{IoError, IoErrorKind};
1112
use std::io::timer;
1213
use std::time::Duration;
1314

@@ -624,6 +625,46 @@ fn test_notifications_next_block() {
624625
}, or_panic!(notifications.next_block()));
625626
}
626627

628+
#[test]
629+
fn test_notifications_next_block_for() {
630+
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None));
631+
or_panic!(conn.execute("LISTEN test_notifications_next_block_for", &[]));
632+
633+
spawn(proc() {
634+
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None));
635+
timer::sleep(Duration::milliseconds(500));
636+
or_panic!(conn.execute("NOTIFY test_notifications_next_block_for, 'foo'", &[]));
637+
});
638+
639+
let mut notifications = conn.notifications();
640+
check_notification(Notification {
641+
pid: 0,
642+
channel: "test_notifications_next_block_for".to_string(),
643+
payload: "foo".to_string()
644+
}, or_panic!(notifications.next_block_for(Duration::seconds(2))));
645+
}
646+
647+
#[test]
648+
fn test_notifications_next_block_for_timeout() {
649+
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None));
650+
or_panic!(conn.execute("LISTEN test_notifications_next_block_for_timeout", &[]));
651+
652+
spawn(proc() {
653+
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None));
654+
timer::sleep(Duration::seconds(2));
655+
or_panic!(conn.execute("NOTIFY test_notifications_next_block_for_timeout, 'foo'", &[]));
656+
});
657+
658+
let mut notifications = conn.notifications();
659+
match notifications.next_block_for(Duration::milliseconds(500)) {
660+
Err(Error::IoError(IoError { kind: IoErrorKind::TimedOut, .. })) => {},
661+
Err(e) => panic!("Unexpected error {}", e),
662+
Ok(_) => panic!("expected error"),
663+
}
664+
665+
or_panic!(conn.execute("SELECT 1", &[]));
666+
}
667+
627668
#[test]
628669
// This test is pretty sad, but I don't think there's a better way :(
629670
fn test_cancel_query() {

0 commit comments

Comments
 (0)