Commit 11bfc807 authored by song's avatar song
Browse files

merge upstream/main

parents c2a6ca8d 62dc0b95
...@@ -56,6 +56,9 @@ var ( ...@@ -56,6 +56,9 @@ var (
} }
) )
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
// allowedHeaders 白名单headers(参考CRS项目) // allowedHeaders 白名单headers(参考CRS项目)
var allowedHeaders = map[string]bool{ var allowedHeaders = map[string]bool{
"accept": true, "accept": true,
...@@ -80,9 +83,17 @@ var allowedHeaders = map[string]bool{ ...@@ -80,9 +83,17 @@ var allowedHeaders = map[string]bool{
// GatewayCache defines cache operations for gateway service // GatewayCache defines cache operations for gateway service
type GatewayCache interface { type GatewayCache interface {
GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error)
SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error
RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error
}
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
func derefGroupID(groupID *int64) int64 {
if groupID == nil {
return 0
}
return *groupID
} }
type AccountWaitPlan struct { type AccountWaitPlan struct {
...@@ -225,11 +236,11 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { ...@@ -225,11 +236,11 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
} }
// BindStickySession sets session -> account binding with standard TTL. // BindStickySession sets session -> account binding with standard TTL.
func (s *GatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error { func (s *GatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error {
if sessionHash == "" || accountID <= 0 || s.cache == nil { if sessionHash == "" || accountID <= 0 || s.cache == nil {
return nil return nil
} }
return s.cache.SetSessionAccountID(ctx, sessionHash, accountID, stickySessionTTL) return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, accountID, stickySessionTTL)
} }
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string { func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
...@@ -356,6 +367,21 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context ...@@ -356,6 +367,21 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
return nil, fmt.Errorf("get group failed: %w", err) return nil, fmt.Errorf("get group failed: %w", err)
} }
platform = group.Platform platform = group.Platform
// 检查 Claude Code 客户端限制
if group.ClaudeCodeOnly {
isClaudeCode := IsClaudeCodeClient(ctx)
if !isClaudeCode {
// 非 Claude Code 客户端,检查是否有降级分组
if group.FallbackGroupID != nil {
// 使用降级分组重新调度
fallbackGroupID := *group.FallbackGroupID
return s.SelectAccountForModelWithExclusions(ctx, &fallbackGroupID, sessionHash, requestedModel, excludedIDs)
}
// 无降级分组,拒绝访问
return nil, ErrClaudeCodeOnly
}
}
} else { } else {
// 无分组时只使用原生 anthropic 平台 // 无分组时只使用原生 anthropic 平台
platform = PlatformAnthropic platform = PlatformAnthropic
...@@ -377,10 +403,17 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -377,10 +403,17 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
cfg := s.schedulingConfig() cfg := s.schedulingConfig()
var stickyAccountID int64 var stickyAccountID int64
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
if accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash); err == nil { if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil {
stickyAccountID = accountID stickyAccountID = accountID
} }
} }
// 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组)
groupID, err := s.checkClaudeCodeRestriction(ctx, groupID)
if err != nil {
return nil, err
}
if s.concurrencyService == nil || !cfg.LoadBatchEnabled { if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs) account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
if err != nil { if err != nil {
...@@ -443,7 +476,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -443,7 +476,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
// ============ Layer 1: 粘性会话优先 ============ // ============ Layer 1: 粘性会话优先 ============
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) { if err == nil && accountID > 0 && !isExcluded(accountID) {
account, err := s.accountRepo.GetByID(ctx, accountID) account, err := s.accountRepo.GetByID(ctx, accountID)
if err == nil && s.isAccountInGroup(account, groupID) && if err == nil && s.isAccountInGroup(account, groupID) &&
...@@ -452,7 +485,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -452,7 +485,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
_ = s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL) _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
return &AccountSelectionResult{ return &AccountSelectionResult{
Account: account, Account: account,
Acquired: true, Acquired: true,
...@@ -509,7 +542,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -509,7 +542,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
if err != nil { if err != nil {
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, sessionHash, preferOAuth); ok { if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok {
return result, nil return result, nil
} }
} else { } else {
...@@ -559,7 +592,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -559,7 +592,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL) _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
} }
return &AccountSelectionResult{ return &AccountSelectionResult{
Account: item.account, Account: item.account,
...@@ -587,7 +620,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -587,7 +620,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return nil, errors.New("no available accounts") return nil, errors.New("no available accounts")
} }
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) { func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
ordered := append([]*Account(nil), candidates...) ordered := append([]*Account(nil), candidates...)
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth) sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
...@@ -595,7 +628,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates ...@@ -595,7 +628,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL) _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL)
} }
return &AccountSelectionResult{ return &AccountSelectionResult{
Account: acc, Account: acc,
...@@ -622,6 +655,42 @@ func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig { ...@@ -622,6 +655,42 @@ func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
} }
} }
// checkClaudeCodeRestriction 检查分组的 Claude Code 客户端限制
// 如果分组启用了 claude_code_only 且请求不是来自 Claude Code 客户端:
// - 有降级分组:返回降级分组的 ID
// - 无降级分组:返回 ErrClaudeCodeOnly 错误
func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID *int64) (*int64, error) {
if groupID == nil {
return groupID, nil
}
// 强制平台模式不检查 Claude Code 限制
if _, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform {
return groupID, nil
}
group, err := s.groupRepo.GetByID(ctx, *groupID)
if err != nil {
return nil, fmt.Errorf("get group failed: %w", err)
}
if !group.ClaudeCodeOnly {
return groupID, nil
}
// 分组启用了 Claude Code 限制
if IsClaudeCodeClient(ctx) {
return groupID, nil
}
// 非 Claude Code 客户端,检查降级分组
if group.FallbackGroupID != nil {
return group.FallbackGroupID, nil
}
return nil, ErrClaudeCodeOnly
}
func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) { func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) {
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform != "" { if hasForcePlatform && forcePlatform != "" {
...@@ -741,13 +810,13 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ...@@ -741,13 +810,13 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
preferOAuth := platform == PlatformGemini preferOAuth := platform == PlatformGemini
// 1. 查询粘性会话 // 1. 查询粘性会话
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err == nil && accountID > 0 { if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded { if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID) account, err := s.accountRepo.GetByID(ctx, accountID)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台) // 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil { if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
} }
return account, nil return account, nil
...@@ -817,7 +886,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ...@@ -817,7 +886,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
// 4. 建立粘性绑定 // 4. 建立粘性绑定
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil { if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
} }
} }
...@@ -833,14 +902,14 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ...@@ -833,14 +902,14 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
// 1. 查询粘性会话 // 1. 查询粘性会话
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err == nil && accountID > 0 { if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded { if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID) account, err := s.accountRepo.GetByID(ctx, accountID)
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度 // 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil { if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
} }
return account, nil return account, nil
...@@ -912,7 +981,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ...@@ -912,7 +981,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
// 4. 建立粘性绑定 // 4. 建立粘性绑定
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil { if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
} }
} }
......
...@@ -109,7 +109,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co ...@@ -109,7 +109,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
cacheKey := "gemini:" + sessionHash cacheKey := "gemini:" + sessionHash
if sessionHash != "" { if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey) accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
if err == nil && accountID > 0 { if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded { if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID) account, err := s.accountRepo.GetByID(ctx, accountID)
...@@ -133,7 +133,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co ...@@ -133,7 +133,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
} }
} }
if usable { if usable {
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL) _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, geminiStickySessionTTL)
return account, nil return account, nil
} }
} }
...@@ -220,7 +220,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co ...@@ -220,7 +220,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
} }
if sessionHash != "" { if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, cacheKey, selected.ID, geminiStickySessionTTL) _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, geminiStickySessionTTL)
} }
return selected, nil return selected, nil
......
...@@ -172,7 +172,7 @@ func (m *mockGroupRepoForGemini) DeleteCascade(ctx context.Context, id int64) ([ ...@@ -172,7 +172,7 @@ func (m *mockGroupRepoForGemini) DeleteCascade(ctx context.Context, id int64) ([
func (m *mockGroupRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) { func (m *mockGroupRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (m *mockGroupRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) { func (m *mockGroupRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (m *mockGroupRepoForGemini) ListActive(ctx context.Context) ([]Group, error) { return nil, nil } func (m *mockGroupRepoForGemini) ListActive(ctx context.Context) ([]Group, error) { return nil, nil }
...@@ -196,14 +196,14 @@ type mockGatewayCacheForGemini struct { ...@@ -196,14 +196,14 @@ type mockGatewayCacheForGemini struct {
sessionBindings map[string]int64 sessionBindings map[string]int64
} }
func (m *mockGatewayCacheForGemini) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) { func (m *mockGatewayCacheForGemini) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
if id, ok := m.sessionBindings[sessionHash]; ok { if id, ok := m.sessionBindings[sessionHash]; ok {
return id, nil return id, nil
} }
return 0, errors.New("not found") return 0, errors.New("not found")
} }
func (m *mockGatewayCacheForGemini) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error { func (m *mockGatewayCacheForGemini) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
if m.sessionBindings == nil { if m.sessionBindings == nil {
m.sessionBindings = make(map[string]int64) m.sessionBindings = make(map[string]int64)
} }
...@@ -211,7 +211,7 @@ func (m *mockGatewayCacheForGemini) SetSessionAccountID(ctx context.Context, ses ...@@ -211,7 +211,7 @@ func (m *mockGatewayCacheForGemini) SetSessionAccountID(ctx context.Context, ses
return nil return nil
} }
func (m *mockGatewayCacheForGemini) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error { func (m *mockGatewayCacheForGemini) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
return nil return nil
} }
......
...@@ -120,15 +120,16 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 ...@@ -120,15 +120,16 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
} }
// OAuth client selection: // OAuth client selection:
// - code_assist: always use built-in Gemini CLI OAuth client (public), regardless of configured client_id/secret. // - code_assist: always use built-in Gemini CLI OAuth client (public)
// - google_one: uses configured OAuth client when provided; otherwise falls back to built-in client. // - google_one: always use built-in Gemini CLI OAuth client (public)
// - ai_studio: requires a user-provided OAuth client. // - ai_studio: requires a user-provided OAuth client
oauthCfg := geminicli.OAuthConfig{ oauthCfg := geminicli.OAuthConfig{
ClientID: s.cfg.Gemini.OAuth.ClientID, ClientID: s.cfg.Gemini.OAuth.ClientID,
ClientSecret: s.cfg.Gemini.OAuth.ClientSecret, ClientSecret: s.cfg.Gemini.OAuth.ClientSecret,
Scopes: s.cfg.Gemini.OAuth.Scopes, Scopes: s.cfg.Gemini.OAuth.Scopes,
} }
if oauthType == "code_assist" { if oauthType == "code_assist" || oauthType == "google_one" {
// Force use of built-in Gemini CLI OAuth client
oauthCfg.ClientID = "" oauthCfg.ClientID = ""
oauthCfg.ClientSecret = "" oauthCfg.ClientSecret = ""
} }
...@@ -576,6 +577,20 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch ...@@ -576,6 +577,20 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
case "google_one": case "google_one":
log.Printf("[GeminiOAuth] Processing google_one OAuth type") log.Printf("[GeminiOAuth] Processing google_one OAuth type")
// Google One accounts use cloudaicompanion API, which requires a project_id.
// For personal accounts, Google auto-assigns a project_id via the LoadCodeAssist API.
if projectID == "" {
log.Printf("[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...")
var err error
projectID, _, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
if err != nil {
log.Printf("[GeminiOAuth] ERROR: Failed to fetch project_id: %v", err)
return nil, fmt.Errorf("google One accounts require a project_id, failed to auto-detect: %w", err)
}
log.Printf("[GeminiOAuth] Successfully fetched project_id: %s", projectID)
}
log.Printf("[GeminiOAuth] Attempting to fetch Google One tier from Drive API...") log.Printf("[GeminiOAuth] Attempting to fetch Google One tier from Drive API...")
// Attempt to fetch Drive storage tier // Attempt to fetch Drive storage tier
var storageInfo *geminicli.DriveStorageInfo var storageInfo *geminicli.DriveStorageInfo
......
...@@ -40,7 +40,7 @@ func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) { ...@@ -40,7 +40,7 @@ func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) {
wantProjectID: "", wantProjectID: "",
}, },
{ {
name: "google_one uses custom client when configured and redirects to localhost", name: "google_one always forces built-in client even when custom client configured",
cfg: &config.Config{ cfg: &config.Config{
Gemini: config.GeminiConfig{ Gemini: config.GeminiConfig{
OAuth: config.GeminiOAuthConfig{ OAuth: config.GeminiOAuthConfig{
...@@ -50,9 +50,9 @@ func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) { ...@@ -50,9 +50,9 @@ func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) {
}, },
}, },
oauthType: "google_one", oauthType: "google_one",
wantClientID: "custom-client-id", wantClientID: geminicli.GeminiCLIOAuthClientID,
wantRedirect: geminicli.AIStudioOAuthRedirectURI, wantRedirect: geminicli.GeminiCLIRedirectURI,
wantScope: geminicli.DefaultGoogleOneScopes, wantScope: geminicli.DefaultCodeAssistScopes,
wantProjectID: "", wantProjectID: "",
}, },
{ {
......
...@@ -22,6 +22,10 @@ type Group struct { ...@@ -22,6 +22,10 @@ type Group struct {
ImagePrice2K *float64 ImagePrice2K *float64
ImagePrice4K *float64 ImagePrice4K *float64
// Claude Code 客户端限制
ClaudeCodeOnly bool
FallbackGroupID *int64
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
......
...@@ -21,7 +21,7 @@ type GroupRepository interface { ...@@ -21,7 +21,7 @@ type GroupRepository interface {
DeleteCascade(ctx context.Context, id int64) ([]int64, error) DeleteCascade(ctx context.Context, id int64) ([]int64, error)
List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error)
ListActive(ctx context.Context) ([]Group, error) ListActive(ctx context.Context) ([]Group, error)
ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error)
......
...@@ -134,11 +134,11 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string { ...@@ -134,11 +134,11 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string {
} }
// BindStickySession sets session -> account binding with standard TTL. // BindStickySession sets session -> account binding with standard TTL.
func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error { func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error {
if sessionHash == "" || accountID <= 0 { if sessionHash == "" || accountID <= 0 {
return nil return nil
} }
return s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, accountID, openaiStickySessionTTL) return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, accountID, openaiStickySessionTTL)
} }
// SelectAccount selects an OpenAI account with sticky session support // SelectAccount selects an OpenAI account with sticky session support
...@@ -155,13 +155,13 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI ...@@ -155,13 +155,13 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 1. Check sticky session // 1. Check sticky session
if sessionHash != "" { if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash) accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
if err == nil && accountID > 0 { if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded { if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID) account, err := s.accountRepo.GetByID(ctx, accountID)
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) { if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
// Refresh sticky session TTL // Refresh sticky session TTL
_ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL) _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
return account, nil return account, nil
} }
} }
...@@ -227,7 +227,7 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C ...@@ -227,7 +227,7 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
// 4. Set sticky session // 4. Set sticky session
if sessionHash != "" { if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, selected.ID, openaiStickySessionTTL) _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, selected.ID, openaiStickySessionTTL)
} }
return selected, nil return selected, nil
...@@ -238,7 +238,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ...@@ -238,7 +238,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
cfg := s.schedulingConfig() cfg := s.schedulingConfig()
var stickyAccountID int64 var stickyAccountID int64
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
if accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash); err == nil { if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash); err == nil {
stickyAccountID = accountID stickyAccountID = accountID
} }
} }
...@@ -298,14 +298,14 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ...@@ -298,14 +298,14 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
// ============ Layer 1: Sticky session ============ // ============ Layer 1: Sticky session ============
if sessionHash != "" { if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash) accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) { if err == nil && accountID > 0 && !isExcluded(accountID) {
account, err := s.accountRepo.GetByID(ctx, accountID) account, err := s.accountRepo.GetByID(ctx, accountID)
if err == nil && account.IsSchedulable() && account.IsOpenAI() && if err == nil && account.IsSchedulable() && account.IsOpenAI() &&
(requestedModel == "" || account.IsModelSupported(requestedModel)) { (requestedModel == "" || account.IsModelSupported(requestedModel)) {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
_ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL) _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
return &AccountSelectionResult{ return &AccountSelectionResult{
Account: account, Account: account,
Acquired: true, Acquired: true,
...@@ -362,7 +362,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ...@@ -362,7 +362,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
if sessionHash != "" { if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, acc.ID, openaiStickySessionTTL) _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, acc.ID, openaiStickySessionTTL)
} }
return &AccountSelectionResult{ return &AccountSelectionResult{
Account: acc, Account: acc,
...@@ -415,7 +415,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ...@@ -415,7 +415,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
if sessionHash != "" { if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, item.account.ID, openaiStickySessionTTL) _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, item.account.ID, openaiStickySessionTTL)
} }
return &AccountSelectionResult{ return &AccountSelectionResult{
Account: item.account, Account: item.account,
...@@ -540,10 +540,19 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco ...@@ -540,10 +540,19 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
bodyModified = true bodyModified = true
} }
// For OAuth accounts using ChatGPT internal API, add store: false // For OAuth accounts using ChatGPT internal API:
// 1. Add store: false
// 2. Normalize input format for Codex API compatibility
if account.Type == AccountTypeOAuth { if account.Type == AccountTypeOAuth {
reqBody["store"] = false reqBody["store"] = false
bodyModified = true bodyModified = true
// Normalize input format: convert AI SDK multi-part content format to simplified format
// AI SDK sends: {"content": [{"type": "input_text", "text": "..."}]}
// Codex API expects: {"content": "..."}
if normalizeInputForCodexAPI(reqBody) {
bodyModified = true
}
} }
// Re-serialize body only if modified // Re-serialize body only if modified
...@@ -1085,6 +1094,101 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel ...@@ -1085,6 +1094,101 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
return newBody return newBody
} }
// normalizeInputForCodexAPI converts AI SDK multi-part content format to simplified format
// that the ChatGPT internal Codex API expects.
//
// AI SDK sends content as an array of typed objects:
//
// {"content": [{"type": "input_text", "text": "hello"}]}
//
// ChatGPT Codex API expects content as a simple string:
//
// {"content": "hello"}
//
// This function modifies reqBody in-place and returns true if any modification was made.
func normalizeInputForCodexAPI(reqBody map[string]any) bool {
input, ok := reqBody["input"]
if !ok {
return false
}
// Handle case where input is a simple string (already compatible)
if _, isString := input.(string); isString {
return false
}
// Handle case where input is an array of messages
inputArray, ok := input.([]any)
if !ok {
return false
}
modified := false
for _, item := range inputArray {
message, ok := item.(map[string]any)
if !ok {
continue
}
content, ok := message["content"]
if !ok {
continue
}
// If content is already a string, no conversion needed
if _, isString := content.(string); isString {
continue
}
// If content is an array (AI SDK format), convert to string
contentArray, ok := content.([]any)
if !ok {
continue
}
// Extract text from content array
var textParts []string
for _, part := range contentArray {
partMap, ok := part.(map[string]any)
if !ok {
continue
}
// Handle different content types
partType, _ := partMap["type"].(string)
switch partType {
case "input_text", "text":
// Extract text from input_text or text type
if text, ok := partMap["text"].(string); ok {
textParts = append(textParts, text)
}
case "input_image", "image":
// For images, we need to preserve the original format
// as ChatGPT Codex API may support images in a different way
// For now, skip image parts (they will be lost in conversion)
// TODO: Consider preserving image data or handling it separately
continue
case "input_file", "file":
// Similar to images, file inputs may need special handling
continue
default:
// For unknown types, try to extract text if available
if text, ok := partMap["text"].(string); ok {
textParts = append(textParts, text)
}
}
}
// Convert content array to string
if len(textParts) > 0 {
message["content"] = strings.Join(textParts, "\n")
modified = true
}
}
return modified
}
// OpenAIRecordUsageInput input for recording usage // OpenAIRecordUsageInput input for recording usage
type OpenAIRecordUsageInput struct { type OpenAIRecordUsageInput struct {
Result *OpenAIForwardResult Result *OpenAIForwardResult
......
...@@ -20,6 +20,7 @@ type ProxyRepository interface { ...@@ -20,6 +20,7 @@ type ProxyRepository interface {
List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error) List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error)
ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error)
ListActive(ctx context.Context) ([]Proxy, error) ListActive(ctx context.Context) ([]Proxy, error)
ListActiveWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) ListActiveWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error)
......
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
...@@ -64,6 +65,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings ...@@ -64,6 +65,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyAPIBaseURL, SettingKeyAPIBaseURL,
SettingKeyContactInfo, SettingKeyContactInfo,
SettingKeyDocURL, SettingKeyDocURL,
SettingKeyLinuxDoConnectEnabled,
} }
settings, err := s.settingRepo.GetMultiple(ctx, keys) settings, err := s.settingRepo.GetMultiple(ctx, keys)
...@@ -71,6 +73,13 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings ...@@ -71,6 +73,13 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
return nil, fmt.Errorf("get public settings: %w", err) return nil, fmt.Errorf("get public settings: %w", err)
} }
linuxDoEnabled := false
if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok {
linuxDoEnabled = raw == "true"
} else {
linuxDoEnabled = s.cfg != nil && s.cfg.LinuxDo.Enabled
}
return &PublicSettings{ return &PublicSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true", EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
...@@ -82,6 +91,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings ...@@ -82,6 +91,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
APIBaseURL: settings[SettingKeyAPIBaseURL], APIBaseURL: settings[SettingKeyAPIBaseURL],
ContactInfo: settings[SettingKeyContactInfo], ContactInfo: settings[SettingKeyContactInfo],
DocURL: settings[SettingKeyDocURL], DocURL: settings[SettingKeyDocURL],
LinuxDoOAuthEnabled: linuxDoEnabled,
}, nil }, nil
} }
...@@ -111,6 +121,14 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet ...@@ -111,6 +121,14 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyTurnstileSecretKey] = settings.TurnstileSecretKey updates[SettingKeyTurnstileSecretKey] = settings.TurnstileSecretKey
} }
// LinuxDo Connect OAuth 登录(终端用户 SSO)
updates[SettingKeyLinuxDoConnectEnabled] = strconv.FormatBool(settings.LinuxDoConnectEnabled)
updates[SettingKeyLinuxDoConnectClientID] = settings.LinuxDoConnectClientID
updates[SettingKeyLinuxDoConnectRedirectURL] = settings.LinuxDoConnectRedirectURL
if settings.LinuxDoConnectClientSecret != "" {
updates[SettingKeyLinuxDoConnectClientSecret] = settings.LinuxDoConnectClientSecret
}
// OEM设置 // OEM设置
updates[SettingKeySiteName] = settings.SiteName updates[SettingKeySiteName] = settings.SiteName
updates[SettingKeySiteLogo] = settings.SiteLogo updates[SettingKeySiteLogo] = settings.SiteLogo
...@@ -141,8 +159,8 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet ...@@ -141,8 +159,8 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool { func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled) value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled)
if err != nil { if err != nil {
// 默认开放注册 // 安全默认:如果设置不存在或查询出错,默认关闭注册
return true return false
} }
return value == "true" return value == "true"
} }
...@@ -271,6 +289,38 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin ...@@ -271,6 +289,38 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.SMTPPassword = settings[SettingKeySMTPPassword] result.SMTPPassword = settings[SettingKeySMTPPassword]
result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey] result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey]
// LinuxDo Connect 设置:
// - 兼容 config.yaml/env(避免老部署因为未迁移到数据库设置而被意外关闭)
// - 支持在后台“系统设置”中覆盖并持久化(存储于 DB)
linuxDoBase := config.LinuxDoConnectConfig{}
if s.cfg != nil {
linuxDoBase = s.cfg.LinuxDo
}
if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok {
result.LinuxDoConnectEnabled = raw == "true"
} else {
result.LinuxDoConnectEnabled = linuxDoBase.Enabled
}
if v, ok := settings[SettingKeyLinuxDoConnectClientID]; ok && strings.TrimSpace(v) != "" {
result.LinuxDoConnectClientID = strings.TrimSpace(v)
} else {
result.LinuxDoConnectClientID = linuxDoBase.ClientID
}
if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" {
result.LinuxDoConnectRedirectURL = strings.TrimSpace(v)
} else {
result.LinuxDoConnectRedirectURL = linuxDoBase.RedirectURL
}
result.LinuxDoConnectClientSecret = strings.TrimSpace(settings[SettingKeyLinuxDoConnectClientSecret])
if result.LinuxDoConnectClientSecret == "" {
result.LinuxDoConnectClientSecret = strings.TrimSpace(linuxDoBase.ClientSecret)
}
result.LinuxDoConnectClientSecretConfigured = result.LinuxDoConnectClientSecret != ""
// Model fallback settings // Model fallback settings
result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true" result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true"
result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022") result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022")
...@@ -289,6 +339,99 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin ...@@ -289,6 +339,99 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
return result return result
} }
// GetLinuxDoConnectOAuthConfig 返回用于登录的“最终生效” LinuxDo Connect 配置。
//
// 优先级:
// - 若对应系统设置键存在,则覆盖 config.yaml/env 的值
// - 否则回退到 config.yaml/env 的值
func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) {
if s == nil || s.cfg == nil {
return config.LinuxDoConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded")
}
effective := s.cfg.LinuxDo
keys := []string{
SettingKeyLinuxDoConnectEnabled,
SettingKeyLinuxDoConnectClientID,
SettingKeyLinuxDoConnectClientSecret,
SettingKeyLinuxDoConnectRedirectURL,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
return config.LinuxDoConnectConfig{}, fmt.Errorf("get linuxdo connect settings: %w", err)
}
if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok {
effective.Enabled = raw == "true"
}
if v, ok := settings[SettingKeyLinuxDoConnectClientID]; ok && strings.TrimSpace(v) != "" {
effective.ClientID = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyLinuxDoConnectClientSecret]; ok && strings.TrimSpace(v) != "" {
effective.ClientSecret = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" {
effective.RedirectURL = strings.TrimSpace(v)
}
if !effective.Enabled {
return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
}
// 基础健壮性校验(避免把用户重定向到一个必然失败或不安全的 OAuth 流程里)。
if strings.TrimSpace(effective.ClientID) == "" {
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client id not configured")
}
if strings.TrimSpace(effective.AuthorizeURL) == "" {
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth authorize url not configured")
}
if strings.TrimSpace(effective.TokenURL) == "" {
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token url not configured")
}
if strings.TrimSpace(effective.UserInfoURL) == "" {
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth userinfo url not configured")
}
if strings.TrimSpace(effective.RedirectURL) == "" {
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url not configured")
}
if strings.TrimSpace(effective.FrontendRedirectURL) == "" {
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth frontend redirect url not configured")
}
if err := config.ValidateAbsoluteHTTPURL(effective.AuthorizeURL); err != nil {
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth authorize url invalid")
}
if err := config.ValidateAbsoluteHTTPURL(effective.TokenURL); err != nil {
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token url invalid")
}
if err := config.ValidateAbsoluteHTTPURL(effective.UserInfoURL); err != nil {
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth userinfo url invalid")
}
if err := config.ValidateAbsoluteHTTPURL(effective.RedirectURL); err != nil {
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url invalid")
}
if err := config.ValidateFrontendRedirectURL(effective.FrontendRedirectURL); err != nil {
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth frontend redirect url invalid")
}
method := strings.ToLower(strings.TrimSpace(effective.TokenAuthMethod))
switch method {
case "", "client_secret_post", "client_secret_basic":
if strings.TrimSpace(effective.ClientSecret) == "" {
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured")
}
case "none":
if !effective.UsePKCE {
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none")
}
default:
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid")
}
return effective, nil
}
// getStringOrDefault 获取字符串值或默认值 // getStringOrDefault 获取字符串值或默认值
func (s *SettingService) getStringOrDefault(settings map[string]string, key, defaultValue string) string { func (s *SettingService) getStringOrDefault(settings map[string]string, key, defaultValue string) string {
if value, ok := settings[key]; ok && value != "" { if value, ok := settings[key]; ok && value != "" {
......
...@@ -18,6 +18,13 @@ type SystemSettings struct { ...@@ -18,6 +18,13 @@ type SystemSettings struct {
TurnstileSecretKey string TurnstileSecretKey string
TurnstileSecretKeyConfigured bool TurnstileSecretKeyConfigured bool
// LinuxDo Connect OAuth 登录(终端用户 SSO)
LinuxDoConnectEnabled bool
LinuxDoConnectClientID string
LinuxDoConnectClientSecret string
LinuxDoConnectClientSecretConfigured bool
LinuxDoConnectRedirectURL string
SiteName string SiteName string
SiteLogo string SiteLogo string
SiteSubtitle string SiteSubtitle string
...@@ -51,5 +58,6 @@ type PublicSettings struct { ...@@ -51,5 +58,6 @@ type PublicSettings struct {
APIBaseURL string APIBaseURL string
ContactInfo string ContactInfo string
DocURL string DocURL string
LinuxDoOAuthEnabled bool
Version string Version string
} }
-- 029_add_group_claude_code_restriction.sql
-- 添加分组级别的 Claude Code 客户端限制功能
-- 添加 claude_code_only 字段:是否仅允许 Claude Code 客户端
ALTER TABLE groups
ADD COLUMN IF NOT EXISTS claude_code_only BOOLEAN NOT NULL DEFAULT FALSE;
-- 添加 fallback_group_id 字段:非 Claude Code 请求降级到的分组
ALTER TABLE groups
ADD COLUMN IF NOT EXISTS fallback_group_id BIGINT REFERENCES groups(id) ON DELETE SET NULL;
-- 添加索引优化查询
CREATE INDEX IF NOT EXISTS idx_groups_claude_code_only
ON groups(claude_code_only) WHERE deleted_at IS NULL;
CREATE INDEX IF NOT EXISTS idx_groups_fallback_group_id
ON groups(fallback_group_id) WHERE deleted_at IS NULL AND fallback_group_id IS NOT NULL;
-- 添加字段注释
COMMENT ON COLUMN groups.claude_code_only IS '是否仅允许 Claude Code 客户端访问此分组';
COMMENT ON COLUMN groups.fallback_group_id IS '非 Claude Code 请求降级使用的分组 ID';
...@@ -234,6 +234,31 @@ jwt: ...@@ -234,6 +234,31 @@ jwt:
# 令牌过期时间(小时,最大 24) # 令牌过期时间(小时,最大 24)
expire_hour: 24 expire_hour: 24
# =============================================================================
# LinuxDo Connect OAuth Login (SSO)
# LinuxDo Connect OAuth 登录(用于 Sub2API 用户登录)
# =============================================================================
linuxdo_connect:
enabled: false
client_id: ""
client_secret: ""
authorize_url: "https://connect.linux.do/oauth2/authorize"
token_url: "https://connect.linux.do/oauth2/token"
userinfo_url: "https://connect.linux.do/api/user"
scopes: "user"
# 示例: "https://your-domain.com/api/v1/auth/oauth/linuxdo/callback"
redirect_url: ""
# 安全提示:
# - 建议使用同源相对路径(以 / 开头),避免把 token 重定向到意外的第三方域名
# - 该地址不应包含 #fragment(本实现使用 URL fragment 传递 access_token)
frontend_redirect_url: "/auth/linuxdo/callback"
token_auth_method: "client_secret_post" # client_secret_post | client_secret_basic | none
# 注意:当 token_auth_method=none(public client)时,必须启用 PKCE
use_pkce: false
userinfo_email_path: ""
userinfo_id_path: ""
userinfo_username_path: ""
# ============================================================================= # =============================================================================
# Default Settings # Default Settings
# 默认设置 # 默认设置
......
...@@ -173,11 +173,12 @@ services: ...@@ -173,11 +173,12 @@ services:
volumes: volumes:
- redis_data:/data - redis_data:/data
command: > command: >
redis-server sh -c '
--save 60 1 redis-server
--appendonly yes --save 60 1
--appendfsync everysec --appendonly yes
${REDIS_PASSWORD:+--requirepass ${REDIS_PASSWORD}} --appendfsync everysec
${REDIS_PASSWORD:+--requirepass "$REDIS_PASSWORD"}'
environment: environment:
- TZ=${TZ:-Asia/Shanghai} - TZ=${TZ:-Asia/Shanghai}
# REDISCLI_AUTH is used by redis-cli for authentication (safer than -a flag) # REDISCLI_AUTH is used by redis-cli for authentication (safer than -a flag)
......
...@@ -16,7 +16,7 @@ import type { ...@@ -16,7 +16,7 @@ import type {
* List all groups with pagination * List all groups with pagination
* @param page - Page number (default: 1) * @param page - Page number (default: 1)
* @param pageSize - Items per page (default: 20) * @param pageSize - Items per page (default: 20)
* @param filters - Optional filters (platform, status, is_exclusive) * @param filters - Optional filters (platform, status, is_exclusive, search)
* @returns Paginated list of groups * @returns Paginated list of groups
*/ */
export async function list( export async function list(
...@@ -26,6 +26,7 @@ export async function list( ...@@ -26,6 +26,7 @@ export async function list(
platform?: GroupPlatform platform?: GroupPlatform
status?: 'active' | 'inactive' status?: 'active' | 'inactive'
is_exclusive?: boolean is_exclusive?: boolean
search?: string
}, },
options?: { options?: {
signal?: AbortSignal signal?: AbortSignal
......
...@@ -34,6 +34,11 @@ export interface SystemSettings { ...@@ -34,6 +34,11 @@ export interface SystemSettings {
turnstile_enabled: boolean turnstile_enabled: boolean
turnstile_site_key: string turnstile_site_key: string
turnstile_secret_key_configured: boolean turnstile_secret_key_configured: boolean
// LinuxDo Connect OAuth 登录(终端用户 SSO)
linuxdo_connect_enabled: boolean
linuxdo_connect_client_id: string
linuxdo_connect_client_secret_configured: boolean
linuxdo_connect_redirect_url: string
// Identity patch configuration (Claude -> Gemini) // Identity patch configuration (Claude -> Gemini)
enable_identity_patch: boolean enable_identity_patch: boolean
identity_patch_prompt: string identity_patch_prompt: string
...@@ -60,6 +65,10 @@ export interface UpdateSettingsRequest { ...@@ -60,6 +65,10 @@ export interface UpdateSettingsRequest {
turnstile_enabled?: boolean turnstile_enabled?: boolean
turnstile_site_key?: string turnstile_site_key?: string
turnstile_secret_key?: string turnstile_secret_key?: string
linuxdo_connect_enabled?: boolean
linuxdo_connect_client_id?: string
linuxdo_connect_client_secret?: string
linuxdo_connect_redirect_url?: string
enable_identity_patch?: boolean enable_identity_patch?: boolean
identity_patch_prompt?: string identity_patch_prompt?: string
} }
......
...@@ -166,7 +166,7 @@ ...@@ -166,7 +166,7 @@
> >
<div <div
:class="[ :class="[
'flex h-8 w-8 items-center justify-center rounded-lg', 'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
accountCategory === 'oauth-based' accountCategory === 'oauth-based'
? 'bg-orange-500 text-white' ? 'bg-orange-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400' : 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
...@@ -196,7 +196,7 @@ ...@@ -196,7 +196,7 @@
> >
<div <div
:class="[ :class="[
'flex h-8 w-8 items-center justify-center rounded-lg', 'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
accountCategory === 'apikey' accountCategory === 'apikey'
? 'bg-purple-500 text-white' ? 'bg-purple-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400' : 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
...@@ -232,7 +232,7 @@ ...@@ -232,7 +232,7 @@
> >
<div <div
:class="[ :class="[
'flex h-8 w-8 items-center justify-center rounded-lg', 'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
accountCategory === 'oauth-based' accountCategory === 'oauth-based'
? 'bg-green-500 text-white' ? 'bg-green-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400' : 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
...@@ -258,7 +258,7 @@ ...@@ -258,7 +258,7 @@
> >
<div <div
:class="[ :class="[
'flex h-8 w-8 items-center justify-center rounded-lg', 'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
accountCategory === 'apikey' accountCategory === 'apikey'
? 'bg-purple-500 text-white' ? 'bg-purple-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400' : 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
...@@ -302,7 +302,7 @@ ...@@ -302,7 +302,7 @@
> >
<div <div
:class="[ :class="[
'flex h-8 w-8 items-center justify-center rounded-lg', 'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
accountCategory === 'oauth-based' accountCategory === 'oauth-based'
? 'bg-blue-500 text-white' ? 'bg-blue-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400' : 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
...@@ -332,7 +332,7 @@ ...@@ -332,7 +332,7 @@
> >
<div <div
:class="[ :class="[
'flex h-8 w-8 items-center justify-center rounded-lg', 'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
accountCategory === 'apikey' accountCategory === 'apikey'
? 'bg-purple-500 text-white' ? 'bg-purple-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400' : 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
...@@ -397,7 +397,7 @@ ...@@ -397,7 +397,7 @@
> >
<div <div
:class="[ :class="[
'flex h-8 w-8 items-center justify-center rounded-lg', 'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
geminiOAuthType === 'google_one' geminiOAuthType === 'google_one'
? 'bg-purple-500 text-white' ? 'bg-purple-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400' : 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
...@@ -440,7 +440,7 @@ ...@@ -440,7 +440,7 @@
> >
<div <div
:class="[ :class="[
'flex h-8 w-8 items-center justify-center rounded-lg', 'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
geminiOAuthType === 'code_assist' geminiOAuthType === 'code_assist'
? 'bg-blue-500 text-white' ? 'bg-blue-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400' : 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
...@@ -518,7 +518,7 @@ ...@@ -518,7 +518,7 @@
> >
<div <div
:class="[ :class="[
'flex h-8 w-8 items-center justify-center rounded-lg', 'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
geminiOAuthType === 'ai_studio' geminiOAuthType === 'ai_studio'
? 'bg-amber-500 text-white' ? 'bg-amber-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400' : 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
...@@ -621,7 +621,7 @@ ...@@ -621,7 +621,7 @@
<div <div
class="flex items-center gap-3 rounded-lg border-2 border-purple-500 bg-purple-50 p-3 dark:bg-purple-900/20" class="flex items-center gap-3 rounded-lg border-2 border-purple-500 bg-purple-50 p-3 dark:bg-purple-900/20"
> >
<div class="flex h-8 w-8 items-center justify-center rounded-lg bg-purple-500 text-white"> <div class="flex h-8 w-8 shrink-0 items-center justify-center rounded-lg bg-purple-500 text-white">
<Icon name="key" size="sm" /> <Icon name="key" size="sm" />
</div> </div>
<div> <div>
......
...@@ -73,113 +73,48 @@ ...@@ -73,113 +73,48 @@
</div> </div>
</fieldset> </fieldset>
<!-- Gemini OAuth Type Selection --> <!-- Gemini OAuth Type Display (read-only) -->
<fieldset v-if="isGemini" class="border-0 p-0"> <div v-if="isGemini" class="rounded-lg border border-gray-200 bg-gray-50 p-4 dark:border-dark-600 dark:bg-dark-700">
<legend class="input-label">{{ t('admin.accounts.oauth.gemini.oauthTypeLabel') }}</legend> <div class="mb-2 text-sm font-medium text-gray-700 dark:text-gray-300">
<div class="mt-2 grid grid-cols-3 gap-3"> {{ t('admin.accounts.oauth.gemini.oauthTypeLabel') }}
<button </div>
type="button" <div class="flex items-center gap-3">
@click="handleSelectGeminiOAuthType('google_one')" <div
:class="[ :class="[
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all', 'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
geminiOAuthType === 'google_one' geminiOAuthType === 'google_one'
? 'border-purple-500 bg-purple-50 dark:bg-purple-900/20' ? 'bg-purple-500 text-white'
: 'border-gray-200 hover:border-purple-300 dark:border-dark-600 dark:hover:border-purple-700' : geminiOAuthType === 'code_assist'
]"
>
<div
:class="[
'flex h-8 w-8 items-center justify-center rounded-lg',
geminiOAuthType === 'google_one'
? 'bg-purple-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
]"
>
<svg class="h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="1.5">
<path stroke-linecap="round" stroke-linejoin="round" d="M15.75 6a3.75 3.75 0 11-7.5 0 3.75 3.75 0 017.5 0zM4.501 20.118a7.5 7.5 0 0114.998 0A17.933 17.933 0 0112 21.75c-2.676 0-5.216-.584-7.499-1.632z" />
</svg>
</div>
<div class="min-w-0">
<span class="block text-sm font-medium text-gray-900 dark:text-white">Google One</span>
<span class="text-xs text-gray-500 dark:text-gray-400">个人账号</span>
</div>
</button>
<button
type="button"
@click="handleSelectGeminiOAuthType('code_assist')"
:class="[
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
geminiOAuthType === 'code_assist'
? 'border-blue-500 bg-blue-50 dark:bg-blue-900/20'
: 'border-gray-200 hover:border-blue-300 dark:border-dark-600 dark:hover:border-blue-700'
]"
>
<div
:class="[
'flex h-8 w-8 items-center justify-center rounded-lg',
geminiOAuthType === 'code_assist'
? 'bg-blue-500 text-white' ? 'bg-blue-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400' : 'bg-amber-500 text-white'
]"
>
<Icon name="cloud" size="sm" />
</div>
<div class="min-w-0">
<span class="block text-sm font-medium text-gray-900 dark:text-white">
{{ t('admin.accounts.gemini.oauthType.builtInTitle') }}
</span>
<span class="text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.gemini.oauthType.builtInDesc') }}
</span>
</div>
</button>
<button
type="button"
:disabled="!geminiAIStudioOAuthEnabled"
@click="handleSelectGeminiOAuthType('ai_studio')"
:class="[
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
!geminiAIStudioOAuthEnabled ? 'cursor-not-allowed opacity-60' : '',
geminiOAuthType === 'ai_studio'
? 'border-purple-500 bg-purple-50 dark:bg-purple-900/20'
: 'border-gray-200 hover:border-purple-300 dark:border-dark-600 dark:hover:border-purple-700'
]" ]"
> >
<div <Icon v-if="geminiOAuthType === 'google_one'" name="user" size="sm" />
:class="[ <Icon v-else-if="geminiOAuthType === 'code_assist'" name="cloud" size="sm" />
'flex h-8 w-8 items-center justify-center rounded-lg', <Icon v-else name="sparkles" size="sm" />
geminiOAuthType === 'ai_studio' </div>
? 'bg-purple-500 text-white' <div>
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400' <span class="block text-sm font-medium text-gray-900 dark:text-white">
]" {{
> geminiOAuthType === 'google_one'
<Icon name="sparkles" size="sm" /> ? 'Google One'
</div> : geminiOAuthType === 'code_assist'
<div class="min-w-0"> ? t('admin.accounts.gemini.oauthType.builtInTitle')
<span class="block text-sm font-medium text-gray-900 dark:text-white"> : t('admin.accounts.gemini.oauthType.customTitle')
{{ t('admin.accounts.gemini.oauthType.customTitle') }} }}
</span> </span>
<span class="text-xs text-gray-500 dark:text-gray-400"> <span class="text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.gemini.oauthType.customDesc') }} {{
</span> geminiOAuthType === 'google_one'
<div v-if="!geminiAIStudioOAuthEnabled" class="group relative mt-1 inline-block"> ? '个人账号'
<span : geminiOAuthType === 'code_assist'
class="rounded bg-amber-100 px-2 py-0.5 text-xs text-amber-700 dark:bg-amber-900/30 dark:text-amber-300" ? t('admin.accounts.gemini.oauthType.builtInDesc')
> : t('admin.accounts.gemini.oauthType.customDesc')
{{ t('admin.accounts.oauth.gemini.aiStudioNotConfiguredShort') }} }}
</span> </span>
<div </div>
class="pointer-events-none absolute left-0 top-full z-10 mt-2 w-[28rem] rounded-md border border-amber-200 bg-amber-50 px-3 py-2 text-xs text-amber-800 opacity-0 shadow-sm transition-opacity group-hover:opacity-100 dark:border-amber-700 dark:bg-amber-900/40 dark:text-amber-200"
>
{{ t('admin.accounts.oauth.gemini.aiStudioNotConfiguredTip') }}
</div>
</div>
</div>
</button>
</div> </div>
</fieldset> </div>
<OAuthAuthorizationFlow <OAuthAuthorizationFlow
ref="oauthFlowRef" ref="oauthFlowRef"
...@@ -299,7 +234,6 @@ const oauthFlowRef = ref<OAuthFlowExposed | null>(null) ...@@ -299,7 +234,6 @@ const oauthFlowRef = ref<OAuthFlowExposed | null>(null)
// State // State
const addMethod = ref<AddMethod>('oauth') const addMethod = ref<AddMethod>('oauth')
const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('code_assist') const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('code_assist')
const geminiAIStudioOAuthEnabled = ref(false)
// Computed - check platform // Computed - check platform
const isOpenAI = computed(() => props.account?.platform === 'openai') const isOpenAI = computed(() => props.account?.platform === 'openai')
...@@ -367,14 +301,6 @@ watch( ...@@ -367,14 +301,6 @@ watch(
? 'ai_studio' ? 'ai_studio'
: 'code_assist' : 'code_assist'
} }
if (isGemini.value) {
geminiOAuth.getCapabilities().then((caps) => {
geminiAIStudioOAuthEnabled.value = !!caps?.ai_studio_oauth_enabled
if (!geminiAIStudioOAuthEnabled.value && geminiOAuthType.value === 'ai_studio') {
geminiOAuthType.value = 'code_assist'
}
})
}
} else { } else {
resetState() resetState()
} }
...@@ -385,7 +311,6 @@ watch( ...@@ -385,7 +311,6 @@ watch(
const resetState = () => { const resetState = () => {
addMethod.value = 'oauth' addMethod.value = 'oauth'
geminiOAuthType.value = 'code_assist' geminiOAuthType.value = 'code_assist'
geminiAIStudioOAuthEnabled.value = false
claudeOAuth.resetState() claudeOAuth.resetState()
openaiOAuth.resetState() openaiOAuth.resetState()
geminiOAuth.resetState() geminiOAuth.resetState()
...@@ -393,14 +318,6 @@ const resetState = () => { ...@@ -393,14 +318,6 @@ const resetState = () => {
oauthFlowRef.value?.reset() oauthFlowRef.value?.reset()
} }
const handleSelectGeminiOAuthType = (oauthType: 'code_assist' | 'google_one' | 'ai_studio') => {
if (oauthType === 'ai_studio' && !geminiAIStudioOAuthEnabled.value) {
appStore.showError(t('admin.accounts.oauth.gemini.aiStudioNotConfigured'))
return
}
geminiOAuthType.value = oauthType
}
const handleClose = () => { const handleClose = () => {
emit('close') emit('close')
} }
......
<template> <template>
<div v-if="selectedIds.length > 0" class="mb-4 flex items-center justify-between p-3 bg-primary-50 rounded-lg"> <div v-if="selectedIds.length > 0" class="mb-4 flex items-center justify-between p-3 bg-primary-50 rounded-lg dark:bg-primary-900/20">
<span class="text-sm font-medium">{{ t('admin.accounts.bulkActions.selected', { count: selectedIds.length }) }}</span> <div class="flex flex-wrap items-center gap-2">
<span class="text-sm font-medium text-primary-900 dark:text-primary-100">
{{ t('admin.accounts.bulkActions.selected', { count: selectedIds.length }) }}
</span>
<button
@click="$emit('select-page')"
class="text-xs font-medium text-primary-700 hover:text-primary-800 dark:text-primary-300 dark:hover:text-primary-200"
>
{{ t('admin.accounts.bulkActions.selectCurrentPage') }}
</button>
<span class="text-gray-300 dark:text-primary-800"></span>
<button
@click="$emit('clear')"
class="text-xs font-medium text-primary-700 hover:text-primary-800 dark:text-primary-300 dark:hover:text-primary-200"
>
{{ t('admin.accounts.bulkActions.clear') }}
</button>
</div>
<div class="flex gap-2"> <div class="flex gap-2">
<button @click="$emit('delete')" class="btn btn-danger btn-sm">{{ t('admin.accounts.bulkActions.delete') }}</button> <button @click="$emit('delete')" class="btn btn-danger btn-sm">{{ t('admin.accounts.bulkActions.delete') }}</button>
<button @click="$emit('toggle-schedulable', true)" class="btn btn-success btn-sm">{{ t('admin.accounts.bulkActions.enableScheduling') }}</button>
<button @click="$emit('toggle-schedulable', false)" class="btn btn-warning btn-sm">{{ t('admin.accounts.bulkActions.disableScheduling') }}</button>
<button @click="$emit('edit')" class="btn btn-primary btn-sm">{{ t('admin.accounts.bulkActions.edit') }}</button> <button @click="$emit('edit')" class="btn btn-primary btn-sm">{{ t('admin.accounts.bulkActions.edit') }}</button>
</div> </div>
</div> </div>
...@@ -10,5 +29,5 @@ ...@@ -10,5 +29,5 @@
<script setup lang="ts"> <script setup lang="ts">
import { useI18n } from 'vue-i18n' import { useI18n } from 'vue-i18n'
defineProps(['selectedIds']); defineEmits(['delete', 'edit']); const { t } = useI18n() defineProps(['selectedIds']); defineEmits(['delete', 'edit', 'clear', 'select-page', 'toggle-schedulable']); const { t } = useI18n()
</script> </script>
\ No newline at end of file
...@@ -73,111 +73,48 @@ ...@@ -73,111 +73,48 @@
</div> </div>
</fieldset> </fieldset>
<!-- Gemini OAuth Type Selection --> <!-- Gemini OAuth Type Display (read-only) -->
<fieldset v-if="isGemini" class="border-0 p-0"> <div v-if="isGemini" class="rounded-lg border border-gray-200 bg-gray-50 p-4 dark:border-dark-600 dark:bg-dark-700">
<legend class="input-label">{{ t('admin.accounts.oauth.gemini.oauthTypeLabel') }}</legend> <div class="mb-2 text-sm font-medium text-gray-700 dark:text-gray-300">
<div class="mt-2 grid grid-cols-3 gap-3"> {{ t('admin.accounts.oauth.gemini.oauthTypeLabel') }}
<button </div>
type="button" <div class="flex items-center gap-3">
@click="handleSelectGeminiOAuthType('google_one')" <div
:class="[ :class="[
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all', 'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
geminiOAuthType === 'google_one' geminiOAuthType === 'google_one'
? 'border-purple-500 bg-purple-50 dark:bg-purple-900/20' ? 'bg-purple-500 text-white'
: 'border-gray-200 hover:border-purple-300 dark:border-dark-600 dark:hover:border-purple-700' : geminiOAuthType === 'code_assist'
]"
>
<div
:class="[
'flex h-8 w-8 items-center justify-center rounded-lg',
geminiOAuthType === 'google_one'
? 'bg-purple-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
]"
>
<Icon name="user" size="sm" />
</div>
<div class="min-w-0">
<span class="block text-sm font-medium text-gray-900 dark:text-white">Google One</span>
<span class="text-xs text-gray-500 dark:text-gray-400">个人账号</span>
</div>
</button>
<button
type="button"
@click="handleSelectGeminiOAuthType('code_assist')"
:class="[
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
geminiOAuthType === 'code_assist'
? 'border-blue-500 bg-blue-50 dark:bg-blue-900/20'
: 'border-gray-200 hover:border-blue-300 dark:border-dark-600 dark:hover:border-blue-700'
]"
>
<div
:class="[
'flex h-8 w-8 items-center justify-center rounded-lg',
geminiOAuthType === 'code_assist'
? 'bg-blue-500 text-white' ? 'bg-blue-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400' : 'bg-amber-500 text-white'
]"
>
<Icon name="cloud" size="sm" />
</div>
<div class="min-w-0">
<span class="block text-sm font-medium text-gray-900 dark:text-white">
{{ t('admin.accounts.gemini.oauthType.builtInTitle') }}
</span>
<span class="text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.gemini.oauthType.builtInDesc') }}
</span>
</div>
</button>
<button
type="button"
:disabled="!geminiAIStudioOAuthEnabled"
@click="handleSelectGeminiOAuthType('ai_studio')"
:class="[
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
!geminiAIStudioOAuthEnabled ? 'cursor-not-allowed opacity-60' : '',
geminiOAuthType === 'ai_studio'
? 'border-purple-500 bg-purple-50 dark:bg-purple-900/20'
: 'border-gray-200 hover:border-purple-300 dark:border-dark-600 dark:hover:border-purple-700'
]" ]"
> >
<div <Icon v-if="geminiOAuthType === 'google_one'" name="user" size="sm" />
:class="[ <Icon v-else-if="geminiOAuthType === 'code_assist'" name="cloud" size="sm" />
'flex h-8 w-8 items-center justify-center rounded-lg', <Icon v-else name="sparkles" size="sm" />
geminiOAuthType === 'ai_studio' </div>
? 'bg-purple-500 text-white' <div>
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400' <span class="block text-sm font-medium text-gray-900 dark:text-white">
]" {{
> geminiOAuthType === 'google_one'
<Icon name="sparkles" size="sm" /> ? 'Google One'
</div> : geminiOAuthType === 'code_assist'
<div class="min-w-0"> ? t('admin.accounts.gemini.oauthType.builtInTitle')
<span class="block text-sm font-medium text-gray-900 dark:text-white"> : t('admin.accounts.gemini.oauthType.customTitle')
{{ t('admin.accounts.gemini.oauthType.customTitle') }} }}
</span> </span>
<span class="text-xs text-gray-500 dark:text-gray-400"> <span class="text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.gemini.oauthType.customDesc') }} {{
</span> geminiOAuthType === 'google_one'
<div v-if="!geminiAIStudioOAuthEnabled" class="group relative mt-1 inline-block"> ? '个人账号'
<span : geminiOAuthType === 'code_assist'
class="rounded bg-amber-100 px-2 py-0.5 text-xs text-amber-700 dark:bg-amber-900/30 dark:text-amber-300" ? t('admin.accounts.gemini.oauthType.builtInDesc')
> : t('admin.accounts.gemini.oauthType.customDesc')
{{ t('admin.accounts.oauth.gemini.aiStudioNotConfiguredShort') }} }}
</span> </span>
<div </div>
class="pointer-events-none absolute left-0 top-full z-10 mt-2 w-[28rem] rounded-md border border-amber-200 bg-amber-50 px-3 py-2 text-xs text-amber-800 opacity-0 shadow-sm transition-opacity group-hover:opacity-100 dark:border-amber-700 dark:bg-amber-900/40 dark:text-amber-200"
>
{{ t('admin.accounts.oauth.gemini.aiStudioNotConfiguredTip') }}
</div>
</div>
</div>
</button>
</div> </div>
</fieldset> </div>
<OAuthAuthorizationFlow <OAuthAuthorizationFlow
ref="oauthFlowRef" ref="oauthFlowRef"
...@@ -297,7 +234,6 @@ const oauthFlowRef = ref<OAuthFlowExposed | null>(null) ...@@ -297,7 +234,6 @@ const oauthFlowRef = ref<OAuthFlowExposed | null>(null)
// State // State
const addMethod = ref<AddMethod>('oauth') const addMethod = ref<AddMethod>('oauth')
const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('code_assist') const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('code_assist')
const geminiAIStudioOAuthEnabled = ref(false)
// Computed - check platform // Computed - check platform
const isOpenAI = computed(() => props.account?.platform === 'openai') const isOpenAI = computed(() => props.account?.platform === 'openai')
...@@ -365,14 +301,6 @@ watch( ...@@ -365,14 +301,6 @@ watch(
? 'ai_studio' ? 'ai_studio'
: 'code_assist' : 'code_assist'
} }
if (isGemini.value) {
geminiOAuth.getCapabilities().then((caps) => {
geminiAIStudioOAuthEnabled.value = !!caps?.ai_studio_oauth_enabled
if (!geminiAIStudioOAuthEnabled.value && geminiOAuthType.value === 'ai_studio') {
geminiOAuthType.value = 'code_assist'
}
})
}
} else { } else {
resetState() resetState()
} }
...@@ -383,7 +311,6 @@ watch( ...@@ -383,7 +311,6 @@ watch(
const resetState = () => { const resetState = () => {
addMethod.value = 'oauth' addMethod.value = 'oauth'
geminiOAuthType.value = 'code_assist' geminiOAuthType.value = 'code_assist'
geminiAIStudioOAuthEnabled.value = false
claudeOAuth.resetState() claudeOAuth.resetState()
openaiOAuth.resetState() openaiOAuth.resetState()
geminiOAuth.resetState() geminiOAuth.resetState()
...@@ -391,14 +318,6 @@ const resetState = () => { ...@@ -391,14 +318,6 @@ const resetState = () => {
oauthFlowRef.value?.reset() oauthFlowRef.value?.reset()
} }
const handleSelectGeminiOAuthType = (oauthType: 'code_assist' | 'google_one' | 'ai_studio') => {
if (oauthType === 'ai_studio' && !geminiAIStudioOAuthEnabled.value) {
appStore.showError(t('admin.accounts.oauth.gemini.aiStudioNotConfigured'))
return
}
geminiOAuthType.value = oauthType
}
const handleClose = () => { const handleClose = () => {
emit('close') emit('close')
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment