Skip to content

Commit c54b5d0

Browse files
committed
gob: make recursive map and slice types work.
Before this fix, types such as type T map[string]T caused infinite recursion in the gob implementation. Now they just work. Fixes #1518. R=rsc CC=golang-dev https://golang.org/cl/4230045
1 parent 8956317 commit c54b5d0

File tree

5 files changed

+156
-76
lines changed

5 files changed

+156
-76
lines changed

Diff for: src/pkg/gob/codec_test.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ func TestScalarDecInstructions(t *testing.T) {
342342
var data struct {
343343
a int
344344
}
345-
instr := &decInstr{decOpMap[reflect.Int], 6, 0, 0, ovfl}
345+
instr := &decInstr{decOpTable[reflect.Int], 6, 0, 0, ovfl}
346346
state := newDecodeStateFromData(signedResult)
347347
execDec("int", instr, state, t, unsafe.Pointer(&data))
348348
if data.a != 17 {
@@ -355,7 +355,7 @@ func TestScalarDecInstructions(t *testing.T) {
355355
var data struct {
356356
a uint
357357
}
358-
instr := &decInstr{decOpMap[reflect.Uint], 6, 0, 0, ovfl}
358+
instr := &decInstr{decOpTable[reflect.Uint], 6, 0, 0, ovfl}
359359
state := newDecodeStateFromData(unsignedResult)
360360
execDec("uint", instr, state, t, unsafe.Pointer(&data))
361361
if data.a != 17 {
@@ -446,7 +446,7 @@ func TestScalarDecInstructions(t *testing.T) {
446446
var data struct {
447447
a uintptr
448448
}
449-
instr := &decInstr{decOpMap[reflect.Uintptr], 6, 0, 0, ovfl}
449+
instr := &decInstr{decOpTable[reflect.Uintptr], 6, 0, 0, ovfl}
450450
state := newDecodeStateFromData(unsignedResult)
451451
execDec("uintptr", instr, state, t, unsafe.Pointer(&data))
452452
if data.a != 17 {
@@ -511,7 +511,7 @@ func TestScalarDecInstructions(t *testing.T) {
511511
var data struct {
512512
a complex64
513513
}
514-
instr := &decInstr{decOpMap[reflect.Complex64], 6, 0, 0, ovfl}
514+
instr := &decInstr{decOpTable[reflect.Complex64], 6, 0, 0, ovfl}
515515
state := newDecodeStateFromData(complexResult)
516516
execDec("complex", instr, state, t, unsafe.Pointer(&data))
517517
if data.a != 17+19i {
@@ -524,7 +524,7 @@ func TestScalarDecInstructions(t *testing.T) {
524524
var data struct {
525525
a complex128
526526
}
527-
instr := &decInstr{decOpMap[reflect.Complex128], 6, 0, 0, ovfl}
527+
instr := &decInstr{decOpTable[reflect.Complex128], 6, 0, 0, ovfl}
528528
state := newDecodeStateFromData(complexResult)
529529
execDec("complex", instr, state, t, unsafe.Pointer(&data))
530530
if data.a != 17+19i {

Diff for: src/pkg/gob/decode.go

+37-26
Original file line numberDiff line numberDiff line change
@@ -671,7 +671,7 @@ func (dec *Decoder) ignoreInterface(state *decodeState) {
671671
}
672672

673673
// Index by Go types.
674-
var decOpMap = []decOp{
674+
var decOpTable = [...]decOp{
675675
reflect.Bool: decBool,
676676
reflect.Int8: decInt8,
677677
reflect.Int16: decInt16,
@@ -701,37 +701,43 @@ var decIgnoreOpMap = map[typeId]decOp{
701701

702702
// Return the decoding op for the base type under rt and
703703
// the indirection count to reach it.
704-
func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int) {
704+
func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProgress map[reflect.Type]*decOp) (*decOp, int) {
705705
ut := userType(rt)
706+
// If this type is already in progress, it's a recursive type (e.g. map[string]*T).
707+
// Return the pointer to the op we're already building.
708+
if opPtr := inProgress[rt]; opPtr != nil {
709+
return opPtr, ut.indir
710+
}
706711
typ := ut.base
707712
indir := ut.indir
708713
var op decOp
709714
k := typ.Kind()
710-
if int(k) < len(decOpMap) {
711-
op = decOpMap[k]
715+
if int(k) < len(decOpTable) {
716+
op = decOpTable[k]
712717
}
713718
if op == nil {
719+
inProgress[rt] = &op
714720
// Special cases
715721
switch t := typ.(type) {
716722
case *reflect.ArrayType:
717723
name = "element of " + name
718724
elemId := dec.wireType[wireId].ArrayT.Elem
719-
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name)
725+
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress)
720726
ovfl := overflow(name)
721727
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
722-
state.dec.decodeArray(t, state, uintptr(p), elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl)
728+
state.dec.decodeArray(t, state, uintptr(p), *elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl)
723729
}
724730

725731
case *reflect.MapType:
726732
name = "element of " + name
727733
keyId := dec.wireType[wireId].MapT.Key
728734
elemId := dec.wireType[wireId].MapT.Elem
729-
keyOp, keyIndir := dec.decOpFor(keyId, t.Key(), name)
730-
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name)
735+
keyOp, keyIndir := dec.decOpFor(keyId, t.Key(), name, inProgress)
736+
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress)
731737
ovfl := overflow(name)
732738
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
733739
up := unsafe.Pointer(p)
734-
state.dec.decodeMap(t, state, uintptr(up), keyOp, elemOp, i.indir, keyIndir, elemIndir, ovfl)
740+
state.dec.decodeMap(t, state, uintptr(up), *keyOp, *elemOp, i.indir, keyIndir, elemIndir, ovfl)
735741
}
736742

737743
case *reflect.SliceType:
@@ -746,10 +752,10 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp
746752
} else {
747753
elemId = dec.wireType[wireId].SliceT.Elem
748754
}
749-
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name)
755+
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress)
750756
ovfl := overflow(name)
751757
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
752-
state.dec.decodeSlice(t, state, uintptr(p), elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl)
758+
state.dec.decodeSlice(t, state, uintptr(p), *elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl)
753759
}
754760

755761
case *reflect.StructType:
@@ -774,7 +780,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp
774780
if op == nil {
775781
errorf("gob: decode can't handle type %s", rt.String())
776782
}
777-
return op, indir
783+
return &op, indir
778784
}
779785

780786
// Return the decoding op for a field that has no destination.
@@ -838,11 +844,15 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp {
838844
// Are these two gob Types compatible?
839845
// Answers the question for basic types, arrays, and slices.
840846
// Structs are considered ok; fields will be checked later.
841-
func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool {
847+
func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[reflect.Type]typeId) bool {
848+
if rhs, ok := inProgress[fr]; ok {
849+
return rhs == fw
850+
}
851+
inProgress[fr] = fw
842852
fr = userType(fr).base
843853
switch t := fr.(type) {
844854
default:
845-
// map, chan, etc: cannot handle.
855+
// chan, etc: cannot handle.
846856
return false
847857
case *reflect.BoolType:
848858
return fw == tBool
@@ -864,14 +874,14 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool {
864874
return false
865875
}
866876
array := wire.ArrayT
867-
return t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem)
877+
return t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem, inProgress)
868878
case *reflect.MapType:
869879
wire, ok := dec.wireType[fw]
870880
if !ok || wire.MapT == nil {
871881
return false
872882
}
873883
MapType := wire.MapT
874-
return dec.compatibleType(t.Key(), MapType.Key) && dec.compatibleType(t.Elem(), MapType.Elem)
884+
return dec.compatibleType(t.Key(), MapType.Key, inProgress) && dec.compatibleType(t.Elem(), MapType.Elem, inProgress)
875885
case *reflect.SliceType:
876886
// Is it an array of bytes?
877887
if t.Elem().Kind() == reflect.Uint8 {
@@ -885,7 +895,7 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool {
885895
sw = dec.wireType[fw].SliceT
886896
}
887897
elem := userType(t.Elem()).base
888-
return sw != nil && dec.compatibleType(elem, sw.Elem)
898+
return sw != nil && dec.compatibleType(elem, sw.Elem, inProgress)
889899
case *reflect.StructType:
890900
return true
891901
}
@@ -906,12 +916,12 @@ func (dec *Decoder) compileSingle(remoteId typeId, rt reflect.Type) (engine *dec
906916
engine = new(decEngine)
907917
engine.instr = make([]decInstr, 1) // one item
908918
name := rt.String() // best we can do
909-
if !dec.compatibleType(rt, remoteId) {
919+
if !dec.compatibleType(rt, remoteId, make(map[reflect.Type]typeId)) {
910920
return nil, os.ErrorString("gob: wrong type received for local value " + name + ": " + dec.typeString(remoteId))
911921
}
912-
op, indir := dec.decOpFor(remoteId, rt, name)
922+
op, indir := dec.decOpFor(remoteId, rt, name, make(map[reflect.Type]*decOp))
913923
ovfl := os.ErrorString(`value for "` + name + `" out of range`)
914-
engine.instr[singletonField] = decInstr{op, singletonField, indir, 0, ovfl}
924+
engine.instr[singletonField] = decInstr{*op, singletonField, indir, 0, ovfl}
915925
engine.numInstr = 1
916926
return
917927
}
@@ -954,6 +964,7 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng
954964
}
955965
engine = new(decEngine)
956966
engine.instr = make([]decInstr, len(wireStruct.Field))
967+
seen := make(map[reflect.Type]*decOp)
957968
// Loop over the fields of the wire type.
958969
for fieldnum := 0; fieldnum < len(wireStruct.Field); fieldnum++ {
959970
wireField := wireStruct.Field[fieldnum]
@@ -969,11 +980,11 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng
969980
engine.instr[fieldnum] = decInstr{op, fieldnum, 0, 0, ovfl}
970981
continue
971982
}
972-
if !dec.compatibleType(localField.Type, wireField.Id) {
983+
if !dec.compatibleType(localField.Type, wireField.Id, make(map[reflect.Type]typeId)) {
973984
errorf("gob: wrong type (%s) for received field %s.%s", localField.Type, wireStruct.Name, wireField.Name)
974985
}
975-
op, indir := dec.decOpFor(wireField.Id, localField.Type, localField.Name)
976-
engine.instr[fieldnum] = decInstr{op, fieldnum, indir, uintptr(localField.Offset), ovfl}
986+
op, indir := dec.decOpFor(wireField.Id, localField.Type, localField.Name, seen)
987+
engine.instr[fieldnum] = decInstr{*op, fieldnum, indir, uintptr(localField.Offset), ovfl}
977988
engine.numInstr++
978989
}
979990
return
@@ -1070,8 +1081,8 @@ func init() {
10701081
default:
10711082
panic("gob: unknown size of int/uint")
10721083
}
1073-
decOpMap[reflect.Int] = iop
1074-
decOpMap[reflect.Uint] = uop
1084+
decOpTable[reflect.Int] = iop
1085+
decOpTable[reflect.Uint] = uop
10751086

10761087
// Finally uintptr
10771088
switch reflect.Typeof(uintptr(0)).Bits() {
@@ -1082,5 +1093,5 @@ func init() {
10821093
default:
10831094
panic("gob: unknown size of uintptr")
10841095
}
1085-
decOpMap[reflect.Uintptr] = uop
1096+
decOpTable[reflect.Uintptr] = uop
10861097
}

Diff for: src/pkg/gob/encode.go

+25-18
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue)
414414
}
415415
}
416416

417-
var encOpMap = []encOp{
417+
var encOpTable = [...]encOp{
418418
reflect.Bool: encBool,
419419
reflect.Int: encInt,
420420
reflect.Int8: encInt8,
@@ -434,18 +434,24 @@ var encOpMap = []encOp{
434434
reflect.String: encString,
435435
}
436436

437-
// Return the encoding op for the base type under rt and
437+
// Return (a pointer to) the encoding op for the base type under rt and
438438
// the indirection count to reach it.
439-
func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) {
439+
func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp) (*encOp, int) {
440440
ut := userType(rt)
441+
// If this type is already in progress, it's a recursive type (e.g. map[string]*T).
442+
// Return the pointer to the op we're already building.
443+
if opPtr := inProgress[rt]; opPtr != nil {
444+
return opPtr, ut.indir
445+
}
441446
typ := ut.base
442447
indir := ut.indir
443-
var op encOp
444448
k := typ.Kind()
445-
if int(k) < len(encOpMap) {
446-
op = encOpMap[k]
449+
var op encOp
450+
if int(k) < len(encOpTable) {
451+
op = encOpTable[k]
447452
}
448453
if op == nil {
454+
inProgress[rt] = &op
449455
// Special cases
450456
switch t := typ.(type) {
451457
case *reflect.SliceType:
@@ -454,25 +460,25 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) {
454460
break
455461
}
456462
// Slices have a header; we decode it to find the underlying array.
457-
elemOp, indir := enc.encOpFor(t.Elem())
463+
elemOp, indir := enc.encOpFor(t.Elem(), inProgress)
458464
op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
459465
slice := (*reflect.SliceHeader)(p)
460466
if !state.sendZero && slice.Len == 0 {
461467
return
462468
}
463469
state.update(i)
464-
state.enc.encodeArray(state.b, slice.Data, elemOp, t.Elem().Size(), indir, int(slice.Len))
470+
state.enc.encodeArray(state.b, slice.Data, *elemOp, t.Elem().Size(), indir, int(slice.Len))
465471
}
466472
case *reflect.ArrayType:
467473
// True arrays have size in the type.
468-
elemOp, indir := enc.encOpFor(t.Elem())
474+
elemOp, indir := enc.encOpFor(t.Elem(), inProgress)
469475
op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
470476
state.update(i)
471-
state.enc.encodeArray(state.b, uintptr(p), elemOp, t.Elem().Size(), indir, t.Len())
477+
state.enc.encodeArray(state.b, uintptr(p), *elemOp, t.Elem().Size(), indir, t.Len())
472478
}
473479
case *reflect.MapType:
474-
keyOp, keyIndir := enc.encOpFor(t.Key())
475-
elemOp, elemIndir := enc.encOpFor(t.Elem())
480+
keyOp, keyIndir := enc.encOpFor(t.Key(), inProgress)
481+
elemOp, elemIndir := enc.encOpFor(t.Elem(), inProgress)
476482
op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
477483
// Maps cannot be accessed by moving addresses around the way
478484
// that slices etc. can. We must recover a full reflection value for
@@ -483,7 +489,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) {
483489
return
484490
}
485491
state.update(i)
486-
state.enc.encodeMap(state.b, mv, keyOp, elemOp, keyIndir, elemIndir)
492+
state.enc.encodeMap(state.b, mv, *keyOp, *elemOp, keyIndir, elemIndir)
487493
}
488494
case *reflect.StructType:
489495
// Generate a closure that calls out to the engine for the nested type.
@@ -511,30 +517,31 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) {
511517
if op == nil {
512518
errorf("gob enc: can't happen: encode type %s", rt.String())
513519
}
514-
return op, indir
520+
return &op, indir
515521
}
516522

517523
// The local Type was compiled from the actual value, so we know it's compatible.
518524
func (enc *Encoder) compileEnc(rt reflect.Type) *encEngine {
519525
srt, isStruct := rt.(*reflect.StructType)
520526
engine := new(encEngine)
527+
seen := make(map[reflect.Type]*encOp)
521528
if isStruct {
522529
for fieldNum := 0; fieldNum < srt.NumField(); fieldNum++ {
523530
f := srt.Field(fieldNum)
524531
if !isExported(f.Name) {
525532
continue
526533
}
527-
op, indir := enc.encOpFor(f.Type)
528-
engine.instr = append(engine.instr, encInstr{op, fieldNum, indir, uintptr(f.Offset)})
534+
op, indir := enc.encOpFor(f.Type, seen)
535+
engine.instr = append(engine.instr, encInstr{*op, fieldNum, indir, uintptr(f.Offset)})
529536
}
530537
if srt.NumField() > 0 && len(engine.instr) == 0 {
531538
errorf("type %s has no exported fields", rt)
532539
}
533540
engine.instr = append(engine.instr, encInstr{encStructTerminator, 0, 0, 0})
534541
} else {
535542
engine.instr = make([]encInstr, 1)
536-
op, indir := enc.encOpFor(rt)
537-
engine.instr[0] = encInstr{op, singletonField, indir, 0} // offset is zero
543+
op, indir := enc.encOpFor(rt, seen)
544+
engine.instr[0] = encInstr{*op, singletonField, indir, 0} // offset is zero
538545
}
539546
return engine
540547
}

Diff for: src/pkg/gob/encoder_test.go

+18
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,24 @@ func TestArray(t *testing.T) {
249249
}
250250
}
251251

252+
func TestRecursiveMapType(t *testing.T) {
253+
type recursiveMap map[string]recursiveMap
254+
r1 := recursiveMap{"A": recursiveMap{"B": nil, "C": nil}, "D": nil}
255+
r2 := make(recursiveMap)
256+
if err := encAndDec(r1, &r2); err != nil {
257+
t.Error(err)
258+
}
259+
}
260+
261+
func TestRecursiveSliceType(t *testing.T) {
262+
type recursiveSlice []recursiveSlice
263+
r1 := recursiveSlice{0: recursiveSlice{0: nil}, 1: nil}
264+
r2 := make(recursiveSlice, 0)
265+
if err := encAndDec(r1, &r2); err != nil {
266+
t.Error(err)
267+
}
268+
}
269+
252270
// Regression test for bug: must send zero values inside arrays
253271
func TestDefaultsInArray(t *testing.T) {
254272
type Type7 struct {

0 commit comments

Comments
 (0)