Skip to content

Commit 30218b6

Browse files
committed
node: allow WebSocket and HTTP work on same port (ethereum#20810)
1 parent e9e383f commit 30218b6

File tree

9 files changed

+424
-205
lines changed

9 files changed

+424
-205
lines changed

node/api.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ func (api *PrivateAdminAPI) StartRPC(host *string, port *int, cors *string, apis
158158
}
159159
}
160160

161-
if err := api.node.startHTTP(fmt.Sprintf("%s:%d", *host, *port), api.node.rpcAPIs, modules, allowedOrigins, allowedVHosts, api.node.config.HTTPTimeouts); err != nil {
161+
if err := api.node.startHTTP(fmt.Sprintf("%s:%d", *host, *port), api.node.rpcAPIs, modules, allowedOrigins, allowedVHosts, api.node.config.HTTPTimeouts, api.node.config.WSOrigins); err != nil {
162162
return false, err
163163
}
164164
return true, nil

node/endpoints.go

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
// Copyright 2018 The go-ethereum Authors
2+
// This file is part of the go-ethereum library.
3+
//
4+
// The go-ethereum library is free software: you can redistribute it and/or modify
5+
// it under the terms of the GNU Lesser General Public License as published by
6+
// the Free Software Foundation, either version 3 of the License, or
7+
// (at your option) any later version.
8+
//
9+
// The go-ethereum library is distributed in the hope that it will be useful,
10+
// but WITHOUT ANY WARRANTY; without even the implied warranty of
11+
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12+
// GNU Lesser General Public License for more details.
13+
//
14+
// You should have received a copy of the GNU Lesser General Public License
15+
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
16+
17+
package node
18+
19+
import (
20+
"net"
21+
"net/http"
22+
"time"
23+
24+
"github.com/XinFinOrg/XDPoSChain/log"
25+
"github.com/XinFinOrg/XDPoSChain/rpc"
26+
)
27+
28+
// StartHTTPEndpoint starts the HTTP RPC endpoint.
29+
func StartHTTPEndpoint(endpoint string, timeouts rpc.HTTPTimeouts, handler http.Handler) (net.Listener, error) {
30+
// start the HTTP listener
31+
var (
32+
listener net.Listener
33+
err error
34+
)
35+
if listener, err = net.Listen("tcp", endpoint); err != nil {
36+
return nil, err
37+
}
38+
// Bundle and start the HTTP server
39+
httpSrv := &http.Server{
40+
Handler: handler,
41+
ReadTimeout: timeouts.ReadTimeout,
42+
WriteTimeout: timeouts.WriteTimeout,
43+
IdleTimeout: timeouts.IdleTimeout,
44+
}
45+
log.Info("StartHTTPEndpoint", "ReadTimeout", timeouts.ReadTimeout, "WriteTimeout", timeouts.WriteTimeout, "IdleTimeout", timeouts.IdleTimeout)
46+
go httpSrv.Serve(listener)
47+
return listener, err
48+
}
49+
50+
// startWSEndpoint starts a websocket endpoint.
51+
func startWSEndpoint(endpoint string, handler http.Handler) (net.Listener, error) {
52+
// start the HTTP listener
53+
var (
54+
listener net.Listener
55+
err error
56+
)
57+
if listener, err = net.Listen("tcp", endpoint); err != nil {
58+
return nil, err
59+
}
60+
wsSrv := &http.Server{Handler: handler}
61+
go wsSrv.Serve(listener)
62+
return listener, err
63+
}
64+
65+
// checkModuleAvailability checks that all names given in modules are actually
66+
// available API services. It assumes that the MetadataApi module ("rpc") is always available;
67+
// the registration of this "rpc" module happens in NewServer() and is thus common to all endpoints.
68+
func checkModuleAvailability(modules []string, apis []rpc.API) (bad, available []string) {
69+
availableSet := make(map[string]struct{})
70+
for _, api := range apis {
71+
if _, ok := availableSet[api.Namespace]; !ok {
72+
availableSet[api.Namespace] = struct{}{}
73+
available = append(available, api.Namespace)
74+
}
75+
}
76+
for _, name := range modules {
77+
if _, ok := availableSet[name]; !ok && name != rpc.MetadataApi {
78+
bad = append(bad, name)
79+
}
80+
}
81+
return bad, available
82+
}
83+
84+
// CheckTimeouts ensures that timeout values are meaningful
85+
func CheckTimeouts(timeouts *rpc.HTTPTimeouts) {
86+
if timeouts.ReadTimeout < time.Second {
87+
log.Warn("Sanitizing invalid HTTP read timeout", "provided", timeouts.ReadTimeout, "updated", rpc.DefaultHTTPTimeouts.ReadTimeout)
88+
timeouts.ReadTimeout = rpc.DefaultHTTPTimeouts.ReadTimeout
89+
}
90+
if timeouts.WriteTimeout < time.Second {
91+
log.Warn("Sanitizing invalid HTTP write timeout", "provided", timeouts.WriteTimeout, "updated", rpc.DefaultHTTPTimeouts.WriteTimeout)
92+
timeouts.WriteTimeout = rpc.DefaultHTTPTimeouts.WriteTimeout
93+
}
94+
if timeouts.IdleTimeout < time.Second {
95+
log.Warn("Sanitizing invalid HTTP idle timeout", "provided", timeouts.IdleTimeout, "updated", rpc.DefaultHTTPTimeouts.IdleTimeout)
96+
timeouts.IdleTimeout = rpc.DefaultHTTPTimeouts.IdleTimeout
97+
}
98+
}

node/node.go

Lines changed: 58 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -269,17 +269,21 @@ func (n *Node) startRPC(services map[reflect.Type]Service) error {
269269
n.stopInProc()
270270
return err
271271
}
272-
if err := n.startHTTP(n.httpEndpoint, apis, n.config.HTTPModules, n.config.HTTPCors, n.config.HTTPVirtualHosts, n.config.HTTPTimeouts); err != nil {
272+
if err := n.startHTTP(n.httpEndpoint, apis, n.config.HTTPModules, n.config.HTTPCors, n.config.HTTPVirtualHosts, n.config.HTTPTimeouts, n.config.WSOrigins); err != nil {
273273
n.stopIPC()
274274
n.stopInProc()
275275
return err
276276
}
277-
if err := n.startWS(n.wsEndpoint, apis, n.config.WSModules, n.config.WSOrigins, n.config.WSExposeAll); err != nil {
278-
n.stopHTTP()
279-
n.stopIPC()
280-
n.stopInProc()
281-
return err
277+
// if endpoints are not the same, start separate servers
278+
if n.httpEndpoint != n.wsEndpoint {
279+
if err := n.startWS(n.wsEndpoint, apis, n.config.WSModules, n.config.WSOrigins, n.config.WSExposeAll); err != nil {
280+
n.stopHTTP()
281+
n.stopIPC()
282+
n.stopInProc()
283+
return err
284+
}
282285
}
286+
283287
// All API endpoints started successfully
284288
n.rpcAPIs = apis
285289
return nil
@@ -348,22 +352,36 @@ func (n *Node) stopIPC() {
348352
}
349353

350354
// startHTTP initializes and starts the HTTP RPC endpoint.
351-
func (n *Node) startHTTP(endpoint string, apis []rpc.API, modules []string, cors []string, vhosts []string, timeouts rpc.HTTPTimeouts) error {
355+
func (n *Node) startHTTP(endpoint string, apis []rpc.API, modules []string, cors []string, vhosts []string, timeouts rpc.HTTPTimeouts, wsOrigins []string) error {
352356
// Short circuit if the HTTP endpoint isn't being exposed
353357
if endpoint == "" {
354358
return nil
355359
}
356-
listener, handler, err := rpc.StartHTTPEndpoint(endpoint, apis, modules, cors, vhosts, timeouts)
360+
// register apis and create handler stack
361+
srv := rpc.NewServer()
362+
err := RegisterApisFromWhitelist(apis, modules, srv, false)
363+
if err != nil {
364+
return err
365+
}
366+
handler := NewHTTPHandlerStack(srv, cors, vhosts, &timeouts)
367+
// wrap handler in websocket handler only if websocket port is the same as http rpc
368+
if n.httpEndpoint == n.wsEndpoint {
369+
handler = NewWebsocketUpgradeHandler(handler, srv.WebsocketHandler(wsOrigins))
370+
}
371+
listener, err := StartHTTPEndpoint(endpoint, timeouts, handler)
357372
if err != nil {
358373
return err
359374
}
360375
n.log.Info("HTTP endpoint opened", "url", fmt.Sprintf("http://%v/", listener.Addr()),
361376
"cors", strings.Join(cors, ","),
362377
"vhosts", strings.Join(vhosts, ","))
378+
if n.httpEndpoint == n.wsEndpoint {
379+
n.log.Info("WebSocket endpoint opened", "url", fmt.Sprintf("ws://%v", listener.Addr()))
380+
}
363381
// All listeners booted successfully
364382
n.httpEndpoint = endpoint
365383
n.httpListener = listener
366-
n.httpHandler = handler
384+
n.httpHandler = srv
367385

368386
return nil
369387
}
@@ -388,15 +406,22 @@ func (n *Node) startWS(endpoint string, apis []rpc.API, modules []string, wsOrig
388406
if endpoint == "" {
389407
return nil
390408
}
391-
listener, handler, err := rpc.StartWSEndpoint(endpoint, apis, modules, wsOrigins, exposeAll)
409+
410+
srv := rpc.NewServer()
411+
handler := srv.WebsocketHandler(wsOrigins)
412+
err := RegisterApisFromWhitelist(apis, modules, srv, exposeAll)
413+
if err != nil {
414+
return err
415+
}
416+
listener, err := startWSEndpoint(endpoint, handler)
392417
if err != nil {
393418
return err
394419
}
395420
n.log.Info("WebSocket endpoint opened", "url", fmt.Sprintf("ws://%s", listener.Addr()))
396421
// All listeners booted successfully
397422
n.wsEndpoint = endpoint
398423
n.wsListener = listener
399-
n.wsHandler = handler
424+
n.wsHandler = srv
400425

401426
return nil
402427
}
@@ -641,3 +666,25 @@ func (n *Node) apis() []rpc.API {
641666
},
642667
}
643668
}
669+
670+
// RegisterApisFromWhitelist checks the given modules' availability, generates a whitelist based on the allowed modules,
671+
// and then registers all of the APIs exposed by the services.
672+
func RegisterApisFromWhitelist(apis []rpc.API, modules []string, srv *rpc.Server, exposeAll bool) error {
673+
if bad, available := checkModuleAvailability(modules, apis); len(bad) > 0 {
674+
log.Error("Unavailable modules in HTTP API list", "unavailable", bad, "available", available)
675+
}
676+
// Generate the whitelist based on the allowed modules
677+
whitelist := make(map[string]bool)
678+
for _, module := range modules {
679+
whitelist[module] = true
680+
}
681+
// Register all the APIs exposed by the services
682+
for _, api := range apis {
683+
if exposeAll || whitelist[api.Namespace] || (len(whitelist) == 0 && api.Public) {
684+
if err := srv.RegisterName(api.Namespace, api.Service); err != nil {
685+
return err
686+
}
687+
}
688+
}
689+
return nil
690+
}

node/node_test.go

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package node
1818

1919
import (
2020
"errors"
21+
"net/http"
2122
"os"
2223
"reflect"
2324
"testing"
@@ -26,6 +27,7 @@ import (
2627
"github.com/XinFinOrg/XDPoSChain/crypto"
2728
"github.com/XinFinOrg/XDPoSChain/p2p"
2829
"github.com/XinFinOrg/XDPoSChain/rpc"
30+
"github.com/stretchr/testify/assert"
2931
)
3032

3133
var (
@@ -332,7 +334,7 @@ func TestServiceStartupAbortion(t *testing.T) {
332334
}
333335

334336
// Tests that even if a registered service fails to shut down cleanly, it does
335-
// not influece the rest of the shutdown invocations.
337+
// not influence the rest of the shutdown invocations.
336338
func TestServiceTerminationGuarantee(t *testing.T) {
337339
stack, err := New(testNodeConfig())
338340
if err != nil {
@@ -572,3 +574,58 @@ func TestAPIGather(t *testing.T) {
572574
}
573575
}
574576
}
577+
578+
func TestWebsocketHTTPOnSamePort_WebsocketRequest(t *testing.T) {
579+
node := startHTTP(t)
580+
defer node.stopHTTP()
581+
582+
wsReq, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:7453", nil)
583+
if err != nil {
584+
t.Error("could not issue new http request ", err)
585+
}
586+
wsReq.Header.Set("Connection", "upgrade")
587+
wsReq.Header.Set("Upgrade", "websocket")
588+
wsReq.Header.Set("Sec-WebSocket-Version", "13")
589+
wsReq.Header.Set("Sec-Websocket-Key", "SGVsbG8sIHdvcmxkIQ==")
590+
591+
resp := doHTTPRequest(t, wsReq)
592+
assert.Equal(t, "websocket", resp.Header.Get("Upgrade"))
593+
}
594+
595+
func TestWebsocketHTTPOnSamePort_HTTPRequest(t *testing.T) {
596+
node := startHTTP(t)
597+
defer node.stopHTTP()
598+
599+
httpReq, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:7453", nil)
600+
if err != nil {
601+
t.Error("could not issue new http request ", err)
602+
}
603+
httpReq.Header.Set("Accept-Encoding", "gzip")
604+
605+
resp := doHTTPRequest(t, httpReq)
606+
assert.Equal(t, "gzip", resp.Header.Get("Content-Encoding"))
607+
}
608+
609+
func startHTTP(t *testing.T) *Node {
610+
conf := &Config{HTTPPort: 7453, WSPort: 7453}
611+
node, err := New(conf)
612+
if err != nil {
613+
t.Error("could not create a new node ", err)
614+
}
615+
616+
err = node.startHTTP("127.0.0.1:7453", []rpc.API{}, []string{}, []string{}, []string{}, rpc.HTTPTimeouts{}, []string{})
617+
if err != nil {
618+
t.Error("could not start http service on node ", err)
619+
}
620+
621+
return node
622+
}
623+
624+
func doHTTPRequest(t *testing.T, req *http.Request) *http.Response {
625+
client := &http.Client{}
626+
resp, err := client.Do(req)
627+
if err != nil {
628+
t.Error("could not issue a GET request to the given endpoint", err)
629+
}
630+
return resp
631+
}

0 commit comments

Comments
 (0)