Skip to content

Commit 8774e81

Browse files
authored
fix(go): fixed bad dotprompt output format parsing (#4109)
1 parent b6961dd commit 8774e81

File tree

4 files changed

+121
-8
lines changed

4 files changed

+121
-8
lines changed

go/ai/prompt.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -776,7 +776,11 @@ func LoadPromptFromSource(r api.Registry, source, name, namespace string) (Promp
776776
}
777777

778778
if inputSchema, ok := metadata.Input.Schema.(map[string]any); ok {
779-
opts.InputSchema = inputSchema
779+
if ref, ok := inputSchema["$ref"].(string); ok {
780+
opts.InputSchema = core.SchemaRef(ref)
781+
} else {
782+
opts.InputSchema = inputSchema
783+
}
780784
}
781785

782786
if metadata.Output.Format != "" {
@@ -794,6 +798,17 @@ func LoadPromptFromSource(r api.Registry, source, name, namespace string) (Promp
794798
}
795799
}
796800

801+
if outputSchema, ok := metadata.Output.Schema.(map[string]any); ok {
802+
if ref, ok := outputSchema["$ref"].(string); ok {
803+
opts.OutputSchema = core.SchemaRef(ref)
804+
} else {
805+
opts.OutputSchema = outputSchema
806+
}
807+
if opts.OutputFormat == "" {
808+
opts.OutputFormat = OutputFormatJSON
809+
}
810+
}
811+
797812
key := promptKey(name, variant, namespace)
798813

799814
prompt := DefinePrompt(r, key, opts, WithPrompt(parsedPrompt.Template))

go/ai/prompt_test.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1368,6 +1368,68 @@ Simple prompt
13681368
t.Fatal("Prompt 'simple' was not registered")
13691369
}
13701370
})
1371+
1372+
t.Run("prompt with inline output schema", func(t *testing.T) {
1373+
reg := registry.New()
1374+
ConfigureFormats(reg)
1375+
1376+
source := `---
1377+
model: test/chat
1378+
output:
1379+
format: json
1380+
schema:
1381+
type: object
1382+
properties:
1383+
title:
1384+
type: string
1385+
description:
1386+
type: string
1387+
required:
1388+
- title
1389+
- description
1390+
---
1391+
Generate something
1392+
`
1393+
prompt, err := LoadPromptFromSource(reg, source, "outputSchemaPrompt", "")
1394+
if err != nil {
1395+
t.Fatalf("LoadPromptFromRaw failed: %v", err)
1396+
}
1397+
if prompt == nil {
1398+
t.Fatal("LoadPromptFromRaw returned nil prompt")
1399+
}
1400+
1401+
actionOpts, err := prompt.Render(context.Background(), nil)
1402+
if err != nil {
1403+
t.Fatalf("Render failed: %v", err)
1404+
}
1405+
1406+
// Verify that the output config is set correctly
1407+
if actionOpts.Output == nil {
1408+
t.Fatal("Expected Output config to be set")
1409+
}
1410+
if actionOpts.Output.Format != OutputFormatJSON {
1411+
t.Errorf("Expected output format 'json', got %q", actionOpts.Output.Format)
1412+
}
1413+
if actionOpts.Output.JsonSchema == nil {
1414+
t.Fatal("Expected output JsonSchema to be set for inline schema")
1415+
}
1416+
1417+
// Verify the schema structure
1418+
schema := actionOpts.Output.JsonSchema
1419+
if schema["type"] != "object" {
1420+
t.Errorf("Expected schema type 'object', got %v", schema["type"])
1421+
}
1422+
properties, ok := schema["properties"].(map[string]any)
1423+
if !ok {
1424+
t.Fatal("Expected schema properties to be a map")
1425+
}
1426+
if _, ok := properties["title"]; !ok {
1427+
t.Error("Expected schema to have 'title' property")
1428+
}
1429+
if _, ok := properties["description"]; !ok {
1430+
t.Error("Expected schema to have 'description' property")
1431+
}
1432+
})
13711433
}
13721434

13731435
// TestDefinePartialAndHelperJourney demonstrates a complete user journey for defining

go/internal/base/json.go

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -138,18 +138,29 @@ func SchemaAsMap(s *jsonschema.Schema) map[string]any {
138138
return m
139139
}
140140

141-
// jsonMarkdownRegex specifically looks for "json" language identifier
142-
var jsonMarkdownRegex = regexp.MustCompile("(?s)```json(.*?)```")
141+
// jsonMarkdownRegex matches fenced code blocks with "json" language identifier (case-insensitive).
142+
var jsonMarkdownRegex = regexp.MustCompile("(?si)```json\\s*(.*?)```")
143+
144+
// plainMarkdownRegex matches fenced code blocks without any language identifier.
145+
var plainMarkdownRegex = regexp.MustCompile("(?s)```\\s*\\n(.*?)```")
143146

144147
// ExtractJSONFromMarkdown returns the contents of the first fenced code block in
145-
// the markdown text md. If there is none, it returns md.
148+
// the markdown text md. It matches code blocks with "json" identifier (case-insensitive)
149+
// or code blocks without any language identifier. If there is no matching block, it returns md.
146150
func ExtractJSONFromMarkdown(md string) string {
151+
// First try to match explicit json code blocks
147152
matches := jsonMarkdownRegex.FindStringSubmatch(md)
148-
if len(matches) < 2 {
149-
return md
153+
if len(matches) >= 2 {
154+
return strings.TrimSpace(matches[1])
155+
}
156+
157+
// Fall back to plain code blocks (no language identifier)
158+
matches = plainMarkdownRegex.FindStringSubmatch(md)
159+
if len(matches) >= 2 {
160+
return strings.TrimSpace(matches[1])
150161
}
151-
// capture group 1 matches the actual fenced JSON block
152-
return strings.TrimSpace(matches[1])
162+
163+
return md
153164
}
154165

155166
// GetJSONObjectLines splits a string by newlines, trims whitespace from each line,

go/internal/base/json_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,31 @@ func TestExtractJSONFromMarkdown(t *testing.T) {
7878
in: "```json\n{\"a\": 1}\n``` ```yaml\nkey: 1\nanother-key: 2```",
7979
want: "{\"a\": 1}",
8080
},
81+
{
82+
desc: "uppercase JSON identifier",
83+
in: "```JSON\n{\"a\": 1}\n```",
84+
want: "{\"a\": 1}",
85+
},
86+
{
87+
desc: "mixed case Json identifier",
88+
in: "```Json\n{\"a\": 1}\n```",
89+
want: "{\"a\": 1}",
90+
},
91+
{
92+
desc: "plain code block without identifier",
93+
in: "```\n{\"a\": 1}\n```",
94+
want: "{\"a\": 1}",
95+
},
96+
{
97+
desc: "plain code block with text before",
98+
in: "Here is the result:\n\n```\n{\"title\": \"Pizza\"}\n```",
99+
want: "{\"title\": \"Pizza\"}",
100+
},
101+
{
102+
desc: "json block preferred over plain block",
103+
in: "```\n{\"plain\": true}\n``` then ```json\n{\"json\": true}\n```",
104+
want: "{\"json\": true}",
105+
},
81106
}
82107
for _, tc := range tests {
83108
t.Run(tc.desc, func(t *testing.T) {

0 commit comments

Comments
 (0)