Skip to content

Commit 9eccd7d

Browse files
authored
Merge pull request #4170 from weiznich/prevent_protocol_level_size_overflows
Enable some numeric cast releated clippy lints and fix them in the code base
2 parents ae82c4a + fad6b6d commit 9eccd7d

File tree

41 files changed

+351
-149
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+351
-149
lines changed

diesel/src/lib.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,10 @@
244244
clippy::enum_glob_use,
245245
clippy::if_not_else,
246246
clippy::items_after_statements,
247-
clippy::used_underscore_binding
247+
clippy::used_underscore_binding,
248+
clippy::cast_possible_wrap,
249+
clippy::cast_possible_truncation,
250+
clippy::cast_sign_loss
248251
)]
249252
#![deny(unsafe_code)]
250253
#![cfg_attr(test, allow(clippy::map_unwrap_or, clippy::unwrap_used))]

diesel/src/mysql/connection/bind.rs

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,10 @@ impl Clone for BindData {
178178
// written. At the time of writing this comment, the `BindData::bind_for_truncated_data`
179179
// function is only called by `Binds::populate_dynamic_buffers` which ensures the corresponding
180180
// invariant.
181-
std::slice::from_raw_parts(ptr.as_ptr(), self.length as usize)
181+
std::slice::from_raw_parts(
182+
ptr.as_ptr(),
183+
self.length.try_into().expect("usize is at least 32bit"),
184+
)
182185
};
183186
let mut vec = slice.to_owned();
184187
let ptr = NonNull::new(vec.as_mut_ptr());
@@ -415,7 +418,10 @@ impl BindData {
415418
// written. At the time of writing this comment, the `BindData::bind_for_truncated_data`
416419
// function is only called by `Binds::populate_dynamic_buffers` which ensures the corresponding
417420
// invariant.
418-
std::slice::from_raw_parts(data.as_ptr(), self.length as usize)
421+
std::slice::from_raw_parts(
422+
data.as_ptr(),
423+
self.length.try_into().expect("Usize is at least 32 bit"),
424+
)
419425
};
420426
Some(MysqlValue::new_internal(slice, tpe))
421427
}
@@ -428,7 +434,10 @@ impl BindData {
428434
fn update_buffer_length(&mut self) {
429435
use std::cmp::min;
430436

431-
let actual_bytes_in_buffer = min(self.capacity, self.length as usize);
437+
let actual_bytes_in_buffer = min(
438+
self.capacity,
439+
self.length.try_into().expect("Usize is at least 32 bit"),
440+
);
432441
self.length = actual_bytes_in_buffer as libc::c_ulong;
433442
}
434443

@@ -474,7 +483,8 @@ impl BindData {
474483
self.bytes = None;
475484

476485
let offset = self.capacity;
477-
let truncated_amount = self.length as usize - offset;
486+
let truncated_amount =
487+
usize::try_from(self.length).expect("Usize is at least 32 bit") - offset;
478488

479489
debug_assert!(
480490
truncated_amount > 0,
@@ -504,7 +514,7 @@ impl BindData {
504514
// offset is zero here as we don't have a buffer yet
505515
// we know the requested length here so we can just request
506516
// the correct size
507-
let mut vec = vec![0_u8; self.length as usize];
517+
let mut vec = vec![0_u8; self.length.try_into().expect("usize is at least 32 bit")];
508518
self.capacity = vec.capacity();
509519
self.bytes = NonNull::new(vec.as_mut_ptr());
510520
mem::forget(vec);

diesel/src/mysql/connection/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ impl Connection for MysqlConnection {
186186
// we have not called result yet, so calling `execute` is
187187
// fine
188188
let stmt_use = unsafe { stmt.execute() }?;
189-
Ok(stmt_use.affected_rows())
189+
stmt_use.affected_rows()
190190
}),
191191
&mut self.transaction_state,
192192
&mut self.instrumentation,

diesel/src/mysql/connection/stmt/mod.rs

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,11 @@ pub(super) struct StatementUse<'a> {
153153
}
154154

155155
impl<'a> StatementUse<'a> {
156-
pub(in crate::mysql::connection) fn affected_rows(&self) -> usize {
156+
pub(in crate::mysql::connection) fn affected_rows(&self) -> QueryResult<usize> {
157157
let affected_rows = unsafe { ffi::mysql_stmt_affected_rows(self.inner.stmt.as_ptr()) };
158-
affected_rows as usize
158+
affected_rows
159+
.try_into()
160+
.map_err(|e| Error::DeserializationError(Box::new(e)))
159161
}
160162

161163
/// This function should be called after `execute` only
@@ -167,14 +169,19 @@ impl<'a> StatementUse<'a> {
167169

168170
pub(super) fn populate_row_buffers(&self, binds: &mut OutputBinds) -> QueryResult<Option<()>> {
169171
let next_row_result = unsafe { ffi::mysql_stmt_fetch(self.inner.stmt.as_ptr()) };
170-
match next_row_result as libc::c_uint {
171-
ffi::MYSQL_NO_DATA => Ok(None),
172-
ffi::MYSQL_DATA_TRUNCATED => binds.populate_dynamic_buffers(self).map(Some),
173-
0 => {
174-
binds.update_buffer_lengths();
175-
Ok(Some(()))
172+
if next_row_result < 0 {
173+
self.inner.did_an_error_occur().map(Some)
174+
} else {
175+
#[allow(clippy::cast_sign_loss)] // that's how it's supposed to be based on the API
176+
match next_row_result as libc::c_uint {
177+
ffi::MYSQL_NO_DATA => Ok(None),
178+
ffi::MYSQL_DATA_TRUNCATED => binds.populate_dynamic_buffers(self).map(Some),
179+
0 => {
180+
binds.update_buffer_lengths();
181+
Ok(Some(()))
182+
}
183+
_error => self.inner.did_an_error_occur().map(Some),
176184
}
177-
_error => self.inner.did_an_error_occur().map(Some),
178185
}
179186
}
180187

@@ -187,7 +194,8 @@ impl<'a> StatementUse<'a> {
187194
ffi::mysql_stmt_fetch_column(
188195
self.inner.stmt.as_ptr(),
189196
bind,
190-
idx as libc::c_uint,
197+
idx.try_into()
198+
.map_err(|e| Error::DeserializationError(Box::new(e)))?,
191199
offset as libc::c_ulong,
192200
);
193201
self.inner.did_an_error_occur()

diesel/src/mysql/types/date_and_time/chrono.rs

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ impl FromSql<Datetime, Mysql> for NaiveDateTime {
2626
impl ToSql<Timestamp, Mysql> for NaiveDateTime {
2727
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Mysql>) -> serialize::Result {
2828
let mysql_time = MysqlTime {
29-
year: self.year() as libc::c_uint,
29+
year: self.year().try_into()?,
3030
month: self.month() as libc::c_uint,
3131
day: self.day() as libc::c_uint,
3232
hour: self.hour() as libc::c_uint,
@@ -48,16 +48,16 @@ impl FromSql<Timestamp, Mysql> for NaiveDateTime {
4848
fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result<Self> {
4949
let mysql_time = <MysqlTime as FromSql<Timestamp, Mysql>>::from_sql(bytes)?;
5050

51-
NaiveDate::from_ymd_opt(mysql_time.year as i32, mysql_time.month, mysql_time.day)
52-
.and_then(|v| {
53-
v.and_hms_micro_opt(
54-
mysql_time.hour,
55-
mysql_time.minute,
56-
mysql_time.second,
57-
mysql_time.second_part as u32,
58-
)
59-
})
60-
.ok_or_else(|| format!("Cannot parse this date: {mysql_time:?}").into())
51+
let micro = mysql_time.second_part.try_into()?;
52+
NaiveDate::from_ymd_opt(
53+
mysql_time.year.try_into()?,
54+
mysql_time.month,
55+
mysql_time.day,
56+
)
57+
.and_then(|v| {
58+
v.and_hms_micro_opt(mysql_time.hour, mysql_time.minute, mysql_time.second, micro)
59+
})
60+
.ok_or_else(|| format!("Cannot parse this date: {mysql_time:?}").into())
6161
}
6262
}
6363

@@ -94,7 +94,7 @@ impl FromSql<Time, Mysql> for NaiveTime {
9494
impl ToSql<Date, Mysql> for NaiveDate {
9595
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Mysql>) -> serialize::Result {
9696
let mysql_time = MysqlTime {
97-
year: self.year() as libc::c_uint,
97+
year: self.year().try_into()?,
9898
month: self.month() as libc::c_uint,
9999
day: self.day() as libc::c_uint,
100100
hour: 0,
@@ -114,8 +114,12 @@ impl ToSql<Date, Mysql> for NaiveDate {
114114
impl FromSql<Date, Mysql> for NaiveDate {
115115
fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result<Self> {
116116
let mysql_time = <MysqlTime as FromSql<Date, Mysql>>::from_sql(bytes)?;
117-
NaiveDate::from_ymd_opt(mysql_time.year as i32, mysql_time.month, mysql_time.day)
118-
.ok_or_else(|| format!("Unable to convert {mysql_time:?} to chrono").into())
117+
NaiveDate::from_ymd_opt(
118+
mysql_time.year.try_into()?,
119+
mysql_time.month,
120+
mysql_time.day,
121+
)
122+
.ok_or_else(|| format!("Unable to convert {mysql_time:?} to chrono").into())
119123
}
120124
}
121125

diesel/src/mysql/types/date_and_time/time.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ fn to_time(dt: MysqlTime) -> Result<NaiveTime, Box<dyn std::error::Error>> {
1515
("year", dt.year),
1616
("month", dt.month),
1717
("day", dt.day),
18-
("offset", dt.time_zone_displacement as u32),
18+
("offset", dt.time_zone_displacement.try_into()?),
1919
] {
2020
if field != 0 {
2121
return Err(format!("Unable to convert {dt:?} to time: {name} must be 0").into());
@@ -63,7 +63,7 @@ fn to_primitive_datetime(dt: OffsetDateTime) -> PrimitiveDateTime {
6363
impl ToSql<Datetime, Mysql> for PrimitiveDateTime {
6464
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Mysql>) -> serialize::Result {
6565
let mysql_time = MysqlTime {
66-
year: self.year() as libc::c_uint,
66+
year: self.year().try_into()?,
6767
month: self.month() as libc::c_uint,
6868
day: self.day() as libc::c_uint,
6969
hour: self.hour() as libc::c_uint,
@@ -171,7 +171,7 @@ impl FromSql<Time, Mysql> for NaiveTime {
171171
impl ToSql<Date, Mysql> for NaiveDate {
172172
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Mysql>) -> serialize::Result {
173173
let mysql_time = MysqlTime {
174-
year: self.year() as libc::c_uint,
174+
year: self.year().try_into()?,
175175
month: self.month() as libc::c_uint,
176176
day: self.day() as libc::c_uint,
177177
hour: 0,

diesel/src/mysql/types/mod.rs

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ impl ToSql<TinyInt, Mysql> for i8 {
2525
impl FromSql<TinyInt, Mysql> for i8 {
2626
fn from_sql(value: MysqlValue<'_>) -> deserialize::Result<Self> {
2727
let bytes = value.as_bytes();
28-
Ok(bytes[0] as i8)
28+
Ok(i8::from_be_bytes([bytes[0]]))
2929
}
3030
}
3131

@@ -69,12 +69,14 @@ where
6969
#[cfg(feature = "mysql_backend")]
7070
impl ToSql<Unsigned<TinyInt>, Mysql> for u8 {
7171
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Mysql>) -> serialize::Result {
72-
ToSql::<TinyInt, Mysql>::to_sql(&(*self as i8), &mut out.reborrow())
72+
out.write_u8(*self)?;
73+
Ok(IsNull::No)
7374
}
7475
}
7576

7677
#[cfg(feature = "mysql_backend")]
7778
impl FromSql<Unsigned<TinyInt>, Mysql> for u8 {
79+
#[allow(clippy::cast_possible_wrap, clippy::cast_sign_loss)] // that's what we want
7880
fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result<Self> {
7981
let signed: i8 = FromSql::<TinyInt, Mysql>::from_sql(bytes)?;
8082
Ok(signed as u8)
@@ -84,12 +86,18 @@ impl FromSql<Unsigned<TinyInt>, Mysql> for u8 {
8486
#[cfg(feature = "mysql_backend")]
8587
impl ToSql<Unsigned<SmallInt>, Mysql> for u16 {
8688
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Mysql>) -> serialize::Result {
87-
ToSql::<SmallInt, Mysql>::to_sql(&(*self as i16), &mut out.reborrow())
89+
out.write_u16::<NativeEndian>(*self)?;
90+
Ok(IsNull::No)
8891
}
8992
}
9093

9194
#[cfg(feature = "mysql_backend")]
9295
impl FromSql<Unsigned<SmallInt>, Mysql> for u16 {
96+
#[allow(
97+
clippy::cast_possible_wrap,
98+
clippy::cast_sign_loss,
99+
clippy::cast_possible_truncation
100+
)] // that's what we want
93101
fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result<Self> {
94102
let signed: i32 = FromSql::<Integer, Mysql>::from_sql(bytes)?;
95103
Ok(signed as u16)
@@ -99,12 +107,18 @@ impl FromSql<Unsigned<SmallInt>, Mysql> for u16 {
99107
#[cfg(feature = "mysql_backend")]
100108
impl ToSql<Unsigned<Integer>, Mysql> for u32 {
101109
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Mysql>) -> serialize::Result {
102-
ToSql::<Integer, Mysql>::to_sql(&(*self as i32), &mut out.reborrow())
110+
out.write_u32::<NativeEndian>(*self)?;
111+
Ok(IsNull::No)
103112
}
104113
}
105114

106115
#[cfg(feature = "mysql_backend")]
107116
impl FromSql<Unsigned<Integer>, Mysql> for u32 {
117+
#[allow(
118+
clippy::cast_possible_wrap,
119+
clippy::cast_sign_loss,
120+
clippy::cast_possible_truncation
121+
)] // that's what we want
108122
fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result<Self> {
109123
let signed: i64 = FromSql::<BigInt, Mysql>::from_sql(bytes)?;
110124
Ok(signed as u32)
@@ -114,12 +128,18 @@ impl FromSql<Unsigned<Integer>, Mysql> for u32 {
114128
#[cfg(feature = "mysql_backend")]
115129
impl ToSql<Unsigned<BigInt>, Mysql> for u64 {
116130
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Mysql>) -> serialize::Result {
117-
ToSql::<BigInt, Mysql>::to_sql(&(*self as i64), &mut out.reborrow())
131+
out.write_u64::<NativeEndian>(*self)?;
132+
Ok(IsNull::No)
118133
}
119134
}
120135

121136
#[cfg(feature = "mysql_backend")]
122137
impl FromSql<Unsigned<BigInt>, Mysql> for u64 {
138+
#[allow(
139+
clippy::cast_possible_wrap,
140+
clippy::cast_sign_loss,
141+
clippy::cast_possible_truncation
142+
)] // that's what we want
123143
fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result<Self> {
124144
let signed: i64 = FromSql::<BigInt, Mysql>::from_sql(bytes)?;
125145
Ok(signed as u64)

diesel/src/mysql/types/primitives.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ where
2222
}
2323
}
2424

25+
#[allow(clippy::cast_possible_truncation)] // that's what we want here
2526
fn f32_to_i64(f: f32) -> deserialize::Result<i64> {
2627
if f <= i64::MAX as f32 && f >= i64::MIN as f32 {
2728
Ok(f.trunc() as i64)
@@ -32,6 +33,7 @@ fn f32_to_i64(f: f32) -> deserialize::Result<i64> {
3233
}
3334
}
3435

36+
#[allow(clippy::cast_possible_truncation)] // that's what we want here
3537
fn f64_to_i64(f: f64) -> deserialize::Result<i64> {
3638
if f <= i64::MAX as f64 && f >= i64::MIN as f64 {
3739
Ok(f.trunc() as i64)
@@ -128,6 +130,8 @@ impl FromSql<Float, Mysql> for f32 {
128130
NumericRepresentation::Medium(x) => Ok(x as Self),
129131
NumericRepresentation::Big(x) => Ok(x as Self),
130132
NumericRepresentation::Float(x) => Ok(x),
133+
// there is currently no way to do this in a better way
134+
#[allow(clippy::cast_possible_truncation)]
131135
NumericRepresentation::Double(x) => Ok(x as Self),
132136
NumericRepresentation::Decimal(bytes) => Ok(str::from_utf8(bytes)?.parse()?),
133137
}

diesel/src/mysql/value.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ impl<'a> MysqlValue<'a> {
6060
pub(crate) fn numeric_value(&self) -> deserialize::Result<NumericRepresentation<'_>> {
6161
Ok(match self.tpe {
6262
MysqlType::UnsignedTiny | MysqlType::Tiny => {
63-
NumericRepresentation::Tiny(self.raw[0] as i8)
63+
NumericRepresentation::Tiny(self.raw[0].try_into()?)
6464
}
6565
MysqlType::UnsignedShort | MysqlType::Short => {
6666
NumericRepresentation::Small(i16::from_ne_bytes((&self.raw[..2]).try_into()?))

diesel/src/pg/connection/copy.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,10 @@ impl<'conn> BufRead for CopyToBuffer<'conn> {
102102
let len =
103103
pq_sys::PQgetCopyData(self.conn.internal_connection.as_ptr(), &mut self.ptr, 0);
104104
match len {
105-
len if len >= 0 => self.len = len as usize + 1,
105+
len if len >= 0 => {
106+
self.len = 1 + usize::try_from(len)
107+
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?
108+
}
106109
-1 => self.len = 0,
107110
_ => {
108111
let error = self.conn.last_error_message();

0 commit comments

Comments
 (0)