Skip to content

Commit ed9081b

Browse files
arjan-balpurnesh42H
authored andcommitted
* Revert "credentials/alts: Add comments to clarify buffer sizing (grpc#8232)" This reverts commit be25d96. * Revert "credentials/alts: Optimize reads (grpc#8204)" This reverts commit b368379.
1 parent 69f23ae commit ed9081b

File tree

4 files changed

+24
-95
lines changed

4 files changed

+24
-95
lines changed

credentials/alts/internal/conn/common.go

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,11 @@ func SliceForAppend(in []byte, n int) (head, tail []byte) {
5454
func ParseFramedMsg(b []byte, maxLen uint32) ([]byte, []byte, error) {
5555
// If the size field is not complete, return the provided buffer as
5656
// remaining buffer.
57-
length, sufficientBytes := parseMessageLength(b)
58-
if !sufficientBytes {
57+
if len(b) < MsgLenFieldSize {
5958
return nil, b, nil
6059
}
60+
msgLenField := b[:MsgLenFieldSize]
61+
length := binary.LittleEndian.Uint32(msgLenField)
6162
if length > maxLen {
6263
return nil, nil, fmt.Errorf("received the frame length %d larger than the limit %d", length, maxLen)
6364
}
@@ -67,14 +68,3 @@ func ParseFramedMsg(b []byte, maxLen uint32) ([]byte, []byte, error) {
6768
}
6869
return b[:MsgLenFieldSize+length], b[MsgLenFieldSize+length:], nil
6970
}
70-
71-
// parseMessageLength returns the message length based on frame header. It also
72-
// returns a boolean indicating if the buffer contains sufficient bytes to parse
73-
// the length header. If there are insufficient bytes, (0, false) is returned.
74-
func parseMessageLength(b []byte) (uint32, bool) {
75-
if len(b) < MsgLenFieldSize {
76-
return 0, false
77-
}
78-
msgLenField := b[:MsgLenFieldSize]
79-
return binary.LittleEndian.Uint32(msgLenField), true
80-
}

credentials/alts/internal/conn/record.go

Lines changed: 20 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,6 @@ const (
6363
// The maximum write buffer size. This *must* be multiple of
6464
// altsRecordDefaultLength.
6565
altsWriteBufferMaxSize = 512 * 1024 // 512KiB
66-
// The initial buffer used to read from the network.
67-
altsReadBufferInitialSize = 32 * 1024 // 32KiB
6866
)
6967

7068
var (
@@ -85,7 +83,7 @@ type conn struct {
8583
net.Conn
8684
crypto ALTSRecordCrypto
8785
// buf holds data that has been read from the connection and decrypted,
88-
// but has not yet been returned by Read. It is a sub-slice of protected.
86+
// but has not yet been returned by Read.
8987
buf []byte
9088
payloadLengthLimit int
9189
// protected holds data read from the network but have not yet been
@@ -113,13 +111,21 @@ func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, prot
113111
}
114112
overhead := MsgLenFieldSize + msgTypeFieldSize + crypto.EncryptionOverhead()
115113
payloadLengthLimit := altsRecordDefaultLength - overhead
116-
// We pre-allocate protected to be of size 32KB during initialization.
117-
// We increase the size of the buffer by the required amount if it can't
118-
// hold a complete encrypted record.
119-
protectedBuf := make([]byte, max(altsReadBufferInitialSize, len(protected)))
120-
// Copy additional data from hanshaker service.
121-
copy(protectedBuf, protected)
122-
protectedBuf = protectedBuf[:len(protected)]
114+
var protectedBuf []byte
115+
if protected == nil {
116+
// We pre-allocate protected to be of size
117+
// 2*altsRecordDefaultLength-1 during initialization. We only
118+
// read from the network into protected when protected does not
119+
// contain a complete frame, which is at most
120+
// altsRecordDefaultLength-1 (bytes). And we read at most
121+
// altsRecordDefaultLength (bytes) data into protected at one
122+
// time. Therefore, 2*altsRecordDefaultLength-1 is large enough
123+
// to buffer data read from the network.
124+
protectedBuf = make([]byte, 0, 2*altsRecordDefaultLength-1)
125+
} else {
126+
protectedBuf = make([]byte, len(protected))
127+
copy(protectedBuf, protected)
128+
}
123129

124130
altsConn := &conn{
125131
Conn: c,
@@ -156,26 +162,11 @@ func (p *conn) Read(b []byte) (n int, err error) {
156162
// Check whether a complete frame has been received yet.
157163
for len(framedMsg) == 0 {
158164
if len(p.protected) == cap(p.protected) {
159-
// We can parse the length header to know exactly how large
160-
// the buffer needs to be to hold the entire frame.
161-
length, didParse := parseMessageLength(p.protected)
162-
if !didParse {
163-
// The protected buffer is initialized with a capacity of
164-
// larger than 4B. It should always hold the message length
165-
// header.
166-
panic(fmt.Sprintf("protected buffer length shorter than expected: %d vs %d", len(p.protected), MsgLenFieldSize))
167-
}
168-
oldProtectedBuf := p.protected
169-
// The new buffer must be able to hold the message length header
170-
// and the entire message.
171-
requiredCapacity := int(length) + MsgLenFieldSize
172-
p.protected = make([]byte, requiredCapacity)
173-
// Copy the contents of the old buffer and set the length of the
174-
// new buffer to the number of bytes already read.
175-
copy(p.protected, oldProtectedBuf)
176-
p.protected = p.protected[:len(oldProtectedBuf)]
165+
tmp := make([]byte, len(p.protected), cap(p.protected)+altsRecordDefaultLength)
166+
copy(tmp, p.protected)
167+
p.protected = tmp
177168
}
178-
n, err = p.Conn.Read(p.protected[len(p.protected):cap(p.protected)])
169+
n, err = p.Conn.Read(p.protected[len(p.protected):min(cap(p.protected), len(p.protected)+altsRecordDefaultLength)])
179170
if err != nil {
180171
return 0, err
181172
}
@@ -194,15 +185,6 @@ func (p *conn) Read(b []byte) (n int, err error) {
194185
}
195186
ciphertext := msg[msgTypeFieldSize:]
196187

197-
// Decrypt directly into the buffer, avoiding a copy from p.buf if
198-
// possible.
199-
if len(b) >= len(ciphertext) {
200-
dec, err := p.crypto.Decrypt(b[:0], ciphertext)
201-
if err != nil {
202-
return 0, err
203-
}
204-
return len(dec), nil
205-
}
206188
// Decrypt requires that if the dst and ciphertext alias, they
207189
// must alias exactly. Code here used to use msg[:0], but msg
208190
// starts MsgLenFieldSize+msgTypeFieldSize bytes earlier than

credentials/alts/internal/conn/record_test.go

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ import (
2626
"math"
2727
"net"
2828
"reflect"
29-
"strings"
3029
"testing"
3130

3231
core "google.golang.org/grpc/credentials/alts/internal"
@@ -189,48 +188,6 @@ func (s) TestLargeMsg(t *testing.T) {
189188
}
190189
}
191190

192-
// TestLargeRecord writes a very large ALTS record and verifies that the server
193-
// receives it correctly. The large ALTS record should cause the reader to
194-
// expand it's read buffer to hold the entire record and store the decrypted
195-
// message until the receiver reads all of the bytes.
196-
func (s) TestLargeRecord(t *testing.T) {
197-
clientConn, serverConn := newConnPair(rekeyRecordProtocol, nil, nil)
198-
msg := []byte(strings.Repeat("a", 2*altsReadBufferInitialSize))
199-
// Increase the size of ALTS records written by the client.
200-
clientConn.payloadLengthLimit = math.MaxInt32
201-
if n, err := clientConn.Write(msg); n != len(msg) || err != nil {
202-
t.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
203-
}
204-
rcvMsg := make([]byte, len(msg))
205-
if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil {
206-
t.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg))
207-
}
208-
if !reflect.DeepEqual(msg, rcvMsg) {
209-
t.Fatalf("Write()/Server Read() = %v, want %v", rcvMsg, msg)
210-
}
211-
}
212-
213-
// BenchmarkLargeMessage measures the performance of ALTS conns for sending and
214-
// receiving a large message.
215-
func BenchmarkLargeMessage(b *testing.B) {
216-
msgLen := 20 * 1024 * 1024 // 20 MiB
217-
msg := make([]byte, msgLen)
218-
rcvMsg := make([]byte, len(msg))
219-
b.ResetTimer()
220-
clientConn, serverConn := newConnPair(rekeyRecordProtocol, nil, nil)
221-
for range b.N {
222-
// Write 20 MiB 5 times to transfer a total of 100 MiB.
223-
for range 5 {
224-
if n, err := clientConn.Write(msg); n != len(msg) || err != nil {
225-
b.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
226-
}
227-
if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil {
228-
b.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg))
229-
}
230-
}
231-
}
232-
}
233-
234191
func testIncorrectMsgType(t *testing.T, rp string) {
235192
// framedMsg is an empty ciphertext with correct framing but wrong
236193
// message type.

credentials/alts/internal/handshaker/handshaker.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,6 @@ func (h *altsHandshaker) accessHandshakerService(req *altspb.HandshakerReq) (*al
308308
// whatever received from the network and send it to the handshaker service.
309309
func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []byte) (*altspb.HandshakerResult, []byte, error) {
310310
var lastWriteTime time.Time
311-
buf := make([]byte, frameLimit)
312311
for {
313312
if len(resp.OutFrames) > 0 {
314313
lastWriteTime = time.Now()
@@ -319,6 +318,7 @@ func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []b
319318
if resp.Result != nil {
320319
return resp.Result, extra, nil
321320
}
321+
buf := make([]byte, frameLimit)
322322
n, err := h.conn.Read(buf)
323323
if err != nil && err != io.EOF {
324324
return nil, nil, err

0 commit comments

Comments
 (0)