Skip to content

Commit bd8ac9e

Browse files
committed
quic: fill out connection id handling
Add support for sending and receiving NEW_CONNECTION_ID and RETIRE_CONNECTION_ID frames. Keep the peer supplied with up to 4 connection IDs. Retire connection IDs as required by the peer. Support connection IDs provided in the preferred_address transport parameter. RFC 9000, Section 5.1. For golang/go#58547 Change-Id: I015a69b94c40a6396e9f117a92c88acaf83c594e Reviewed-on: https://go-review.googlesource.com/c/net/+/513440 TryBot-Result: Gopher Robot <gobot@golang.org> Run-TryBot: Damien Neil <dneil@google.com> Reviewed-by: Jonathan Amsterdam <jba@google.com>
1 parent 08001cc commit bd8ac9e

15 files changed

Lines changed: 998 additions & 75 deletions

internal/quic/conn.go

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ type connListener interface {
6969
type connTestHooks interface {
7070
nextMessage(msgc chan any, nextTimeout time.Time) (now time.Time, message any)
7171
handleTLSEvent(tls.QUICEvent)
72+
newConnID(seq int64) ([]byte, error)
7273
}
7374

7475
func newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip.AddrPort, config *Config, l connListener, hooks connTestHooks) (*Conn, error) {
@@ -90,12 +91,12 @@ func newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip.
9091
c.msgc = make(chan any, 1)
9192

9293
if c.side == clientSide {
93-
if err := c.connIDState.initClient(newRandomConnID); err != nil {
94+
if err := c.connIDState.initClient(c.newConnIDFunc()); err != nil {
9495
return nil, err
9596
}
96-
initialConnID = c.connIDState.dstConnID()
97+
initialConnID, _ = c.connIDState.dstConnID()
9798
} else {
98-
if err := c.connIDState.initServer(newRandomConnID, initialConnID); err != nil {
99+
if err := c.connIDState.initServer(c.newConnIDFunc(), initialConnID); err != nil {
99100
return nil, err
100101
}
101102
}
@@ -154,11 +155,27 @@ func (c *Conn) discardKeys(now time.Time, space numberSpace) {
154155
}
155156

156157
// receiveTransportParameters applies transport parameters sent by the peer.
157-
func (c *Conn) receiveTransportParameters(p transportParameters) {
158+
func (c *Conn) receiveTransportParameters(p transportParameters) error {
158159
c.peerAckDelayExponent = p.ackDelayExponent
159160
c.loss.setMaxAckDelay(p.maxAckDelay)
161+
if err := c.connIDState.setPeerActiveConnIDLimit(p.activeConnIDLimit, c.newConnIDFunc()); err != nil {
162+
return err
163+
}
164+
if p.preferredAddrConnID != nil {
165+
var (
166+
seq int64 = 1 // sequence number of this conn id is 1
167+
retirePriorTo int64 = 0 // retire nothing
168+
resetToken [16]byte
169+
)
170+
copy(resetToken[:], p.preferredAddrResetToken)
171+
if err := c.connIDState.handleNewConnID(seq, retirePriorTo, p.preferredAddrConnID, resetToken); err != nil {
172+
return err
173+
}
174+
}
160175

161176
// TODO: Many more transport parameters to come.
177+
178+
return nil
162179
}
163180

164181
type timerEvent struct{}
@@ -295,3 +312,10 @@ func firstTime(a, b time.Time) time.Time {
295312
return b
296313
}
297314
}
315+
316+
func (c *Conn) newConnIDFunc() newConnIDFunc {
317+
if c.testHooks != nil {
318+
return c.testHooks.newConnID
319+
}
320+
return newRandomConnID
321+
}

internal/quic/conn_id.go

Lines changed: 231 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
package quic
88

99
import (
10+
"bytes"
1011
"crypto/rand"
1112
)
1213

@@ -18,8 +19,16 @@ type connIDState struct {
1819
// Local IDs are usually issued by us, and remote IDs by the peer.
1920
// The exception is the transient destination connection ID sent in
2021
// a client's Initial packets, which is chosen by the client.
22+
//
23+
// These are []connID rather than []*connID to minimize allocations.
2124
local []connID
2225
remote []connID
26+
27+
nextLocalSeq int64
28+
retireRemotePriorTo int64 // largest Retire Prior To value sent by the peer
29+
peerActiveConnIDLimit int64 // peer's active_connection_id_limit transport parameter
30+
31+
needSend bool
2332
}
2433

2534
// A connID is a connection ID and associated metadata.
@@ -32,23 +41,36 @@ type connID struct {
3241
//
3342
// For the transient destination ID in a client's Initial packet, this is -1.
3443
seq int64
44+
45+
// retired is set when the connection ID is retired.
46+
retired bool
47+
48+
// send is set when the connection ID's state needs to be sent to the peer.
49+
//
50+
// For local IDs, this indicates a new ID that should be sent
51+
// in a NEW_CONNECTION_ID frame.
52+
//
53+
// For remote IDs, this indicates a retired ID that should be sent
54+
// in a RETIRE_CONNECTION_ID frame.
55+
send sentVal
3556
}
3657

3758
func (s *connIDState) initClient(newID newConnIDFunc) error {
3859
// Client chooses its initial connection ID, and sends it
3960
// in the Source Connection ID field of the first Initial packet.
40-
locid, err := newID()
61+
locid, err := newID(0)
4162
if err != nil {
4263
return err
4364
}
4465
s.local = append(s.local, connID{
4566
seq: 0,
4667
cid: locid,
4768
})
69+
s.nextLocalSeq = 1
4870

4971
// Client chooses an initial, transient connection ID for the server,
5072
// and sends it in the Destination Connection ID field of the first Initial packet.
51-
remid, err := newID()
73+
remid, err := newID(-1)
5274
if err != nil {
5375
return err
5476
}
@@ -70,14 +92,15 @@ func (s *connIDState) initServer(newID newConnIDFunc, dstConnID []byte) error {
7092

7193
// Server chooses a connection ID, and sends it in the Source Connection ID of
7294
// the response to the clent.
73-
locid, err := newID()
95+
locid, err := newID(0)
7496
if err != nil {
7597
return err
7698
}
7799
s.local = append(s.local, connID{
78100
seq: 0,
79101
cid: locid,
80102
})
103+
s.nextLocalSeq = 1
81104
return nil
82105
}
83106

@@ -91,8 +114,44 @@ func (s *connIDState) srcConnID() []byte {
91114
}
92115

93116
// dstConnID is the Destination Connection ID to use in a sent packet.
94-
func (s *connIDState) dstConnID() []byte {
95-
return s.remote[0].cid
117+
func (s *connIDState) dstConnID() (cid []byte, ok bool) {
118+
for i := range s.remote {
119+
if !s.remote[i].retired {
120+
return s.remote[i].cid, true
121+
}
122+
}
123+
return nil, false
124+
}
125+
126+
// setPeerActiveConnIDLimit sets the active_connection_id_limit
127+
// transport parameter received from the peer.
128+
func (s *connIDState) setPeerActiveConnIDLimit(lim int64, newID newConnIDFunc) error {
129+
s.peerActiveConnIDLimit = lim
130+
return s.issueLocalIDs(newID)
131+
}
132+
133+
func (s *connIDState) issueLocalIDs(newID newConnIDFunc) error {
134+
toIssue := min(int(s.peerActiveConnIDLimit), maxPeerActiveConnIDLimit)
135+
for i := range s.local {
136+
if s.local[i].seq != -1 && !s.local[i].retired {
137+
toIssue--
138+
}
139+
}
140+
for toIssue > 0 {
141+
cid, err := newID(s.nextLocalSeq)
142+
if err != nil {
143+
return err
144+
}
145+
s.local = append(s.local, connID{
146+
seq: s.nextLocalSeq,
147+
cid: cid,
148+
})
149+
s.local[len(s.local)-1].send.setUnsent()
150+
s.nextLocalSeq++
151+
s.needSend = true
152+
toIssue--
153+
}
154+
return nil
96155
}
97156

98157
// handlePacket updates the connection ID state during the handshake
@@ -128,19 +187,184 @@ func (s *connIDState) handlePacket(side connSide, ptype packetType, srcConnID []
128187
}
129188
}
130189

190+
func (s *connIDState) handleNewConnID(seq, retire int64, cid []byte, resetToken [16]byte) error {
191+
if len(s.remote[0].cid) == 0 {
192+
// "An endpoint that is sending packets with a zero-length
193+
// Destination Connection ID MUST treat receipt of a NEW_CONNECTION_ID
194+
// frame as a connection error of type PROTOCOL_VIOLATION."
195+
// https://www.rfc-editor.org/rfc/rfc9000.html#section-19.15-6
196+
return localTransportError(errProtocolViolation)
197+
}
198+
199+
if retire > s.retireRemotePriorTo {
200+
s.retireRemotePriorTo = retire
201+
}
202+
203+
have := false // do we already have this connection ID?
204+
active := 0
205+
for i := range s.remote {
206+
rcid := &s.remote[i]
207+
if !rcid.retired && rcid.seq < s.retireRemotePriorTo {
208+
s.retireRemote(rcid)
209+
}
210+
if !rcid.retired {
211+
active++
212+
}
213+
if rcid.seq == seq {
214+
if !bytes.Equal(rcid.cid, cid) {
215+
return localTransportError(errProtocolViolation)
216+
}
217+
have = true // yes, we've seen this sequence number
218+
}
219+
}
220+
221+
if !have {
222+
// This is a new connection ID that we have not seen before.
223+
//
224+
// We could take steps to keep the list of remote connection IDs
225+
// sorted by sequence number, but there's no particular need
226+
// so we don't bother.
227+
s.remote = append(s.remote, connID{
228+
seq: seq,
229+
cid: cloneBytes(cid),
230+
})
231+
if seq < s.retireRemotePriorTo {
232+
// This ID was already retired by a previous NEW_CONNECTION_ID frame.
233+
s.retireRemote(&s.remote[len(s.remote)-1])
234+
} else {
235+
active++
236+
}
237+
}
238+
239+
if active > activeConnIDLimit {
240+
// Retired connection IDs (including newly-retired ones) do not count
241+
// against the limit.
242+
// https://www.rfc-editor.org/rfc/rfc9000.html#section-5.1.1-5
243+
return localTransportError(errConnectionIDLimit)
244+
}
245+
246+
// "An endpoint SHOULD limit the number of connection IDs it has retired locally
247+
// for which RETIRE_CONNECTION_ID frames have not yet been acknowledged."
248+
// https://www.rfc-editor.org/rfc/rfc9000#section-5.1.2-6
249+
//
250+
// Set a limit of four times the active_connection_id_limit for
251+
// the total number of remote connection IDs we keep state for locally.
252+
if len(s.remote) > 4*activeConnIDLimit {
253+
return localTransportError(errConnectionIDLimit)
254+
}
255+
256+
return nil
257+
}
258+
259+
// retireRemote marks a remote connection ID as retired.
260+
func (s *connIDState) retireRemote(rcid *connID) {
261+
rcid.retired = true
262+
rcid.send.setUnsent()
263+
s.needSend = true
264+
}
265+
266+
func (s *connIDState) handleRetireConnID(seq int64, newID newConnIDFunc) error {
267+
if seq >= s.nextLocalSeq {
268+
return localTransportError(errProtocolViolation)
269+
}
270+
for i := range s.local {
271+
if s.local[i].seq == seq {
272+
s.local = append(s.local[:i], s.local[i+1:]...)
273+
break
274+
}
275+
}
276+
s.issueLocalIDs(newID)
277+
return nil
278+
}
279+
280+
func (s *connIDState) ackOrLossNewConnectionID(pnum packetNumber, seq int64, fate packetFate) {
281+
for i := range s.local {
282+
if s.local[i].seq != seq {
283+
continue
284+
}
285+
s.local[i].send.ackOrLoss(pnum, fate)
286+
if fate != packetAcked {
287+
s.needSend = true
288+
}
289+
return
290+
}
291+
}
292+
293+
func (s *connIDState) ackOrLossRetireConnectionID(pnum packetNumber, seq int64, fate packetFate) {
294+
for i := 0; i < len(s.remote); i++ {
295+
if s.remote[i].seq != seq {
296+
continue
297+
}
298+
if fate == packetAcked {
299+
// We have retired this connection ID, and the peer has acked.
300+
// Discard its state completely.
301+
s.remote = append(s.remote[:i], s.remote[i+1:]...)
302+
} else {
303+
// RETIRE_CONNECTION_ID frame was lost, mark for retransmission.
304+
s.needSend = true
305+
s.remote[i].send.ackOrLoss(pnum, fate)
306+
}
307+
return
308+
}
309+
}
310+
311+
// appendFrames appends NEW_CONNECTION_ID and RETIRE_CONNECTION_ID frames
312+
// to the current packet.
313+
//
314+
// It returns true if no more frames need appending,
315+
// false if not everything fit in the current packet.
316+
func (s *connIDState) appendFrames(w *packetWriter, pnum packetNumber, pto bool) bool {
317+
if !s.needSend && !pto {
318+
// Fast path: We don't need to send anything.
319+
return true
320+
}
321+
retireBefore := int64(0)
322+
if s.local[0].seq != -1 {
323+
retireBefore = s.local[0].seq
324+
}
325+
for i := range s.local {
326+
if !s.local[i].send.shouldSendPTO(pto) {
327+
continue
328+
}
329+
if !w.appendNewConnectionIDFrame(
330+
s.local[i].seq,
331+
retireBefore,
332+
s.local[i].cid,
333+
[16]byte{}, // TODO: stateless reset token
334+
) {
335+
return false
336+
}
337+
s.local[i].send.setSent(pnum)
338+
}
339+
for i := range s.remote {
340+
if !s.remote[i].send.shouldSendPTO(pto) {
341+
continue
342+
}
343+
if !w.appendRetireConnectionIDFrame(s.remote[i].seq) {
344+
return false
345+
}
346+
s.remote[i].send.setSent(pnum)
347+
}
348+
s.needSend = false
349+
return true
350+
}
351+
131352
func cloneBytes(b []byte) []byte {
132353
n := make([]byte, len(b))
133354
copy(n, b)
134355
return n
135356
}
136357

137-
type newConnIDFunc func() ([]byte, error)
358+
type newConnIDFunc func(seq int64) ([]byte, error)
138359

139-
func newRandomConnID() ([]byte, error) {
360+
func newRandomConnID(_ int64) ([]byte, error) {
140361
// It is not necessary for connection IDs to be cryptographically secure,
141362
// but it doesn't hurt.
142363
id := make([]byte, connIDLen)
143364
if _, err := rand.Read(id); err != nil {
365+
// TODO: Surface this error as a metric or log event or something.
366+
// rand.Read really shouldn't ever fail, but if it does, we should
367+
// have a way to inform the user.
144368
return nil, err
145369
}
146370
return id, nil

0 commit comments

Comments
 (0)