Skip to content

Commit cd7fb73

Browse files
committed
Support context for MarshalJSON and UnmarshalJSON
1 parent a2ba5e8 commit cd7fb73

17 files changed

+355
-51
lines changed

decode.go

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package json
22

33
import (
4+
"context"
45
"fmt"
56
"io"
67
"reflect"
@@ -39,7 +40,7 @@ func unmarshal(data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error {
3940
}
4041
ctx := decoder.TakeRuntimeContext()
4142
ctx.Buf = src
42-
ctx.Option.Flag = 0
43+
ctx.Option.Flags = 0
4344
for _, optFunc := range optFuncs {
4445
optFunc(ctx.Option)
4546
}
@@ -52,6 +53,36 @@ func unmarshal(data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error {
5253
return validateEndBuf(src, cursor)
5354
}
5455

56+
func unmarshalContext(ctx context.Context, data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error {
57+
src := make([]byte, len(data)+1) // append nul byte to the end
58+
copy(src, data)
59+
60+
header := (*emptyInterface)(unsafe.Pointer(&v))
61+
62+
if err := validateType(header.typ, uintptr(header.ptr)); err != nil {
63+
return err
64+
}
65+
dec, err := decoder.CompileToGetDecoder(header.typ)
66+
if err != nil {
67+
return err
68+
}
69+
rctx := decoder.TakeRuntimeContext()
70+
rctx.Buf = src
71+
rctx.Option.Flags = 0
72+
rctx.Option.Flags |= decoder.ContextOption
73+
rctx.Option.Context = ctx
74+
for _, optFunc := range optFuncs {
75+
optFunc(rctx.Option)
76+
}
77+
cursor, err := dec.Decode(rctx, 0, 0, header.ptr)
78+
if err != nil {
79+
decoder.ReleaseRuntimeContext(rctx)
80+
return err
81+
}
82+
decoder.ReleaseRuntimeContext(rctx)
83+
return validateEndBuf(src, cursor)
84+
}
85+
5586
func unmarshalNoEscape(data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error {
5687
src := make([]byte, len(data)+1) // append nul byte to the end
5788
copy(src, data)
@@ -68,7 +99,7 @@ func unmarshalNoEscape(data []byte, v interface{}, optFuncs ...DecodeOptionFunc)
6899

69100
ctx := decoder.TakeRuntimeContext()
70101
ctx.Buf = src
71-
ctx.Option.Flag = 0
102+
ctx.Option.Flags = 0
72103
for _, optFunc := range optFuncs {
73104
optFunc(ctx.Option)
74105
}
@@ -137,6 +168,14 @@ func (d *Decoder) Decode(v interface{}) error {
137168
return d.DecodeWithOption(v)
138169
}
139170

171+
// DecodeContext reads the next JSON-encoded value from its
172+
// input and stores it in the value pointed to by v with context.Context.
173+
func (d *Decoder) DecodeContext(ctx context.Context, v interface{}) error {
174+
d.s.Option.Flags |= decoder.ContextOption
175+
d.s.Option.Context = ctx
176+
return d.DecodeWithOption(v)
177+
}
178+
140179
func (d *Decoder) DecodeWithOption(v interface{}, optFuncs ...DecodeOptionFunc) error {
141180
header := (*emptyInterface)(unsafe.Pointer(&v))
142181
typ := header.typ

decode_test.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package json_test
22

33
import (
44
"bytes"
5+
"context"
56
"encoding"
67
stdjson "encoding/json"
78
"errors"
@@ -3620,3 +3621,48 @@ func TestDecodeEscapedCharField(t *testing.T) {
36203621
}
36213622
})
36223623
}
3624+
3625+
type unmarshalContextKey struct{}
3626+
3627+
type unmarshalContextStructType struct {
3628+
v int
3629+
}
3630+
3631+
func (t *unmarshalContextStructType) UnmarshalJSON(ctx context.Context, b []byte) error {
3632+
v := ctx.Value(unmarshalContextKey{})
3633+
s, ok := v.(string)
3634+
if !ok {
3635+
return fmt.Errorf("failed to propagate parent context.Context")
3636+
}
3637+
if s != "hello" {
3638+
return fmt.Errorf("failed to propagate parent context.Context")
3639+
}
3640+
t.v = 100
3641+
return nil
3642+
}
3643+
3644+
func TestDecodeContextOption(t *testing.T) {
3645+
src := []byte("10")
3646+
buf := bytes.NewBuffer(src)
3647+
3648+
t.Run("UnmarshalContext", func(t *testing.T) {
3649+
ctx := context.WithValue(context.Background(), unmarshalContextKey{}, "hello")
3650+
var v unmarshalContextStructType
3651+
if err := json.UnmarshalContext(ctx, src, &v); err != nil {
3652+
t.Fatal(err)
3653+
}
3654+
if v.v != 100 {
3655+
t.Fatal("failed to decode with context")
3656+
}
3657+
})
3658+
t.Run("DecodeContext", func(t *testing.T) {
3659+
ctx := context.WithValue(context.Background(), unmarshalContextKey{}, "hello")
3660+
var v unmarshalContextStructType
3661+
if err := json.NewDecoder(buf).DecodeContext(ctx, &v); err != nil {
3662+
t.Fatal(err)
3663+
}
3664+
if v.v != 100 {
3665+
t.Fatal("failed to decode with context")
3666+
}
3667+
})
3668+
}

encode.go

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package json
22

33
import (
4+
"context"
45
"io"
56
"unsafe"
67

@@ -35,15 +36,28 @@ func (e *Encoder) Encode(v interface{}) error {
3536
// EncodeWithOption call Encode with EncodeOption.
3637
func (e *Encoder) EncodeWithOption(v interface{}, optFuncs ...EncodeOptionFunc) error {
3738
ctx := encoder.TakeRuntimeContext()
39+
ctx.Option.Flag = 0
3840

3941
err := e.encodeWithOption(ctx, v, optFuncs...)
4042

4143
encoder.ReleaseRuntimeContext(ctx)
4244
return err
4345
}
4446

47+
// EncodeContext call Encode with context.Context and EncodeOption.
48+
func (e *Encoder) EncodeContext(ctx context.Context, v interface{}, optFuncs ...EncodeOptionFunc) error {
49+
rctx := encoder.TakeRuntimeContext()
50+
rctx.Option.Flag = 0
51+
rctx.Option.Flag |= encoder.ContextOption
52+
rctx.Option.Context = ctx
53+
54+
err := e.encodeWithOption(rctx, v, optFuncs...)
55+
56+
encoder.ReleaseRuntimeContext(rctx)
57+
return err
58+
}
59+
4560
func (e *Encoder) encodeWithOption(ctx *encoder.RuntimeContext, v interface{}, optFuncs ...EncodeOptionFunc) error {
46-
ctx.Option.Flag = 0
4761
if e.enabledHTMLEscape {
4862
ctx.Option.Flag |= encoder.HTMLEscapeOption
4963
}
@@ -94,6 +108,33 @@ func (e *Encoder) SetIndent(prefix, indent string) {
94108
e.enabledIndent = true
95109
}
96110

111+
func marshalContext(ctx context.Context, v interface{}, optFuncs ...EncodeOptionFunc) ([]byte, error) {
112+
rctx := encoder.TakeRuntimeContext()
113+
rctx.Option.Flag = 0
114+
rctx.Option.Flag = encoder.HTMLEscapeOption | encoder.ContextOption
115+
rctx.Option.Context = ctx
116+
for _, optFunc := range optFuncs {
117+
optFunc(rctx.Option)
118+
}
119+
120+
buf, err := encode(rctx, v)
121+
if err != nil {
122+
encoder.ReleaseRuntimeContext(rctx)
123+
return nil, err
124+
}
125+
126+
// this line exists to escape call of `runtime.makeslicecopy` .
127+
// if use `make([]byte, len(buf)-1)` and `copy(copied, buf)`,
128+
// dst buffer size and src buffer size are differrent.
129+
// in this case, compiler uses `runtime.makeslicecopy`, but it is slow.
130+
buf = buf[:len(buf)-1]
131+
copied := make([]byte, len(buf))
132+
copy(copied, buf)
133+
134+
encoder.ReleaseRuntimeContext(rctx)
135+
return copied, nil
136+
}
137+
97138
func marshal(v interface{}, optFuncs ...EncodeOptionFunc) ([]byte, error) {
98139
ctx := encoder.TakeRuntimeContext()
99140

encode_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package json_test
22

33
import (
44
"bytes"
5+
"context"
56
"encoding"
67
stdjson "encoding/json"
78
"errors"
@@ -1918,3 +1919,42 @@ func TestEncodeMapKeyTypeInterface(t *testing.T) {
19181919
t.Fatal("expected error")
19191920
}
19201921
}
1922+
1923+
type marshalContextKey struct{}
1924+
1925+
type marshalContextStructType struct{}
1926+
1927+
func (t *marshalContextStructType) MarshalJSON(ctx context.Context) ([]byte, error) {
1928+
v := ctx.Value(marshalContextKey{})
1929+
s, ok := v.(string)
1930+
if !ok {
1931+
return nil, fmt.Errorf("failed to propagate parent context.Context")
1932+
}
1933+
if s != "hello" {
1934+
return nil, fmt.Errorf("failed to propagate parent context.Context")
1935+
}
1936+
return []byte(`"success"`), nil
1937+
}
1938+
1939+
func TestEncodeContextOption(t *testing.T) {
1940+
t.Run("MarshalContext", func(t *testing.T) {
1941+
ctx := context.WithValue(context.Background(), marshalContextKey{}, "hello")
1942+
b, err := json.MarshalContext(ctx, &marshalContextStructType{})
1943+
if err != nil {
1944+
t.Fatal(err)
1945+
}
1946+
if string(b) != `"success"` {
1947+
t.Fatal("failed to encode with MarshalerContext")
1948+
}
1949+
})
1950+
t.Run("EncodeContext", func(t *testing.T) {
1951+
ctx := context.WithValue(context.Background(), marshalContextKey{}, "hello")
1952+
buf := bytes.NewBuffer([]byte{})
1953+
if err := json.NewEncoder(buf).EncodeContext(ctx, &marshalContextStructType{}); err != nil {
1954+
t.Fatal(err)
1955+
}
1956+
if buf.String() != "\"success\"\n" {
1957+
t.Fatal("failed to encode with EncodeContext")
1958+
}
1959+
})
1960+
}

internal/decoder/compile.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ func compileToGetDecoderSlowPath(typeptr uintptr, typ *runtime.Type) (Decoder, e
6060

6161
func compileHead(typ *runtime.Type, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) {
6262
switch {
63-
case runtime.PtrTo(typ).Implements(unmarshalJSONType):
63+
case implementsUnmarshalJSONType(runtime.PtrTo(typ)):
6464
return newUnmarshalJSONDecoder(runtime.PtrTo(typ), "", ""), nil
6565
case runtime.PtrTo(typ).Implements(unmarshalTextType):
6666
return newUnmarshalTextDecoder(runtime.PtrTo(typ), "", ""), nil
@@ -70,7 +70,7 @@ func compileHead(typ *runtime.Type, structTypeToDecoder map[uintptr]Decoder) (De
7070

7171
func compile(typ *runtime.Type, structName, fieldName string, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) {
7272
switch {
73-
case runtime.PtrTo(typ).Implements(unmarshalJSONType):
73+
case implementsUnmarshalJSONType(runtime.PtrTo(typ)):
7474
return newUnmarshalJSONDecoder(runtime.PtrTo(typ), structName, fieldName), nil
7575
case runtime.PtrTo(typ).Implements(unmarshalTextType):
7676
return newUnmarshalTextDecoder(runtime.PtrTo(typ), structName, fieldName), nil
@@ -133,7 +133,7 @@ func compile(typ *runtime.Type, structName, fieldName string, structTypeToDecode
133133

134134
func isStringTagSupportedType(typ *runtime.Type) bool {
135135
switch {
136-
case runtime.PtrTo(typ).Implements(unmarshalJSONType):
136+
case implementsUnmarshalJSONType(runtime.PtrTo(typ)):
137137
return false
138138
case runtime.PtrTo(typ).Implements(unmarshalTextType):
139139
return false
@@ -494,3 +494,7 @@ func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeTo
494494
structDec.tryOptimize()
495495
return structDec, nil
496496
}
497+
498+
func implementsUnmarshalJSONType(typ *runtime.Type) bool {
499+
return typ.Implements(unmarshalJSONType) || typ.Implements(unmarshalJSONContextType)
500+
}

internal/decoder/interface.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,21 @@ func decodeStreamUnmarshaler(s *Stream, depth int64, unmarshaler json.Unmarshale
117117
return nil
118118
}
119119

120+
func decodeStreamUnmarshalerContext(s *Stream, depth int64, unmarshaler unmarshalerContext) error {
121+
start := s.cursor
122+
if err := s.skipValue(depth); err != nil {
123+
return err
124+
}
125+
src := s.buf[start:s.cursor]
126+
dst := make([]byte, len(src))
127+
copy(dst, src)
128+
129+
if err := unmarshaler.UnmarshalJSON(s.Option.Context, dst); err != nil {
130+
return err
131+
}
132+
return nil
133+
}
134+
120135
func decodeUnmarshaler(buf []byte, cursor, depth int64, unmarshaler json.Unmarshaler) (int64, error) {
121136
cursor = skipWhiteSpace(buf, cursor)
122137
start := cursor
@@ -134,6 +149,23 @@ func decodeUnmarshaler(buf []byte, cursor, depth int64, unmarshaler json.Unmarsh
134149
return end, nil
135150
}
136151

152+
func decodeUnmarshalerContext(ctx *RuntimeContext, buf []byte, cursor, depth int64, unmarshaler unmarshalerContext) (int64, error) {
153+
cursor = skipWhiteSpace(buf, cursor)
154+
start := cursor
155+
end, err := skipValue(buf, cursor, depth)
156+
if err != nil {
157+
return 0, err
158+
}
159+
src := buf[start:end]
160+
dst := make([]byte, len(src))
161+
copy(dst, src)
162+
163+
if err := unmarshaler.UnmarshalJSON(ctx.Option.Context, dst); err != nil {
164+
return 0, err
165+
}
166+
return end, nil
167+
}
168+
137169
func decodeStreamTextUnmarshaler(s *Stream, depth int64, unmarshaler encoding.TextUnmarshaler, p unsafe.Pointer) error {
138170
start := s.cursor
139171
if err := s.skipValue(depth); err != nil {
@@ -260,6 +292,9 @@ func (d *interfaceDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Pointer
260292
}))
261293
rv := reflect.ValueOf(runtimeInterfaceValue)
262294
if rv.NumMethod() > 0 && rv.CanInterface() {
295+
if u, ok := rv.Interface().(unmarshalerContext); ok {
296+
return decodeStreamUnmarshalerContext(s, depth, u)
297+
}
263298
if u, ok := rv.Interface().(json.Unmarshaler); ok {
264299
return decodeStreamUnmarshaler(s, depth, u)
265300
}
@@ -317,6 +352,9 @@ func (d *interfaceDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p un
317352
}))
318353
rv := reflect.ValueOf(runtimeInterfaceValue)
319354
if rv.NumMethod() > 0 && rv.CanInterface() {
355+
if u, ok := rv.Interface().(unmarshalerContext); ok {
356+
return decodeUnmarshalerContext(ctx, buf, cursor, depth, u)
357+
}
320358
if u, ok := rv.Interface().(json.Unmarshaler); ok {
321359
return decodeUnmarshaler(buf, cursor, depth, u)
322360
}

internal/decoder/option.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
package decoder
22

3-
type OptionFlag int
3+
import "context"
4+
5+
type OptionFlags uint8
46

57
const (
6-
FirstWinOption OptionFlag = 1 << iota
8+
FirstWinOption OptionFlags = 1 << iota
9+
ContextOption
710
)
811

912
type Option struct {
10-
Flag OptionFlag
13+
Flags OptionFlags
14+
Context context.Context
1115
}

internal/decoder/struct.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,7 @@ func (d *structDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Pointer) e
665665
seenFields map[int]struct{}
666666
seenFieldNum int
667667
)
668-
firstWin := (s.Option.Flag & FirstWinOption) != 0
668+
firstWin := (s.Option.Flags & FirstWinOption) != 0
669669
if firstWin {
670670
seenFields = make(map[int]struct{}, d.fieldUniqueNameNum)
671671
}
@@ -752,7 +752,7 @@ func (d *structDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsaf
752752
seenFields map[int]struct{}
753753
seenFieldNum int
754754
)
755-
firstWin := (ctx.Option.Flag & FirstWinOption) != 0
755+
firstWin := (ctx.Option.Flags & FirstWinOption) != 0
756756
if firstWin {
757757
seenFields = make(map[int]struct{}, d.fieldUniqueNameNum)
758758
}

0 commit comments

Comments
 (0)