Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 23 additions & 16 deletions go/ai/format_array.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"encoding/json"
"errors"
"fmt"
"strings"

"github.com/firebase/genkit/go/internal/base"
)
Expand Down Expand Up @@ -78,27 +79,33 @@ func (a arrayHandler) ParseMessage(m *Message) (*Message, error) {
return nil, errors.New("message has no content")
}

var newParts []*Part
var nonTextParts []*Part
accumulatedText := strings.Builder{}

for _, part := range m.Content {
if !part.IsText() {
newParts = append(newParts, part)
nonTextParts = append(nonTextParts, part)
} else {
lines := base.GetJsonObjectLines(part.Text)
for _, line := range lines {
var schemaBytes []byte
schemaBytes, err := json.Marshal(a.config.Schema["items"])
if err != nil {
return nil, fmt.Errorf("expected schema is not valid: %w", err)
}
if err = base.ValidateRaw([]byte(line), schemaBytes); err != nil {
return nil, err
}

newParts = append(newParts, NewJSONPart(line))
}
accumulatedText.WriteString(part.Text)
}
}
m.Content = newParts

var newParts []*Part
lines := base.GetJsonObjectLines(accumulatedText.String())
for _, line := range lines {
var schemaBytes []byte
schemaBytes, err := json.Marshal(a.config.Schema["items"])
if err != nil {
return nil, fmt.Errorf("expected schema is not valid: %w", err)
}
if err = base.ValidateRaw([]byte(line), schemaBytes); err != nil {
return nil, err
}

newParts = append(newParts, NewJSONPart(line))
}

m.Content = append(newParts, nonTextParts...)
}

return m, nil
Expand Down
31 changes: 19 additions & 12 deletions go/ai/format_enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,24 +77,31 @@ func (e enumHandler) ParseMessage(m *Message) (*Message, error) {
return nil, errors.New("message has no content")
}

for i, part := range m.Content {
var nonTextParts []*Part
accumulatedText := strings.Builder{}
for _, part := range m.Content {
if !part.IsText() {
continue
nonTextParts = append(nonTextParts, part)
} else {
accumulatedText.WriteString(part.Text)
}
}

// replace single and double quotes
re := regexp.MustCompile(`['"]`)
clean := re.ReplaceAllString(part.Text, "")

// trim whitespace
trimmed := strings.TrimSpace(clean)
// replace single and double quotes
re := regexp.MustCompile(`['"]`)
clean := re.ReplaceAllString(accumulatedText.String(), "")

if !slices.Contains(e.enums, trimmed) {
return nil, fmt.Errorf("message %s not in list of valid enums: %s", trimmed, strings.Join(e.enums, ", "))
}
// trim whitespace
trimmed := strings.TrimSpace(clean)

m.Content[i] = NewTextPart(trimmed)
if !slices.Contains(e.enums, trimmed) {
return nil, fmt.Errorf("message %s not in list of valid enums: %s", trimmed, strings.Join(e.enums, ", "))
}

newParts := []*Part{NewTextPart(trimmed)}
newParts = append(newParts, nonTextParts...)

m.Content = newParts
}

return m, nil
Expand Down
47 changes: 28 additions & 19 deletions go/ai/format_json.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"encoding/json"
"errors"
"fmt"
"strings"

"github.com/firebase/genkit/go/internal/base"
)
Expand Down Expand Up @@ -79,30 +80,38 @@ func (j jsonHandler) ParseMessage(m *Message) (*Message, error) {
return nil, errors.New("message has no content")
}

for i, part := range m.Content {
var nonTextParts []*Part
accumulatedText := strings.Builder{}

for _, part := range m.Content {
if !part.IsText() {
continue
}

text := base.ExtractJSONFromMarkdown(part.Text)

if j.config.Schema != nil {
var schemaBytes []byte
schemaBytes, err := json.Marshal(j.config.Schema)
if err != nil {
return nil, fmt.Errorf("expected schema is not valid: %w", err)
}
if err = base.ValidateRaw([]byte(text), schemaBytes); err != nil {
return nil, err
}
nonTextParts = append(nonTextParts, part)
} else {
if !base.ValidJSON(text) {
return nil, errors.New("message is not a valid JSON")
}
accumulatedText.WriteString(part.Text)
}
}

m.Content[i] = NewJSONPart(text)
text := base.ExtractJSONFromMarkdown(accumulatedText.String())

if j.config.Schema != nil {
var schemaBytes []byte
schemaBytes, err := json.Marshal(j.config.Schema)
if err != nil {
return nil, fmt.Errorf("expected schema is not valid: %w", err)
}
if err = base.ValidateRaw([]byte(text), schemaBytes); err != nil {
return nil, err
}
} else {
if !base.ValidJSON(text) {
return nil, errors.New("message is not a valid JSON")
}
}

newParts := []*Part{NewJSONPart(text)}
newParts = append(newParts, nonTextParts...)

m.Content = newParts
}

return m, nil
Expand Down
41 changes: 24 additions & 17 deletions go/ai/format_jsonl.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"encoding/json"
"errors"
"fmt"
"strings"

"github.com/firebase/genkit/go/internal/base"
)
Expand Down Expand Up @@ -79,29 +80,35 @@ func (j jsonlHandler) ParseMessage(m *Message) (*Message, error) {
return nil, errors.New("message has no content")
}

var newParts []*Part
var nonTextParts []*Part
accumulatedText := strings.Builder{}

for _, part := range m.Content {
if !part.IsText() {
newParts = append(newParts, part)
nonTextParts = append(nonTextParts, part)
} else {
lines := base.GetJsonObjectLines(part.Text)
for _, line := range lines {
if j.config.Schema != nil {
var schemaBytes []byte
schemaBytes, err := json.Marshal(j.config.Schema["items"])
if err != nil {
return nil, fmt.Errorf("expected schema is not valid: %w", err)
}
if err = base.ValidateRaw([]byte(line), schemaBytes); err != nil {
return nil, err
}
}

newParts = append(newParts, NewJSONPart(line))
accumulatedText.WriteString(part.Text)
}
}

var newParts []*Part
lines := base.GetJsonObjectLines(accumulatedText.String())
for _, line := range lines {
if j.config.Schema != nil {
var schemaBytes []byte
schemaBytes, err := json.Marshal(j.config.Schema["items"])
if err != nil {
return nil, fmt.Errorf("expected schema is not valid: %w", err)
}
if err = base.ValidateRaw([]byte(line), schemaBytes); err != nil {
return nil, err
}
}

newParts = append(newParts, NewJSONPart(line))
}
m.Content = newParts

m.Content = append(newParts, nonTextParts...)
}

return m, nil
Expand Down
Loading
Loading