Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 81 additions & 1 deletion common/decls/decls.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package decls

import (
"context"
"fmt"
"strings"

Expand Down Expand Up @@ -342,6 +343,7 @@ func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) {
Unary: o.guardedUnaryOp(f.Name(), f.disableTypeGuards),
Binary: o.guardedBinaryOp(f.Name(), f.disableTypeGuards),
Function: o.guardedFunctionOp(f.Name(), f.disableTypeGuards),
Async: o.guardedAsyncOp(f.Name(), f.disableTypeGuards),
OperandTrait: o.OperandTrait(),
NonStrict: o.IsNonStrict(),
}
Expand All @@ -362,6 +364,7 @@ func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) {
Unary: f.singleton.Unary,
Binary: f.singleton.Binary,
Function: f.singleton.Function,
Async: f.singleton.Async,
OperandTrait: f.singleton.OperandTrait,
},
}
Expand All @@ -380,6 +383,7 @@ func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) {
Unary: overloads[0].Unary,
Binary: overloads[0].Binary,
Function: overloads[0].Function,
Async: overloads[0].Async,
NonStrict: overloads[0].NonStrict,
OperandTrait: overloads[0].OperandTrait,
}), nil
Expand Down Expand Up @@ -538,6 +542,30 @@ func SingletonFunctionBinding(fn functions.FunctionOp, traits ...int) FunctionOp
}
}

// SingletonAsyncBinding creates a singleton async function definition to be used with all function overloads.
// The provided function is called in its own goroutine with the provided context. The function should
// block until the result is available, and the framework manages goroutine and channel lifecycle.
//
// Note, this approach works well if operand is expected to have a specific trait which it implements,
// e.g. traits.ContainerType. Otherwise, prefer per-overload async bindings.
func SingletonAsyncBinding(fn functions.BlockingAsyncOp, traits ...int) FunctionOpt {
trait := 0
for _, t := range traits {
trait = trait | t
}
return func(f *FunctionDecl) (*FunctionDecl, error) {
if f.singleton != nil {
return nil, fmt.Errorf("function already has a singleton binding: %s", f.Name())
}
f.singleton = &functions.Overload{
Operator: f.Name(),
Async: wrapAsyncOp(fn),
OperandTrait: trait,
}
return f, nil
}
}

// Overload defines a new global overload with an overload id, argument types, and result type. Through the
// use of OverloadOpt options, the overload may also be configured with a binding, an operand trait, and to
// be non-strict.
Expand Down Expand Up @@ -622,6 +650,8 @@ type OverloadDecl struct {
binaryOp functions.BinaryOp
// functionOp is a catch-all for zero-arity and three-plus arity functions.
functionOp functions.FunctionOp
// asyncOp is an asynchronous function binding that returns a channel.
asyncOp functions.AsyncOp
}

// Examples returns a list of string examples for the overload.
Expand Down Expand Up @@ -750,7 +780,7 @@ func (o *OverloadDecl) SignatureOverlaps(other *OverloadDecl) bool {

// HasBinding indicates whether the overload already has a definition.
func (o *OverloadDecl) HasBinding() bool {
return o != nil && (o.unaryOp != nil || o.binaryOp != nil || o.functionOp != nil)
return o != nil && (o.unaryOp != nil || o.binaryOp != nil || o.functionOp != nil || o.asyncOp != nil)
}

// guardedUnaryOp creates an invocation guard around the provided unary operator, if one is defined.
Expand Down Expand Up @@ -792,6 +822,22 @@ func (o *OverloadDecl) guardedFunctionOp(funcName string, disableTypeGuards bool
}
}

// guardedAsyncOp creates an invocation guard around the provided async function binding, if one is provided.
func (o *OverloadDecl) guardedAsyncOp(funcName string, disableTypeGuards bool) functions.AsyncOp {
if o.asyncOp == nil {
return nil
}
return func(ctx context.Context, args ...ref.Val) <-chan ref.Val {
if !o.matchesRuntimeSignature(disableTypeGuards, args...) {
ch := make(chan ref.Val, 1)
ch <- MaybeNoSuchOverload(funcName, args...)
close(ch)
return ch
}
return o.asyncOp(ctx, args...)
}
}

