Skip to content

Commit aae82b1

Browse files
committed
feat: implement OpenAI client for text, image, audio generation and model listing (#10)
- Added OpenAIClient struct to interact with OpenAI's APIs. - Implemented GenerateText method for chat completion. - Implemented GenerateImage method for DALL-E image generation. - Implemented GenerateAudio method for text-to-speech audio generation. - Added ListModels method to retrieve available models from OpenAI. - Included error handling for API responses and required fields.
1 parent 3fd0ea1 commit aae82b1

25 files changed

+1739
-221
lines changed

binding/ai/service.go

Lines changed: 178 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ package ai
33
import (
44
"context"
55
"fmt"
6+
7+
"firebringer/database"
8+
aiservice "firebringer/service/ai"
69
)
710

811
// Service provides AI methods for the frontend
@@ -24,37 +27,41 @@ func (s *Service) SetContext(ctx context.Context) {
2427
type TextRequest struct {
2528
Prompt string `json:"prompt"`
2629
Model string `json:"model"`
30+
ProviderID int `json:"providerId"`
2731
Temperature *float64 `json:"temperature,omitempty"`
2832
MaxTokens *int `json:"maxTokens,omitempty"`
2933
Options map[string]interface{} `json:"options,omitempty"`
3034
}
3135

3236
// ImageRequest defines the parameters for image generation
3337
type ImageRequest struct {
34-
Prompt string `json:"prompt"`
35-
Model string `json:"model"`
36-
Size string `json:"size,omitempty"`
37-
Quality string `json:"quality,omitempty"`
38-
Style string `json:"style,omitempty"`
39-
Options map[string]interface{} `json:"options,omitempty"`
38+
Prompt string `json:"prompt"`
39+
Model string `json:"model"`
40+
ProviderID int `json:"providerId"`
41+
Size string `json:"size,omitempty"`
42+
Quality string `json:"quality,omitempty"`
43+
Style string `json:"style,omitempty"`
44+
Options map[string]interface{} `json:"options,omitempty"`
4045
}
4146

4247
// VideoRequest defines the parameters for video generation
4348
type VideoRequest struct {
4449
Prompt string `json:"prompt"`
4550
Model string `json:"model"`
51+
ProviderID int `json:"providerId"`
4652
Duration string `json:"duration,omitempty"`
4753
Resolution string `json:"resolution,omitempty"`
4854
Options map[string]interface{} `json:"options,omitempty"`
4955
}
5056

5157
// AudioRequest defines the parameters for audio generation
5258
type AudioRequest struct {
53-
Prompt string `json:"prompt"`
54-
Model string `json:"model"`
55-
Voice string `json:"voice,omitempty"`
56-
Speed *float64 `json:"speed,omitempty"`
57-
Options map[string]interface{} `json:"options,omitempty"`
59+
Prompt string `json:"prompt"`
60+
Model string `json:"model"`
61+
ProviderID int `json:"providerId"`
62+
Voice string `json:"voice,omitempty"`
63+
Speed *float64 `json:"speed,omitempty"`
64+
Options map[string]interface{} `json:"options,omitempty"`
5865
}
5966

6067
// AIResponse defines the common response structure for AI requests
@@ -64,34 +71,186 @@ type AIResponse struct {
6471
Raw interface{} `json:"raw,omitempty"`
6572
}
6673

74+
func (s *Service) getClient(providerID int) (aiservice.AIClient, error) {
75+
config, err := database.GetModelProvider(providerID)
76+
if err != nil {
77+
return nil, fmt.Errorf("failed to get config for provider id %d: %w", providerID, err)
78+
}
79+
if config == nil {
80+
return nil, fmt.Errorf("no configuration found for provider id %d. Please configure it in settings", providerID)
81+
}
82+
83+
return aiservice.NewClient(*config)
84+
}
85+
6786
// GenerateText generates text based on the prompt
6887
func (s *Service) GenerateText(req TextRequest) (*AIResponse, error) {
69-
// TODO: Implement actual AI call
88+
client, err := s.getClient(req.ProviderID)
89+
if err != nil {
90+
return nil, err
91+
}
92+
93+
aiReq := aiservice.TextGenerateRequest{
94+
Prompt: req.Prompt,
95+
Model: req.Model,
96+
Temperature: req.Temperature,
97+
MaxTokens: req.MaxTokens,
98+
Options: req.Options,
99+
}
100+
101+
resp, err := client.GenerateText(s.ctx, aiReq)
102+
if err != nil {
103+
return nil, err
104+
}
105+
70106
return &AIResponse{
71-
Content: fmt.Sprintf("Generated text for prompt: %s using model: %s", req.Prompt, req.Model),
107+
Content: resp.Content,
108+
Usage: map[string]interface{}{
109+
"promptTokens": resp.PromptTokens,
110+
"outputTokens": resp.OutputTokens,
111+
"totalTokens": resp.TotalTokens,
112+
},
113+
Raw: resp,
72114
}, nil
73115
}
74116

75117
// GenerateImage generates an image based on the prompt
76118
func (s *Service) GenerateImage(req ImageRequest) (*AIResponse, error) {
77-
// TODO: Implement actual AI call
119+
client, err := s.getClient(req.ProviderID)
120+
if err != nil {
121+
return nil, err
122+
}
123+
124+
aiReq := aiservice.ImageGenerateRequest{
125+
Prompt: req.Prompt,
126+
Model: req.Model,
127+
Size: req.Size,
128+
Quality: req.Quality,
129+
Style: req.Style,
130+
Options: req.Options,
131+
}
132+
133+
resp, err := client.GenerateImage(s.ctx, aiReq)
134+
if err != nil {
135+
return nil, err
136+
}
137+
138+
content := resp.URL
139+
if content == "" {
140+
content = resp.B64JSON
141+
}
142+
78143
return &AIResponse{
79-
Content: fmt.Sprintf("Generated image URL/data for prompt: %s using model: %s", req.Prompt, req.Model),
144+
Content: content,
145+
Raw: resp,
80146
}, nil
81147
}
82148

83149
// GenerateVideo generates a video based on the prompt
84150
func (s *Service) GenerateVideo(req VideoRequest) (*AIResponse, error) {
85-
// TODO: Implement actual AI call
151+
client, err := s.getClient(req.ProviderID)
152+
if err != nil {
153+
return nil, err
154+
}
155+
156+
aiReq := aiservice.VideoGenerateRequest{
157+
Prompt: req.Prompt,
158+
Model: req.Model,
159+
Duration: req.Duration,
160+
Resolution: req.Resolution,
161+
Options: req.Options,
162+
}
163+
164+
resp, err := client.GenerateVideo(s.ctx, aiReq)
165+
if err != nil {
166+
return nil, err
167+
}
168+
169+
content := resp.URL
170+
if content == "" && len(resp.Data) > 0 {
171+
content = "Video data (base64 or bytes)" // Improve this for frontend to handle bytes
172+
// potentially convert bytes to base64 if it is bytes
173+
}
174+
86175
return &AIResponse{
87-
Content: fmt.Sprintf("Generated video URL/data for prompt: %s using model: %s", req.Prompt, req.Model),
176+
Content: content,
177+
Raw: resp,
88178
}, nil
89179
}
90180

91181
// GenerateAudio generates audio based on the prompt
92182
func (s *Service) GenerateAudio(req AudioRequest) (*AIResponse, error) {
93-
// TODO: Implement actual AI call
183+
client, err := s.getClient(req.ProviderID)
184+
if err != nil {
185+
return nil, err
186+
}
187+
188+
aiReq := aiservice.AudioGenerateRequest{
189+
Prompt: req.Prompt,
190+
Model: req.Model,
191+
Voice: req.Voice,
192+
Speed: req.Speed,
193+
Options: req.Options,
194+
}
195+
196+
resp, err := client.GenerateAudio(s.ctx, aiReq)
197+
if err != nil {
198+
return nil, err
199+
}
200+
201+
// Usually audio is returned as bytes.
202+
// We might want to base64 encode it for the frontend or return a Blob URL if we could.
203+
// For now, let's assume valid JSON marshalling or handle it in specific response type
204+
content := ""
205+
if len(resp.Data) > 0 {
206+
// Simple indicator, actual data in Raw or handled by frontend from specific field?
207+
// Actually AIResponse.Raw is interface{}, so it will marshal the []byte as base64 string automatically in JSON.
208+
content = "Audio generated"
209+
}
210+
94211
return &AIResponse{
95-
Content: fmt.Sprintf("Generated audio URL/data for prompt: %s using model: %s", req.Prompt, req.Model),
212+
Content: content,
213+
Raw: resp,
96214
}, nil
97215
}
216+
217+
// ListModels lists available models for a given provider ID. If providerId is nil, lists from all providers.
218+
func (s *Service) ListModels(providerId *int) ([]aiservice.Model, error) {
219+
if providerId == nil {
220+
configs, err := database.ListModelProviders()
221+
if err != nil {
222+
return nil, fmt.Errorf("failed to list providers: %w", err)
223+
}
224+
225+
var allModels []aiservice.Model
226+
for _, config := range configs {
227+
client, err := aiservice.NewClient(config)
228+
if err != nil {
229+
fmt.Printf("failed to create client for %s: %v\n", config.Name, err)
230+
continue
231+
}
232+
models, err := client.ListModels(s.ctx)
233+
if err != nil {
234+
fmt.Printf("failed to list models for %s: %v\n", config.Name, err)
235+
continue
236+
}
237+
allModels = append(allModels, models...)
238+
}
239+
return allModels, nil
240+
}
241+
242+
config, err := database.GetModelProvider(*providerId)
243+
if err != nil {
244+
return nil, fmt.Errorf("failed to get config for provider id %d: %w", *providerId, err)
245+
}
246+
if config == nil {
247+
return nil, fmt.Errorf("no configuration found for provider id %d", *providerId)
248+
}
249+
250+
client, err := aiservice.NewClient(*config)
251+
if err != nil {
252+
return nil, fmt.Errorf("failed to create client for provider id %d: %w", *providerId, err)
253+
}
254+
255+
return client.ListModels(s.ctx)
256+
}

binding/database/service.go

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,30 +4,28 @@ import (
44
db "firebringer/database"
55
)
66

7-
// Service provides database methods for the frontend
87
type Service struct{}
98

10-
// NewService creates a new Database Service
119
func NewService() *Service {
1210
return &Service{}
1311
}
1412

15-
// GetAIConfig returns the configuration for a specific provider
16-
func (s *Service) GetAIConfig(provider string) (*db.AIConfig, error) {
17-
return db.GetAIConfig(db.AIProvider(provider))
13+
// GetModelProvider retrieves a model provider configuration by ID
14+
func (s *Service) GetModelProvider(id int) (*db.ModelProvider, error) {
15+
return db.GetModelProvider(id)
1816
}
1917

20-
// SaveAIConfig saves the configuration for a specific provider
21-
func (s *Service) SaveAIConfig(config db.AIConfig) error {
22-
return db.SaveAIConfig(&config)
18+
// SaveModelProvider saves or updates a model provider configuration
19+
func (s *Service) SaveModelProvider(config db.ModelProvider) error {
20+
return db.SaveModelProvider(config)
2321
}
2422

25-
// DeleteAIConfig deletes the configuration for a specific provider
26-
func (s *Service) DeleteAIConfig(provider string) error {
27-
return db.DeleteAIConfig(db.AIProvider(provider))
23+
// DeleteModelProvider deletes a model provider configuration
24+
func (s *Service) DeleteModelProvider(id int) error {
25+
return db.DeleteModelProvider(id)
2826
}
2927

30-
// ListAIConfigs returns all AI configurations
31-
func (s *Service) ListAIConfigs() ([]db.AIConfig, error) {
32-
return db.ListAIConfigs()
28+
// ListModelProviders lists all model provider configurations
29+
func (s *Service) ListModelProviders() ([]db.ModelProvider, error) {
30+
return db.ListModelProviders()
3331
}

database/db.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@ func InitDB() error {
2424
}
2525

2626
schema := `
27-
CREATE TABLE IF NOT EXISTS ai_configs (
28-
provider TEXT PRIMARY KEY,
27+
CREATE TABLE IF NOT EXISTS model_providers (
28+
id INTEGER PRIMARY KEY AUTOINCREMENT,
29+
name TEXT NOT NULL,
30+
type TEXT NOT NULL,
2931
api_key TEXT NOT NULL,
3032
base_url TEXT DEFAULT '',
3133
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,

database/models.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@ const (
1010
ProviderClaude AIProvider = "claude"
1111
)
1212

13-
type AIConfig struct {
14-
Provider AIProvider `db:"provider" json:"provider"`
13+
// ModelProvider represents an AI model provider configuration
14+
type ModelProvider struct {
15+
ID int `db:"id" json:"id"`
16+
Name string `db:"name" json:"name"`
17+
Type AIProvider `db:"type" json:"type"`
1518
APIKey string `db:"api_key" json:"apiKey"`
1619
BaseURL string `db:"base_url" json:"baseUrl"`
1720
CreatedAt time.Time `db:"created_at" json:"createdAt"`

0 commit comments

Comments
 (0)