Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions src/common/error/src/status_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ pub enum StatusCode {
AuthHeaderNotFound = 7003,
/// Invalid http authorization header
InvalidAuthHeader = 7004,
/// Illegal request to connect catalog-schema
AccessDenied = 7005,
// ====== End of auth related status code =====
}

Expand Down
1 change: 1 addition & 0 deletions src/datanode/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ impl Services {
mysql_io_runtime,
Default::default(),
None,
None,
))
}
};
Expand Down
5 changes: 1 addition & 4 deletions src/frontend/src/frontend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ use std::sync::Arc;

use meta_client::MetaClientOpts;
use serde::{Deserialize, Serialize};
use servers::auth::UserProviderRef;
use servers::http::HttpOptions;
use servers::Mode;
use snafu::prelude::*;
Expand Down Expand Up @@ -92,8 +91,6 @@ impl<T: FrontendInstance> Frontend<T> {
let instance = Arc::new(instance);

// TODO(sunng87): merge this into instance
let provider = self.plugins.get::<UserProviderRef>().cloned();

Services::start(&self.opts, instance, provider).await
Services::start(&self.opts, instance, self.plugins.clone()).await
}
}
14 changes: 12 additions & 2 deletions src/frontend/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use std::sync::Arc;

use common_runtime::Builder as RuntimeBuilder;
use common_telemetry::info;
use servers::auth::UserProviderRef;
use servers::auth::{SchemaValidatorRef, UserProviderRef};
use servers::grpc::GrpcServer;
use servers::http::HttpServer;
use servers::mysql::server::MysqlServer;
Expand All @@ -34,19 +34,23 @@ use crate::frontend::FrontendOptions;
use crate::influxdb::InfluxdbOptions;
use crate::instance::FrontendInstance;
use crate::prometheus::PrometheusOptions;
use crate::Plugins;

pub(crate) struct Services;

impl Services {
pub(crate) async fn start<T>(
opts: &FrontendOptions,
instance: Arc<T>,
user_provider: Option<UserProviderRef>,
plugins: Arc<Plugins>,
) -> Result<()>
where
T: FrontendInstance,
{
info!("Starting frontend servers");
let user_provider = plugins.get::<UserProviderRef>().cloned();
let schema_validator = plugins.get::<SchemaValidatorRef>().cloned();

let grpc_server_and_addr = if let Some(opts) = &opts.grpc_options {
let grpc_addr = parse_addr(&opts.addr)?;

Expand Down Expand Up @@ -84,6 +88,7 @@ impl Services {
mysql_io_runtime,
opts.tls.clone(),
user_provider.clone(),
schema_validator.clone(),
);

Some((mysql_server, mysql_addr))
Expand All @@ -107,6 +112,7 @@ impl Services {
opts.tls.clone(),
pg_io_runtime,
user_provider.clone(),
schema_validator.clone(),
)) as Box<dyn Server>;

Some((pg_server, pg_addr))
Expand Down Expand Up @@ -143,6 +149,10 @@ impl Services {
http_server.set_user_provider(user_provider);
}

if let Some(schema_validator) = schema_validator {
http_server.set_schema_validator(schema_validator);
}

if opentsdb_server_and_addr.is_some() {
http_server.set_opentsdb_handler(instance.clone());
}
Expand Down
1 change: 1 addition & 0 deletions src/servers/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ rustls-pemfile = "1.0"
schemars = "0.8"
serde.workspace = true
serde_json = "1.0"
serde_urlencoded = "0.7"
session = { path = "../session" }
sha1 = "0.10"
snafu = { version = "0.7", features = ["backtraces"] }
Expand Down
113 changes: 109 additions & 4 deletions src/servers/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

pub mod user_provider;

use std::sync::Arc;

use common_error::ext::BoxedError;
Expand All @@ -24,6 +22,8 @@ use snafu::{Backtrace, ErrorCompat, OptionExt, Snafu};

use crate::auth::user_provider::StaticUserProvider;

pub mod user_provider;

#[async_trait::async_trait]
pub trait UserProvider: Send + Sync {
fn name(&self) -> &str;
Expand Down Expand Up @@ -70,12 +70,26 @@ pub fn user_provider_from_option(opt: &String) -> Result<UserProviderRef> {
}
}

/// [`SchemaValidator`] validates whether a connection request
/// from a certain user to a certain catalog/schema is legal.
/// This authorization is performed after a user is authenticated,
/// so that the user's [`UserInfo`] should be already stored in the session.
#[async_trait::async_trait]
pub trait SchemaValidator: Send + Sync {
async fn validate(&self, catalog: &str, schema: &str, user_info: &UserInfo) -> Result<()>;
}

pub type SchemaValidatorRef = Arc<dyn SchemaValidator>;

#[derive(Debug, Snafu)]
#[snafu(visibility(pub))]
pub enum Error {
#[snafu(display("Invalid config value: {}, {}", value, msg))]
InvalidConfig { value: String, msg: String },

#[snafu(display("Illegal runtime param: {}", msg))]
IllegalParam { msg: String },

#[snafu(display("IO error, source: {}", source))]
Io {
source: std::io::Error,
Expand All @@ -96,18 +110,32 @@ pub enum Error {

#[snafu(display("Username and password does not match, username: {}", username))]
UserPasswordMismatch { username: String },

#[snafu(display(
"User {} is not allowed to access catalog {} and schema {}",
username,
catalog,
schema
))]
AccessDenied {
catalog: String,
schema: String,
username: String,
},
}

impl ErrorExt for Error {
fn status_code(&self) -> StatusCode {
match self {
Error::InvalidConfig { .. } => StatusCode::InvalidArguments,
Error::IllegalParam { .. } => StatusCode::InvalidArguments,
Error::Io { .. } => StatusCode::Internal,
Error::AuthBackend { .. } => StatusCode::Internal,

Error::UserNotFound { .. } => StatusCode::UserNotFound,
Error::UnsupportedPasswordType { .. } => StatusCode::UnsupportedPasswordType,
Error::UserPasswordMismatch { .. } => StatusCode::UserPasswordMismatch,
Error::AccessDenied { .. } => StatusCode::AccessDenied,
}
}

Expand All @@ -123,7 +151,7 @@ impl ErrorExt for Error {
pub type Result<T> = std::result::Result<T, Error>;

#[cfg(test)]
pub mod test {
pub mod test_mock_user_provider {
use super::{Identity, Password, UserInfo, UserProvider};

pub struct MockUserProvider {}
Expand Down Expand Up @@ -168,11 +196,64 @@ pub mod test {
}
}

#[cfg(test)]
pub mod test_mock_schema_validator {

use session::context::UserInfo;

use super::SchemaValidator;
use crate::auth::AccessDeniedSnafu;

pub struct MockSchemaValidator {
catalog: String,
schema: String,
username: String,
}

impl MockSchemaValidator {
pub fn new(catalog: &str, schema: &str, username: &str) -> Self {
Self {
catalog: catalog.to_string(),
schema: schema.to_string(),
username: username.to_string(),
}
}
}

#[async_trait::async_trait]
impl SchemaValidator for MockSchemaValidator {
async fn validate(
&self,
catalog: &str,
schema: &str,
user_info: &UserInfo,
) -> Result<(), super::Error> {
if catalog == self.catalog
&& schema == self.schema
&& user_info.username() == self.username
{
Ok(())
} else {
AccessDeniedSnafu {
catalog: catalog.to_string(),
schema: schema.to_string(),
username: user_info.username().to_string(),
}
.fail()
}
}
}
}

#[cfg(test)]
mod tests {
use super::test::MockUserProvider;
use session::context::UserInfo;

use super::test_mock_user_provider::MockUserProvider;
use super::{Identity, Password, UserProvider};
use crate::auth;
use crate::auth::test_mock_schema_validator::MockSchemaValidator;
use crate::auth::SchemaValidator;

#[tokio::test]
async fn test_auth_by_plain_text() {
Expand Down Expand Up @@ -225,4 +306,28 @@ mod tests {
auth::Error::UserPasswordMismatch { .. }
);
}

#[tokio::test]
async fn test_schema_validate() {
let validator = MockSchemaValidator::new("greptime", "public", "test_user");
let right_user = UserInfo::new("test_user");
let wrong_user = UserInfo::default();

// check catalog
let re = validator
.validate("greptime_wrong", "public", &right_user)
.await;
assert!(re.is_err());
// check schema
let re = validator
.validate("greptime", "public_wrong", &right_user)
.await;
assert!(re.is_err());
// check username
let re = validator.validate("greptime", "public", &wrong_user).await;
assert!(re.is_err());
// check ok
let re = validator.validate("greptime", "public", &right_user).await;
assert!(re.is_ok());
}
}
6 changes: 6 additions & 0 deletions src/servers/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,12 @@ impl From<std::io::Error> for Error {
}
}

impl From<auth::Error> for Error {
fn from(e: auth::Error) -> Self {
Error::Auth { source: e }
}
}

impl IntoResponse for Error {
fn into_response(self) -> Response {
let (status, error_message) = match self {
Expand Down
17 changes: 15 additions & 2 deletions src/servers/src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ use tower_http::trace::TraceLayer;

use self::authorize::HttpAuth;
use self::influxdb::influxdb_write;
use crate::auth::UserProviderRef;
use crate::auth::{SchemaValidatorRef, UserProviderRef};
use crate::error::{AlreadyStartedSnafu, Result, StartHttpSnafu};
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
use crate::query_handler::{
Expand Down Expand Up @@ -104,6 +104,7 @@ pub struct HttpServer {
script_handler: Option<ScriptHandlerRef>,
shutdown_tx: Mutex<Option<Sender<()>>>,
user_provider: Option<UserProviderRef>,
schema_validator: Option<SchemaValidatorRef>,
Comment thread
MichaelScofield marked this conversation as resolved.
Outdated
}

#[derive(Clone, Debug, Serialize, Deserialize)]
Expand Down Expand Up @@ -358,6 +359,7 @@ impl HttpServer {
influxdb_handler: None,
prom_handler: None,
user_provider: None,
schema_validator: None,
script_handler: None,
shutdown_tx: Mutex::new(None),
}
Expand Down Expand Up @@ -403,6 +405,14 @@ impl HttpServer {
self.user_provider.get_or_insert(user_provider);
}

pub fn set_schema_validator(&mut self, schema_validator: SchemaValidatorRef) {
debug_assert!(
self.schema_validator.is_none(),
"Schema validator can be set only once!"
);
self.schema_validator.get_or_insert(schema_validator);
}

pub fn make_app(&self) -> Router {
let mut api = OpenApi {
info: Info {
Expand Down Expand Up @@ -465,7 +475,10 @@ impl HttpServer {
.layer(TimeoutLayer::new(self.options.timeout))
// custom layer
.layer(AsyncRequireAuthorizationLayer::new(
HttpAuth::<BoxBody>::new(self.user_provider.clone()),
HttpAuth::<BoxBody>::new(
self.user_provider.clone(),
self.schema_validator.clone(),
),
)),
)
}
Expand Down
Loading