@@ -269,17 +269,21 @@ func (n *Node) startRPC(services map[reflect.Type]Service) error {
269
269
n .stopInProc ()
270
270
return err
271
271
}
272
- if err := n .startHTTP (n .httpEndpoint , apis , n .config .HTTPModules , n .config .HTTPCors , n .config .HTTPVirtualHosts , n .config .HTTPTimeouts ); err != nil {
272
+ if err := n .startHTTP (n .httpEndpoint , apis , n .config .HTTPModules , n .config .HTTPCors , n .config .HTTPVirtualHosts , n .config .HTTPTimeouts , n . config . WSOrigins ); err != nil {
273
273
n .stopIPC ()
274
274
n .stopInProc ()
275
275
return err
276
276
}
277
- if err := n .startWS (n .wsEndpoint , apis , n .config .WSModules , n .config .WSOrigins , n .config .WSExposeAll ); err != nil {
278
- n .stopHTTP ()
279
- n .stopIPC ()
280
- n .stopInProc ()
281
- return err
277
+ // if endpoints are not the same, start separate servers
278
+ if n .httpEndpoint != n .wsEndpoint {
279
+ if err := n .startWS (n .wsEndpoint , apis , n .config .WSModules , n .config .WSOrigins , n .config .WSExposeAll ); err != nil {
280
+ n .stopHTTP ()
281
+ n .stopIPC ()
282
+ n .stopInProc ()
283
+ return err
284
+ }
282
285
}
286
+
283
287
// All API endpoints started successfully
284
288
n .rpcAPIs = apis
285
289
return nil
@@ -348,22 +352,36 @@ func (n *Node) stopIPC() {
348
352
}
349
353
350
354
// startHTTP initializes and starts the HTTP RPC endpoint.
351
- func (n * Node ) startHTTP (endpoint string , apis []rpc.API , modules []string , cors []string , vhosts []string , timeouts rpc.HTTPTimeouts ) error {
355
+ func (n * Node ) startHTTP (endpoint string , apis []rpc.API , modules []string , cors []string , vhosts []string , timeouts rpc.HTTPTimeouts , wsOrigins [] string ) error {
352
356
// Short circuit if the HTTP endpoint isn't being exposed
353
357
if endpoint == "" {
354
358
return nil
355
359
}
356
- listener , handler , err := rpc .StartHTTPEndpoint (endpoint , apis , modules , cors , vhosts , timeouts )
360
+ // register apis and create handler stack
361
+ srv := rpc .NewServer ()
362
+ err := RegisterApisFromWhitelist (apis , modules , srv , false )
363
+ if err != nil {
364
+ return err
365
+ }
366
+ handler := NewHTTPHandlerStack (srv , cors , vhosts , & timeouts )
367
+ // wrap handler in websocket handler only if websocket port is the same as http rpc
368
+ if n .httpEndpoint == n .wsEndpoint {
369
+ handler = NewWebsocketUpgradeHandler (handler , srv .WebsocketHandler (wsOrigins ))
370
+ }
371
+ listener , err := StartHTTPEndpoint (endpoint , timeouts , handler )
357
372
if err != nil {
358
373
return err
359
374
}
360
375
n .log .Info ("HTTP endpoint opened" , "url" , fmt .Sprintf ("http://%v/" , listener .Addr ()),
361
376
"cors" , strings .Join (cors , "," ),
362
377
"vhosts" , strings .Join (vhosts , "," ))
378
+ if n .httpEndpoint == n .wsEndpoint {
379
+ n .log .Info ("WebSocket endpoint opened" , "url" , fmt .Sprintf ("ws://%v" , listener .Addr ()))
380
+ }
363
381
// All listeners booted successfully
364
382
n .httpEndpoint = endpoint
365
383
n .httpListener = listener
366
- n .httpHandler = handler
384
+ n .httpHandler = srv
367
385
368
386
return nil
369
387
}
@@ -388,15 +406,22 @@ func (n *Node) startWS(endpoint string, apis []rpc.API, modules []string, wsOrig
388
406
if endpoint == "" {
389
407
return nil
390
408
}
391
- listener , handler , err := rpc .StartWSEndpoint (endpoint , apis , modules , wsOrigins , exposeAll )
409
+
410
+ srv := rpc .NewServer ()
411
+ handler := srv .WebsocketHandler (wsOrigins )
412
+ err := RegisterApisFromWhitelist (apis , modules , srv , exposeAll )
413
+ if err != nil {
414
+ return err
415
+ }
416
+ listener , err := startWSEndpoint (endpoint , handler )
392
417
if err != nil {
393
418
return err
394
419
}
395
420
n .log .Info ("WebSocket endpoint opened" , "url" , fmt .Sprintf ("ws://%s" , listener .Addr ()))
396
421
// All listeners booted successfully
397
422
n .wsEndpoint = endpoint
398
423
n .wsListener = listener
399
- n .wsHandler = handler
424
+ n .wsHandler = srv
400
425
401
426
return nil
402
427
}
@@ -641,3 +666,25 @@ func (n *Node) apis() []rpc.API {
641
666
},
642
667
}
643
668
}
669
+
670
+ // RegisterApisFromWhitelist checks the given modules' availability, generates a whitelist based on the allowed modules,
671
+ // and then registers all of the APIs exposed by the services.
672
+ func RegisterApisFromWhitelist (apis []rpc.API , modules []string , srv * rpc.Server , exposeAll bool ) error {
673
+ if bad , available := checkModuleAvailability (modules , apis ); len (bad ) > 0 {
674
+ log .Error ("Unavailable modules in HTTP API list" , "unavailable" , bad , "available" , available )
675
+ }
676
+ // Generate the whitelist based on the allowed modules
677
+ whitelist := make (map [string ]bool )
678
+ for _ , module := range modules {
679
+ whitelist [module ] = true
680
+ }
681
+ // Register all the APIs exposed by the services
682
+ for _ , api := range apis {
683
+ if exposeAll || whitelist [api .Namespace ] || (len (whitelist ) == 0 && api .Public ) {
684
+ if err := srv .RegisterName (api .Namespace , api .Service ); err != nil {
685
+ return err
686
+ }
687
+ }
688
+ }
689
+ return nil
690
+ }
0 commit comments