1
1
use std:: env;
2
2
use std:: io:: { self , Error , ErrorKind } ;
3
3
use std:: sync:: Arc ;
4
- use std:: time:: { SystemTime , UNIX_EPOCH } ;
4
+ use std:: time:: { SystemTime , UNIX_EPOCH , Duration } ;
5
+ use std:: collections:: HashSet ;
6
+ use std:: sync:: RwLock ;
5
7
6
8
use actix_cors:: Cors ;
7
9
use actix_web:: {
@@ -11,9 +13,13 @@ use actix_web::{
11
13
12
14
use serde:: Deserialize ;
13
15
use serde_json:: json;
16
+ use tokio:: time:: interval;
17
+ use reqwest;
14
18
15
19
use su:: domain:: { flows, init_deps, router, Deps , PromMetrics } ;
16
20
21
+ type IpWhitelist = Arc < RwLock < HashSet < String > > > ;
22
+
17
23
#[ derive( Deserialize ) ]
18
24
struct FromTo {
19
25
from : Option < String > ,
@@ -61,6 +67,40 @@ fn err_response(err: String) -> HttpResponse {
61
67
. body ( error_json. to_string ( ) )
62
68
}
63
69
70
+ fn get_client_ip ( req : & HttpRequest ) -> Option < String > {
71
+ req. connection_info ( )
72
+ . realip_remote_addr ( )
73
+ . map ( |ip| ip. to_string ( ) )
74
+ }
75
+
76
+ fn is_ip_allowed ( ip : & str , whitelist : & IpWhitelist ) -> bool {
77
+ println ! ( "Checking if IP {} is allowed" , ip) ;
78
+ match whitelist. read ( ) {
79
+ Ok ( ips) => ips. contains ( ip) ,
80
+ Err ( _) => false ,
81
+ }
82
+ }
83
+
84
+ async fn fetch_ip_whitelist ( url : String ) -> Result < HashSet < String > , reqwest:: Error > {
85
+ let response = reqwest:: get ( url) . await ?;
86
+ let ips: Vec < String > = response. json ( ) . await ?;
87
+ Ok ( ips. into_iter ( ) . collect ( ) )
88
+ }
89
+
90
+ async fn update_ip_whitelist_loop ( ip_whitelist : IpWhitelist , whitelist_url : String ) {
91
+ let mut interval = interval ( Duration :: from_secs ( 30 ) ) ;
92
+
93
+ loop {
94
+ interval. tick ( ) . await ;
95
+
96
+ if let Ok ( new_ips) = fetch_ip_whitelist ( whitelist_url. clone ( ) ) . await {
97
+ if let Ok ( mut ips) = ip_whitelist. write ( ) {
98
+ * ips = new_ips;
99
+ }
100
+ }
101
+ }
102
+ }
103
+
64
104
async fn base (
65
105
data : web:: Data < AppState > ,
66
106
query_params : web:: Query < ProcessId > ,
@@ -128,6 +168,16 @@ async fn main_post_route(
128
168
return HttpResponse :: ServiceUnavailable ( )
129
169
. json ( json ! ( { "error" : "Server is warming up. Please try again later." } ) ) ;
130
170
}
171
+
172
+ if let Some ( client_ip) = get_client_ip ( & req) {
173
+ if !is_ip_allowed ( & client_ip, & data. ip_whitelist ) {
174
+ return HttpResponse :: Forbidden ( )
175
+ . json ( json ! ( { "error" : "Access denied" } ) ) ;
176
+ }
177
+ } else {
178
+ return HttpResponse :: BadRequest ( )
179
+ . json ( json ! ( { "error" : "Access denied" } ) ) ;
180
+ }
131
181
match router:: redirect_data_item (
132
182
data. deps . clone ( ) ,
133
183
req_body. to_vec ( ) ,
@@ -281,6 +331,7 @@ struct AppState {
281
331
deps : Arc < Deps > ,
282
332
metrics : Arc < PromMetrics > ,
283
333
startup_time : u64 ,
334
+ ip_whitelist : IpWhitelist ,
284
335
}
285
336
286
337
#[ actix_web:: main]
@@ -311,13 +362,32 @@ async fn main() -> io::Result<()> {
311
362
. as_secs ( ) ;
312
363
313
364
let ( deps, metrics) = init_deps ( mode) . await ;
365
+
366
+ let run_deps = deps. clone ( ) ;
367
+
368
+ let ip_whitelist: IpWhitelist = Arc :: new ( RwLock :: new ( HashSet :: new ( ) ) ) ;
369
+
370
+ match fetch_ip_whitelist ( run_deps. config . ip_whitelist_url ( ) ) . await {
371
+ Ok ( initial_ips) => {
372
+ if let Ok ( mut ips) = ip_whitelist. write ( ) {
373
+ * ips = initial_ips;
374
+ }
375
+ }
376
+ Err ( _) => { }
377
+ }
378
+
314
379
let app_state = web:: Data :: new ( AppState {
315
380
deps,
316
381
metrics,
317
382
startup_time,
383
+ ip_whitelist : ip_whitelist. clone ( ) ,
318
384
} ) ;
319
385
320
- let run_deps = app_state. deps . clone ( ) ;
386
+ let whitelist_clone = ip_whitelist. clone ( ) ;
387
+ let whitelist_url = run_deps. config . ip_whitelist_url ( ) ;
388
+ tokio:: spawn ( async move {
389
+ update_ip_whitelist_loop ( whitelist_clone, whitelist_url) . await ;
390
+ } ) ;
321
391
322
392
if run_deps. config . mode ( ) == "router" {
323
393
match router:: init_schedulers ( run_deps. clone ( ) ) . await {
0 commit comments