Skip to content

Commit 67e83f7

Browse files
authored
fix(go): added bucketing of float64s based on input schema (#3552)
1 parent 424692b commit 67e83f7

File tree

3 files changed

+128
-7
lines changed

3 files changed

+128
-7
lines changed

go/core/action.go

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -233,28 +233,45 @@ func (a *ActionDef[In, Out, Stream]) RunJSON(ctx context.Context, input json.Raw
233233
if err := base.ValidateJSON(input, a.desc.InputSchema); err != nil {
234234
return nil, NewError(INVALID_ARGUMENT, err.Error())
235235
}
236-
var in In
237-
if input != nil {
238-
json.Unmarshal(input, &in)
236+
237+
var i In
238+
if len(input) > 0 {
239+
if err := json.Unmarshal(input, &i); err != nil {
240+
return nil, NewError(INVALID_ARGUMENT, "invalid input: %v", err)
241+
}
242+
243+
// Adhere to the input schema if the number type is ambiguous and the input type is an any.
244+
converted, err := base.ConvertJSONNumbers(i, a.desc.InputSchema)
245+
if err != nil {
246+
return nil, NewError(INVALID_ARGUMENT, "invalid input: %v", err)
247+
}
248+
249+
if result, ok := converted.(In); ok {
250+
i = result
251+
}
239252
}
240-
var callback func(context.Context, Stream) error
253+
254+
var scb StreamCallback[Stream]
241255
if cb != nil {
242-
callback = func(ctx context.Context, s Stream) error {
256+
scb = func(ctx context.Context, s Stream) error {
243257
bytes, err := json.Marshal(s)
244258
if err != nil {
245259
return err
246260
}
247261
return cb(ctx, json.RawMessage(bytes))
248262
}
249263
}
250-
out, err := a.Run(ctx, in, callback)
264+
265+
out, err := a.Run(ctx, i, scb)
251266
if err != nil {
252267
return nil, err
253268
}
269+
254270
bytes, err := json.Marshal(out)
255271
if err != nil {
256272
return nil, err
257273
}
274+
258275
return json.RawMessage(bytes), nil
259276
}
260277

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
//
15+
// SPDX-License-Identifier: Apache-2.0
16+
17+
package base
18+
19+
import (
20+
"fmt"
21+
)
22+
23+
// ConvertJSONNumbers recursively traverses a data structure and a corresponding JSON schema.
24+
// It converts instances of float64 into int64 or float64 based on the schema's "type" property.
25+
func ConvertJSONNumbers(data any, schema map[string]any) (any, error) {
26+
if data == nil || schema == nil {
27+
return data, nil
28+
}
29+
30+
switch d := data.(type) {
31+
case float64:
32+
return convertFloat64(d, schema)
33+
case map[string]any:
34+
return convertObjectNumbers(d, schema)
35+
case []any:
36+
return convertArrayNumbers(d, schema)
37+
default:
38+
return data, nil
39+
}
40+
}
41+
42+
// convertFloat64 converts a float64 to an int64 or float64 based on the schema's "type" property.
43+
func convertFloat64(f float64, schema map[string]any) (any, error) {
44+
schemaType, ok := schema["type"].(string)
45+
if !ok {
46+
return f, nil // No type specified, leave as float64
47+
}
48+
49+
switch schemaType {
50+
case "integer":
51+
// Convert float64 to int64 if it represents a whole number
52+
if f == float64(int64(f)) {
53+
return int64(f), nil
54+
}
55+
return nil, fmt.Errorf("cannot convert %f to integer: not a whole number", f)
56+
case "number":
57+
return f, nil // Already a float64
58+
default:
59+
return f, nil // Not a numeric type, leave as is
60+
}
61+
}
62+
63+
// convertObjectNumbers converts any float64s in the map values to int64 or float64 based on the schema's "type" property.
64+
func convertObjectNumbers(obj map[string]any, schema map[string]any) (map[string]any, error) {
65+
props, ok := schema["properties"].(map[string]any)
66+
if !ok {
67+
return obj, nil // No properties to guide conversion
68+
}
69+
70+
newObj := make(map[string]any, len(obj))
71+
for k, v := range obj {
72+
newObj[k] = v // Copy original value
73+
74+
propSchema, ok := props[k].(map[string]any)
75+
if !ok {
76+
continue // No schema for this property
77+
}
78+
79+
converted, err := ConvertJSONNumbers(v, propSchema)
80+
if err != nil {
81+
return nil, err
82+
}
83+
newObj[k] = converted
84+
}
85+
return newObj, nil
86+
}
87+
88+
// convertArrayNumbers converts any float64s in the array values to int64 or float64 based on the schema's "type" property.
89+
func convertArrayNumbers(arr []any, schema map[string]any) ([]any, error) {
90+
items, ok := schema["items"].(map[string]any)
91+
if !ok {
92+
return arr, nil // No items schema to guide conversion
93+
}
94+
95+
newArr := make([]any, len(arr))
96+
for i, v := range arr {
97+
converted, err := ConvertJSONNumbers(v, items)
98+
if err != nil {
99+
return nil, err
100+
}
101+
newArr[i] = converted
102+
}
103+
return newArr, nil
104+
}

go/internal/version.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,6 @@ package internal
1818

1919
// Version is the current tagged release of this module.
2020
// That is, it should match the value of the latest `go/v*` git tag.
21-
const Version = "1.0.0"
21+
const Version = "1.0.2"
2222

2323
const GENKIT_REFLECTION_API_SPEC_VERSION = 1

0 commit comments

Comments
 (0)