Skip to content

Commit 399e2d0

Browse files
authored
credentials/alts: Optimize Reads (Roll forward #8236) (#8271)
1 parent 4cedec4 commit 399e2d0

File tree

4 files changed

+95
-24
lines changed

4 files changed

+95
-24
lines changed

credentials/alts/internal/conn/common.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,10 @@ func SliceForAppend(in []byte, n int) (head, tail []byte) {
5454
func ParseFramedMsg(b []byte, maxLen uint32) ([]byte, []byte, error) {
5555
// If the size field is not complete, return the provided buffer as
5656
// remaining buffer.
57-
if len(b) < MsgLenFieldSize {
57+
length, sufficientBytes := parseMessageLength(b)
58+
if !sufficientBytes {
5859
return nil, b, nil
5960
}
60-
msgLenField := b[:MsgLenFieldSize]
61-
length := binary.LittleEndian.Uint32(msgLenField)
6261
if length > maxLen {
6362
return nil, nil, fmt.Errorf("received the frame length %d larger than the limit %d", length, maxLen)
6463
}
@@ -68,3 +67,14 @@ func ParseFramedMsg(b []byte, maxLen uint32) ([]byte, []byte, error) {
6867
}
6968
return b[:MsgLenFieldSize+length], b[MsgLenFieldSize+length:], nil
7069
}
70+
71+
// parseMessageLength returns the message length based on frame header. It also
72+
// returns a boolean indicating if the buffer contains sufficient bytes to parse
73+
// the length header. If there are insufficient bytes, (0, false) is returned.
74+
func parseMessageLength(b []byte) (uint32, bool) {
75+
if len(b) < MsgLenFieldSize {
76+
return 0, false
77+
}
78+
msgLenField := b[:MsgLenFieldSize]
79+
return binary.LittleEndian.Uint32(msgLenField), true
80+
}

credentials/alts/internal/conn/record.go

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ const (
6767
// The maximum write buffer size. This *must* be multiple of
6868
// altsRecordDefaultLength.
6969
altsWriteBufferMaxSize = 512 * 1024 // 512KiB
70+
// The initial buffer used to read from the network.
71+
altsReadBufferInitialSize = 32 * 1024 // 32KiB
7072
)
7173

7274
var (
@@ -87,7 +89,7 @@ type conn struct {
8789
net.Conn
8890
crypto ALTSRecordCrypto
8991
// buf holds data that has been read from the connection and decrypted,
90-
// but has not yet been returned by Read.
92+
// but has not yet been returned by Read. It is a sub-slice of protected.
9193
buf []byte
9294
payloadLengthLimit int
9395
// protected holds data read from the network but have not yet been
@@ -115,21 +117,13 @@ func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, prot
115117
}
116118
overhead := MsgLenFieldSize + msgTypeFieldSize + crypto.EncryptionOverhead()
117119
payloadLengthLimit := altsRecordDefaultLength - overhead
118-
var protectedBuf []byte
119-
if protected == nil {
120-
// We pre-allocate protected to be of size
121-
// 2*altsRecordDefaultLength-1 during initialization. We only
122-
// read from the network into protected when protected does not
123-
// contain a complete frame, which is at most
124-
// altsRecordDefaultLength-1 (bytes). And we read at most
125-
// altsRecordDefaultLength (bytes) data into protected at one
126-
// time. Therefore, 2*altsRecordDefaultLength-1 is large enough
127-
// to buffer data read from the network.
128-
protectedBuf = make([]byte, 0, 2*altsRecordDefaultLength-1)
129-
} else {
130-
protectedBuf = make([]byte, len(protected))
131-
copy(protectedBuf, protected)
132-
}
120+
// We pre-allocate protected to be of size 32KB during initialization.
121+
// We increase the size of the buffer by the required amount if it can't
122+
// hold a complete encrypted record.
123+
protectedBuf := make([]byte, max(altsReadBufferInitialSize, len(protected)))
124+
// Copy additional data from hanshaker service.
125+
copy(protectedBuf, protected)
126+
protectedBuf = protectedBuf[:len(protected)]
133127

134128
altsConn := &conn{
135129
Conn: c,
@@ -166,11 +160,26 @@ func (p *conn) Read(b []byte) (n int, err error) {
166160
// Check whether a complete frame has been received yet.
167161
for len(framedMsg) == 0 {
168162
if len(p.protected) == cap(p.protected) {
169-
tmp := make([]byte, len(p.protected), cap(p.protected)+altsRecordDefaultLength)
170-
copy(tmp, p.protected)
171-
p.protected = tmp
163+
// We can parse the length header to know exactly how large
164+
// the buffer needs to be to hold the entire frame.
165+
length, didParse := parseMessageLength(p.protected)
166+
if !didParse {
167+
// The protected buffer is initialized with a capacity of
168+
// larger than 4B. It should always hold the message length
169+
// header.
170+
panic(fmt.Sprintf("protected buffer length shorter than expected: %d vs %d", len(p.protected), MsgLenFieldSize))
171+
}
172+
oldProtectedBuf := p.protected
173+
// The new buffer must be able to hold the message length header
174+
// and the entire message.
175+
requiredCapacity := int(length) + MsgLenFieldSize
176+
p.protected = make([]byte, requiredCapacity)
177+
// Copy the contents of the old buffer and set the length of the
178+
// new buffer to the number of bytes already read.
179+
copy(p.protected, oldProtectedBuf)
180+
p.protected = p.protected[:len(oldProtectedBuf)]
172181
}
173-
n, err = p.Conn.Read(p.protected[len(p.protected):min(cap(p.protected), len(p.protected)+altsRecordDefaultLength)])
182+
n, err = p.Conn.Read(p.protected[len(p.protected):cap(p.protected)])
174183
if err != nil {
175184
return 0, err
176185
}
@@ -189,6 +198,15 @@ func (p *conn) Read(b []byte) (n int, err error) {
189198
}
190199
ciphertext := msg[msgTypeFieldSize:]
191200

201+
// Decrypt directly into the buffer, avoiding a copy from p.buf if
202+
// possible.
203+
if len(b) >= len(ciphertext) {
204+
dec, err := p.crypto.Decrypt(b[:0], ciphertext)
205+
if err != nil {
206+
return 0, err
207+
}
208+
return len(dec), nil
209+
}
192210
// Decrypt requires that if the dst and ciphertext alias, they
193211
// must alias exactly. Code here used to use msg[:0], but msg
194212
// starts MsgLenFieldSize+msgTypeFieldSize bytes earlier than

credentials/alts/internal/conn/record_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"math"
2727
"net"
2828
"reflect"
29+
"strings"
2930
"testing"
3031

3132
core "google.golang.org/grpc/credentials/alts/internal"
@@ -188,6 +189,48 @@ func (s) TestLargeMsg(t *testing.T) {
188189
}
189190
}
190191

192+
// TestLargeRecord writes a very large ALTS record and verifies that the server
193+
// receives it correctly. The large ALTS record should cause the reader to
194+
// expand it's read buffer to hold the entire record and store the decrypted
195+
// message until the receiver reads all of the bytes.
196+
func (s) TestLargeRecord(t *testing.T) {
197+
clientConn, serverConn := newConnPair(rekeyRecordProtocol, nil, nil)
198+
msg := []byte(strings.Repeat("a", 2*altsReadBufferInitialSize))
199+
// Increase the size of ALTS records written by the client.
200+
clientConn.payloadLengthLimit = math.MaxInt32
201+
if n, err := clientConn.Write(msg); n != len(msg) || err != nil {
202+
t.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
203+
}
204+
rcvMsg := make([]byte, len(msg))
205+
if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil {
206+
t.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg))
207+
}
208+
if !reflect.DeepEqual(msg, rcvMsg) {
209+
t.Fatalf("Write()/Server Read() = %v, want %v", rcvMsg, msg)
210+
}
211+
}
212+
213+
// BenchmarkLargeMessage measures the performance of ALTS conns for sending and
214+
// receiving a large message.
215+
func BenchmarkLargeMessage(b *testing.B) {
216+
msgLen := 20 * 1024 * 1024 // 20 MiB
217+
msg := make([]byte, msgLen)
218+
rcvMsg := make([]byte, len(msg))
219+
b.ResetTimer()
220+
clientConn, serverConn := newConnPair(rekeyRecordProtocol, nil, nil)
221+
for range b.N {
222+
// Write 20 MiB 5 times to transfer a total of 100 MiB.
223+
for range 5 {
224+
if n, err := clientConn.Write(msg); n != len(msg) || err != nil {
225+
b.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
226+
}
227+
if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil {
228+
b.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg))
229+
}
230+
}
231+
}
232+
}
233+
191234
func testIncorrectMsgType(t *testing.T, rp string) {
192235
// framedMsg is an empty ciphertext with correct framing but wrong
193236
// message type.

credentials/alts/internal/handshaker/handshaker.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ func (h *altsHandshaker) accessHandshakerService(req *altspb.HandshakerReq) (*al
308308
// whatever received from the network and send it to the handshaker service.
309309
func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []byte) (*altspb.HandshakerResult, []byte, error) {
310310
var lastWriteTime time.Time
311+
buf := make([]byte, frameLimit)
311312
for {
312313
if len(resp.OutFrames) > 0 {
313314
lastWriteTime = time.Now()
@@ -318,7 +319,6 @@ func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []b
318319
if resp.Result != nil {
319320
return resp.Result, extra, nil
320321
}
321-
buf := make([]byte, frameLimit)
322322
n, err := h.conn.Read(buf)
323323
if err != nil && err != io.EOF {
324324
return nil, nil, err

0 commit comments

Comments
 (0)