Commit 538ae31a authored by 陈曦's avatar 陈曦
Browse files

merge v0.1.121 and fixed conflict

parents 74828a7c 48912014
Pipeline #82338 passed with stage
in 17 seconds
......@@ -122,7 +122,7 @@ scripts
.code-review-state
#openspec/
code-reviews/
#AGENTS.md
AGENTS.md
backend/cmd/server/server
deploy/docker-compose.override.yml
.gocache/
......
......@@ -101,6 +101,13 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
<td>Thanks to Bestproxy for sponsoring this project! <a href="https://bestproxy.com/?keyword=a2e8iuol">Bestproxy</a> provides high-purity residential IPs with dedicated one-IP-per-account support. By combining real home networks with fingerprint isolation, it enables link environment isolation and reduces the probability of association-based risk control.</td>
</tr>
<tr>
<td width="180"><a href="https://pateway.ai/?ch=1tsfr51"><img src="assets/partners/logos/pateway.png" alt="pateway" width="150"></a></td>
<td>Thanks to PatewayAI for sponsoring this project! PatewayAI is a premium model API relay service provider built for heavy AI developers, focused on direct official connections. Offering the full Claude series and Codex series models, 100% sourced directly from official providers — no dilution, no substitution, open to verification. Billing is fully transparent with token-level invoices that can be audited line by line.
Enterprise-grade high concurrency is also supported, with a dedicated management platform for enterprise clients. Enterprise customers can sign formal contracts and receive invoices. Visit the official website for more details and contact information.
Register now via <a href="https://pateway.ai/?ch=1tsfr51">this link</a> to receive $3 in trial credits. User top-ups start as low as 60% off, and referring friends earns both parties rewards — referral bonuses up to $150.</td>
</tr>
</table>
## Ecosystem
......
......@@ -100,6 +100,13 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
<td>感谢 Bestproxy 赞助了本项目!<a href="https://bestproxy.com/?keyword=a2e8iuol">Bestproxy</a> 是一家提供高纯度住宅IP,支持一号一IP独享,结合真实家庭网络与指纹隔离,可实现链路环境隔离,降低关联风控概率。</td>
</tr>
<tr>
<td width="180"><a href="https://pateway.ai/?ch=1tsfr51"><img src="assets/partners/logos/pateway.png" alt="pateway" width="150"></a></td>
<td>感谢 PatewayAI 赞助了本项目!PatewayAI 是一家面向重度 AI 开发者、专注官方直连的高品质模型 API 中转服务商。提供 Claude 全系列与 Codex 系列模型,100% 官方源直供,不掺假不注水,欢迎检验。计费透明,Token 级账单可逐笔核验。
时支持企业级高并发,并为企业客户提供了专业的管理平台,企业客户可签订正式合同并开具发票,更多详情进入官网获取联系方式。
在通过 <a href="https://pateway.ai/?ch=1tsfr51">此链接</a> 注册即送 $3 试用额度,用户充值低至 6 折,邀请好友双向赠送,邀请奖励可达 $150。</td>
</tr>
</table>
## 生态项目
......
......@@ -100,6 +100,13 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
<td>Bestproxy のご支援に感謝します!<a href="https://bestproxy.com/?keyword=a2e8iuol">Bestproxy</a> は高純度の住宅IPを提供し、1アカウント1IP専有をサポートしています。実際の家庭ネットワークとフィンガープリント分離を組み合わせることで、リンク環境の分離を実現し、関連付けによるリスク管理の確率を低減します。</td>
</tr>
<tr>
<td width="180"><a href="https://pateway.ai/?ch=1tsfr51"><img src="assets/partners/logos/pateway.png" alt="pateway" width="150"></a></td>
<td>PatewayAI のご支援に感謝します!PatewayAI は、ヘビーAI開発者向けに公式直結を重視した高品質モデルAPIリレーサービスプロバイダーです。Claude 全シリーズおよび Codex シリーズモデルを提供し、100%公式ソースから直接供給 — 偽りなし、水増しなし、検証歓迎。課金は完全透明で、トークン単位の請求書を1件ずつ監査可能です。
エンタープライズ級の高同時接続にも対応し、法人顧客向けに専用管理プラットフォームを提供しています。法人顧客は正式な契約を締結し、請求書の発行が可能です。詳細は公式サイトでお問い合わせください。
<a href="https://pateway.ai/?ch=1tsfr51">こちらのリンク</a>から登録すると、$3 のトライアルクレジットがもらえます。チャージは最大40%オフ、友達紹介で双方にボーナス付与 — 紹介報酬は最大 $150。</td>
</tr>
</table>
## エコシステム
......
......@@ -65,7 +65,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userGroupRateRepository := repository.NewUserGroupRateRepository(db)
billingCacheService := service.ProvideBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, userRPMCache, userGroupRateRepository, configConfig)
apiKeyCache := repository.NewAPIKeyCache(redisClient)
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig)
apiKeyService := service.ProvideAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig, billingCacheService)
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
......@@ -145,13 +145,14 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
gatewayCache := repository.NewGatewayCache(redisClient)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, claudeTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
......@@ -178,7 +179,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
billingService := service.NewBillingService(configConfig, pricingService)
identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
digestSessionStore := service.NewDigestSessionStore()
channelRepository := repository.NewChannelRepository(db)
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
......@@ -186,7 +186,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
......
......@@ -26,11 +26,12 @@ const (
// Account type constants
const (
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
AccountTypeAPIKey = "apikey" // API Key类型账号
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
AccountTypeAPIKey = "apikey" // API Key类型账号
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
AccountTypeServiceAccount = "service_account" // Google Service Account 类型账号(用于 Vertex AI)
)
// Redeem type constants
......
......@@ -98,7 +98,7 @@ type CreateAccountRequest struct {
Name string `json:"name" binding:"required"`
Notes *string `json:"notes"`
Platform string `json:"platform" binding:"required"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock service_account"`
Credentials map[string]any `json:"credentials" binding:"required"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
......@@ -117,7 +117,7 @@ type CreateAccountRequest struct {
type UpdateAccountRequest struct {
Name string `json:"name"`
Notes *string `json:"notes"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock service_account"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
......@@ -134,19 +134,29 @@ type UpdateAccountRequest struct {
// BulkUpdateAccountsRequest represents the payload for bulk editing accounts
type BulkUpdateAccountsRequest struct {
AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
Name string `json:"name"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
RateMultiplier *float64 `json:"rate_multiplier"`
LoadFactor *int `json:"load_factor"`
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
Schedulable *bool `json:"schedulable"`
GroupIDs *[]int64 `json:"group_ids"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
AccountIDs []int64 `json:"account_ids"`
Filters *BulkUpdateAccountFilters `json:"filters"`
Name string `json:"name"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
RateMultiplier *float64 `json:"rate_multiplier"`
LoadFactor *int `json:"load_factor"`
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
Schedulable *bool `json:"schedulable"`
GroupIDs *[]int64 `json:"group_ids"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
}
type BulkUpdateAccountFilters struct {
Platform string `json:"platform"`
Type string `json:"type"`
Status string `json:"status"`
Group string `json:"group"`
Search string `json:"search"`
PrivacyMode string `json:"privacy_mode"`
}
// CheckMixedChannelRequest represents check mixed channel risk request
......@@ -1369,6 +1379,10 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
response.BadRequest(c, "rate_multiplier must be >= 0")
return
}
if len(req.AccountIDs) == 0 && req.Filters == nil {
response.BadRequest(c, "account_ids or filters is required")
return
}
// base_rpm 输入校验:负值归零,超过 10000 截断
sanitizeExtraBaseRPM(req.Extra)
......@@ -1394,6 +1408,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
result, err := h.adminService.BulkUpdateAccounts(c.Request.Context(), &service.BulkUpdateAccountsInput{
AccountIDs: req.AccountIDs,
Filters: toServiceBulkUpdateAccountFilters(req.Filters),
Name: req.Name,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
......@@ -1429,6 +1444,20 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
response.Success(c, result)
}
func toServiceBulkUpdateAccountFilters(filters *BulkUpdateAccountFilters) *service.BulkUpdateAccountFilters {
if filters == nil {
return nil
}
return &service.BulkUpdateAccountFilters{
Platform: filters.Platform,
Type: filters.Type,
Status: filters.Status,
Group: filters.Group,
Search: filters.Search,
PrivacyMode: filters.PrivacyMode,
}
}
// ========== OAuth Handlers ==========
// GenerateAuthURLRequest represents the request for generating auth URL
......
......@@ -196,3 +196,29 @@ func TestAccountHandlerBulkUpdateMixedChannelConfirmSkips(t *testing.T) {
require.Equal(t, float64(2), data["success"])
require.Equal(t, float64(0), data["failed"])
}
func TestBulkUpdateAcceptsFilterTargetRequest(t *testing.T) {
adminSvc := newStubAdminService()
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"filters": map[string]any{
"platform": "openai",
"type": "oauth",
"status": "active",
"group": "12",
"privacy_mode": "blocked",
"search": "bulk-target",
},
"schedulable": true,
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, float64(0), resp["code"])
}
......@@ -8,6 +8,7 @@ import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
......@@ -222,3 +223,66 @@ func TestOpsWSHelpers(t *testing.T) {
require.True(t, isAddrInTrustedProxies(addr, prefixes))
require.False(t, isAddrInTrustedProxies(netip.MustParseAddr("192.168.0.1"), prefixes))
}
// TestOpenAIFastPolicySettingsFromDTO_NormalizesServiceTier 验证 admin
// 写入路径会把 ServiceTier 的空字符串/空白/大小写归一化为
// service.OpenAIFastTierAny ("all"),避免落盘时 "" 与 "all" 双语义。
func TestOpenAIFastPolicySettingsFromDTO_NormalizesServiceTier(t *testing.T) {
t.Run("nil input returns nil", func(t *testing.T) {
require.Nil(t, openaiFastPolicySettingsFromDTO(nil))
})
t.Run("empty service_tier becomes 'all'", func(t *testing.T) {
in := &dto.OpenAIFastPolicySettings{
Rules: []dto.OpenAIFastPolicyRule{{
ServiceTier: "",
Action: "filter",
Scope: "all",
}},
}
out := openaiFastPolicySettingsFromDTO(in)
require.NotNil(t, out)
require.Len(t, out.Rules, 1)
require.Equal(t, service.OpenAIFastTierAny, out.Rules[0].ServiceTier)
require.Equal(t, "all", out.Rules[0].ServiceTier)
})
t.Run("whitespace-only service_tier becomes 'all'", func(t *testing.T) {
in := &dto.OpenAIFastPolicySettings{
Rules: []dto.OpenAIFastPolicyRule{{
ServiceTier: " ",
Action: "pass",
Scope: "all",
}},
}
out := openaiFastPolicySettingsFromDTO(in)
require.Equal(t, service.OpenAIFastTierAny, out.Rules[0].ServiceTier)
})
t.Run("uppercase service_tier is lowercased", func(t *testing.T) {
in := &dto.OpenAIFastPolicySettings{
Rules: []dto.OpenAIFastPolicyRule{{
ServiceTier: "PRIORITY",
Action: "filter",
Scope: "all",
}},
}
out := openaiFastPolicySettingsFromDTO(in)
require.Equal(t, service.OpenAIFastTierPriority, out.Rules[0].ServiceTier)
})
t.Run("non-empty values pass through (lowercased)", func(t *testing.T) {
in := &dto.OpenAIFastPolicySettings{
Rules: []dto.OpenAIFastPolicyRule{
{ServiceTier: "priority", Action: "filter", Scope: "all"},
{ServiceTier: "flex", Action: "block", Scope: "oauth"},
{ServiceTier: "all", Action: "pass", Scope: "apikey"},
},
}
out := openaiFastPolicySettingsFromDTO(in)
require.Len(t, out.Rules, 3)
require.Equal(t, service.OpenAIFastTierPriority, out.Rules[0].ServiceTier)
require.Equal(t, service.OpenAIFastTierFlex, out.Rules[1].ServiceTier)
require.Equal(t, service.OpenAIFastTierAny, out.Rules[2].ServiceTier)
})
}
......@@ -565,6 +565,22 @@ func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
return nil, service.ErrAPIKeyNotFound
}
func (s *stubAdminService) AdminResetAPIKeyRateLimitUsage(ctx context.Context, keyID int64) (*service.APIKey, error) {
for i := range s.apiKeys {
if s.apiKeys[i].ID == keyID {
s.apiKeys[i].Usage5h = 0
s.apiKeys[i].Usage1d = 0
s.apiKeys[i].Usage7d = 0
s.apiKeys[i].Window5hStart = nil
s.apiKeys[i].Window1dStart = nil
s.apiKeys[i].Window7dStart = nil
k := s.apiKeys[i]
return &k, nil
}
}
return nil, service.ErrAPIKeyNotFound
}
func (s *stubAdminService) AdminSetCaptureRequests(ctx context.Context, keyID int64, enabled bool) (*service.APIKey, error) {
for i := range s.apiKeys {
if s.apiKeys[i].ID == keyID {
......
......@@ -27,11 +27,13 @@ func NewAdminAPIKeyHandler(adminService service.AdminService) *AdminAPIKeyHandle
}
}
// AdminUpdateAPIKeyGroupRequest represents the request to update an API key's group
// AdminUpdateAPIKeyGroupRequest represents the request to update an API key.
type AdminUpdateAPIKeyGroupRequest struct {
GroupID *int64 `json:"group_id"` // nil=不修改, 0=解绑, >0=绑定到目标分组
GroupID *int64 `json:"group_id"` // nil=不修改, 0=解绑, >0=绑定到目标分组
ResetRateLimitUsage *bool `json:"reset_rate_limit_usage"` // true=重置 5h/1d/7d 限速用量
}
// UpdateGroup handles updating an API key's admin-managed fields.
// SetCaptureRequests 开启或关闭指定 API Key 的请求体捕获,并立即失效认证缓存。
// PUT /api/v1/admin/api-keys/:id/capture-requests
func (h *AdminAPIKeyHandler) SetCaptureRequests(c *gin.Context) {
......@@ -74,11 +76,23 @@ func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) {
return
}
var resetKey *service.APIKey
if req.ResetRateLimitUsage != nil && *req.ResetRateLimitUsage {
resetKey, err = h.adminService.AdminResetAPIKeyRateLimitUsage(c.Request.Context(), keyID)
if err != nil {
response.ErrorFrom(c, err)
return
}
}
result, err := h.adminService.AdminUpdateAPIKeyGroupID(c.Request.Context(), keyID, req.GroupID)
if err != nil {
response.ErrorFrom(c, err)
return
}
if resetKey != nil && req.GroupID == nil {
result.APIKey = resetKey
}
resp := struct {
APIKey *dto.APIKey `json:"api_key"`
......
......@@ -8,6 +8,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/service"
......@@ -117,6 +118,45 @@ func TestAdminAPIKeyHandler_UpdateGroup_Unbind(t *testing.T) {
require.Nil(t, resp.Data.APIKey.GroupID)
}
func TestAdminAPIKeyHandler_ResetRateLimitUsage(t *testing.T) {
svc := newStubAdminService()
now := time.Now()
svc.apiKeys[0].Usage5h = 1.2
svc.apiKeys[0].Usage1d = 3.4
svc.apiKeys[0].Usage7d = 5.6
svc.apiKeys[0].Window5hStart = &now
svc.apiKeys[0].Window1dStart = &now
svc.apiKeys[0].Window7dStart = &now
router := setupAPIKeyHandler(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{"reset_rate_limit_usage":true}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Data struct {
APIKey struct {
Usage5h float64 `json:"usage_5h"`
Usage1d float64 `json:"usage_1d"`
Usage7d float64 `json:"usage_7d"`
Window5hStart *time.Time `json:"window_5h_start"`
Window1dStart *time.Time `json:"window_1d_start"`
Window7dStart *time.Time `json:"window_7d_start"`
} `json:"api_key"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Zero(t, resp.Data.APIKey.Usage5h)
require.Zero(t, resp.Data.APIKey.Usage1d)
require.Zero(t, resp.Data.APIKey.Usage7d)
require.Nil(t, resp.Data.APIKey.Window5hStart)
require.Nil(t, resp.Data.APIKey.Window1dStart)
require.Nil(t, resp.Data.APIKey.Window7dStart)
}
func TestAdminAPIKeyHandler_UpdateGroup_ServiceError(t *testing.T) {
svc := &failingUpdateGroupService{
stubAdminService: newStubAdminService(),
......
......@@ -209,6 +209,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
EnableFingerprintUnification: settings.EnableFingerprintUnification,
EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
EnableCCHSigning: settings.EnableCCHSigning,
EnableAnthropicCacheTTL1hInjection: settings.EnableAnthropicCacheTTL1hInjection,
WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
PaymentVisibleMethodAlipaySource: settings.PaymentVisibleMethodAlipaySource,
PaymentVisibleMethodWxpaySource: settings.PaymentVisibleMethodWxpaySource,
......@@ -248,9 +249,51 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
AffiliateEnabled: settings.AffiliateEnabled,
}
// OpenAI fast policy (stored under a dedicated setting key)
if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil {
slog.Error("openai_fast_policy_settings_get_failed", "error", err)
} else if fastPolicy != nil {
payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy)
}
response.Success(c, systemSettingsResponseData(payload, authSourceDefaults))
}
// openaiFastPolicySettingsToDTO converts service -> dto for OpenAI fast policy.
func openaiFastPolicySettingsToDTO(s *service.OpenAIFastPolicySettings) *dto.OpenAIFastPolicySettings {
if s == nil {
return nil
}
rules := make([]dto.OpenAIFastPolicyRule, len(s.Rules))
for i, r := range s.Rules {
rules[i] = dto.OpenAIFastPolicyRule(r)
}
return &dto.OpenAIFastPolicySettings{Rules: rules}
}
// openaiFastPolicySettingsFromDTO converts dto -> service for OpenAI fast policy.
//
// 规范化 ServiceTier:在 DTO 进入 service 层之前统一把空字符串归一为
// service.OpenAIFastTierAny ("all"),避免管理员保存时空串与 "all" 同时
// 表达"匹配任意 tier"造成数据库取值的二义性。其它非空值原样透传,由
// service.SetOpenAIFastPolicySettings 负责合法值校验。
func openaiFastPolicySettingsFromDTO(s *dto.OpenAIFastPolicySettings) *service.OpenAIFastPolicySettings {
if s == nil {
return nil
}
rules := make([]service.OpenAIFastPolicyRule, len(s.Rules))
for i, r := range s.Rules {
rules[i] = service.OpenAIFastPolicyRule(r)
tier := strings.ToLower(strings.TrimSpace(rules[i].ServiceTier))
if tier == "" {
tier = service.OpenAIFastTierAny
}
rules[i].ServiceTier = tier
}
return &service.OpenAIFastPolicySettings{Rules: rules}
}
// UpdateSettingsRequest 更新设置请求
type UpdateSettingsRequest struct {
// 注册设置
......@@ -399,9 +442,10 @@ type UpdateSettingsRequest struct {
BackendModeEnabled bool `json:"backend_mode_enabled"`
// Gateway forwarding behavior
EnableFingerprintUnification *bool `json:"enable_fingerprint_unification"`
EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
EnableCCHSigning *bool `json:"enable_cch_signing"`
EnableFingerprintUnification *bool `json:"enable_fingerprint_unification"`
EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
EnableCCHSigning *bool `json:"enable_cch_signing"`
EnableAnthropicCacheTTL1hInjection *bool `json:"enable_anthropic_cache_ttl_1h_injection"`
// Payment visible method routing
PaymentVisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"`
......@@ -452,6 +496,9 @@ type UpdateSettingsRequest struct {
// Affiliate (邀请返利) feature switch
AffiliateEnabled *bool `json:"affiliate_enabled"`
// OpenAI fast/flex policy (optional, only updated when provided)
OpenAIFastPolicySettings *dto.OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
}
// UpdateSettings 更新系统设置
......@@ -1228,6 +1275,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
return previousSettings.EnableCCHSigning
}(),
EnableAnthropicCacheTTL1hInjection: func() bool {
if req.EnableAnthropicCacheTTL1hInjection != nil {
return *req.EnableAnthropicCacheTTL1hInjection
}
return previousSettings.EnableAnthropicCacheTTL1hInjection
}(),
PaymentVisibleMethodAlipaySource: func() string {
if req.PaymentVisibleMethodAlipaySource != nil {
return strings.TrimSpace(*req.PaymentVisibleMethodAlipaySource)
......@@ -1350,6 +1403,14 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
return
}
// Update OpenAI fast policy (stored under dedicated key, only when provided).
if req.OpenAIFastPolicySettings != nil {
if err := h.settingService.SetOpenAIFastPolicySettings(c.Request.Context(), openaiFastPolicySettingsFromDTO(req.OpenAIFastPolicySettings)); err != nil {
response.BadRequest(c, err.Error())
return
}
}
// Update payment configuration (integrated into system settings).
// Skip if no payment fields were provided (prevents accidental wipe).
if h.paymentConfigService != nil && hasPaymentFields(req) {
......@@ -1517,6 +1578,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification,
EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough,
EnableCCHSigning: updatedSettings.EnableCCHSigning,
EnableAnthropicCacheTTL1hInjection: updatedSettings.EnableAnthropicCacheTTL1hInjection,
PaymentVisibleMethodAlipaySource: updatedSettings.PaymentVisibleMethodAlipaySource,
PaymentVisibleMethodWxpaySource: updatedSettings.PaymentVisibleMethodWxpaySource,
PaymentVisibleMethodAlipayEnabled: updatedSettings.PaymentVisibleMethodAlipayEnabled,
......@@ -1555,6 +1617,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
AffiliateEnabled: updatedSettings.AffiliateEnabled,
}
if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil {
slog.Error("openai_fast_policy_settings_get_failed", "error", err)
} else if fastPolicy != nil {
payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy)
}
response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults))
}
......@@ -1891,6 +1958,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.EnableCCHSigning != after.EnableCCHSigning {
changed = append(changed, "enable_cch_signing")
}
if before.EnableAnthropicCacheTTL1hInjection != after.EnableAnthropicCacheTTL1hInjection {
changed = append(changed, "enable_anthropic_cache_ttl_1h_injection")
}
if before.PaymentVisibleMethodAlipaySource != after.PaymentVisibleMethodAlipaySource {
changed = append(changed, "payment_visible_method_alipay_source")
}
......
......@@ -26,7 +26,12 @@ func (s *settingHandlerRepoStub) Get(ctx context.Context, key string) (*service.
}
func (s *settingHandlerRepoStub) GetValue(ctx context.Context, key string) (string, error) {
panic("unexpected GetValue call")
if s.values != nil {
if value, ok := s.values[key]; ok {
return value, nil
}
}
return "", nil
}
func (s *settingHandlerRepoStub) Set(ctx context.Context, key, value string) error {
......
......@@ -142,9 +142,10 @@ type SystemSettings struct {
BackendModeEnabled bool `json:"backend_mode_enabled"`
// Gateway forwarding behavior
EnableFingerprintUnification bool `json:"enable_fingerprint_unification"`
EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"`
EnableCCHSigning bool `json:"enable_cch_signing"`
EnableFingerprintUnification bool `json:"enable_fingerprint_unification"`
EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"`
EnableCCHSigning bool `json:"enable_cch_signing"`
EnableAnthropicCacheTTL1hInjection bool `json:"enable_anthropic_cache_ttl_1h_injection"`
// Web Search Emulation
WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"`
......@@ -198,6 +199,9 @@ type SystemSettings struct {
// Affiliate (邀请返利) feature switch
AffiliateEnabled bool `json:"affiliate_enabled"`
// OpenAI fast/flex policy
OpenAIFastPolicySettings *OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
}
type DefaultSubscriptionSetting struct {
......@@ -294,6 +298,22 @@ type BetaPolicySettings struct {
Rules []BetaPolicyRule `json:"rules"`
}
// OpenAIFastPolicyRule OpenAI fast/flex 策略规则 DTO
type OpenAIFastPolicyRule struct {
ServiceTier string `json:"service_tier"`
Action string `json:"action"`
Scope string `json:"scope"`
ErrorMessage string `json:"error_message,omitempty"`
ModelWhitelist []string `json:"model_whitelist,omitempty"`
FallbackAction string `json:"fallback_action,omitempty"`
FallbackErrorMessage string `json:"fallback_error_message,omitempty"`
}
// OpenAIFastPolicySettings OpenAI fast 策略配置 DTO
type OpenAIFastPolicySettings struct {
Rules []OpenAIFastPolicyRule `json:"rules"`
}
// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
// Returns empty slice on empty/invalid input.
func ParseCustomMenuItems(raw string) []CustomMenuItem {
......
......@@ -288,6 +288,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// [DEBUG-STICKY] 打印会话 hash 生成结果
reqLog.Info("sticky.session_hash_generated",
zap.String("session_hash", sessionHash),
zap.String("metadata_user_id_raw", parsedReq.MetadataUserID),
)
// 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台
platform := ""
if forcePlatform, ok := middleware2.GetForcePlatformFromContext(c); ok {
......@@ -304,6 +310,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
var sessionBoundAccountID int64
if sessionKey != "" {
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
// [DEBUG-STICKY] 打印粘性会话查询结果
reqLog.Info("sticky.cache_lookup",
zap.String("session_key", sessionKey),
zap.Int64("bound_account_id", sessionBoundAccountID),
)
if sessionBoundAccountID > 0 {
prefetchedGroupID := int64(0)
if apiKey.GroupID != nil {
......@@ -312,6 +323,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
ctx := service.WithPrefetchedStickySession(c.Request.Context(), sessionBoundAccountID, prefetchedGroupID, h.metadataBridgeEnabled())
c.Request = c.Request.WithContext(ctx)
}
} else {
reqLog.Info("sticky.no_session_key", zap.String("session_hash", sessionHash))
}
// 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
......@@ -567,6 +580,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for {
// 选择支持该模型的账号
reqLog.Info("sticky.selecting_account",
zap.String("session_key", sessionKey),
zap.Int64("sticky_bound_account_id", sessionBoundAccountID),
zap.Bool("has_bound_session", hasBoundSession),
zap.Int("failed_account_count", len(fs.FailedAccountIDs)),
)
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, subject.UserID)
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
......@@ -600,6 +619,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
account := selection.Account
setOpsSelectedAccount(c, account.ID, account.Platform)
// [DEBUG-STICKY] 打印账号选择结果
reqLog.Info("sticky.account_selected",
zap.Int64("selected_account_id", account.ID),
zap.String("account_name", account.Name),
zap.Bool("slot_acquired", selection.Acquired),
zap.Bool("has_wait_plan", selection.WaitPlan != nil),
zap.Int64("sticky_bound_account_id", sessionBoundAccountID),
zap.Bool("sticky_honored", sessionBoundAccountID > 0 && sessionBoundAccountID == account.ID),
)
// 检查请求拦截(预热请求、SUGGESTION MODE等)
if account.IsInterceptWarmupEnabled() {
interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient)
......@@ -666,6 +695,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// Slot acquired: no longer waiting in queue.
releaseWait()
reqLog.Info("sticky.bind_after_wait",
zap.String("session_key", sessionKey),
zap.Int64("account_id", account.ID),
)
if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
......@@ -860,6 +893,17 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
}
// 绑定粘性会话(成功转发后绑定/刷新)
// - 无现有绑定(首次请求):创建绑定
// - 选中账号与粘性账号一致:刷新 TTL
// - 粘性账号因负载/RPM 被跳过、选中了其他账号:不覆盖原绑定,
// 下次请求粘性账号恢复后仍可命中
if sessionKey != "" && (sessionBoundAccountID == 0 || sessionBoundAccountID == account.ID) {
if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
}
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
......
......@@ -50,6 +50,9 @@ func (f *fakeSchedulerCache) UpdateLastUsed(_ context.Context, _ map[int64]time.
func (f *fakeSchedulerCache) TryLockBucket(_ context.Context, _ service.SchedulerBucket, _ time.Duration) (bool, error) {
return true, nil
}
func (f *fakeSchedulerCache) UnlockBucket(_ context.Context, _ service.SchedulerBucket) error {
return nil
}
func (f *fakeSchedulerCache) ListBuckets(_ context.Context) ([]service.SchedulerBucket, error) {
return nil, nil
}
......
......@@ -991,9 +991,40 @@ func TestAnthropicToResponses_ToolChoiceSpecific(t *testing.T) {
var tc map[string]any
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
assert.Equal(t, "function", tc["type"])
fn, ok := tc["function"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "get_weather", fn["name"])
assert.Equal(t, "get_weather", tc["name"])
assert.NotContains(t, tc, "function")
}
func TestResponsesToAnthropicRequest_ToolChoiceFunctionName(t *testing.T) {
req := &ResponsesRequest{
Model: "gpt-5.2",
Input: json.RawMessage(`[{"role":"user","content":"Hello"}]`),
ToolChoice: json.RawMessage(`{"type":"function","name":"get_weather"}`),
}
resp, err := ResponsesToAnthropicRequest(req)
require.NoError(t, err)
var tc map[string]string
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
assert.Equal(t, "tool", tc["type"])
assert.Equal(t, "get_weather", tc["name"])
}
func TestResponsesToAnthropicRequest_ToolChoiceLegacyFunctionName(t *testing.T) {
req := &ResponsesRequest{
Model: "gpt-5.2",
Input: json.RawMessage(`[{"role":"user","content":"Hello"}]`),
ToolChoice: json.RawMessage(`{"type":"function","function":{"name":"get_weather"}}`),
}
resp, err := ResponsesToAnthropicRequest(req)
require.NoError(t, err)
var tc map[string]string
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
assert.Equal(t, "tool", tc["type"])
assert.Equal(t, "get_weather", tc["name"])
}
// ---------------------------------------------------------------------------
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment