"frontend/git@web.lueluesay.top:chenxi/sub2api.git" did not exist on "fbffb08aae4a8bea2db6a6392c9ef30f777a14d0"
Commit 429f38d0 authored by shaw's avatar shaw
Browse files

Merge PR #37: Add Gemini OAuth and Messages Compat Support

parents 2d89f366 2714be99
......@@ -21,6 +21,9 @@ coverage.html
# 依赖(使用 go mod)
vendor/
# Go 编译缓存
backend/.gocache/
# ===================
# Node.js / Vue 前端
# ===================
......
......@@ -69,6 +69,7 @@ func provideCleanup(
emailQueue *service.EmailQueueService,
oauth *service.OAuthService,
openaiOAuth *service.OpenAIOAuthService,
geminiOAuth *service.GeminiOAuthService,
) func() {
return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
......@@ -99,6 +100,10 @@ func provideCleanup(
openaiOAuth.Stop()
return nil
}},
{"GeminiOAuthService", func() error {
geminiOAuth.Stop()
return nil
}},
{"Redis", func() error {
return rdb.Close()
}},
......
......@@ -80,17 +80,23 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
openAIOAuthClient := repository.NewOpenAIOAuthClient()
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, configConfig)
rateLimitService := service.NewRateLimitService(accountRepository, configConfig)
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher)
geminiTokenCache := repository.NewGeminiTokenCache(client)
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
httpUpstream := repository.NewHTTPUpstream(configConfig)
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, httpUpstream)
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, httpUpstream)
concurrencyCache := repository.NewConcurrencyCache(client)
concurrencyService := service.NewConcurrencyService(concurrencyCache)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
oAuthHandler := admin.NewOAuthHandler(oAuthService)
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
proxyHandler := admin.NewProxyHandler(adminService)
adminRedeemHandler := admin.NewRedeemHandler(adminService)
settingHandler := admin.NewSettingHandler(settingService, emailService)
......@@ -101,7 +107,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
systemHandler := handler.ProvideSystemHandler(updateService)
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
gatewayCache := repository.NewGatewayCache(client)
pricingRemoteClient := repository.NewPricingRemoteClient()
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
......@@ -112,7 +118,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
identityCache := repository.NewIdentityCache(client)
identityService := service.NewIdentityService(identityCache)
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream)
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, userService, concurrencyService, billingCacheService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
......@@ -120,10 +127,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
apiKeyAuthMiddleware := middleware.NewApiKeyAuthMiddleware(apiKeyService, subscriptionService)
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware)
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService)
httpServer := server.ProvideHTTPServer(configConfig, engine)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, configConfig)
v := provideCleanup(db, client, tokenRefreshService, pricingService, emailQueueService, oAuthService, openAIOAuthService)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
v := provideCleanup(db, client, tokenRefreshService, pricingService, emailQueueService, oAuthService, openAIOAuthService, geminiOAuthService)
application := &Application{
Server: httpServer,
Cleanup: v,
......@@ -153,6 +160,7 @@ func provideCleanup(
emailQueue *service.EmailQueueService,
oauth *service.OAuthService,
openaiOAuth *service.OpenAIOAuthService,
geminiOAuth *service.GeminiOAuthService,
) func() {
return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
......@@ -182,6 +190,10 @@ func provideCleanup(
openaiOAuth.Stop()
return nil
}},
{"GeminiOAuthService", func() error {
geminiOAuth.Stop()
return nil
}},
{"Redis", func() error {
return rdb.Close()
}},
......
......@@ -18,6 +18,17 @@ type Config struct {
Gateway GatewayConfig `mapstructure:"gateway"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
Gemini GeminiConfig `mapstructure:"gemini"`
}
type GeminiConfig struct {
OAuth GeminiOAuthConfig `mapstructure:"oauth"`
}
type GeminiOAuthConfig struct {
ClientID string `mapstructure:"client_id"`
ClientSecret string `mapstructure:"client_secret"`
Scopes string `mapstructure:"scopes"`
}
// TokenRefreshConfig OAuth token自动刷新配置
......@@ -211,9 +222,16 @@ func setDefaults() {
// TokenRefresh
viper.SetDefault("token_refresh.enabled", true)
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 1.5) // 提前1.5小时刷新
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新(适配Google 1小时token)
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
// Gemini OAuth - configure via environment variables or config file
// GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET
// Default: uses Gemini CLI public credentials (set via environment)
viper.SetDefault("gemini.oauth.client_id", "")
viper.SetDefault("gemini.oauth.client_secret", "")
viper.SetDefault("gemini.oauth.scopes", "")
}
func (c *Config) Validate() error {
......
......@@ -2,9 +2,11 @@ package admin
import (
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
......@@ -30,6 +32,7 @@ type AccountHandler struct {
adminService service.AdminService
oauthService *service.OAuthService
openaiOAuthService *service.OpenAIOAuthService
geminiOAuthService *service.GeminiOAuthService
rateLimitService *service.RateLimitService
accountUsageService *service.AccountUsageService
accountTestService *service.AccountTestService
......@@ -42,6 +45,7 @@ func NewAccountHandler(
adminService service.AdminService,
oauthService *service.OAuthService,
openaiOAuthService *service.OpenAIOAuthService,
geminiOAuthService *service.GeminiOAuthService,
rateLimitService *service.RateLimitService,
accountUsageService *service.AccountUsageService,
accountTestService *service.AccountTestService,
......@@ -52,6 +56,7 @@ func NewAccountHandler(
adminService: adminService,
oauthService: oauthService,
openaiOAuthService: openaiOAuthService,
geminiOAuthService: geminiOAuthService,
rateLimitService: rateLimitService,
accountUsageService: accountUsageService,
accountTestService: accountTestService,
......@@ -345,6 +350,19 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
newCredentials[k] = v
}
}
} else if account.Platform == service.PlatformGemini {
tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil {
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
return
}
newCredentials = h.geminiOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
} else {
// Use Anthropic/Claude OAuth service to refresh token
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
......@@ -362,10 +380,14 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
// Update token-related fields
newCredentials["access_token"] = tokenInfo.AccessToken
newCredentials["token_type"] = tokenInfo.TokenType
newCredentials["expires_in"] = tokenInfo.ExpiresIn
newCredentials["expires_at"] = tokenInfo.ExpiresAt
newCredentials["refresh_token"] = tokenInfo.RefreshToken
newCredentials["scope"] = tokenInfo.Scope
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
if strings.TrimSpace(tokenInfo.RefreshToken) != "" {
newCredentials["refresh_token"] = tokenInfo.RefreshToken
}
if strings.TrimSpace(tokenInfo.Scope) != "" {
newCredentials["scope"] = tokenInfo.Scope
}
}
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
......@@ -858,6 +880,44 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
return
}
// Handle Gemini accounts
if account.IsGemini() {
// For OAuth accounts: return default Gemini models
if account.IsOAuth() {
response.Success(c, geminicli.DefaultModels)
return
}
// For API Key accounts: return models based on model_mapping
mapping := account.GetModelMapping()
if len(mapping) == 0 {
response.Success(c, geminicli.DefaultModels)
return
}
var models []geminicli.Model
for requestedModel := range mapping {
var found bool
for _, dm := range geminicli.DefaultModels {
if dm.ID == requestedModel {
models = append(models, dm)
found = true
break
}
}
if !found {
models = append(models, geminicli.Model{
ID: requestedModel,
Type: "model",
DisplayName: requestedModel,
CreatedAt: "",
})
}
}
response.Success(c, models)
return
}
// Handle Claude/Anthropic accounts
// For OAuth and Setup-Token accounts: return default models
if account.IsOAuth() {
......
package admin
import (
"fmt"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type GeminiOAuthHandler struct {
geminiOAuthService *service.GeminiOAuthService
}
func NewGeminiOAuthHandler(geminiOAuthService *service.GeminiOAuthService) *GeminiOAuthHandler {
return &GeminiOAuthHandler{geminiOAuthService: geminiOAuthService}
}
// GET /api/v1/admin/gemini/oauth/capabilities
func (h *GeminiOAuthHandler) GetCapabilities(c *gin.Context) {
cfg := h.geminiOAuthService.GetOAuthConfig()
response.Success(c, cfg)
}
type GeminiGenerateAuthURLRequest struct {
ProxyID *int64 `json:"proxy_id"`
ProjectID string `json:"project_id"`
// OAuth 类型: "code_assist" (需要 project_id) 或 "ai_studio" (不需要 project_id)
// 默认为 "code_assist" 以保持向后兼容
OAuthType string `json:"oauth_type"`
}
// GenerateAuthURL generates Google OAuth authorization URL for Gemini.
// POST /api/v1/admin/gemini/oauth/auth-url
func (h *GeminiOAuthHandler) GenerateAuthURL(c *gin.Context) {
var req GeminiGenerateAuthURLRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// 默认使用 code_assist 以保持向后兼容
oauthType := strings.TrimSpace(req.OAuthType)
if oauthType == "" {
oauthType = "code_assist"
}
if oauthType != "code_assist" && oauthType != "ai_studio" {
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist' or 'ai_studio'")
return
}
// Always pass the "hosted" callback URI; the OAuth service may override it depending on
// oauth_type and whether the built-in Gemini CLI OAuth client is used.
redirectURI := deriveGeminiRedirectURI(c)
result, err := h.geminiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, redirectURI, req.ProjectID, oauthType)
if err != nil {
msg := err.Error()
// Treat missing/invalid OAuth client configuration as a user/config error.
if strings.Contains(msg, "OAuth client not configured") || strings.Contains(msg, "requires your own OAuth Client") {
response.BadRequest(c, "Failed to generate auth URL: "+msg)
return
}
response.InternalError(c, "Failed to generate auth URL: "+msg)
return
}
response.Success(c, result)
}
type GeminiExchangeCodeRequest struct {
SessionID string `json:"session_id" binding:"required"`
State string `json:"state" binding:"required"`
Code string `json:"code" binding:"required"`
ProxyID *int64 `json:"proxy_id"`
// OAuth 类型: "code_assist" 或 "ai_studio",需要与 GenerateAuthURL 时的类型一致
OAuthType string `json:"oauth_type"`
}
// ExchangeCode exchanges authorization code for tokens.
// POST /api/v1/admin/gemini/oauth/exchange-code
func (h *GeminiOAuthHandler) ExchangeCode(c *gin.Context) {
var req GeminiExchangeCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// 默认使用 code_assist 以保持向后兼容
oauthType := strings.TrimSpace(req.OAuthType)
if oauthType == "" {
oauthType = "code_assist"
}
if oauthType != "code_assist" && oauthType != "ai_studio" {
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist' or 'ai_studio'")
return
}
tokenInfo, err := h.geminiOAuthService.ExchangeCode(c.Request.Context(), &service.GeminiExchangeCodeInput{
SessionID: req.SessionID,
State: req.State,
Code: req.Code,
ProxyID: req.ProxyID,
OAuthType: oauthType,
})
if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error())
return
}
response.Success(c, tokenInfo)
}
func deriveGeminiRedirectURI(c *gin.Context) string {
origin := strings.TrimSpace(c.GetHeader("Origin"))
if origin != "" {
return strings.TrimRight(origin, "/") + "/auth/callback"
}
scheme := "http"
if c.Request.TLS != nil {
scheme = "https"
}
if xfProto := strings.TrimSpace(c.GetHeader("X-Forwarded-Proto")); xfProto != "" {
scheme = strings.TrimSpace(strings.Split(xfProto, ",")[0])
}
host := strings.TrimSpace(c.Request.Host)
if xfHost := strings.TrimSpace(c.GetHeader("X-Forwarded-Host")); xfHost != "" {
host = strings.TrimSpace(strings.Split(xfHost, ",")[0])
}
return fmt.Sprintf("%s://%s/auth/callback", scheme, host)
}
......@@ -21,15 +21,23 @@ import (
// GatewayHandler handles API gateway requests
type GatewayHandler struct {
gatewayService *service.GatewayService
geminiCompatService *service.GeminiMessagesCompatService
userService *service.UserService
billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper
}
// NewGatewayHandler creates a new GatewayHandler
func NewGatewayHandler(gatewayService *service.GatewayService, userService *service.UserService, concurrencyService *service.ConcurrencyService, billingCacheService *service.BillingCacheService) *GatewayHandler {
func NewGatewayHandler(
gatewayService *service.GatewayService,
geminiCompatService *service.GeminiMessagesCompatService,
userService *service.UserService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
) *GatewayHandler {
return &GatewayHandler{
gatewayService: gatewayService,
geminiCompatService: geminiCompatService,
userService: userService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude),
......@@ -114,8 +122,18 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 计算粘性会话hash
sessionHash := h.gatewayService.GenerateSessionHash(body)
platform := ""
if apiKey.Group != nil {
platform = apiKey.Group.Platform
}
// 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
var account *service.Account
if platform == service.PlatformGemini {
account, err = h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
} else {
account, err = h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
}
if err != nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
......@@ -143,7 +161,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 转发请求
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
var result *service.ForwardResult
if platform == service.PlatformGemini {
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
} else {
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, body)
}
if err != nil {
// 错误响应已在Forward中处理,这里只记录日志
log.Printf("Forward request failed: %v", err)
......
package handler
import (
"context"
"io"
"log"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// GeminiV1BetaListModels proxies:
// GET /v1beta/models
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
apiKey, ok := middleware.GetApiKeyFromContext(c)
if !ok || apiKey == nil {
googleError(c, http.StatusUnauthorized, "Invalid API key")
return
}
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
return
}
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
if err != nil {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
res, err := h.geminiCompatService.ForwardAIStudioGET(c.Request.Context(), account, "/v1beta/models")
if err != nil {
googleError(c, http.StatusBadGateway, err.Error())
return
}
if shouldFallbackGeminiModels(res) {
c.JSON(http.StatusOK, gemini.FallbackModelsList())
return
}
writeUpstreamResponse(c, res)
}
// GeminiV1BetaGetModel proxies:
// GET /v1beta/models/{model}
func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
apiKey, ok := middleware.GetApiKeyFromContext(c)
if !ok || apiKey == nil {
googleError(c, http.StatusUnauthorized, "Invalid API key")
return
}
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
return
}
modelName := strings.TrimSpace(c.Param("model"))
if modelName == "" {
googleError(c, http.StatusBadRequest, "Missing model in URL")
return
}
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
if err != nil {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
res, err := h.geminiCompatService.ForwardAIStudioGET(c.Request.Context(), account, "/v1beta/models/"+modelName)
if err != nil {
googleError(c, http.StatusBadGateway, err.Error())
return
}
if shouldFallbackGeminiModels(res) {
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
return
}
writeUpstreamResponse(c, res)
}
// GeminiV1BetaModels proxies Gemini native REST endpoints like:
// POST /v1beta/models/{model}:generateContent
// POST /v1beta/models/{model}:streamGenerateContent?alt=sse
func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
apiKey, ok := middleware.GetApiKeyFromContext(c)
if !ok || apiKey == nil {
googleError(c, http.StatusUnauthorized, "Invalid API key")
return
}
authSubject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok {
googleError(c, http.StatusInternalServerError, "User context not found")
return
}
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
return
}
modelName, action, err := parseGeminiModelAction(strings.TrimPrefix(c.Param("modelAction"), "/"))
if err != nil {
googleError(c, http.StatusNotFound, err.Error())
return
}
stream := action == "streamGenerateContent"
body, err := io.ReadAll(c.Request.Body)
if err != nil {
googleError(c, http.StatusBadRequest, "Failed to read request body")
return
}
if len(body) == 0 {
googleError(c, http.StatusBadRequest, "Request body is empty")
return
}
// Get subscription (may be nil)
subscription, _ := middleware.GetSubscriptionFromContext(c)
// For Gemini native API, do not send Claude-style ping frames.
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone)
// 0) wait queue check
maxWait := service.CalculateMaxWait(authSubject.Concurrency)
canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), authSubject.UserID, maxWait)
if err != nil {
log.Printf("Increment wait count failed: %v", err)
} else if !canWait {
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
return
}
defer geminiConcurrency.DecrementWaitCount(c.Request.Context(), authSubject.UserID)
// 1) user concurrency slot
streamStarted := false
userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, authSubject.UserID, authSubject.Concurrency, stream, &streamStarted)
if err != nil {
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
if userReleaseFunc != nil {
defer userReleaseFunc()
}
// 2) billing eligibility check (after wait)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
googleError(c, http.StatusForbidden, err.Error())
return
}
// 3) select account (sticky session based on request body)
sessionHash := h.gatewayService.GenerateSessionHash(body)
account, err := h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, modelName)
if err != nil {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
// 4) account concurrency slot
accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted)
if err != nil {
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
if accountReleaseFunc != nil {
defer accountReleaseFunc()
}
// 5) forward (writes response to client)
result, err := h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body)
if err != nil {
// ForwardNative already wrote the response
log.Printf("Gemini native forward failed: %v", err)
return
}
// 6) record usage async
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ApiKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}()
}
func parseGeminiModelAction(rest string) (model string, action string, err error) {
rest = strings.TrimSpace(rest)
if rest == "" {
return "", "", &pathParseError{"missing path"}
}
// Standard: {model}:{action}
if i := strings.Index(rest, ":"); i > 0 && i < len(rest)-1 {
return rest[:i], rest[i+1:], nil
}
// Fallback: {model}/{action}
if i := strings.Index(rest, "/"); i > 0 && i < len(rest)-1 {
return rest[:i], rest[i+1:], nil
}
return "", "", &pathParseError{"invalid model action path"}
}
type pathParseError struct{ msg string }
func (e *pathParseError) Error() string { return e.msg }
func googleError(c *gin.Context, status int, message string) {
c.JSON(status, gin.H{
"error": gin.H{
"code": status,
"message": message,
"status": googleapi.HTTPStatusToGoogleStatus(status),
},
})
}
func writeUpstreamResponse(c *gin.Context, res *service.UpstreamHTTPResult) {
if res == nil {
googleError(c, http.StatusBadGateway, "Empty upstream response")
return
}
for k, vv := range res.Headers {
// Avoid overriding content-length and hop-by-hop headers.
if strings.EqualFold(k, "Content-Length") || strings.EqualFold(k, "Transfer-Encoding") || strings.EqualFold(k, "Connection") {
continue
}
for _, v := range vv {
c.Writer.Header().Add(k, v)
}
}
contentType := res.Headers.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
}
c.Data(res.StatusCode, contentType, res.Body)
}
func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
if res == nil {
return true
}
if res.StatusCode != http.StatusUnauthorized && res.StatusCode != http.StatusForbidden {
return false
}
if strings.Contains(strings.ToLower(res.Headers.Get("Www-Authenticate")), "insufficient_scope") {
return true
}
if strings.Contains(strings.ToLower(string(res.Body)), "insufficient authentication scopes") {
return true
}
if strings.Contains(strings.ToLower(string(res.Body)), "access_token_scope_insufficient") {
return true
}
return false
}
......@@ -12,6 +12,7 @@ type AdminHandlers struct {
Account *admin.AccountHandler
OAuth *admin.OAuthHandler
OpenAIOAuth *admin.OpenAIOAuthHandler
GeminiOAuth *admin.GeminiOAuthHandler
Proxy *admin.ProxyHandler
Redeem *admin.RedeemHandler
Setting *admin.SettingHandler
......
......@@ -15,6 +15,7 @@ func ProvideAdminHandlers(
accountHandler *admin.AccountHandler,
oauthHandler *admin.OAuthHandler,
openaiOAuthHandler *admin.OpenAIOAuthHandler,
geminiOAuthHandler *admin.GeminiOAuthHandler,
proxyHandler *admin.ProxyHandler,
redeemHandler *admin.RedeemHandler,
settingHandler *admin.SettingHandler,
......@@ -29,6 +30,7 @@ func ProvideAdminHandlers(
Account: accountHandler,
OAuth: oauthHandler,
OpenAIOAuth: openaiOAuthHandler,
GeminiOAuth: geminiOAuthHandler,
Proxy: proxyHandler,
Redeem: redeemHandler,
Setting: settingHandler,
......@@ -95,6 +97,7 @@ var ProviderSet = wire.NewSet(
admin.NewAccountHandler,
admin.NewOAuthHandler,
admin.NewOpenAIOAuthHandler,
admin.NewGeminiOAuthHandler,
admin.NewProxyHandler,
admin.NewRedeemHandler,
admin.NewSettingHandler,
......
package gemini
// This package provides minimal fallback model metadata for Gemini native endpoints.
// It is used when upstream model listing is unavailable (e.g. OAuth token missing AI Studio scopes).
type Model struct {
Name string `json:"name"`
DisplayName string `json:"displayName,omitempty"`
Description string `json:"description,omitempty"`
SupportedGenerationMethods []string `json:"supportedGenerationMethods,omitempty"`
}
type ModelsListResponse struct {
Models []Model `json:"models"`
}
func DefaultModels() []Model {
methods := []string{"generateContent", "streamGenerateContent"}
return []Model{
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.0-flash-lite", SupportedGenerationMethods: methods},
{Name: "models/gemini-1.5-pro", SupportedGenerationMethods: methods},
{Name: "models/gemini-1.5-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-1.5-flash-8b", SupportedGenerationMethods: methods},
}
}
func FallbackModelsList() ModelsListResponse {
return ModelsListResponse{Models: DefaultModels()}
}
func FallbackModel(model string) Model {
methods := []string{"generateContent", "streamGenerateContent"}
if model == "" {
return Model{Name: "models/unknown", SupportedGenerationMethods: methods}
}
if len(model) >= 7 && model[:7] == "models/" {
return Model{Name: model, SupportedGenerationMethods: methods}
}
return Model{Name: "models/" + model, SupportedGenerationMethods: methods}
}
package geminicli
// LoadCodeAssistRequest matches done-hub's internal Code Assist call.
type LoadCodeAssistRequest struct {
Metadata LoadCodeAssistMetadata `json:"metadata"`
}
type LoadCodeAssistMetadata struct {
IDEType string `json:"ideType"`
Platform string `json:"platform"`
PluginType string `json:"pluginType"`
}
type LoadCodeAssistResponse struct {
CurrentTier string `json:"currentTier,omitempty"`
CloudAICompanionProject string `json:"cloudaicompanionProject,omitempty"`
AllowedTiers []AllowedTier `json:"allowedTiers,omitempty"`
}
type AllowedTier struct {
ID string `json:"id"`
IsDefault bool `json:"isDefault,omitempty"`
}
type OnboardUserRequest struct {
TierID string `json:"tierId"`
Metadata LoadCodeAssistMetadata `json:"metadata"`
}
type OnboardUserResponse struct {
Done bool `json:"done"`
Response *OnboardUserResultData `json:"response,omitempty"`
Name string `json:"name,omitempty"`
}
type OnboardUserResultData struct {
CloudAICompanionProject any `json:"cloudaicompanionProject,omitempty"`
}
package geminicli
import "time"
const (
AIStudioBaseURL = "https://generativelanguage.googleapis.com"
GeminiCliBaseURL = "https://cloudcode-pa.googleapis.com"
AuthorizeURL = "https://accounts.google.com/o/oauth2/v2/auth"
TokenURL = "https://oauth2.googleapis.com/token"
// AIStudioOAuthRedirectURI is the default redirect URI used for AI Studio OAuth.
// This matches the "copy/paste callback URL" flow used by OpenAI OAuth in this project.
// Note: You still need to register this redirect URI in your Google OAuth client
// unless you use an OAuth client type that permits localhost redirect URIs.
AIStudioOAuthRedirectURI = "http://localhost:1455/auth/callback"
// DefaultScopes for Code Assist (includes cloud-platform for API access plus userinfo scopes)
// Required by Google's Code Assist API.
DefaultCodeAssistScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile"
// DefaultScopes for AI Studio (uses generativelanguage API with OAuth)
// Reference: https://ai.google.dev/gemini-api/docs/oauth
// For regular Google accounts, supports API calls to generativelanguage.googleapis.com
// Note: Google Auth platform currently documents the OAuth scope as
// https://www.googleapis.com/auth/generative-language.retriever (often with cloud-platform).
DefaultAIStudioScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever"
// GeminiCLIRedirectURI is the redirect URI used by Gemini CLI for Code Assist OAuth.
GeminiCLIRedirectURI = "https://codeassist.google.com/authcode"
// GeminiCLIOAuthClientID/Secret are the public OAuth client credentials used by Google Gemini CLI.
// They enable the "login without creating your own OAuth client" experience, but Google may
// restrict which scopes are allowed for this client.
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
GeminiCLIOAuthClientSecret = "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"
SessionTTL = 30 * time.Minute
// GeminiCLIUserAgent mimics Gemini CLI to maximize compatibility with internal endpoints.
GeminiCLIUserAgent = "GeminiCLI/0.1.5 (Windows; AMD64)"
)
package geminicli
// Model represents a selectable Gemini model for UI/testing purposes.
// Keep JSON fields consistent with existing frontend expectations.
type Model struct {
ID string `json:"id"`
Type string `json:"type"`
DisplayName string `json:"display_name"`
CreatedAt string `json:"created_at"`
}
// DefaultModels is the curated Gemini model list used by the admin UI "test account" flow.
var DefaultModels = []Model{
{ID: "gemini-3-pro", Type: "model", DisplayName: "Gemini 3 Pro", CreatedAt: ""},
{ID: "gemini-3-flash", Type: "model", DisplayName: "Gemini 3 Flash", CreatedAt: ""},
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
}
// DefaultTestModel is the default model to preselect in test flows.
const DefaultTestModel = "gemini-2.5-pro"
package geminicli
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"net/url"
"strings"
"sync"
"time"
)
type OAuthConfig struct {
ClientID string
ClientSecret string
Scopes string
}
type OAuthSession struct {
State string `json:"state"`
CodeVerifier string `json:"code_verifier"`
ProxyURL string `json:"proxy_url,omitempty"`
RedirectURI string `json:"redirect_uri"`
ProjectID string `json:"project_id,omitempty"`
OAuthType string `json:"oauth_type"` // "code_assist" 或 "ai_studio"
CreatedAt time.Time `json:"created_at"`
}
type SessionStore struct {
mu sync.RWMutex
sessions map[string]*OAuthSession
stopCh chan struct{}
}
func NewSessionStore() *SessionStore {
store := &SessionStore{
sessions: make(map[string]*OAuthSession),
stopCh: make(chan struct{}),
}
go store.cleanup()
return store
}
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
s.mu.Lock()
defer s.mu.Unlock()
s.sessions[sessionID] = session
}
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
session, ok := s.sessions[sessionID]
if !ok {
return nil, false
}
if time.Since(session.CreatedAt) > SessionTTL {
return nil, false
}
return session, true
}
func (s *SessionStore) Delete(sessionID string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, sessionID)
}
func (s *SessionStore) Stop() {
select {
case <-s.stopCh:
return
default:
close(s.stopCh)
}
}
func (s *SessionStore) cleanup() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-s.stopCh:
return
case <-ticker.C:
s.mu.Lock()
for id, session := range s.sessions {
if time.Since(session.CreatedAt) > SessionTTL {
delete(s.sessions, id)
}
}
s.mu.Unlock()
}
}
}
func GenerateRandomBytes(n int) ([]byte, error) {
b := make([]byte, n)
_, err := rand.Read(b)
if err != nil {
return nil, err
}
return b, nil
}
func GenerateState() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return base64URLEncode(bytes), nil
}
func GenerateSessionID() (string, error) {
bytes, err := GenerateRandomBytes(16)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// GenerateCodeVerifier returns an RFC 7636 compatible code verifier (43+ chars).
func GenerateCodeVerifier() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return base64URLEncode(bytes), nil
}
func GenerateCodeChallenge(verifier string) string {
hash := sha256.Sum256([]byte(verifier))
return base64URLEncode(hash[:])
}
func base64URLEncode(data []byte) string {
return strings.TrimRight(base64.URLEncoding.EncodeToString(data), "=")
}
// EffectiveOAuthConfig returns the effective OAuth configuration.
// oauthType: "code_assist" or "ai_studio" (defaults to "code_assist" if empty).
//
// If ClientID/ClientSecret is not provided, this falls back to the built-in Gemini CLI OAuth client.
//
// Note: The built-in Gemini CLI OAuth client is restricted and may reject some scopes (e.g.
// https://www.googleapis.com/auth/generative-language), which will surface as
// "restricted_client" / "Unregistered scope(s)" errors during browser authorization.
func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error) {
effective := OAuthConfig{
ClientID: strings.TrimSpace(cfg.ClientID),
ClientSecret: strings.TrimSpace(cfg.ClientSecret),
Scopes: strings.TrimSpace(cfg.Scopes),
}
// Normalize scopes: allow comma-separated input but send space-delimited scopes to Google.
if effective.Scopes != "" {
effective.Scopes = strings.Join(strings.Fields(strings.ReplaceAll(effective.Scopes, ",", " ")), " ")
}
// Fall back to built-in Gemini CLI OAuth client when not configured.
if effective.ClientID == "" && effective.ClientSecret == "" {
effective.ClientID = GeminiCLIOAuthClientID
effective.ClientSecret = GeminiCLIOAuthClientSecret
} else if effective.ClientID == "" || effective.ClientSecret == "" {
return OAuthConfig{}, fmt.Errorf("OAuth client not configured: please set both client_id and client_secret (or leave both empty to use the built-in Gemini CLI client)")
}
isBuiltinClient := effective.ClientID == GeminiCLIOAuthClientID &&
effective.ClientSecret == GeminiCLIOAuthClientSecret
if effective.Scopes == "" {
// Use different default scopes based on OAuth type
if oauthType == "ai_studio" {
// Built-in client can't request some AI Studio scopes (notably generative-language).
if isBuiltinClient {
effective.Scopes = DefaultCodeAssistScopes
} else {
effective.Scopes = DefaultAIStudioScopes
}
} else {
// Default to Code Assist scopes
effective.Scopes = DefaultCodeAssistScopes
}
} else if oauthType == "ai_studio" && isBuiltinClient {
// If user overrides scopes while still using the built-in client, strip restricted scopes.
parts := strings.Fields(effective.Scopes)
filtered := make([]string, 0, len(parts))
for _, s := range parts {
if strings.Contains(s, "generative-language") {
continue
}
filtered = append(filtered, s)
}
if len(filtered) == 0 {
effective.Scopes = DefaultCodeAssistScopes
} else {
effective.Scopes = strings.Join(filtered, " ")
}
}
// Backward compatibility: normalize older AI Studio scope to the currently documented one.
if oauthType == "ai_studio" && effective.Scopes != "" {
parts := strings.Fields(effective.Scopes)
for i := range parts {
if parts[i] == "https://www.googleapis.com/auth/generative-language" {
parts[i] = "https://www.googleapis.com/auth/generative-language.retriever"
}
}
effective.Scopes = strings.Join(parts, " ")
}
return effective, nil
}
func BuildAuthorizationURL(cfg OAuthConfig, state, codeChallenge, redirectURI, projectID, oauthType string) (string, error) {
effectiveCfg, err := EffectiveOAuthConfig(cfg, oauthType)
if err != nil {
return "", err
}
redirectURI = strings.TrimSpace(redirectURI)
if redirectURI == "" {
return "", fmt.Errorf("redirect_uri is required")
}
params := url.Values{}
params.Set("response_type", "code")
params.Set("client_id", effectiveCfg.ClientID)
params.Set("redirect_uri", redirectURI)
params.Set("scope", effectiveCfg.Scopes)
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
params.Set("access_type", "offline")
params.Set("prompt", "consent")
params.Set("include_granted_scopes", "true")
if strings.TrimSpace(projectID) != "" {
params.Set("project_id", strings.TrimSpace(projectID))
}
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode()), nil
}
package geminicli
import "strings"
const maxLogBodyLen = 2048
func SanitizeBodyForLogs(body string) string {
body = truncateBase64InMessage(body)
if len(body) > maxLogBodyLen {
body = body[:maxLogBodyLen] + "...[truncated]"
}
return body
}
func truncateBase64InMessage(message string) string {
const maxBase64Length = 50
result := message
offset := 0
for {
idx := strings.Index(result[offset:], ";base64,")
if idx == -1 {
break
}
actualIdx := offset + idx
start := actualIdx + len(";base64,")
end := start
for end < len(result) && isBase64Char(result[end]) {
end++
}
if end-start > maxBase64Length {
result = result[:start+maxBase64Length] + "...[truncated]" + result[end:]
offset = start + maxBase64Length + len("...[truncated]")
continue
}
offset = end
}
return result
}
func isBase64Char(c byte) bool {
return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '+' || c == '/' || c == '='
}
package geminicli
type TokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token,omitempty"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
Scope string `json:"scope,omitempty"`
}
package googleapi
import "net/http"
// HTTPStatusToGoogleStatus maps HTTP status codes to Google-style error status strings.
func HTTPStatusToGoogleStatus(status int) string {
switch status {
case http.StatusBadRequest:
return "INVALID_ARGUMENT"
case http.StatusUnauthorized:
return "UNAUTHENTICATED"
case http.StatusForbidden:
return "PERMISSION_DENIED"
case http.StatusNotFound:
return "NOT_FOUND"
case http.StatusTooManyRequests:
return "RESOURCE_EXHAUSTED"
default:
if status >= 500 {
return "INTERNAL"
}
return "UNKNOWN"
}
}
package repository
import (
"context"
"fmt"
"net/url"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/imroc/req/v3"
)
type geminiOAuthClient struct {
tokenURL string
cfg *config.Config
}
func NewGeminiOAuthClient(cfg *config.Config) service.GeminiOAuthClient {
return &geminiOAuthClient{
tokenURL: geminicli.TokenURL,
cfg: cfg,
}
}
func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, codeVerifier, redirectURI, proxyURL string) (*geminicli.TokenResponse, error) {
client := createGeminiReqClient(proxyURL)
// Use different OAuth clients based on oauthType:
// - code_assist: always use built-in Gemini CLI OAuth client (public)
// - ai_studio: requires a user-provided OAuth client
oauthCfgInput := geminicli.OAuthConfig{
ClientID: c.cfg.Gemini.OAuth.ClientID,
ClientSecret: c.cfg.Gemini.OAuth.ClientSecret,
Scopes: c.cfg.Gemini.OAuth.Scopes,
}
if oauthType == "code_assist" {
oauthCfgInput.ClientID = ""
oauthCfgInput.ClientSecret = ""
}
oauthCfg, err := geminicli.EffectiveOAuthConfig(oauthCfgInput, oauthType)
if err != nil {
return nil, err
}
formData := url.Values{}
formData.Set("grant_type", "authorization_code")
formData.Set("client_id", oauthCfg.ClientID)
formData.Set("client_secret", oauthCfg.ClientSecret)
formData.Set("code", code)
formData.Set("code_verifier", codeVerifier)
formData.Set("redirect_uri", redirectURI)
var tokenResp geminicli.TokenResponse
resp, err := client.R().
SetContext(ctx).
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(c.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, geminicli.SanitizeBodyForLogs(resp.String()))
}
return &tokenResp, nil
}
func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) {
client := createGeminiReqClient(proxyURL)
oauthCfgInput := geminicli.OAuthConfig{
ClientID: c.cfg.Gemini.OAuth.ClientID,
ClientSecret: c.cfg.Gemini.OAuth.ClientSecret,
Scopes: c.cfg.Gemini.OAuth.Scopes,
}
if oauthType == "code_assist" {
oauthCfgInput.ClientID = ""
oauthCfgInput.ClientSecret = ""
}
oauthCfg, err := geminicli.EffectiveOAuthConfig(oauthCfgInput, oauthType)
if err != nil {
return nil, err
}
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
formData.Set("refresh_token", refreshToken)
formData.Set("client_id", oauthCfg.ClientID)
formData.Set("client_secret", oauthCfg.ClientSecret)
var tokenResp geminicli.TokenResponse
resp, err := client.R().
SetContext(ctx).
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(c.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, geminicli.SanitizeBodyForLogs(resp.String()))
}
return &tokenResp, nil
}
func createGeminiReqClient(proxyURL string) *req.Client {
client := req.C().SetTimeout(60 * time.Second)
if proxyURL != "" {
client.SetProxyURL(proxyURL)
}
return client
}
package repository
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const (
geminiTokenKeyPrefix = "gemini:token:"
geminiRefreshLockKeyPrefix = "gemini:refresh_lock:"
)
type geminiTokenCache struct {
rdb *redis.Client
}
func NewGeminiTokenCache(rdb *redis.Client) service.GeminiTokenCache {
return &geminiTokenCache{rdb: rdb}
}
func (c *geminiTokenCache) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
return c.rdb.Get(ctx, key).Result()
}
func (c *geminiTokenCache) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
return c.rdb.Set(ctx, key, token, ttl).Err()
}
func (c *geminiTokenCache) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey)
return c.rdb.SetNX(ctx, key, 1, ttl).Result()
}
func (c *geminiTokenCache) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey)
return c.rdb.Del(ctx, key).Err()
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment