Skip to content

Commit ae8ee01

Browse files
authored
feat: per-project parallelism (#533)
1 parent 5e604b4 commit ae8ee01

File tree

4 files changed

+109
-4
lines changed

4 files changed

+109
-4
lines changed

gateway/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ impl std::fmt::Display for Error {
119119

120120
impl StdError for Error {}
121121

122-
#[derive(Debug, sqlx::Type, Serialize, Clone, PartialEq, Eq)]
122+
#[derive(Debug, sqlx::Type, Serialize, Clone, PartialEq, Eq, Hash)]
123123
#[sqlx(transparent)]
124124
pub struct ProjectName(String);
125125

gateway/src/service.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ use crate::acme::CustomDomain;
2828
use crate::args::ContextArgs;
2929
use crate::auth::{Key, Permissions, ScopedUser, User};
3030
use crate::project::Project;
31-
use crate::task::TaskBuilder;
31+
use crate::task::{BoxedTask, TaskBuilder};
32+
use crate::worker::TaskRouter;
3233
use crate::{AccountName, DockerContext, Error, ErrorKind, ProjectDetails, ProjectName};
3334

3435
pub static MIGRATIONS: Migrator = sqlx::migrate!("./migrations");
@@ -187,6 +188,7 @@ impl GatewayContextProvider {
187188
pub struct GatewayService {
188189
provider: GatewayContextProvider,
189190
db: SqlitePool,
191+
task_router: TaskRouter<BoxedTask>,
190192
}
191193

192194
impl GatewayService {
@@ -201,7 +203,13 @@ impl GatewayService {
201203

202204
let provider = GatewayContextProvider::new(docker, container_settings);
203205

204-
Self { provider, db }
206+
let task_router = TaskRouter::new();
207+
208+
Self {
209+
provider,
210+
db,
211+
task_router,
212+
}
205213
}
206214

207215
pub async fn route(
@@ -547,6 +555,10 @@ impl GatewayService {
547555
pub fn new_task(self: &Arc<Self>) -> TaskBuilder {
548556
TaskBuilder::new(self.clone())
549557
}
558+
559+
pub fn task_router(&self) -> TaskRouter<BoxedTask> {
560+
self.task_router.clone()
561+
}
550562
}
551563

552564
#[derive(Clone)]

gateway/src/task.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use uuid::Uuid;
1212

1313
use crate::project::*;
1414
use crate::service::{GatewayContext, GatewayService};
15+
use crate::worker::TaskRouter;
1516
use crate::{AccountName, EndState, Error, ErrorKind, ProjectName, Refresh, State};
1617

1718
// Default maximum _total_ time a task is allowed to run
@@ -199,14 +200,51 @@ impl TaskBuilder {
199200
}
200201

201202
pub async fn send(self, sender: &Sender<BoxedTask>) -> Result<TaskHandle, Error> {
203+
let project_name = self.project_name.clone().expect("project_name is required");
204+
let task_router = self.service.task_router();
202205
let (task, handle) = AndThenNotify::after(self.build());
206+
let task = Route::<BoxedTask>::to(project_name, Box::new(task), task_router);
203207
match timeout(TASK_SEND_TIMEOUT, sender.send(Box::new(task))).await {
204208
Ok(Ok(_)) => Ok(handle),
205209
_ => Err(Error::from_kind(ErrorKind::ServiceUnavailable)),
206210
}
207211
}
208212
}
209213

214+
pub struct Route<T> {
215+
project_name: ProjectName,
216+
inner: Option<T>,
217+
router: TaskRouter<T>,
218+
}
219+
220+
impl<T> Route<T> {
221+
pub fn to(project_name: ProjectName, what: T, router: TaskRouter<T>) -> Self {
222+
Self {
223+
project_name,
224+
inner: Some(what),
225+
router,
226+
}
227+
}
228+
}
229+
230+
#[async_trait]
231+
impl Task<()> for Route<BoxedTask> {
232+
type Output = ();
233+
234+
type Error = Error;
235+
236+
async fn poll(&mut self, _ctx: ()) -> TaskResult<Self::Output, Self::Error> {
237+
if let Some(task) = self.inner.take() {
238+
match self.router.route(&self.project_name, task).await {
239+
Ok(_) => TaskResult::Done(()),
240+
Err(_) => TaskResult::Err(Error::from_kind(ErrorKind::Internal)),
241+
}
242+
} else {
243+
TaskResult::Done(())
244+
}
245+
}
246+
}
247+
210248
pub struct RunFn<F, O> {
211249
f: F,
212250
_output: PhantomData<O>,

gateway/src/worker.rs

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1+
use std::collections::HashMap;
2+
use std::sync::Arc;
3+
4+
use tokio::sync::mpsc::error::SendError;
15
use tokio::sync::mpsc::{channel, Receiver, Sender};
6+
use tokio::sync::RwLock;
27
use tracing::{debug, info};
38

49
use crate::task::{BoxedTask, TaskResult};
5-
use crate::Error;
10+
use crate::{Error, ProjectName};
611

712
pub const WORKER_QUEUE_SIZE: usize = 2048;
813

@@ -71,3 +76,53 @@ impl Worker<BoxedTask> {
7176
Ok(self)
7277
}
7378
}
79+
80+
pub struct TaskRouter<W> {
81+
table: Arc<RwLock<HashMap<ProjectName, Sender<W>>>>,
82+
}
83+
84+
impl<W> Clone for TaskRouter<W> {
85+
fn clone(&self) -> Self {
86+
Self {
87+
table: self.table.clone(),
88+
}
89+
}
90+
}
91+
92+
impl<W> Default for TaskRouter<W> {
93+
fn default() -> Self {
94+
Self::new()
95+
}
96+
}
97+
98+
impl<W> TaskRouter<W> {
99+
pub fn new() -> Self {
100+
Self {
101+
table: Arc::new(RwLock::new(HashMap::new())),
102+
}
103+
}
104+
}
105+
106+
impl TaskRouter<BoxedTask> {
107+
pub async fn route(
108+
&self,
109+
name: &ProjectName,
110+
task: BoxedTask,
111+
) -> Result<(), SendError<BoxedTask>> {
112+
let mut table = self.table.write().await;
113+
if let Some(sender) = table.get(name) {
114+
sender.send(task).await
115+
} else {
116+
let worker = Worker::new();
117+
let sender = worker.sender();
118+
119+
tokio::spawn(worker.start());
120+
121+
let res = sender.send(task).await;
122+
123+
table.insert(name.clone(), sender);
124+
125+
res
126+
}
127+
}
128+
}

0 commit comments

Comments
 (0)