Skip to content

Commit 2787d3c

Browse files
authored
feat(gemini): show grouding google search (#753)
* feat(gemini): show grouding google search * fix(gemini): extra new line for grounding metadata of OpenAI stream choice
1 parent 71e6db8 commit 2787d3c

File tree

1 file changed

+85
-0
lines changed

1 file changed

+85
-0
lines changed

providers/gemini/type.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,14 @@ func (candidate *GeminiChatCandidate) ToOpenAIStreamChoice(request *types.ChatCo
150150
choice.Delta.Image = images
151151
}
152152

153+
// Add grounding metadata as markdown citations
154+
if candidate.GroundingMetadata != nil && showGoogleSearchMeta(request) {
155+
groundingMarkdown := formatGroundingMetadataAsMarkdown(candidate.GroundingMetadata)
156+
if groundingMarkdown != "" {
157+
content = append(content, "\n\n" + groundingMarkdown)
158+
}
159+
}
160+
153161
choice.Delta.Content = strings.Join(content, "\n")
154162

155163
if len(reasoningContent) > 0 {
@@ -230,6 +238,18 @@ func (candidate *GeminiChatCandidate) ToOpenAIChoice(request *types.ChatCompleti
230238

231239
choice.Message.Content = strings.Join(content, "\n")
232240

241+
// Add grounding metadata as markdown citations
242+
if candidate.GroundingMetadata != nil && showGoogleSearchMeta(request) {
243+
groundingMarkdown := formatGroundingMetadataAsMarkdown(candidate.GroundingMetadata)
244+
if groundingMarkdown != "" {
245+
if contentStr, ok := choice.Message.Content.(string); ok && contentStr != "" {
246+
choice.Message.Content = contentStr + "\n\n" + groundingMarkdown
247+
} else {
248+
choice.Message.Content = groundingMarkdown
249+
}
250+
}
251+
}
252+
233253
if len(reasoningContent) > 0 {
234254
choice.Message.ReasoningContent = strings.Join(reasoningContent, "\n")
235255
}
@@ -365,6 +385,7 @@ type GeminiChatCandidate struct {
365385
CitationMetadata any `json:"citationMetadata,omitempty"`
366386
TokenCount int `json:"tokenCount,omitempty"`
367387
GroundingAttributions []any `json:"groundingAttributions,omitempty"`
388+
GroundingMetadata *GeminiGroundingMetadata `json:"groundingMetadata,omitempty"`
368389
AvgLogprobs any `json:"avgLogprobs,omitempty"`
369390
}
370391

@@ -621,3 +642,67 @@ func isEmptyOrOnlyNewlines(s string) bool {
621642
trimmed := strings.TrimSpace(s)
622643
return trimmed == ""
623644
}
645+
646+
type GeminiGroundingMetadata struct {
647+
GroundingChunks []GeminiGroundingChunk `json:"groundingChunks,omitempty"`
648+
WebSearchQueries []string `json:"webSearchQueries,omitempty"`
649+
}
650+
651+
type GeminiGroundingChunk struct {
652+
Web *GeminiGroundingChunkWeb `json:"web,omitempty"`
653+
}
654+
655+
type GeminiGroundingChunkWeb struct {
656+
Uri string `json:"uri,omitempty"`
657+
Title string `json:"title,omitempty"`
658+
}
659+
660+
// checks if googleSearch tool has "show" parameter
661+
func showGoogleSearchMeta(request *types.ChatCompletionRequest) bool {
662+
functions := request.GetFunctions()
663+
if functions == nil {
664+
return false
665+
}
666+
667+
for _, function := range functions {
668+
if function.Name == "googleSearch" && function.Parameters != nil {
669+
if paramStr, ok := function.Parameters.(string); ok && paramStr == "show" {
670+
return true
671+
}
672+
}
673+
}
674+
675+
return false
676+
}
677+
678+
// formats grounding metadata as markdown citation
679+
func formatGroundingMetadataAsMarkdown(metadata *GeminiGroundingMetadata) string {
680+
if metadata == nil || len(metadata.GroundingChunks) == 0 {
681+
return ""
682+
}
683+
var result strings.Builder
684+
// Add search queries
685+
if len(metadata.WebSearchQueries) > 0 {
686+
result.WriteString("> Searched ")
687+
for i, query := range metadata.WebSearchQueries {
688+
if i > 0 {
689+
result.WriteString(" and ")
690+
}
691+
result.WriteString(fmt.Sprintf(`"%s"`, query))
692+
}
693+
result.WriteString("\n")
694+
}
695+
// Add grounding chunks as numbered list
696+
linkCount := 0
697+
for _, chunk := range metadata.GroundingChunks {
698+
if chunk.Web != nil && chunk.Web.Uri != "" {
699+
linkCount++
700+
title := chunk.Web.Title
701+
if title == "" {
702+
title = chunk.Web.Uri
703+
}
704+
result.WriteString(fmt.Sprintf("> %d. [%s](%s)\n", linkCount, title, chunk.Web.Uri))
705+
}
706+
}
707+
return result.String()
708+
}

0 commit comments

Comments
 (0)