// matchesRuntimeUnarySignature indicates whether the argument type is runtime assiganble to the overload's expected argument.
func (o *OverloadDecl) matchesRuntimeUnarySignature(disableTypeGuards bool, arg ref.Val) bool {
return matchRuntimeArgType(o.IsNonStrict(), disableTypeGuards, o.ArgTypes()[0], arg) &&
Expand Down Expand Up @@ -897,6 +943,40 @@ func FunctionBinding(binding functions.FunctionOp) OverloadOpt {
}
}

// AsyncBinding provides the implementation of an asynchronous overload. The provided function
// is called in its own goroutine with the provided context. The function should block until
// the result is available, and the framework manages goroutine and channel lifecycle.
//
// This follows the same pattern used by gRPC-Go and other major Go frameworks where user
// code is synchronous and the framework manages concurrency.
func AsyncBinding(fn functions.BlockingAsyncOp) OverloadOpt {
return func(o *OverloadDecl) (*OverloadDecl, error) {
if o.HasBinding() {
return nil, fmt.Errorf("overload already has a binding: %s", o.ID())
}
if o.hasLateBinding {
return nil, fmt.Errorf("overload already has a late binding: %s", o.ID())
}
o.asyncOp = wrapAsyncOp(fn)
return o, nil
}
}

// wrapAsyncOp adapts a blocking function into the channel-based AsyncOp used internally.
//
// The blocking function is invoked synchronously and its result delivered on a buffered channel.
// The interpreter always invokes an AsyncOp from a dedicated goroutine, so running the blocking
// call inline here keeps the framework to a single goroutine per async call rather than spawning
// an additional one to bridge blocking-to-channel.
func wrapAsyncOp(fn functions.BlockingAsyncOp) functions.AsyncOp {
return func(ctx context.Context, args ...ref.Val) <-chan ref.Val {
ch := make(chan ref.Val, 1)
ch <- fn(ctx, args...)
close(ch)
return ch
}
}

