Skip to content

Commit 041cd42

Browse files
authored
refactor: do not call use upon mysql connection (#818)
1 parent f907a93 commit 041cd42

File tree

3 files changed

+64
-13
lines changed

3 files changed

+64
-13
lines changed

src/servers/src/error.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,9 @@ pub enum Error {
249249
#[snafu(backtrace)]
250250
source: common_grpc::error::Error,
251251
},
252+
253+
#[snafu(display("Cannot find requested database: {}-{}", catalog, schema))]
254+
DatabaseNotFound { catalog: String, schema: String },
252255
}
253256

254257
pub type Result<T> = std::result::Result<T, Error>;
@@ -306,6 +309,8 @@ impl ErrorExt for Error {
306309
| InvalidAuthorizationHeader { .. }
307310
| InvalidBase64Value { .. }
308311
| InvalidUtf8Value { .. } => StatusCode::InvalidAuthHeader,
312+
313+
DatabaseNotFound { .. } => StatusCode::DatabaseNotFound,
309314
}
310315
}
311316

src/servers/src/mysql/handler.rs

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use std::sync::Arc;
1717
use std::time::Instant;
1818

1919
use async_trait::async_trait;
20+
use common_catalog::consts::DEFAULT_CATALOG_NAME;
2021
use common_query::Output;
2122
use common_telemetry::{error, trace};
2223
use opensrv_mysql::{
@@ -183,14 +184,21 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
183184
}
184185

185186
async fn on_init<'a>(&'a mut self, database: &'a str, w: InitWriter<'a, W>) -> Result<()> {
186-
let query = format!("USE {}", database.trim());
187-
let output = self.do_query(&query).await.remove(0);
188-
if let Err(e) = output {
189-
w.error(ErrorKind::ER_UNKNOWN_ERROR, e.to_string().as_bytes())
190-
.await
187+
// TODO(sunng87): set catalog
188+
if self
189+
.query_handler
190+
.is_valid_schema(DEFAULT_CATALOG_NAME, database)?
191+
{
192+
let context = self.session.context();
193+
// TODO(sunng87): set catalog
194+
context.set_current_schema(database);
195+
w.ok().await.map_err(|e| e.into())
191196
} else {
192-
w.ok().await
197+
error::DatabaseNotFoundSnafu {
198+
catalog: DEFAULT_CATALOG_NAME,
199+
schema: database,
200+
}
201+
.fail()
193202
}
194-
.map_err(|e| e.into())
195203
}
196204
}

src/servers/tests/mysql/mysql_server_test.rs

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use std::net::SocketAddr;
1616
use std::sync::Arc;
1717
use std::time::Duration;
1818

19+
use common_catalog::consts::DEFAULT_SCHEMA_NAME;
1920
use common_recordbatch::RecordBatch;
2021
use common_runtime::Builder as RuntimeBuilder;
2122
use datatypes::schema::Schema;
@@ -91,7 +92,7 @@ async fn test_shutdown_mysql_server() -> Result<()> {
9192
for _ in 0..2 {
9293
join_handles.push(tokio::spawn(async move {
9394
for _ in 0..1000 {
94-
match create_connection(server_port, false).await {
95+
match create_connection(server_port, None, false).await {
9596
Ok(mut connection) => {
9697
let result: u32 = connection
9798
.query_first("SELECT uint32s FROM numbers LIMIT 1")
@@ -197,7 +198,39 @@ async fn test_server_required_secure_client_plain() -> Result<()> {
197198
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
198199
let server_addr = mysql_server.start(listening).await.unwrap();
199200

200-
let r = create_connection(server_addr.port(), client_tls).await;
201+
let r = create_connection(server_addr.port(), None, client_tls).await;
202+
assert!(r.is_err());
203+
Ok(())
204+
}
205+
206+
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
207+
async fn test_db_name() -> Result<()> {
208+
let server_tls = TlsOption::default();
209+
let client_tls = false;
210+
211+
#[allow(unused)]
212+
let TestingData {
213+
column_schemas,
214+
mysql_columns_def,
215+
columns,
216+
mysql_text_output_rows,
217+
} = all_datatype_testing_data();
218+
let schema = Arc::new(Schema::new(column_schemas.clone()));
219+
let recordbatch = RecordBatch::new(schema, columns).unwrap();
220+
let table = MemTable::new("all_datatypes", recordbatch);
221+
222+
let mysql_server = create_mysql_server(table, server_tls)?;
223+
224+
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
225+
let server_addr = mysql_server.start(listening).await.unwrap();
226+
227+
let r = create_connection(server_addr.port(), None, client_tls).await;
228+
assert!(r.is_ok());
229+
230+
let r = create_connection(server_addr.port(), Some(DEFAULT_SCHEMA_NAME), client_tls).await;
231+
assert!(r.is_ok());
232+
233+
let r = create_connection(server_addr.port(), Some("tomcat"), client_tls).await;
201234
assert!(r.is_err());
202235
Ok(())
203236
}
@@ -219,7 +252,7 @@ async fn do_test_query_all_datatypes(server_tls: TlsOption, client_tls: bool) ->
219252
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
220253
let server_addr = mysql_server.start(listening).await.unwrap();
221254

222-
let mut connection = create_connection(server_addr.port(), client_tls)
255+
let mut connection = create_connection(server_addr.port(), None, client_tls)
223256
.await
224257
.unwrap();
225258

@@ -261,7 +294,7 @@ async fn test_query_concurrently() -> Result<()> {
261294
join_handles.push(tokio::spawn(async move {
262295
let mut rand: StdRng = rand::SeedableRng::from_entropy();
263296

264-
let mut connection = create_connection(server_port, false).await.unwrap();
297+
let mut connection = create_connection(server_port, None, false).await.unwrap();
265298
for _ in 0..expect_executed_queries_per_worker {
266299
let expected: u32 = rand.gen_range(0..100);
267300
let result: u32 = connection
@@ -275,7 +308,7 @@ async fn test_query_concurrently() -> Result<()> {
275308

276309
let should_recreate_conn = expected == 1;
277310
if should_recreate_conn {
278-
connection = create_connection(server_port, false).await.unwrap();
311+
connection = create_connection(server_port, None, false).await.unwrap();
279312
}
280313
}
281314
expect_executed_queries_per_worker
@@ -289,12 +322,17 @@ async fn test_query_concurrently() -> Result<()> {
289322
Ok(())
290323
}
291324

292-
async fn create_connection(port: u16, ssl: bool) -> mysql_async::Result<mysql_async::Conn> {
325+
async fn create_connection(
326+
port: u16,
327+
db_name: Option<&str>,
328+
ssl: bool,
329+
) -> mysql_async::Result<mysql_async::Conn> {
293330
let mut opts = mysql_async::OptsBuilder::default()
294331
.ip_or_hostname("127.0.0.1")
295332
.tcp_port(port)
296333
.prefer_socket(false)
297334
.wait_timeout(Some(1000))
335+
.db_name(db_name)
298336
.user(Some("greptime".to_string()))
299337
.pass(Some("greptime".to_string()));
300338

0 commit comments

Comments
 (0)