@@ -11,6 +11,7 @@ import (
1111 "net/http"
1212 "strings"
1313
14+ shimjson "github.com/openai/openai-go/v3/internal/encoding/json"
1415 "github.com/tidwall/gjson"
1516)
1617
@@ -134,10 +135,11 @@ func (s *eventStreamDecoder) Err() error {
134135}
135136
136137type Stream [T any ] struct {
137- decoder Decoder
138- cur T
139- err error
140- done bool
138+ decoder Decoder
139+ cur T
140+ err error
141+ done bool
142+ synthesizeEventData bool
141143}
142144
143145func NewStream [T any ](decoder Decoder , err error ) * Stream [T ] {
@@ -147,6 +149,14 @@ func NewStream[T any](decoder Decoder, err error) *Stream[T] {
147149 }
148150}
149151
152+ func NewStreamWithSynthesizeEventData [T any ](decoder Decoder , err error ) * Stream [T ] {
153+ return & Stream [T ]{
154+ decoder : decoder ,
155+ err : err ,
156+ synthesizeEventData : true ,
157+ }
158+ }
159+
150160// Next returns false if the stream has ended or an error occurred.
151161// Call Stream.Current() to get the current value.
152162// Call Stream.Err() to get the error.
@@ -183,34 +193,32 @@ func (s *Stream[T]) Next() bool {
183193 return false
184194 }
185195 var nxt T
186-
187- if s .decoder .Event ().Type == "" || ! strings .HasPrefix (s .decoder .Event ().Type , "thread." ) {
188- ep := gjson .GetBytes (s .decoder .Event ().Data , "error" )
189- if ep .Exists () {
190- s .err = fmt .Errorf ("received error while streaming: %s" , ep .String ())
191- return false
196+ data := s .decoder .Event ().Data
197+ if s .decoder .Event ().Type != "" && strings .HasPrefix (s .decoder .Event ().Type , "thread." ) {
198+ synthesized := map [string ]any {
199+ "event" : s .decoder .Event ().Type ,
200+ "data" : json .RawMessage (data ),
192201 }
193- s .err = json . Unmarshal ( s . decoder . Event (). Data , & nxt )
202+ data , s .err = shimjson . Marshal ( synthesized )
194203 if s .err != nil {
195204 return false
196205 }
197- s .cur = nxt
198- return true
199- } else {
200- ep := gjson .GetBytes (s .decoder .Event ().Data , "error" )
201- if ep .Exists () {
202- s .err = fmt .Errorf ("received error while streaming: %s" , ep .String ())
203- return false
206+ } else if s .synthesizeEventData {
207+ synthesized := map [string ]any {
208+ "event" : s .decoder .Event ().Type ,
209+ "data" : json .RawMessage (data ),
204210 }
205- event := s .decoder .Event ().Type
206- data := s .decoder .Event ().Data
207- s .err = json .Unmarshal ([]byte (fmt .Sprintf (`{ "event": %q, "data": %s }` , event , data )), & nxt )
211+ data , s .err = shimjson .Marshal (synthesized )
208212 if s .err != nil {
209213 return false
210214 }
211- s .cur = nxt
212- return true
213215 }
216+ s .err = json .Unmarshal (data , & nxt )
217+ if s .err != nil {
218+ return false
219+ }
220+ s .cur = nxt
221+ return true
214222 }
215223
216224 // decoder.Next() may be false because of an error
0 commit comments