Skip to content

Commit 902fd6a

Browse files
authored
Merge pull request #229 from goccy/feature/keep-original-slice-reference
Keep original reference of slice element
2 parents d3951e3 + 90d4d18 commit 902fd6a

File tree

2 files changed

+65
-19
lines changed

2 files changed

+65
-19
lines changed

decode_slice.go

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,20 @@ func newSliceDecoder(dec decoder, elemType *rtype, size uintptr, structName, fie
4949
}
5050
}
5151

52-
func (d *sliceDecoder) newSlice() *sliceHeader {
52+
func (d *sliceDecoder) newSlice(src *sliceHeader) *sliceHeader {
5353
slice := d.arrayPool.Get().(*sliceHeader)
54-
slice.len = 0
54+
if src.len > 0 {
55+
// copy original elem
56+
if slice.cap < src.cap {
57+
data := newArray(d.elemType, src.cap)
58+
slice = &sliceHeader{data: data, len: src.len, cap: src.cap}
59+
} else {
60+
slice.len = src.len
61+
}
62+
copySlice(d.elemType, *slice, *src)
63+
} else {
64+
slice.len = 0
65+
}
5566
return slice
5667
}
5768

@@ -109,7 +120,8 @@ func (d *sliceDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) er
109120
return nil
110121
}
111122
idx := 0
112-
slice := d.newSlice()
123+
slice := d.newSlice((*sliceHeader)(p))
124+
srcLen := slice.len
113125
capacity := slice.cap
114126
data := slice.data
115127
for {
@@ -121,12 +133,17 @@ func (d *sliceDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) er
121133
copySlice(d.elemType, dst, src)
122134
}
123135
ep := unsafe.Pointer(uintptr(data) + uintptr(idx)*d.size)
124-
if d.isElemPointerType {
125-
**(**unsafe.Pointer)(unsafe.Pointer(&ep)) = nil // initialize elem pointer
126-
} else {
127-
// assign new element to the slice
128-
typedmemmove(d.elemType, ep, unsafe_New(d.elemType))
136+
137+
// if srcLen is greater than idx, keep the original reference
138+
if srcLen <= idx {
139+
if d.isElemPointerType {
140+
**(**unsafe.Pointer)(unsafe.Pointer(&ep)) = nil // initialize elem pointer
141+
} else {
142+
// assign new element to the slice
143+
typedmemmove(d.elemType, ep, unsafe_New(d.elemType))
144+
}
129145
}
146+
130147
if err := d.valueDecoder.decodeStream(s, depth, ep); err != nil {
131148
return err
132149
}
@@ -212,7 +229,8 @@ func (d *sliceDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer)
212229
return cursor, nil
213230
}
214231
idx := 0
215-
slice := d.newSlice()
232+
slice := d.newSlice((*sliceHeader)(p))
233+
srcLen := slice.len
216234
capacity := slice.cap
217235
data := slice.data
218236
for {
@@ -224,11 +242,14 @@ func (d *sliceDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer)
224242
copySlice(d.elemType, dst, src)
225243
}
226244
ep := unsafe.Pointer(uintptr(data) + uintptr(idx)*d.size)
227-
if d.isElemPointerType {
228-
**(**unsafe.Pointer)(unsafe.Pointer(&ep)) = nil // initialize elem pointer
229-
} else {
230-
// assign new element to the slice
231-
typedmemmove(d.elemType, ep, unsafe_New(d.elemType))
245+
// if srcLen is greater than idx, keep the original reference
246+
if srcLen <= idx {
247+
if d.isElemPointerType {
248+
**(**unsafe.Pointer)(unsafe.Pointer(&ep)) = nil // initialize elem pointer
249+
} else {
250+
// assign new element to the slice
251+
typedmemmove(d.elemType, ep, unsafe_New(d.elemType))
252+
}
232253
}
233254
c, err := d.valueDecoder.decode(buf, cursor, depth, ep)
234255
if err != nil {

decode_test.go

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3052,7 +3052,7 @@ func TestMultipleDecodeWithRawMessage(t *testing.T) {
30523052
type intUnmarshaler int
30533053

30543054
func (u *intUnmarshaler) UnmarshalJSON(b []byte) error {
3055-
if *u != 0 {
3055+
if *u != 0 && *u != 10 {
30563056
return fmt.Errorf("failed to decode of slice with int unmarshaler")
30573057
}
30583058
*u = 10
@@ -3062,7 +3062,7 @@ func (u *intUnmarshaler) UnmarshalJSON(b []byte) error {
30623062
type arrayUnmarshaler [5]int
30633063

30643064
func (u *arrayUnmarshaler) UnmarshalJSON(b []byte) error {
3065-
if (*u)[0] != 0 {
3065+
if (*u)[0] != 0 && (*u)[0] != 10 {
30663066
return fmt.Errorf("failed to decode of slice with array unmarshaler")
30673067
}
30683068
(*u)[0] = 10
@@ -3072,22 +3072,24 @@ func (u *arrayUnmarshaler) UnmarshalJSON(b []byte) error {
30723072
type mapUnmarshaler map[string]int
30733073

30743074
func (u *mapUnmarshaler) UnmarshalJSON(b []byte) error {
3075-
if len(*u) != 0 {
3075+
if len(*u) != 0 && len(*u) != 1 {
30763076
return fmt.Errorf("failed to decode of slice with map unmarshaler")
30773077
}
30783078
*u = map[string]int{"a": 10}
30793079
return nil
30803080
}
30813081

30823082
type structUnmarshaler struct {
3083-
A int
3083+
A int
3084+
notFirst bool
30843085
}
30853086

30863087
func (u *structUnmarshaler) UnmarshalJSON(b []byte) error {
3087-
if u.A != 0 {
3088+
if !u.notFirst && u.A != 0 {
30883089
return fmt.Errorf("failed to decode of slice with struct unmarshaler")
30893090
}
30903091
u.A = 10
3092+
u.notFirst = true
30913093
return nil
30923094
}
30933095

@@ -3199,6 +3201,29 @@ func TestSliceElemUnmarshaler(t *testing.T) {
31993201
})
32003202
}
32013203

3204+
type keepRefTest struct {
3205+
A int
3206+
B string
3207+
}
3208+
3209+
func (t *keepRefTest) UnmarshalJSON(data []byte) error {
3210+
v := []interface{}{&t.A, &t.B}
3211+
return json.Unmarshal(data, &v)
3212+
}
3213+
3214+
func TestKeepReferenceSlice(t *testing.T) {
3215+
var v keepRefTest
3216+
if err := json.Unmarshal([]byte(`[54,"hello"]`), &v); err != nil {
3217+
t.Fatal(err)
3218+
}
3219+
if v.A != 54 {
3220+
t.Fatal("failed to keep reference for slice")
3221+
}
3222+
if v.B != "hello" {
3223+
t.Fatal("failed to keep reference for slice")
3224+
}
3225+
}
3226+
32023227
func TestInvalidTopLevelValue(t *testing.T) {
32033228
t.Run("invalid end of buffer", func(t *testing.T) {
32043229
var v struct{}

0 commit comments

Comments
 (0)