Skip to content

Commit 45dc6bb

Browse files
chore(internal): refactor sse event parsing
1 parent 8b7d268 commit 45dc6bb

File tree

3 files changed

+34
-26
lines changed

3 files changed

+34
-26
lines changed

betathread.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ func (r *BetaThreadService) NewAndRunStreaming(ctx context.Context, body BetaThr
127127
opts = append(opts, option.WithJSONSet("stream", true))
128128
path := "threads/runs"
129129
err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &raw, opts...)
130-
return ssestream.NewStream[AssistantStreamEventUnion](ssestream.NewDecoder(raw), err)
130+
return ssestream.NewStreamWithSynthesizeEventData[AssistantStreamEventUnion](ssestream.NewDecoder(raw), err)
131131
}
132132

133133
// AssistantResponseFormatOptionUnion contains all possible properties and values

betathreadrun.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ func (r *BetaThreadRunService) NewStreaming(ctx context.Context, threadID string
7878
}
7979
path := fmt.Sprintf("threads/%s/runs", threadID)
8080
err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, params, &raw, opts...)
81-
return ssestream.NewStream[AssistantStreamEventUnion](ssestream.NewDecoder(raw), err)
81+
return ssestream.NewStreamWithSynthesizeEventData[AssistantStreamEventUnion](ssestream.NewDecoder(raw), err)
8282
}
8383

8484
// Retrieves a run.
@@ -215,7 +215,7 @@ func (r *BetaThreadRunService) SubmitToolOutputsStreaming(ctx context.Context, t
215215
}
216216
path := fmt.Sprintf("threads/%s/runs/%s/submit_tool_outputs", threadID, runID)
217217
err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &raw, opts...)
218-
return ssestream.NewStream[AssistantStreamEventUnion](ssestream.NewDecoder(raw), err)
218+
return ssestream.NewStreamWithSynthesizeEventData[AssistantStreamEventUnion](ssestream.NewDecoder(raw), err)
219219
}
220220

221221
// Tool call objects

packages/ssestream/ssestream.go

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"net/http"
1212
"strings"
1313

14+
shimjson "github.com/openai/openai-go/v3/internal/encoding/json"
1415
"github.com/tidwall/gjson"
1516
)
1617

@@ -134,10 +135,11 @@ func (s *eventStreamDecoder) Err() error {
134135
}
135136

136137
type Stream[T any] struct {
137-
decoder Decoder
138-
cur T
139-
err error
140-
done bool
138+
decoder Decoder
139+
cur T
140+
err error
141+
done bool
142+
synthesizeEventData bool
141143
}
142144

143145
func NewStream[T any](decoder Decoder, err error) *Stream[T] {
@@ -147,6 +149,14 @@ func NewStream[T any](decoder Decoder, err error) *Stream[T] {
147149
}
148150
}
149151

152+
func NewStreamWithSynthesizeEventData[T any](decoder Decoder, err error) *Stream[T] {
153+
return &Stream[T]{
154+
decoder: decoder,
155+
err: err,
156+
synthesizeEventData: true,
157+
}
158+
}
159+
150160
// Next returns false if the stream has ended or an error occurred.
151161
// Call Stream.Current() to get the current value.
152162
// Call Stream.Err() to get the error.
@@ -183,34 +193,32 @@ func (s *Stream[T]) Next() bool {
183193
return false
184194
}
185195
var nxt T
186-
187-
if s.decoder.Event().Type == "" || !strings.HasPrefix(s.decoder.Event().Type, "thread.") {
188-
ep := gjson.GetBytes(s.decoder.Event().Data, "error")
189-
if ep.Exists() {
190-
s.err = fmt.Errorf("received error while streaming: %s", ep.String())
191-
return false
196+
data := s.decoder.Event().Data
197+
if s.decoder.Event().Type != "" && strings.HasPrefix(s.decoder.Event().Type, "thread.") {
198+
synthesized := map[string]any{
199+
"event": s.decoder.Event().Type,
200+
"data": json.RawMessage(data),
192201
}
193-
s.err = json.Unmarshal(s.decoder.Event().Data, &nxt)
202+
data, s.err = shimjson.Marshal(synthesized)
194203
if s.err != nil {
195204
return false
196205
}
197-
s.cur = nxt
198-
return true
199-
} else {
200-
ep := gjson.GetBytes(s.decoder.Event().Data, "error")
201-
if ep.Exists() {
202-
s.err = fmt.Errorf("received error while streaming: %s", ep.String())
203-
return false
206+
} else if s.synthesizeEventData {
207+
synthesized := map[string]any{
208+
"event": s.decoder.Event().Type,
209+
"data": json.RawMessage(data),
204210
}
205-
event := s.decoder.Event().Type
206-
data := s.decoder.Event().Data
207-
s.err = json.Unmarshal([]byte(fmt.Sprintf(`{ "event": %q, "data": %s }`, event, data)), &nxt)
211+
data, s.err = shimjson.Marshal(synthesized)
208212
if s.err != nil {
209213
return false
210214
}
211-
s.cur = nxt
212-
return true
213215
}
216+
s.err = json.Unmarshal(data, &nxt)
217+
if s.err != nil {
218+
return false
219+
}
220+
s.cur = nxt
221+
return true
214222
}
215223

216224
// decoder.Next() may be false because of an error

0 commit comments

Comments
 (0)