Skip to content

Commit 83ed32f

Browse files
authored
feat(ai): add instance AI providers and transcription (#5829)
Co-authored-by: memoclaw <265580040+memoclaw@users.noreply.github.com>
1 parent 40fd700 commit 83ed32f

33 files changed

+3522
-111
lines changed

internal/ai/ai.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package ai
2+
3+
// ProviderType identifies an AI provider implementation.
4+
type ProviderType string
5+
6+
const (
7+
// ProviderOpenAI is OpenAI's hosted API.
8+
ProviderOpenAI ProviderType = "OPENAI"
9+
// ProviderOpenAICompatible is an OpenAI-compatible API endpoint.
10+
ProviderOpenAICompatible ProviderType = "OPENAI_COMPATIBLE"
11+
// ProviderAnthropic is Anthropic's API.
12+
ProviderAnthropic ProviderType = "ANTHROPIC"
13+
// ProviderGemini is Google's Gemini API.
14+
ProviderGemini ProviderType = "GEMINI"
15+
)
16+
17+
// ProviderConfig configures a callable AI provider connection.
18+
type ProviderConfig struct {
19+
ID string
20+
Title string
21+
Type ProviderType
22+
Endpoint string
23+
APIKey string
24+
Models []string
25+
DefaultModel string
26+
}

internal/ai/errors.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package ai
2+
3+
import "github.com/pkg/errors"
4+
5+
var (
6+
// ErrProviderNotFound indicates that a requested provider ID does not exist.
7+
ErrProviderNotFound = errors.New("AI provider not found")
8+
// ErrCapabilityUnsupported indicates that the provider does not support the requested capability.
9+
ErrCapabilityUnsupported = errors.New("AI provider capability unsupported")
10+
)

internal/ai/openai/client.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package openai
2+
3+
import (
4+
"net/http"
5+
"net/url"
6+
"strings"
7+
"time"
8+
9+
"github.com/pkg/errors"
10+
11+
"github.com/usememos/memos/internal/ai"
12+
)
13+
14+
const defaultEndpoint = "https://api.openai.com/v1"
15+
16+
// Transcriber transcribes audio with OpenAI-compatible transcription APIs.
17+
type Transcriber struct {
18+
endpoint string
19+
apiKey string
20+
httpClient *http.Client
21+
}
22+
23+
// NewTranscriber creates a new OpenAI-compatible transcriber.
24+
func NewTranscriber(config ai.ProviderConfig, options ...Option) (*Transcriber, error) {
25+
endpoint := strings.TrimSpace(config.Endpoint)
26+
if endpoint == "" {
27+
endpoint = defaultEndpoint
28+
}
29+
if _, err := url.ParseRequestURI(endpoint); err != nil {
30+
return nil, errors.Wrap(err, "invalid OpenAI endpoint")
31+
}
32+
if config.APIKey == "" {
33+
return nil, errors.New("OpenAI API key is required")
34+
}
35+
36+
transcriber := &Transcriber{
37+
endpoint: endpoint,
38+
apiKey: config.APIKey,
39+
httpClient: &http.Client{
40+
Timeout: 2 * time.Minute,
41+
},
42+
}
43+
for _, option := range options {
44+
option(transcriber)
45+
}
46+
return transcriber, nil
47+
}
48+
49+
// Option configures a Transcriber.
50+
type Option func(*Transcriber)
51+
52+
// WithHTTPClient sets the HTTP client used by the transcriber.
53+
func WithHTTPClient(client *http.Client) Option {
54+
return func(t *Transcriber) {
55+
if client != nil {
56+
t.httpClient = client
57+
}
58+
}
59+
}
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
package openai
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/json"
7+
"io"
8+
"mime"
9+
"mime/multipart"
10+
"net/http"
11+
"net/textproto"
12+
"strings"
13+
14+
"github.com/pkg/errors"
15+
16+
"github.com/usememos/memos/internal/ai"
17+
)
18+
19+
type transcriptionResponse struct {
20+
Text string `json:"text"`
21+
Language string `json:"language"`
22+
Duration float64 `json:"duration"`
23+
}
24+
25+
type errorResponse struct {
26+
Error struct {
27+
Message string `json:"message"`
28+
Type string `json:"type"`
29+
Code string `json:"code"`
30+
} `json:"error"`
31+
}
32+
33+
// Transcribe transcribes audio with the /audio/transcriptions endpoint.
34+
func (t *Transcriber) Transcribe(ctx context.Context, request ai.TranscribeRequest) (*ai.TranscribeResponse, error) {
35+
if strings.TrimSpace(request.Model) == "" {
36+
return nil, errors.New("model is required")
37+
}
38+
if request.Audio == nil {
39+
return nil, errors.New("audio is required")
40+
}
41+
42+
body := &bytes.Buffer{}
43+
writer := multipart.NewWriter(body)
44+
if err := writeAudioFilePart(writer, request); err != nil {
45+
return nil, err
46+
}
47+
if err := writer.WriteField("model", request.Model); err != nil {
48+
return nil, errors.Wrap(err, "failed to write model field")
49+
}
50+
if err := writer.WriteField("response_format", "json"); err != nil {
51+
return nil, errors.Wrap(err, "failed to write response format field")
52+
}
53+
if request.Prompt != "" {
54+
if err := writer.WriteField("prompt", request.Prompt); err != nil {
55+
return nil, errors.Wrap(err, "failed to write prompt field")
56+
}
57+
}
58+
if request.Language != "" {
59+
if err := writer.WriteField("language", request.Language); err != nil {
60+
return nil, errors.Wrap(err, "failed to write language field")
61+
}
62+
}
63+
if err := writer.Close(); err != nil {
64+
return nil, errors.Wrap(err, "failed to close multipart writer")
65+
}
66+
67+
httpRequest, err := http.NewRequestWithContext(ctx, http.MethodPost, strings.TrimRight(t.endpoint, "/")+"/audio/transcriptions", body)
68+
if err != nil {
69+
return nil, errors.Wrap(err, "failed to create transcription request")
70+
}
71+
httpRequest.Header.Set("Authorization", "Bearer "+t.apiKey)
72+
httpRequest.Header.Set("Content-Type", writer.FormDataContentType())
73+
74+
httpResponse, err := t.httpClient.Do(httpRequest)
75+
if err != nil {
76+
return nil, errors.Wrap(err, "failed to send transcription request")
77+
}
78+
defer httpResponse.Body.Close()
79+
80+
responseBody, err := io.ReadAll(httpResponse.Body)
81+
if err != nil {
82+
return nil, errors.Wrap(err, "failed to read transcription response")
83+
}
84+
if httpResponse.StatusCode < http.StatusOK || httpResponse.StatusCode >= http.StatusMultipleChoices {
85+
return nil, errors.Errorf("transcription request failed with status %d: %s", httpResponse.StatusCode, extractErrorMessage(responseBody))
86+
}
87+
88+
var response transcriptionResponse
89+
if err := json.Unmarshal(responseBody, &response); err != nil {
90+
return nil, errors.Wrap(err, "failed to unmarshal transcription response")
91+
}
92+
return &ai.TranscribeResponse{
93+
Text: response.Text,
94+
Language: response.Language,
95+
Duration: response.Duration,
96+
}, nil
97+
}
98+
99+
func writeAudioFilePart(writer *multipart.Writer, request ai.TranscribeRequest) error {
100+
filename := strings.TrimSpace(request.Filename)
101+
if filename == "" {
102+
filename = "audio"
103+
}
104+
contentType := strings.TrimSpace(request.ContentType)
105+
if contentType == "" {
106+
contentType = "application/octet-stream"
107+
} else {
108+
mediaType, _, err := mime.ParseMediaType(contentType)
109+
if err != nil {
110+
return errors.Wrap(err, "invalid audio content type")
111+
}
112+
contentType = mediaType
113+
}
114+
115+
header := make(textproto.MIMEHeader)
116+
header.Set("Content-Disposition", mime.FormatMediaType("form-data", map[string]string{
117+
"name": "file",
118+
"filename": sanitizeFilename(filename),
119+
}))
120+
header.Set("Content-Type", contentType)
121+
part, err := writer.CreatePart(header)
122+
if err != nil {
123+
return errors.Wrap(err, "failed to create audio file part")
124+
}
125+
if _, err := io.Copy(part, request.Audio); err != nil {
126+
return errors.Wrap(err, "failed to write audio file part")
127+
}
128+
return nil
129+
}
130+
131+
func extractErrorMessage(responseBody []byte) string {
132+
var response errorResponse
133+
if err := json.Unmarshal(responseBody, &response); err == nil && response.Error.Message != "" {
134+
return response.Error.Message
135+
}
136+
return string(responseBody)
137+
}
138+
139+
func sanitizeFilename(filename string) string {
140+
filename = strings.NewReplacer("\r", "_", "\n", "_").Replace(filename)
141+
if strings.TrimSpace(filename) == "" {
142+
return "audio"
143+
}
144+
return filename
145+
}
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package openai
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"net/http"
7+
"net/http/httptest"
8+
"strings"
9+
"testing"
10+
"time"
11+
12+
"github.com/stretchr/testify/require"
13+
14+
"github.com/usememos/memos/internal/ai"
15+
)
16+
17+
func TestTranscribe(t *testing.T) {
18+
t.Parallel()
19+
20+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
21+
require.Equal(t, http.MethodPost, r.Method)
22+
require.Equal(t, "/audio/transcriptions", r.URL.Path)
23+
require.Equal(t, "Bearer test-key", r.Header.Get("Authorization"))
24+
require.NoError(t, r.ParseMultipartForm(10<<20))
25+
require.Equal(t, "gpt-4o-transcribe", r.FormValue("model"))
26+
require.Equal(t, "json", r.FormValue("response_format"))
27+
require.Equal(t, "domain words", r.FormValue("prompt"))
28+
require.Equal(t, "en", r.FormValue("language"))
29+
30+
file, header, err := r.FormFile("file")
31+
require.NoError(t, err)
32+
defer file.Close()
33+
require.Equal(t, "voice.wav", header.Filename)
34+
require.Equal(t, "audio/wav", header.Header.Get("Content-Type"))
35+
36+
w.Header().Set("Content-Type", "application/json")
37+
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
38+
"text": "hello world",
39+
"language": "en",
40+
"duration": 1.5,
41+
}))
42+
}))
43+
defer server.Close()
44+
45+
transcriber, err := NewTranscriber(ai.ProviderConfig{
46+
Endpoint: server.URL,
47+
APIKey: "test-key",
48+
})
49+
require.NoError(t, err)
50+
51+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
52+
defer cancel()
53+
response, err := transcriber.Transcribe(ctx, ai.TranscribeRequest{
54+
Model: "gpt-4o-transcribe",
55+
Filename: "voice.wav",
56+
ContentType: "audio/wav",
57+
Audio: strings.NewReader("RIFF"),
58+
Prompt: "domain words",
59+
Language: "en",
60+
})
61+
require.NoError(t, err)
62+
require.Equal(t, "hello world", response.Text)
63+
require.Equal(t, "en", response.Language)
64+
require.Equal(t, 1.5, response.Duration)
65+
}

internal/ai/resolver.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package ai
2+
3+
import "github.com/pkg/errors"
4+
5+
// FindProvider returns the provider with the given ID.
6+
func FindProvider(providers []ProviderConfig, providerID string) (*ProviderConfig, error) {
7+
if providerID == "" {
8+
return nil, errors.Wrap(ErrProviderNotFound, "provider ID is required")
9+
}
10+
for _, provider := range providers {
11+
if provider.ID == providerID {
12+
return &provider, nil
13+
}
14+
}
15+
return nil, errors.Wrapf(ErrProviderNotFound, "provider ID %q", providerID)
16+
}

internal/ai/transcription.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package ai
2+
3+
import (
4+
"context"
5+
"io"
6+
)
7+
8+
// Transcriber transcribes audio into text.
9+
type Transcriber interface {
10+
Transcribe(ctx context.Context, request TranscribeRequest) (*TranscribeResponse, error)
11+
}
12+
13+
// TranscribeRequest contains an audio transcription request.
14+
type TranscribeRequest struct {
15+
Model string
16+
Filename string
17+
ContentType string
18+
Audio io.Reader
19+
Size int64
20+
Prompt string
21+
Language string
22+
}
23+
24+
// TranscribeResponse contains an audio transcription response.
25+
type TranscribeResponse struct {
26+
Text string
27+
Language string
28+
Duration float64
29+
}

0 commit comments

Comments
 (0)