From cc111cd86b6c9a3cf640c13715abb3685f09eafe Mon Sep 17 00:00:00 2001 From: Lee Bousfield Date: Mon, 13 Feb 2023 15:31:25 -0700 Subject: [PATCH] Add a batch response size limit --- rpc/handler.go | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/rpc/handler.go b/rpc/handler.go index cd95a067f3..822fb985cf 100644 --- a/rpc/handler.go +++ b/rpc/handler.go @@ -19,6 +19,7 @@ package rpc import ( "context" "encoding/json" + "fmt" "reflect" "strconv" "strings" @@ -34,21 +35,20 @@ import ( // // The entry points for incoming messages are: // -// h.handleMsg(message) -// h.handleBatch(message) +// h.handleMsg(message) +// h.handleBatch(message) // // Outgoing calls use the requestOp struct. Register the request before sending it // on the connection: // -// op := &requestOp{ids: ...} -// h.addRequestOp(op) +// op := &requestOp{ids: ...} +// h.addRequestOp(op) // // Now send the request, then wait for the reply to be delivered through handleMsg: // -// if err := op.wait(...); err != nil { -// h.removeRequestOp(op) // timeout, etc. -// } -// +// if err := op.wait(...); err != nil { +// h.removeRequestOp(op) // timeout, etc. +// } type handler struct { reg *serviceRegistry unsubscribeCb *callback @@ -92,6 +92,8 @@ func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg * return h } +const maxBatchResponseSize int = 10_000_000 // 10MB + // handleBatch executes all messages in a batch and returns the responses. func (h *handler) handleBatch(msgs []*jsonrpcMessage) { // Emit error response for empty batches: @@ -114,10 +116,21 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) { } // Process calls on a goroutine because they may block indefinitely: h.startCallProc(func(cp *callProc) { - answers := make([]*jsonrpcMessage, 0, len(msgs)) + answers := make([]json.RawMessage, 0, len(msgs)) + var totalSize int for _, msg := range calls { if answer := h.handleCallMsg(cp, msg); answer != nil { - answers = append(answers, answer) + serialized, err := json.Marshal(answer) + if err != nil { + h.conn.writeJSON(cp.ctx, errorMessage(&parseError{"error serializing response: " + err.Error()})) + return + } + totalSize += len(serialized) + if totalSize > maxBatchResponseSize { + h.conn.writeJSON(cp.ctx, errorMessage(&invalidRequestError{fmt.Sprintf("batch response exceeded limit of %v bytes", maxBatchResponseSize)})) + return + } + answers = append(answers, serialized) } } h.addSubscriptions(cp.notifiers)