diff --git a/CHANGELOG.md b/CHANGELOG.md index 54cc904773c1..1eaf29f6d3eb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/ ### Added +* `MysqlConnection::establish` is able to initiate SSL connection. The database URL should contain `ssl_mode` parameter with a value of the [MySQL client command option `--ssl-mode`](https://dev.mysql.com/doc/refman/5.7/en/connection-options.html#option_general_ssl-mode) if desired. + * `Connection` and `SimpleConnection` traits are implemented for a broader range of `r2d2::PooledConnection` types when the `r2d2` feature is enabled. @@ -80,7 +82,7 @@ for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/ * Support for `uuid` version < 0.7.0 has been removed. * Support for `bigdecimal` < 0.0.13 has been removed. * Support for `pq-sys` < 0.4.0 has been removed. -* Support for `mysqlclient-sys` < 0.2.0 has been removed. +* Support for `mysqlclient-sys` < 0.2.5 has been removed. * Support for `time` types has been removed. * Support for `chrono` < 0.4.19 has been removed. * The `NonNull` trait for sql types has been removed in favour of the new `SqlType` trait. diff --git a/diesel/Cargo.toml b/diesel/Cargo.toml index a5530b2fbebc..4041cacb36c7 100644 --- a/diesel/Cargo.toml +++ b/diesel/Cargo.toml @@ -17,7 +17,7 @@ byteorder = { version = "1.0", optional = true } chrono = { version = "0.4.19", optional = true, default-features = false, features = ["clock", "std"] } libc = { version = "0.2.0", optional = true } libsqlite3-sys = { version = ">=0.17.2, <0.24.0", optional = true, features = ["bundled_bindings"] } -mysqlclient-sys = { version = "0.2.0", optional = true } +mysqlclient-sys = { version = "0.2.5", optional = true } pq-sys = { version = "0.4.0", optional = true } quickcheck = { version = "1.0.3", optional = true } serde_json = { version = ">=0.8.0, <2.0", optional = true } diff --git a/diesel/src/mysql/connection/mod.rs b/diesel/src/mysql/connection/mod.rs index 9b71dcb9e46c..f9e95592e216 100644 --- a/diesel/src/mysql/connection/mod.rs +++ b/diesel/src/mysql/connection/mod.rs @@ -56,6 +56,13 @@ impl Connection for MysqlConnection { type Backend = Mysql; type TransactionManager = AnsiTransactionManager; + /// Establishes a new connection to the MySQL database + /// `database_url` may be enhanced by GET parameters + /// `mysql://[user[:password]@]host/database_name[?unix_socket=socket-path&ssl_mode=SSL_MODE*]` + /// + /// * `unix_socket` excepts the path to the unix socket + /// * `ssl_mode` expects a value defined for MySQL client command option `--ssl-mode` + /// See fn establish(database_url: &str) -> ConnectionResult { use crate::result::ConnectionError::CouldntSetupConfiguration; diff --git a/diesel/src/mysql/connection/raw.rs b/diesel/src/mysql/connection/raw.rs index 4c29c930a872..0f1fae5a9831 100644 --- a/diesel/src/mysql/connection/raw.rs +++ b/diesel/src/mysql/connection/raw.rs @@ -48,6 +48,10 @@ impl RawConnection { let unix_socket = connection_options.unix_socket(); let client_flags = connection_options.client_flags(); + if let Some(ssl_mode) = connection_options.ssl_mode() { + self.set_ssl_mode(ssl_mode) + } + unsafe { // Make sure you don't use the fake one! ffi::mysql_real_connect( @@ -180,6 +184,19 @@ impl RawConnection { unsafe { ffi::mysql_next_result(self.0.as_ptr()) }; self.did_an_error_occur() } + + fn set_ssl_mode(&self, ssl_mode: mysqlclient_sys::mysql_ssl_mode) { + let v = ssl_mode as u32; + let v_ptr: *const u32 = &v; + let n = ptr::NonNull::new(v_ptr as *mut u32).expect("NonNull::new failed"); + unsafe { + mysqlclient_sys::mysql_options( + self.0.as_ptr(), + mysqlclient_sys::mysql_option::MYSQL_OPT_SSL_MODE, + n.as_ptr() as *const std::ffi::c_void, + ) + }; + } } impl Drop for RawConnection { diff --git a/diesel/src/mysql/connection/url.rs b/diesel/src/mysql/connection/url.rs index 897cc920494a..582b7cc2d28d 100644 --- a/diesel/src/mysql/connection/url.rs +++ b/diesel/src/mysql/connection/url.rs @@ -8,6 +8,8 @@ use std::ffi::{CStr, CString}; use crate::result::{ConnectionError, ConnectionResult}; +use mysqlclient_sys::mysql_ssl_mode; + bitflags::bitflags! { pub struct CapabilityFlags: u32 { const CLIENT_LONG_PASSWORD = 0x00000001; @@ -46,6 +48,7 @@ pub struct ConnectionOptions { port: Option, unix_socket: Option, client_flags: CapabilityFlags, + ssl_mode: Option, } impl ConnectionOptions { @@ -73,6 +76,24 @@ impl ConnectionOptions { _ => None, }; + let ssl_mode = match query_pairs.get("ssl_mode") { + Some(v) => { + let ssl_mode = match v.to_lowercase().as_str() { + "disabled" => mysql_ssl_mode::SSL_MODE_DISABLED, + "preferred" => mysql_ssl_mode::SSL_MODE_PREFERRED, + "required" => mysql_ssl_mode::SSL_MODE_REQUIRED, + "verify_ca" => mysql_ssl_mode::SSL_MODE_VERIFY_CA, + "verify_identity" => mysql_ssl_mode::SSL_MODE_VERIFY_IDENTITY, + _ => { + let msg = "unknown ssl_mode"; + return Err(ConnectionError::InvalidConnectionUrl(msg.into())); + } + }; + Some(ssl_mode) + } + _ => None, + }; + let host = match url.host() { Some(Host::Ipv6(host)) => Some(CString::new(host.to_string())?), Some(host) if host.to_string() == "localhost" && unix_socket != None => None, @@ -101,6 +122,7 @@ impl ConnectionOptions { port: url.port(), unix_socket: unix_socket, client_flags: client_flags, + ssl_mode: ssl_mode, }) } @@ -131,6 +153,10 @@ impl ConnectionOptions { pub fn client_flags(&self) -> CapabilityFlags { self.client_flags } + + pub fn ssl_mode(&self) -> Option { + self.ssl_mode + } } fn decode_into_cstring(s: &str) -> ConnectionResult { @@ -266,3 +292,29 @@ fn unix_socket_tests() { conn_opts.unix_socket.unwrap() ); } + +#[test] +fn ssl_mode() { + let ssl_mode = |url| ConnectionOptions::parse(url).unwrap().ssl_mode(); + assert_eq!(ssl_mode("mysql://localhost"), None); + assert_eq!( + ssl_mode("mysql://localhost?ssl_mode=disabled"), + Some(mysql_ssl_mode::SSL_MODE_DISABLED) + ); + assert_eq!( + ssl_mode("mysql://localhost?ssl_mode=PREFERRED"), + Some(mysql_ssl_mode::SSL_MODE_PREFERRED) + ); + assert_eq!( + ssl_mode("mysql://localhost?ssl_mode=required"), + Some(mysql_ssl_mode::SSL_MODE_REQUIRED) + ); + assert_eq!( + ssl_mode("mysql://localhost?ssl_mode=VERIFY_CA"), + Some(mysql_ssl_mode::SSL_MODE_VERIFY_CA) + ); + assert_eq!( + ssl_mode("mysql://localhost?ssl_mode=verify_identity"), + Some(mysql_ssl_mode::SSL_MODE_VERIFY_IDENTITY) + ); +}