// LateFunctionBinding indicates that the function has a binding which is not known at compile time.
// This is useful for functions which have side-effects or are not deterministically computable.
func LateFunctionBinding() OverloadOpt {
Expand Down
144 changes: 144 additions & 0 deletions common/decls/decls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package decls

import (
"context"
"reflect"
"strings"
"testing"
Expand Down Expand Up @@ -1371,6 +1372,149 @@ func TestNilVariable(t *testing.T) {
}
}

func TestAsyncBinding(t *testing.T) {
fn, err := NewFunction("async_fn",
Overload("async_fn_int", []*types.Type{types.IntType}, types.IntType,
AsyncBinding(func(ctx context.Context, args ...ref.Val) ref.Val {
return args[0]
}),
),
)
if err != nil {
t.Fatalf("NewFunction() failed: %v", err)
}
for _, od := range fn.OverloadDecls() {
if !od.HasBinding() {
t.Errorf("Overload %s does not have binding, wanted async binding", od.ID())
}
}
bindings, err := fn.Bindings()
if err != nil {
t.Fatalf("fn.Bindings() produced an err: %v", err)
}
if len(bindings) != 2 {
t.Errorf("fn.Bindings() produced %d bindings, wanted 2", len(bindings))
}
for _, binding := range bindings {
if binding.Async == nil {
t.Fatal("binding missing Async implementation")
}

ctx := context.Background()
ch := binding.Async(ctx, types.Int(42))
select {
case res := <-ch:
if res.Equal(types.Int(42)) != types.True {
t.Errorf("async binding returned %v, wanted 42", res)
}
case <-time.After(1 * time.Second):
t.Fatal("async binding timed out")
}
}
}

func TestAsyncBindingTypeGuards(t *testing.T) {
fn, err := NewFunction("async_fn",
Overload("async_fn_int", []*types.Type{types.IntType}, types.IntType,
AsyncBinding(func(ctx context.Context, args ...ref.Val) ref.Val {
return args[0]
}),
),
)
if err != nil {
t.Fatalf("NewFunction() failed: %v", err)
}
bindings, err := fn.Bindings()
if err != nil {
t.Fatalf("fn.Bindings() produced an err: %v", err)
}
if len(bindings) != 2 {
t.Errorf("fn.Bindings() produced %d bindings, wanted 2", len(bindings))
}
for _, binding := range bindings {
if binding.Async == nil {
t.Fatal("binding missing Async implementation")
}

ctx := context.Background()
ch := binding.Async(ctx, types.String("hello"))
select {
case res := <-ch:
if !types.IsError(res) {
t.Errorf("async binding returned %v, wanted error", res)
}
case <-time.After(1 * time.Second):
t.Fatal("async binding timed out")
}
}
}

func TestSingletonAsyncBinding(t *testing.T) {
fn, err := NewFunction("async_fn",
Overload("async_fn_int", []*types.Type{types.IntType}, types.IntType),
Overload("async_fn_string", []*types.Type{types.StringType}, types.StringType),
SingletonAsyncBinding(func(ctx context.Context, args ...ref.Val) ref.Val {
return args[0]
}),
)
if err != nil {
t.Fatalf("NewFunction() failed: %v", err)
}
bindings, err := fn.Bindings()
if err != nil {
t.Fatalf("fn.Bindings() produced an err: %v", err)
}
if len(bindings) != 1 {
t.Errorf("fn.Bindings() produced %d bindings, wanted one", len(bindings))
}
binding := bindings[0]
if binding.Async == nil {
t.Fatal("binding missing Async implementation")
}

ctx := context.Background()
ch := binding.Async(ctx, types.String("hello"))
select {
case res := <-ch:
if res.Equal(types.String("hello")) != types.True {
t.Errorf("async binding returned %v, wanted hello", res)
}
case <-time.After(1 * time.Second):
t.Fatal("async binding timed out")
}
}

func TestAsyncBindingRedefinition(t *testing.T) {
_, err := NewFunction("async_fn",
Overload("async_fn_int", []*types.Type{types.IntType}, types.IntType,
AsyncBinding(func(ctx context.Context, args ...ref.Val) ref.Val {
return args[0]
}),
AsyncBinding(func(ctx context.Context, args ...ref.Val) ref.Val {
return args[0]
}),
),
)
if err == nil || !strings.Contains(err.Error(), "already has a binding") {
t.Errorf("NewFunction() got %v, wanted already has a binding", err)
}
}

func TestSingletonAsyncBindingRedefinition(t *testing.T) {
_, err := NewFunction("async_fn",
Overload("async_fn_int", []*types.Type{types.IntType}, types.IntType),
SingletonAsyncBinding(func(ctx context.Context, args ...ref.Val) ref.Val {
return args[0]
}),
SingletonAsyncBinding(func(ctx context.Context, args ...ref.Val) ref.Val {
return args[0]
}),
)
if err == nil || !strings.Contains(err.Error(), "already has a singleton binding") {
t.Errorf("NewFunction() got %v, wanted already has a singleton binding", err)
}
}

func testMerge(t *testing.T, funcs ...*FunctionDecl) *FunctionDecl {
t.Helper()
fn := funcs[0]
Expand Down
3 changes: 3 additions & 0 deletions interpreter/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ go_library(
name = "go_default_library",
srcs = [
"activation.go",
"async.go",
"attribute_patterns.go",
"attributes.go",
"decorators.go",
Expand Down Expand Up @@ -46,6 +47,7 @@ go_test(
name = "go_default_test",
srcs = [
"activation_test.go",
"async_test.go",
"attribute_patterns_test.go",
"attributes_test.go",
"frame_test.go",
Expand All @@ -65,6 +67,7 @@ go_test(
"//common/operators:go_default_library",
"//common/stdlib:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"//parser:go_default_library",
"//test:go_default_library",
"//test/proto2pb:go_default_library",
Expand Down
Loading