Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions go/core/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ type flowContext struct {
// DefineFlow creates a Flow that runs fn, and registers it as an action. fn takes an input of type In and returns an output of type Out.
func DefineFlow[In, Out any](r api.Registry, name string, fn Func[In, Out]) *Flow[In, Out, struct{}] {
return (*Flow[In, Out, struct{}])(DefineAction(r, name, api.ActionTypeFlow, nil, nil, func(ctx context.Context, input In) (Out, error) {
fc := &flowContext{}
fc := &flowContext{
flowName: name,
}
ctx = flowContextKey.NewContext(ctx, fc)
return fn(ctx, input)
}))
Expand All @@ -65,7 +67,9 @@ func DefineFlow[In, Out any](r api.Registry, name string, fn Func[In, Out]) *Flo
// Otherwise, it should ignore the callback and just return a result.
func DefineStreamingFlow[In, Out, Stream any](r api.Registry, name string, fn StreamingFunc[In, Out, Stream]) *Flow[In, Out, Stream] {
return (*Flow[In, Out, Stream])(DefineStreamingAction(r, name, api.ActionTypeFlow, nil, nil, func(ctx context.Context, input In, cb func(context.Context, Stream) error) (Out, error) {
fc := &flowContext{}
fc := &flowContext{
flowName: name,
}
ctx = flowContextKey.NewContext(ctx, fc)
return fn(ctx, input, cb)
}))
Expand Down
23 changes: 23 additions & 0 deletions go/core/flow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,26 @@ func TestRunFlow(t *testing.T) {
t.Errorf("got %d, want %d", got, want)
}
}

func TestFlowNameFromContext(t *testing.T) {
r := registry.New()
flows := []*Flow[struct{}, string, struct{}]{
DefineFlow(r, "DefineFlow", func(ctx context.Context, _ struct{}) (string, error) {
return FlowNameFromContext(ctx), nil
}),
DefineStreamingFlow(r, "DefineStreamingFlow", func(ctx context.Context, _ struct{}, s StreamCallback[struct{}]) (string, error) {
return FlowNameFromContext(ctx), nil
}),
}
for _, flow := range flows {
t.Run(flow.Name(), func(t *testing.T) {
got, err := flow.Run(context.Background(), struct{}{})
if err != nil {
t.Fatal(err)
}
if want := flow.Name(); got != want {
t.Errorf("got '%s', want '%s'", got, want)
}
})
}
}
Loading