Commit b2d71da2 authored by ianshaw's avatar ianshaw
Browse files

feat(backend): 实现 Gemini AI Studio OAuth 和消息兼容服务

- gemini_oauth_service.go: 新增 AI Studio OAuth 类型支持
- gemini_token_provider.go: Token 提供器增强
- gemini_messages_compat_service.go: 支持 AI Studio 端点
- account_test_service.go: Gemini 账户可用性检测
- gateway_service.go: 网关服务适配
- openai_gateway_service.go: OpenAI 兼容层调整
parent 2d6e1d26
......@@ -3,6 +3,7 @@ package service
import (
"bufio"
"bytes"
"context"
"crypto/rand"
"encoding/hex"
"encoding/json"
......@@ -16,6 +17,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
......@@ -38,19 +40,30 @@ type TestEvent struct {
// AccountTestService handles account testing operations
type AccountTestService struct {
accountRepo AccountRepository
oauthService *OAuthService
openaiOAuthService *OpenAIOAuthService
httpUpstream HTTPUpstream
accountRepo AccountRepository
oauthService *OAuthService
openaiOAuthService *OpenAIOAuthService
geminiOAuthService *GeminiOAuthService
geminiTokenProvider *GeminiTokenProvider
httpUpstream HTTPUpstream
}
// NewAccountTestService creates a new AccountTestService
func NewAccountTestService(accountRepo AccountRepository, oauthService *OAuthService, openaiOAuthService *OpenAIOAuthService, httpUpstream HTTPUpstream) *AccountTestService {
func NewAccountTestService(
accountRepo AccountRepository,
oauthService *OAuthService,
openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService,
geminiTokenProvider *GeminiTokenProvider,
httpUpstream HTTPUpstream,
) *AccountTestService {
return &AccountTestService{
accountRepo: accountRepo,
oauthService: oauthService,
openaiOAuthService: openaiOAuthService,
httpUpstream: httpUpstream,
accountRepo: accountRepo,
oauthService: oauthService,
openaiOAuthService: openaiOAuthService,
geminiOAuthService: geminiOAuthService,
geminiTokenProvider: geminiTokenProvider,
httpUpstream: httpUpstream,
}
}
......@@ -123,6 +136,10 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
return s.testOpenAIAccountConnection(c, account, modelID)
}
if account.IsGemini() {
return s.testGeminiAccountConnection(c, account, modelID)
}
return s.testClaudeAccountConnection(c, account, modelID)
}
......@@ -368,6 +385,247 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
return s.processOpenAIStream(c, resp.Body)
}
// testGeminiAccountConnection tests a Gemini account's connection
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *model.Account, modelID string) error {
ctx := c.Request.Context()
// Determine the model to use
testModelID := modelID
if testModelID == "" {
testModelID = geminicli.DefaultTestModel
}
// For API Key accounts with model mapping, map the model
if account.Type == model.AccountTypeApiKey {
mapping := account.GetModelMapping()
if len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists {
testModelID = mappedModel
}
}
}
// Set SSE headers
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
// Create test payload (Gemini format)
payload := createGeminiTestPayload()
// Build request based on account type
var req *http.Request
var err error
switch account.Type {
case model.AccountTypeApiKey:
req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload)
case model.AccountTypeOAuth:
req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload)
default:
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
}
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to build request: %s", err.Error()))
}
// Send test_start event
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
// Get proxy and execute request
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(req, proxyURL)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
}
// Process SSE stream
return s.processGeminiStream(c, resp.Body)
}
// buildGeminiAPIKeyRequest builds request for Gemini API Key accounts
func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, account *model.Account, modelID string, payload []byte) (*http.Request, error) {
apiKey := account.GetCredential("api_key")
if strings.TrimSpace(apiKey) == "" {
return nil, fmt.Errorf("No API key available")
}
baseURL := account.GetCredential("base_url")
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
// Use streamGenerateContent for real-time feedback
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse",
strings.TrimRight(baseURL, "/"), modelID)
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(payload))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("x-goog-api-key", apiKey)
return req, nil
}
// buildGeminiOAuthRequest builds request for Gemini OAuth accounts
func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, account *model.Account, modelID string, payload []byte) (*http.Request, error) {
if s.geminiTokenProvider == nil {
return nil, fmt.Errorf("Gemini token provider not configured")
}
// Get access token (auto-refreshes if needed)
accessToken, err := s.geminiTokenProvider.GetAccessToken(ctx, account)
if err != nil {
return nil, fmt.Errorf("Failed to get access token: %w", err)
}
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID == "" {
// AI Studio OAuth mode (no project_id): call generativelanguage API directly with Bearer token.
baseURL := account.GetCredential("base_url")
if strings.TrimSpace(baseURL) == "" {
baseURL = geminicli.AIStudioBaseURL
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(baseURL, "/"), modelID)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+accessToken)
return req, nil
}
// Wrap payload in Code Assist format
var inner map[string]any
if err := json.Unmarshal(payload, &inner); err != nil {
return nil, err
}
wrapped := map[string]any{
"model": modelID,
"project": projectID,
"request": inner,
}
wrappedBytes, _ := json.Marshal(wrapped)
fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", geminicli.GeminiCliBaseURL)
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(wrappedBytes))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
return req, nil
}
// createGeminiTestPayload creates a minimal test payload for Gemini API
func createGeminiTestPayload() []byte {
payload := map[string]any{
"contents": []map[string]any{
{
"role": "user",
"parts": []map[string]any{
{"text": "hi"},
},
},
},
"systemInstruction": map[string]any{
"parts": []map[string]any{
{"text": "You are a helpful AI assistant."},
},
},
}
bytes, _ := json.Marshal(payload)
return bytes
}
// processGeminiStream processes SSE stream from Gemini API
func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader) error {
reader := bufio.NewReader(body)
for {
line, err := reader.ReadString('\n')
if err != nil {
if err == io.EOF {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error()))
}
line = strings.TrimSpace(line)
if line == "" || !strings.HasPrefix(line, "data: ") {
continue
}
jsonStr := strings.TrimPrefix(line, "data: ")
if jsonStr == "[DONE]" {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
var data map[string]any
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
continue
}
// Extract text from candidates[0].content.parts[].text
if candidates, ok := data["candidates"].([]any); ok && len(candidates) > 0 {
if candidate, ok := candidates[0].(map[string]any); ok {
// Check for completion
if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
// Extract content
if content, ok := candidate["content"].(map[string]any); ok {
if parts, ok := content["parts"].([]any); ok {
for _, part := range parts {
if partMap, ok := part.(map[string]any); ok {
if text, ok := partMap["text"].(string); ok && text != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: text})
}
}
}
}
}
}
}
// Handle errors
if errData, ok := data["error"].(map[string]any); ok {
errorMsg := "Unknown error"
if msg, ok := errData["message"].(string); ok {
errorMsg = msg
}
return s.sendErrorAndEnd(c, errorMsg)
}
}
}
// createOpenAITestPayload creates a test payload for OpenAI Responses API
func createOpenAITestPayload(modelID string, isOAuth bool) map[string]any {
payload := map[string]any{
......
......@@ -317,8 +317,17 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
selected = acc
} else if acc.Priority == selected.Priority {
// 优先级相同时,选最久未用的
if acc.LastUsedAt == nil || (selected.LastUsedAt != nil && acc.LastUsedAt.Before(*selected.LastUsedAt)) {
switch {
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
selected = acc
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
// keep selected (both never used)
default:
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc
}
}
}
}
......
......@@ -19,6 +19,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
......@@ -57,6 +58,11 @@ func NewGeminiMessagesCompatService(
}
}
// GetTokenProvider returns the token provider for OAuth accounts
func (s *GeminiMessagesCompatService) GetTokenProvider() *GeminiTokenProvider {
return s.tokenProvider
}
func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) {
cacheKey := "gemini:" + sessionHash
if sessionHash != "" {
......@@ -94,8 +100,20 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
if acc.Priority < selected.Priority {
selected = acc
} else if acc.Priority == selected.Priority {
if acc.LastUsedAt == nil || (selected.LastUsedAt != nil && acc.LastUsedAt.Before(*selected.LastUsedAt)) {
switch {
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
selected = acc
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
// Prefer OAuth accounts when both are unused (more compatible for Code Assist flows).
if acc.Type == model.AccountTypeOAuth && selected.Type != model.AccountTypeOAuth {
selected = acc
}
default:
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc
}
}
}
}
......@@ -114,6 +132,96 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
return selected, nil
}
// SelectAccountForAIStudioEndpoints selects an account that is likely to succeed against
// generativelanguage.googleapis.com (e.g. GET /v1beta/models).
//
// Preference order:
// 1) API key accounts (AI Studio)
// 2) OAuth accounts without project_id (AI Studio OAuth)
// 3) OAuth accounts explicitly marked as ai_studio
// 4) Any remaining Gemini accounts (fallback)
func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx context.Context, groupID *int64) (*model.Account, error) {
var accounts []model.Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformGemini)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformGemini)
}
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
if len(accounts) == 0 {
return nil, errors.New("no available Gemini accounts")
}
rank := func(a *model.Account) int {
if a == nil {
return 999
}
switch a.Type {
case model.AccountTypeApiKey:
if strings.TrimSpace(a.GetCredential("api_key")) != "" {
return 0
}
return 9
case model.AccountTypeOAuth:
if strings.TrimSpace(a.GetCredential("project_id")) == "" {
return 1
}
if strings.TrimSpace(a.GetCredential("oauth_type")) == "ai_studio" {
return 2
}
// Code Assist OAuth tokens often lack AI Studio scopes for models listing.
return 3
default:
return 10
}
}
var selected *model.Account
for i := range accounts {
acc := &accounts[i]
if selected == nil {
selected = acc
continue
}
r1, r2 := rank(acc), rank(selected)
if r1 < r2 {
selected = acc
continue
}
if r1 > r2 {
continue
}
if acc.Priority < selected.Priority {
selected = acc
} else if acc.Priority == selected.Priority {
switch {
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
selected = acc
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
if acc.Type == model.AccountTypeOAuth && selected.Type != model.AccountTypeOAuth {
selected = acc
}
default:
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc
}
}
}
}
if selected == nil {
return nil, errors.New("no available Gemini accounts")
}
return selected, nil
}
func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*ForwardResult, error) {
startTime := time.Now()
......@@ -146,6 +254,11 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
var requestIDHeader string
var buildReq func(ctx context.Context) (*http.Request, string, error)
useUpstreamStream := req.Stream
if account.Type == model.AccountTypeOAuth && !req.Stream && strings.TrimSpace(account.GetCredential("project_id")) != "" {
// Code Assist's non-streaming generateContent may return no content; use streaming upstream and aggregate.
useUpstreamStream = true
}
switch account.Type {
case model.AccountTypeApiKey:
......@@ -190,38 +303,61 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID == "" {
return nil, "", errors.New("missing project_id in account credentials")
}
action := "generateContent"
if req.Stream {
if useUpstreamStream {
action = "streamGenerateContent"
}
fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, action)
if req.Stream {
fullURL += "?alt=sse"
}
wrapped := map[string]any{
"model": mappedModel,
"project": projectID,
}
var inner any
if err := json.Unmarshal(geminiReq, &inner); err != nil {
return nil, "", fmt.Errorf("failed to parse gemini request: %w", err)
}
wrapped["request"] = inner
wrappedBytes, _ := json.Marshal(wrapped)
// Two modes for OAuth:
// 1. With project_id -> Code Assist API (wrapped request)
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
if projectID != "" {
// Mode 1: Code Assist API
fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, action)
if useUpstreamStream {
fullURL += "?alt=sse"
}
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBytes))
if err != nil {
return nil, "", err
wrapped := map[string]any{
"model": mappedModel,
"project": projectID,
}
var inner any
if err := json.Unmarshal(geminiReq, &inner); err != nil {
return nil, "", fmt.Errorf("failed to parse gemini request: %w", err)
}
wrapped["request"] = inner
wrappedBytes, _ := json.Marshal(wrapped)
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBytes))
if err != nil {
return nil, "", err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
upstreamReq.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
return upstreamReq, "x-request-id", nil
} else {
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, action)
if useUpstreamStream {
fullURL += "?alt=sse"
}
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(geminiReq))
if err != nil {
return nil, "", err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
return upstreamReq, "x-request-id", nil
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
upstreamReq.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
return upstreamReq, "x-request-id", nil
}
requestIDHeader = "x-request-id"
......@@ -301,17 +437,315 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
} else {
usage, err = s.handleNonStreamingResponse(c, resp, originalModel)
if useUpstreamStream {
collected, usageObj, err := collectGeminiSSE(resp.Body, true)
if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream stream")
}
claudeResp, usageObj2 := convertGeminiToClaudeMessage(collected, originalModel)
c.JSON(http.StatusOK, claudeResp)
usage = usageObj2
if usageObj != nil && (usageObj.InputTokens > 0 || usageObj.OutputTokens > 0) {
usage = usageObj
}
} else {
usage, err = s.handleNonStreamingResponse(c, resp, originalModel)
if err != nil {
return nil, err
}
}
}
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
Model: originalModel,
Stream: req.Stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
}
func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.Context, account *model.Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
startTime := time.Now()
if strings.TrimSpace(originalModel) == "" {
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing model in URL")
}
if strings.TrimSpace(action) == "" {
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing action in URL")
}
if len(body) == 0 {
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty")
}
switch action {
case "generateContent", "streamGenerateContent", "countTokens":
// ok
default:
return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action)
}
mappedModel := originalModel
if account.Type == model.AccountTypeApiKey {
mappedModel = account.GetMappedModel(originalModel)
}
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
useUpstreamStream := stream
upstreamAction := action
if account.Type == model.AccountTypeOAuth && !stream && action == "generateContent" && strings.TrimSpace(account.GetCredential("project_id")) != "" {
// Code Assist's non-streaming generateContent may return no content; use streaming upstream and aggregate.
useUpstreamStream = true
upstreamAction = "streamGenerateContent"
}
forceAIStudio := action == "countTokens"
var requestIDHeader string
var buildReq func(ctx context.Context) (*http.Request, string, error)
switch account.Type {
case model.AccountTypeApiKey:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
apiKey := account.GetCredential("api_key")
if strings.TrimSpace(apiKey) == "" {
return nil, "", errors.New("Gemini api_key not configured")
}
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, upstreamAction)
if useUpstreamStream {
fullURL += "?alt=sse"
}
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(body))
if err != nil {
return nil, "", err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("x-goog-api-key", apiKey)
return upstreamReq, "x-request-id", nil
}
requestIDHeader = "x-request-id"
case model.AccountTypeOAuth:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
if s.tokenProvider == nil {
return nil, "", errors.New("Gemini token provider not configured")
}
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
if err != nil {
return nil, "", err
}
projectID := strings.TrimSpace(account.GetCredential("project_id"))
// Two modes for OAuth:
// 1. With project_id -> Code Assist API (wrapped request)
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
if projectID != "" && !forceAIStudio {
// Mode 1: Code Assist API
fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, upstreamAction)
if useUpstreamStream {
fullURL += "?alt=sse"
}
wrapped := map[string]any{
"model": mappedModel,
"project": projectID,
}
var inner any
if err := json.Unmarshal(body, &inner); err != nil {
return nil, "", fmt.Errorf("failed to parse gemini request: %w", err)
}
wrapped["request"] = inner
wrappedBytes, _ := json.Marshal(wrapped)
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBytes))
if err != nil {
return nil, "", err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
upstreamReq.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
return upstreamReq, "x-request-id", nil
} else {
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, upstreamAction)
if useUpstreamStream {
fullURL += "?alt=sse"
}
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(body))
if err != nil {
return nil, "", err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
return upstreamReq, "x-request-id", nil
}
}
requestIDHeader = "x-request-id"
default:
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Unsupported account type: "+account.Type)
}
var resp *http.Response
for attempt := 1; attempt <= geminiMaxRetries; attempt++ {
upstreamReq, idHeader, err := buildReq(ctx)
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil, err
}
// Local build error: don't retry.
if strings.Contains(err.Error(), "missing project_id") {
return nil, s.writeGoogleError(c, http.StatusBadRequest, err.Error())
}
return nil, s.writeGoogleError(c, http.StatusBadGateway, err.Error())
}
requestIDHeader = idHeader
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
if err != nil {
if attempt < geminiMaxRetries {
log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err)
sleepGeminiBackoff(attempt)
continue
}
if action == "countTokens" {
estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
return &ForwardResult{
RequestID: "",
Usage: ClaudeUsage{},
Model: originalModel,
Stream: false,
Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil
}
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
}
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
if resp.StatusCode == 429 {
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
if attempt < geminiMaxRetries {
log.Printf("Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries)
sleepGeminiBackoff(attempt)
continue
}
if action == "countTokens" {
estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
return &ForwardResult{
RequestID: "",
Usage: ClaudeUsage{},
Model: originalModel,
Stream: false,
Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil
}
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
}
break
}
defer func() { _ = resp.Body.Close() }()
requestID := resp.Header.Get(requestIDHeader)
if requestID == "" {
requestID = resp.Header.Get("x-goog-request-id")
}
if requestID != "" {
c.Header("x-request-id", requestID)
}
isOAuth := account.Type == model.AccountTypeOAuth
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
// Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens.
// This avoids Gemini SDKs failing hard during preflight token counting.
if action == "countTokens" && isOAuth && isGeminiInsufficientScope(resp.Header, respBody) {
estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
return &ForwardResult{
RequestID: requestID,
Usage: ClaudeUsage{},
Model: originalModel,
Stream: false,
Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil
}
respBody = unwrapIfNeeded(isOAuth, respBody)
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
}
c.Data(resp.StatusCode, contentType, respBody)
return nil, fmt.Errorf("gemini upstream error: %d", resp.StatusCode)
}
var usage *ClaudeUsage
var firstTokenMs *int
if stream {
streamRes, err := s.handleNativeStreamingResponse(c, resp, startTime, isOAuth)
if err != nil {
return nil, err
}
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
} else {
if useUpstreamStream {
collected, usageObj, err := collectGeminiSSE(resp.Body, isOAuth)
if err != nil {
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Failed to read upstream stream")
}
b, _ := json.Marshal(collected)
c.Data(http.StatusOK, "application/json", b)
usage = usageObj
} else {
usageResp, err := s.handleNativeNonStreamingResponse(c, resp, isOAuth)
if err != nil {
return nil, err
}
usage = usageResp
}
}
if usage == nil {
usage = &ClaudeUsage{}
}
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
Model: originalModel,
Stream: req.Stream,
Stream: stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
......@@ -590,22 +1024,29 @@ func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, re
openBlockIndex := -1
openBlockType := ""
seenText := ""
openToolIndex := -1
openToolID := ""
openToolName := ""
seenToolJSON := ""
reader := bufio.NewReader(resp.Body)
for {
line, err := reader.ReadString('\n')
if err != nil {
if errors.Is(err, io.EOF) {
break
}
if err != nil && !errors.Is(err, io.EOF) {
return nil, fmt.Errorf("stream read error: %w", err)
}
if !strings.HasPrefix(line, "data:") {
if errors.Is(err, io.EOF) {
break
}
continue
}
payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
if payload == "" || payload == "[DONE]" {
if errors.Is(err, io.EOF) {
break
}
continue
}
......@@ -670,7 +1111,7 @@ func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, re
name = "tool"
}
// Close any open block before tool_use.
// Close any open text block before tool_use.
if openBlockIndex >= 0 {
writeSSE(c.Writer, "content_block_stop", map[string]any{
"type": "content_block_stop",
......@@ -680,40 +1121,63 @@ func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, re
openBlockType = ""
}
toolID := "toolu_" + randomHex(8)
toolIndex := nextBlockIndex
nextBlockIndex++
sawToolUse = true
writeSSE(c.Writer, "content_block_start", map[string]any{
"type": "content_block_start",
"index": toolIndex,
"content_block": map[string]any{
"type": "tool_use",
"id": toolID,
"name": name,
"input": map[string]any{},
},
})
// If we receive streamed tool args in pieces, keep a single tool block open and emit deltas.
if openToolIndex >= 0 && openToolName != name {
writeSSE(c.Writer, "content_block_stop", map[string]any{
"type": "content_block_stop",
"index": openToolIndex,
})
openToolIndex = -1
openToolID = ""
openToolName = ""
seenToolJSON = ""
}
if openToolIndex < 0 {
openToolID = "toolu_" + randomHex(8)
openToolIndex = nextBlockIndex
openToolName = name
nextBlockIndex++
sawToolUse = true
argsJSON := "{}"
if args != nil {
if b, err := json.Marshal(args); err == nil {
argsJSON = string(b)
writeSSE(c.Writer, "content_block_start", map[string]any{
"type": "content_block_start",
"index": openToolIndex,
"content_block": map[string]any{
"type": "tool_use",
"id": openToolID,
"name": name,
"input": map[string]any{},
},
})
}
argsJSONText := "{}"
switch v := args.(type) {
case nil:
// keep default "{}"
case string:
if strings.TrimSpace(v) != "" {
argsJSONText = v
}
default:
if b, err := json.Marshal(args); err == nil && len(b) > 0 {
argsJSONText = string(b)
}
}
writeSSE(c.Writer, "content_block_delta", map[string]any{
"type": "content_block_delta",
"index": toolIndex,
"delta": map[string]any{
"type": "input_json_delta",
"partial_json": argsJSON,
},
})
writeSSE(c.Writer, "content_block_stop", map[string]any{
"type": "content_block_stop",
"index": toolIndex,
})
delta, newSeen := computeGeminiTextDelta(seenToolJSON, argsJSONText)
seenToolJSON = newSeen
if delta != "" {
writeSSE(c.Writer, "content_block_delta", map[string]any{
"type": "content_block_delta",
"index": openToolIndex,
"delta": map[string]any{
"type": "input_json_delta",
"partial_json": delta,
},
})
}
flusher.Flush()
}
}
......@@ -721,6 +1185,11 @@ func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, re
if u := extractGeminiUsage(geminiResp); u != nil {
usage = *u
}
// Process the final unterminated line at EOF as well.
if errors.Is(err, io.EOF) {
break
}
}
if openBlockIndex >= 0 {
......@@ -729,6 +1198,12 @@ func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, re
"index": openBlockIndex,
})
}
if openToolIndex >= 0 {
writeSSE(c.Writer, "content_block_stop", map[string]any{
"type": "content_block_stop",
"index": openToolIndex,
})
}
stopReason := mapGeminiFinishReasonToClaudeStopReason(finishReason)
if sawToolUse {
......@@ -779,6 +1254,369 @@ func (s *GeminiMessagesCompatService) writeClaudeError(c *gin.Context, status in
return fmt.Errorf("%s", message)
}
func (s *GeminiMessagesCompatService) writeGoogleError(c *gin.Context, status int, message string) error {
c.JSON(status, gin.H{
"error": gin.H{
"code": status,
"message": message,
"status": googleapi.HTTPStatusToGoogleStatus(status),
},
})
return fmt.Errorf("%s", message)
}
func unwrapIfNeeded(isOAuth bool, raw []byte) []byte {
if !isOAuth {
return raw
}
inner, err := unwrapGeminiResponse(raw)
if err != nil {
return raw
}
b, err := json.Marshal(inner)
if err != nil {
return raw
}
return b
}
func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsage, error) {
reader := bufio.NewReader(body)
var last map[string]any
var lastWithParts map[string]any
usage := &ClaudeUsage{}
for {
line, err := reader.ReadString('\n')
if len(line) > 0 {
trimmed := strings.TrimRight(line, "\r\n")
if strings.HasPrefix(trimmed, "data:") {
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
switch payload {
case "", "[DONE]":
if payload == "[DONE]" {
return pickGeminiCollectResult(last, lastWithParts), usage, nil
}
default:
var parsed map[string]any
if isOAuth {
inner, err := unwrapGeminiResponse([]byte(payload))
if err == nil && inner != nil {
parsed = inner
}
} else {
_ = json.Unmarshal([]byte(payload), &parsed)
}
if parsed != nil {
last = parsed
if u := extractGeminiUsage(parsed); u != nil {
usage = u
}
if parts := extractGeminiParts(parsed); len(parts) > 0 {
lastWithParts = parsed
}
}
}
}
}
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return nil, nil, err
}
}
return pickGeminiCollectResult(last, lastWithParts), usage, nil
}
func pickGeminiCollectResult(last map[string]any, lastWithParts map[string]any) map[string]any {
if lastWithParts != nil {
return lastWithParts
}
if last != nil {
return last
}
return map[string]any{}
}
type geminiNativeStreamResult struct {
usage *ClaudeUsage
firstTokenMs *int
}
func isGeminiInsufficientScope(headers http.Header, body []byte) bool {
if strings.Contains(strings.ToLower(headers.Get("Www-Authenticate")), "insufficient_scope") {
return true
}
lower := strings.ToLower(string(body))
return strings.Contains(lower, "insufficient authentication scopes") || strings.Contains(lower, "access_token_scope_insufficient")
}
func estimateGeminiCountTokens(reqBody []byte) int {
var obj map[string]any
if err := json.Unmarshal(reqBody, &obj); err != nil {
return 0
}
var texts []string
// systemInstruction.parts[].text
if si, ok := obj["systemInstruction"].(map[string]any); ok {
if parts, ok := si["parts"].([]any); ok {
for _, p := range parts {
if pm, ok := p.(map[string]any); ok {
if t, ok := pm["text"].(string); ok && strings.TrimSpace(t) != "" {
texts = append(texts, t)
}
}
}
}
}
// contents[].parts[].text
if contents, ok := obj["contents"].([]any); ok {
for _, c := range contents {
cm, ok := c.(map[string]any)
if !ok {
continue
}
parts, ok := cm["parts"].([]any)
if !ok {
continue
}
for _, p := range parts {
pm, ok := p.(map[string]any)
if !ok {
continue
}
if t, ok := pm["text"].(string); ok && strings.TrimSpace(t) != "" {
texts = append(texts, t)
}
}
}
}
total := 0
for _, t := range texts {
total += estimateTokensForText(t)
}
if total < 0 {
return 0
}
return total
}
func estimateTokensForText(s string) int {
s = strings.TrimSpace(s)
if s == "" {
return 0
}
runes := []rune(s)
if len(runes) == 0 {
return 0
}
ascii := 0
for _, r := range runes {
if r <= 0x7f {
ascii++
}
}
asciiRatio := float64(ascii) / float64(len(runes))
if asciiRatio >= 0.8 {
// Roughly 4 chars per token for English-like text.
return (len(runes) + 3) / 4
}
// For CJK-heavy text, approximate 1 rune per token.
return len(runes)
}
type UpstreamHTTPResult struct {
StatusCode int
Headers http.Header
Body []byte
}
func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Context, resp *http.Response, isOAuth bool) (*ClaudeUsage, error) {
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
var parsed map[string]any
if isOAuth {
parsed, err = unwrapGeminiResponse(respBody)
if err == nil && parsed != nil {
respBody, _ = json.Marshal(parsed)
}
} else {
_ = json.Unmarshal(respBody, &parsed)
}
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
}
c.Data(resp.StatusCode, contentType, respBody)
if parsed != nil {
if u := extractGeminiUsage(parsed); u != nil {
return u, nil
}
}
return &ClaudeUsage{}, nil
}
func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, isOAuth bool) (*geminiNativeStreamResult, error) {
c.Status(resp.StatusCode)
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "text/event-stream; charset=utf-8"
}
c.Header("Content-Type", contentType)
flusher, ok := c.Writer.(http.Flusher)
if !ok {
return nil, errors.New("streaming not supported")
}
reader := bufio.NewReader(resp.Body)
usage := &ClaudeUsage{}
var firstTokenMs *int
for {
line, err := reader.ReadString('\n')
if len(line) > 0 {
trimmed := strings.TrimRight(line, "\r\n")
if strings.HasPrefix(trimmed, "data:") {
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
// Keepalive / done markers
if payload == "" || payload == "[DONE]" {
_, _ = io.WriteString(c.Writer, line)
flusher.Flush()
} else {
var rawToWrite string
rawToWrite = payload
var parsed map[string]any
if isOAuth {
inner, err := unwrapGeminiResponse([]byte(payload))
if err == nil && inner != nil {
parsed = inner
if b, err := json.Marshal(inner); err == nil {
rawToWrite = string(b)
}
}
} else {
_ = json.Unmarshal([]byte(payload), &parsed)
}
if parsed != nil {
if u := extractGeminiUsage(parsed); u != nil {
usage = u
}
}
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
if isOAuth {
// SSE format requires double newline (\n\n) to separate events
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", rawToWrite)
} else {
// Pass-through for AI Studio responses.
_, _ = io.WriteString(c.Writer, line)
}
flusher.Flush()
}
} else {
_, _ = io.WriteString(c.Writer, line)
flusher.Flush()
}
}
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return nil, err
}
}
return &geminiNativeStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
// ForwardAIStudioGET forwards a GET request to AI Studio (generativelanguage.googleapis.com) for
// endpoints like /v1beta/models and /v1beta/models/{model}.
//
// This is used to support Gemini SDKs that call models listing endpoints before generation.
func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, account *model.Account, path string) (*UpstreamHTTPResult, error) {
if account == nil {
return nil, errors.New("account is nil")
}
path = strings.TrimSpace(path)
if path == "" || !strings.HasPrefix(path, "/") {
return nil, errors.New("invalid path")
}
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
fullURL := strings.TrimRight(baseURL, "/") + path
var proxyURL string
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
if err != nil {
return nil, err
}
switch account.Type {
case model.AccountTypeApiKey:
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
if apiKey == "" {
return nil, errors.New("Gemini api_key not configured")
}
req.Header.Set("x-goog-api-key", apiKey)
case model.AccountTypeOAuth:
if s.tokenProvider == nil {
return nil, errors.New("Gemini token provider not configured")
}
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+accessToken)
default:
return nil, fmt.Errorf("unsupported account type: %s", account.Type)
}
resp, err := s.httpUpstream.Do(req, proxyURL)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
return &UpstreamHTTPResult{
StatusCode: resp.StatusCode,
Headers: resp.Header.Clone(),
Body: body,
}, nil
}
func unwrapGeminiResponse(raw []byte) (map[string]any, error) {
var outer map[string]any
if err := json.Unmarshal(raw, &outer); err != nil {
......
......@@ -2,8 +2,12 @@ package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
......@@ -43,7 +47,7 @@ type GeminiAuthURLResult struct {
State string `json:"state"`
}
func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI string) (*GeminiAuthURLResult, error) {
func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, projectID, oauthType string) (*GeminiAuthURLResult, error) {
state, err := geminicli.GenerateState()
if err != nil {
return nil, fmt.Errorf("failed to generate state: %w", err)
......@@ -66,22 +70,38 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
}
}
// 两种 OAuth 模式都使用相同的配置,只是 scopes 不同
// scopes 会在 EffectiveOAuthConfig 中根据 oauthType 自动选择
oauthCfg := geminicli.OAuthConfig{
ClientID: s.cfg.Gemini.OAuth.ClientID,
ClientSecret: s.cfg.Gemini.OAuth.ClientSecret,
Scopes: s.cfg.Gemini.OAuth.Scopes,
}
session := &geminicli.OAuthSession{
State: state,
CodeVerifier: codeVerifier,
ProxyURL: proxyURL,
RedirectURI: redirectURI,
ProjectID: strings.TrimSpace(projectID),
OAuthType: oauthType,
CreatedAt: time.Now(),
}
s.sessionStore.Set(sessionID, session)
oauthCfg := geminicli.OAuthConfig{
ClientID: s.cfg.Gemini.OAuth.ClientID,
ClientSecret: s.cfg.Gemini.OAuth.ClientSecret,
Scopes: s.cfg.Gemini.OAuth.Scopes,
effectiveCfg, err := geminicli.EffectiveOAuthConfig(oauthCfg, oauthType)
if err != nil {
return nil, err
}
authURL, err := geminicli.BuildAuthorizationURL(oauthCfg, state, codeChallenge, redirectURI)
// For Code Assist with Gemini CLI credentials, use the CLI's redirect URI
if oauthType == "code_assist" {
redirectURI = geminicli.GeminiCLIRedirectURI
session.RedirectURI = redirectURI
s.sessionStore.Set(sessionID, session)
}
authURL, err := geminicli.BuildAuthorizationURL(effectiveCfg, state, codeChallenge, redirectURI, session.ProjectID, oauthType)
if err != nil {
return nil, err
}
......@@ -94,11 +114,11 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
}
type GeminiExchangeCodeInput struct {
SessionID string
State string
Code string
RedirectURI string
ProxyID *int64
SessionID string
State string
Code string
ProxyID *int64
OAuthType string // "code_assist" 或 "ai_studio"
}
type GeminiTokenInfo struct {
......@@ -109,6 +129,7 @@ type GeminiTokenInfo struct {
TokenType string `json:"token_type"`
Scope string `json:"scope,omitempty"`
ProjectID string `json:"project_id,omitempty"`
OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio"
}
func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) {
......@@ -129,19 +150,38 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
}
redirectURI := session.RedirectURI
if strings.TrimSpace(input.RedirectURI) != "" {
redirectURI = input.RedirectURI
}
tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL)
if err != nil {
return nil, fmt.Errorf("failed to exchange code: %w", err)
}
sessionProjectID := strings.TrimSpace(session.ProjectID)
oauthType := session.OAuthType
if oauthType == "" {
oauthType = "code_assist" // 默认为 code_assist 以兼容旧 session
}
s.sessionStore.Delete(input.SessionID)
// 计算过期时间时减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
projectID, _ := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
projectID := sessionProjectID
// 对于 code_assist 模式,project_id 是必需的
// 对于 ai_studio 模式,project_id 是可选的(不影响使用 AI Studio API)
if oauthType == "code_assist" {
if projectID == "" {
var err error
projectID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
if err != nil {
// 记录警告但不阻断流程,允许后续补充 project_id
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err)
}
}
if strings.TrimSpace(projectID) == "" {
return nil, fmt.Errorf("missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project")
}
}
return &GeminiTokenInfo{
AccessToken: tokenResp.AccessToken,
......@@ -151,6 +191,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
ExpiresAt: expiresAt,
Scope: tokenResp.Scope,
ProjectID: projectID,
OAuthType: oauthType,
}, nil
}
......@@ -223,7 +264,39 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *m
}
}
return s.RefreshToken(ctx, refreshToken, proxyURL)
tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL)
if err != nil {
return nil, err
}
// Preserve oauth_type from the account (defaults to code_assist for backward compatibility).
oauthType := strings.TrimSpace(account.GetCredential("oauth_type"))
if oauthType == "" {
oauthType = "code_assist"
}
tokenInfo.OAuthType = oauthType
// Preserve account's project_id when present.
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
if existingProjectID != "" {
tokenInfo.ProjectID = existingProjectID
}
// For Code Assist, project_id is required. Auto-detect if missing.
// For AI Studio OAuth, project_id is optional and should not block refresh.
if oauthType == "code_assist" && strings.TrimSpace(tokenInfo.ProjectID) == "" {
projectID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
if err != nil {
return nil, fmt.Errorf("failed to auto-detect project_id: %w", err)
}
projectID = strings.TrimSpace(projectID)
if projectID == "" {
return nil, fmt.Errorf("failed to auto-detect project_id: empty result")
}
tokenInfo.ProjectID = projectID
}
return tokenInfo, nil
}
func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo) map[string]any {
......@@ -243,6 +316,9 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo)
if tokenInfo.ProjectID != "" {
creds["project_id"] = tokenInfo.ProjectID
}
if tokenInfo.OAuthType != "" {
creds["oauth_type"] = tokenInfo.OAuthType
}
return creds
}
......@@ -255,20 +331,28 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
return "", errors.New("code assist client not configured")
}
loadResp, err := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil)
if err == nil && strings.TrimSpace(loadResp.CurrentTier) != "" && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" {
loadResp, loadErr := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil)
if loadErr == nil && loadResp != nil && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" {
return strings.TrimSpace(loadResp.CloudAICompanionProject), nil
}
// pick default tier from allowedTiers, fallback to LEGACY.
// Pick tier from allowedTiers; if no default tier is marked, pick the first non-empty tier ID.
tierID := "LEGACY"
if loadResp != nil {
for _, tier := range loadResp.AllowedTiers {
if tier.IsDefault && strings.TrimSpace(tier.ID) != "" {
tierID = tier.ID
tierID = strings.TrimSpace(tier.ID)
break
}
}
if strings.TrimSpace(tierID) == "" || tierID == "LEGACY" {
for _, tier := range loadResp.AllowedTiers {
if strings.TrimSpace(tier.ID) != "" {
tierID = strings.TrimSpace(tier.ID)
break
}
}
}
}
req := &geminicli.OnboardUserRequest{
......@@ -284,24 +368,116 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
for attempt := 1; attempt <= maxAttempts; attempt++ {
resp, err := s.codeAssist.OnboardUser(ctx, accessToken, proxyURL, req)
if err != nil {
// If Code Assist onboarding fails (e.g. INVALID_ARGUMENT), fallback to Cloud Resource Manager projects.
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" {
return strings.TrimSpace(fallback), nil
}
return "", err
}
if resp.Done {
if resp.Response == nil || resp.Response.CloudAICompanionProject == nil {
return "", errors.New("onboardUser completed but no project_id returned")
}
switch v := resp.Response.CloudAICompanionProject.(type) {
case string:
return strings.TrimSpace(v), nil
case map[string]any:
if id, ok := v["id"].(string); ok {
return strings.TrimSpace(id), nil
if resp.Response != nil && resp.Response.CloudAICompanionProject != nil {
switch v := resp.Response.CloudAICompanionProject.(type) {
case string:
return strings.TrimSpace(v), nil
case map[string]any:
if id, ok := v["id"].(string); ok {
return strings.TrimSpace(id), nil
}
}
}
return "", errors.New("onboardUser returned unsupported project_id format")
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" {
return strings.TrimSpace(fallback), nil
}
return "", errors.New("onboardUser completed but no project_id returned")
}
time.Sleep(2 * time.Second)
}
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" {
return strings.TrimSpace(fallback), nil
}
if loadErr != nil {
return "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts)
}
return "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts)
}
type googleCloudProject struct {
ProjectID string `json:"projectId"`
DisplayName string `json:"name"`
LifecycleState string `json:"lifecycleState"`
}
type googleCloudProjectsResponse struct {
Projects []googleCloudProject `json:"projects"`
}
func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyURL string) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://cloudresourcemanager.googleapis.com/v1/projects", nil)
if err != nil {
return "", fmt.Errorf("failed to create resource manager request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
client := &http.Client{Timeout: 30 * time.Second}
if strings.TrimSpace(proxyURL) != "" {
if proxyURLParsed, err := url.Parse(strings.TrimSpace(proxyURL)); err == nil {
client.Transport = &http.Transport{Proxy: http.ProxyURL(proxyURLParsed)}
}
}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("resource manager request failed: %w", err)
}
defer resp.Body.Close()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read resource manager response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("resource manager HTTP %d: %s", resp.StatusCode, string(bodyBytes))
}
var projectsResp googleCloudProjectsResponse
if err := json.Unmarshal(bodyBytes, &projectsResp); err != nil {
return "", fmt.Errorf("failed to parse resource manager response: %w", err)
}
active := make([]googleCloudProject, 0, len(projectsResp.Projects))
for _, p := range projectsResp.Projects {
if p.LifecycleState == "ACTIVE" && strings.TrimSpace(p.ProjectID) != "" {
active = append(active, p)
}
}
if len(active) == 0 {
return "", errors.New("no ACTIVE projects found from resource manager")
}
// Prefer likely companion projects first.
for _, p := range active {
id := strings.ToLower(strings.TrimSpace(p.ProjectID))
name := strings.ToLower(strings.TrimSpace(p.DisplayName))
if strings.Contains(id, "cloud-ai-companion") || strings.Contains(name, "cloud ai companion") || strings.Contains(name, "code assist") {
return strings.TrimSpace(p.ProjectID), nil
}
}
// Then prefer "default".
for _, p := range active {
id := strings.ToLower(strings.TrimSpace(p.ProjectID))
name := strings.ToLower(strings.TrimSpace(p.DisplayName))
if strings.Contains(id, "default") || strings.Contains(name, "default") {
return strings.TrimSpace(p.ProjectID), nil
}
}
return strings.TrimSpace(active[0].ProjectID), nil
}
......@@ -3,6 +3,7 @@ package service
import (
"context"
"errors"
"log"
"strconv"
"strings"
"time"
......@@ -95,6 +96,40 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *model
return "", errors.New("access_token not found in credentials")
}
// project_id is optional now:
// - If present: will use Code Assist API (requires project_id)
// - If absent: will use AI Studio API with OAuth token (like regular API key mode)
// Auto-detect project_id only if explicitly enabled via a credential flag
projectID := strings.TrimSpace(account.GetCredential("project_id"))
autoDetectProjectID := account.GetCredential("auto_detect_project_id") == "true"
if projectID == "" && autoDetectProjectID {
if p.geminiOAuthService == nil {
return accessToken, nil // Fallback to AI Studio API mode
}
var proxyURL string
if account.ProxyID != nil && p.geminiOAuthService.proxyRepo != nil {
if proxy, err := p.geminiOAuthService.proxyRepo.GetByID(ctx, *account.ProxyID); err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
detected, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL)
if err != nil {
log.Printf("[GeminiTokenProvider] Auto-detect project_id failed: %v, fallback to AI Studio API mode", err)
return accessToken, nil
}
detected = strings.TrimSpace(detected)
if detected != "" {
if account.Credentials == nil {
account.Credentials = model.JSONB{}
}
account.Credentials["project_id"] = detected
_ = p.accountRepo.Update(ctx, account)
}
}
// 3) Populate cache with TTL.
if p.tokenCache != nil {
ttl := 30 * time.Minute
......
......@@ -166,9 +166,18 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
if acc.Priority < selected.Priority {
selected = acc
} else if acc.Priority == selected.Priority {
// Same priority, select least recently used
if acc.LastUsedAt == nil || (selected.LastUsedAt != nil && acc.LastUsedAt.Before(*selected.LastUsedAt)) {
switch {
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
selected = acc
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
// keep selected (both never used)
default:
// Same priority, select least recently used
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc
}
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment