Commit fd43be8d authored by yangjianbo's avatar yangjianbo
Browse files

merge: 合并 main 分支到 test,解决 config 和 modelWhitelist 冲突



- config.go: 保留 Sora 配置,合入 SubscriptionCache 配置
- useModelWhitelist.ts: 同时保留 soraModels 和 antigravityModels
Co-Authored-By: default avatarClaude Opus 4.6 <noreply@anthropic.com>
parents 792bef61 836ba14b
......@@ -14,6 +14,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type stubOpenAIAccountRepo struct {
......@@ -204,6 +205,22 @@ func (c *stubGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID i
return nil
}
func (c *stubGatewayCache) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
return 0, nil
}
func (c *stubGatewayCache) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) {
return nil, nil
}
func (c *stubGatewayCache) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (c *stubGatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
now := time.Now()
resetAt := now.Add(10 * time.Minute)
......@@ -1066,6 +1083,43 @@ func TestOpenAIStreamingHeadersOverride(t *testing.T) {
}
}
func TestOpenAIStreamingReuseScannerBufferAndStillWorks(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: pr,
Header: http.Header{},
}
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"input_tokens_details\":{\"cached_tokens\":3}}}}\n\n"))
}()
result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.usage)
require.Equal(t, 1, result.usage.InputTokens)
require.Equal(t, 2, result.usage.OutputTokens)
require.Equal(t, 3, result.usage.CacheReadInputTokens)
}
func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
......@@ -1149,3 +1203,226 @@ func TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist(t *testing.T) {
t.Fatalf("expected non-allowlisted host to fail")
}
}
// ==================== P1-08 修复:model 替换性能优化测试 ====================
func TestReplaceModelInSSELine(t *testing.T) {
svc := &OpenAIGatewayService{}
tests := []struct {
name string
line string
from string
to string
expected string
}{
{
name: "顶层 model 字段替换",
line: `data: {"id":"chatcmpl-123","model":"gpt-4o","choices":[]}`,
from: "gpt-4o",
to: "my-custom-model",
expected: `data: {"id":"chatcmpl-123","model":"my-custom-model","choices":[]}`,
},
{
name: "嵌套 response.model 替换",
line: `data: {"type":"response","response":{"id":"resp-1","model":"gpt-4o","output":[]}}`,
from: "gpt-4o",
to: "my-model",
expected: `data: {"type":"response","response":{"id":"resp-1","model":"my-model","output":[]}}`,
},
{
name: "model 不匹配时不替换",
line: `data: {"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
from: "gpt-4o",
to: "my-model",
expected: `data: {"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
},
{
name: "无 model 字段时不替换",
line: `data: {"id":"chatcmpl-123","choices":[]}`,
from: "gpt-4o",
to: "my-model",
expected: `data: {"id":"chatcmpl-123","choices":[]}`,
},
{
name: "空 data 行",
line: `data: `,
from: "gpt-4o",
to: "my-model",
expected: `data: `,
},
{
name: "[DONE] 行",
line: `data: [DONE]`,
from: "gpt-4o",
to: "my-model",
expected: `data: [DONE]`,
},
{
name: "非 data: 前缀行",
line: `event: message`,
from: "gpt-4o",
to: "my-model",
expected: `event: message`,
},
{
name: "非法 JSON 不替换",
line: `data: {invalid json}`,
from: "gpt-4o",
to: "my-model",
expected: `data: {invalid json}`,
},
{
name: "无空格 data: 格式",
line: `data:{"id":"x","model":"gpt-4o"}`,
from: "gpt-4o",
to: "my-model",
expected: `data: {"id":"x","model":"my-model"}`,
},
{
name: "model 名含特殊字符",
line: `data: {"model":"org/model-v2.1-beta"}`,
from: "org/model-v2.1-beta",
to: "custom/alias",
expected: `data: {"model":"custom/alias"}`,
},
{
name: "空行",
line: "",
from: "gpt-4o",
to: "my-model",
expected: "",
},
{
name: "保持其他字段不变",
line: `data: {"id":"abc","object":"chat.completion.chunk","model":"gpt-4o","created":1234567890,"choices":[{"index":0,"delta":{"content":"hi"}}]}`,
from: "gpt-4o",
to: "alias",
expected: `data: {"id":"abc","object":"chat.completion.chunk","model":"alias","created":1234567890,"choices":[{"index":0,"delta":{"content":"hi"}}]}`,
},
{
name: "顶层优先于嵌套:同时存在两个 model",
line: `data: {"model":"gpt-4o","response":{"model":"gpt-4o"}}`,
from: "gpt-4o",
to: "replaced",
expected: `data: {"model":"replaced","response":{"model":"gpt-4o"}}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := svc.replaceModelInSSELine(tt.line, tt.from, tt.to)
require.Equal(t, tt.expected, got)
})
}
}
func TestReplaceModelInSSEBody(t *testing.T) {
svc := &OpenAIGatewayService{}
tests := []struct {
name string
body string
from string
to string
expected string
}{
{
name: "多行 SSE body 替换",
body: "data: {\"model\":\"gpt-4o\",\"choices\":[]}\n\ndata: {\"model\":\"gpt-4o\",\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: [DONE]\n",
from: "gpt-4o",
to: "alias",
expected: "data: {\"model\":\"alias\",\"choices\":[]}\n\ndata: {\"model\":\"alias\",\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: [DONE]\n",
},
{
name: "无需替换的 body",
body: "data: {\"model\":\"gpt-3.5-turbo\"}\n\ndata: [DONE]\n",
from: "gpt-4o",
to: "alias",
expected: "data: {\"model\":\"gpt-3.5-turbo\"}\n\ndata: [DONE]\n",
},
{
name: "混合 event 和 data 行",
body: "event: message\ndata: {\"model\":\"gpt-4o\"}\n\n",
from: "gpt-4o",
to: "alias",
expected: "event: message\ndata: {\"model\":\"alias\"}\n\n",
},
{
name: "空 body",
body: "",
from: "gpt-4o",
to: "alias",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := svc.replaceModelInSSEBody(tt.body, tt.from, tt.to)
require.Equal(t, tt.expected, got)
})
}
}
func TestReplaceModelInResponseBody(t *testing.T) {
svc := &OpenAIGatewayService{}
tests := []struct {
name string
body string
from string
to string
expected string
}{
{
name: "替换顶层 model",
body: `{"id":"chatcmpl-123","model":"gpt-4o","choices":[]}`,
from: "gpt-4o",
to: "alias",
expected: `{"id":"chatcmpl-123","model":"alias","choices":[]}`,
},
{
name: "model 不匹配不替换",
body: `{"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
from: "gpt-4o",
to: "alias",
expected: `{"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
},
{
name: "无 model 字段不替换",
body: `{"id":"chatcmpl-123","choices":[]}`,
from: "gpt-4o",
to: "alias",
expected: `{"id":"chatcmpl-123","choices":[]}`,
},
{
name: "非法 JSON 返回原值",
body: `not json`,
from: "gpt-4o",
to: "alias",
expected: `not json`,
},
{
name: "空 body 返回原值",
body: ``,
from: "gpt-4o",
to: "alias",
expected: ``,
},
{
name: "保持嵌套结构不变",
body: `{"model":"gpt-4o","usage":{"prompt_tokens":10,"completion_tokens":20},"choices":[{"message":{"role":"assistant","content":"hello"}}]}`,
from: "gpt-4o",
to: "alias",
expected: `{"model":"alias","usage":{"prompt_tokens":10,"completion_tokens":20},"choices":[{"message":{"role":"assistant","content":"hello"}}]}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := svc.replaceModelInResponseBody([]byte(tt.body), tt.from, tt.to)
require.Equal(t, tt.expected, string(got))
})
}
}
......@@ -66,7 +66,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
}
isAvailable := acc.Status == StatusActive && acc.Schedulable && !isRateLimited && !isOverloaded && !isTempUnsched
scopeRateLimits := acc.GetAntigravityScopeRateLimits()
if acc.Platform != "" {
......
......@@ -255,3 +255,142 @@ func (s *OpsService) GetConcurrencyStats(
return platform, group, account, &collectedAt, nil
}
// listAllActiveUsersForOps returns all active users with their concurrency settings.
func (s *OpsService) listAllActiveUsersForOps(ctx context.Context) ([]User, error) {
if s == nil || s.userRepo == nil {
return []User{}, nil
}
out := make([]User, 0, 128)
page := 1
for {
users, pageInfo, err := s.userRepo.ListWithFilters(ctx, pagination.PaginationParams{
Page: page,
PageSize: opsAccountsPageSize,
}, UserListFilters{
Status: StatusActive,
})
if err != nil {
return nil, err
}
if len(users) == 0 {
break
}
out = append(out, users...)
if pageInfo != nil && int64(len(out)) >= pageInfo.Total {
break
}
if len(users) < opsAccountsPageSize {
break
}
page++
if page > 10_000 {
log.Printf("[Ops] listAllActiveUsersForOps: aborting after too many pages")
break
}
}
return out, nil
}
// getUsersLoadMapBestEffort returns user load info for the given users.
func (s *OpsService) getUsersLoadMapBestEffort(ctx context.Context, users []User) map[int64]*UserLoadInfo {
if s == nil || s.concurrencyService == nil {
return map[int64]*UserLoadInfo{}
}
if len(users) == 0 {
return map[int64]*UserLoadInfo{}
}
// De-duplicate IDs (and keep the max concurrency to avoid under-reporting).
unique := make(map[int64]int, len(users))
for _, u := range users {
if u.ID <= 0 {
continue
}
if prev, ok := unique[u.ID]; !ok || u.Concurrency > prev {
unique[u.ID] = u.Concurrency
}
}
batch := make([]UserWithConcurrency, 0, len(unique))
for id, maxConc := range unique {
batch = append(batch, UserWithConcurrency{
ID: id,
MaxConcurrency: maxConc,
})
}
out := make(map[int64]*UserLoadInfo, len(batch))
for i := 0; i < len(batch); i += opsConcurrencyBatchChunkSize {
end := i + opsConcurrencyBatchChunkSize
if end > len(batch) {
end = len(batch)
}
part, err := s.concurrencyService.GetUsersLoadBatch(ctx, batch[i:end])
if err != nil {
// Best-effort: return zeros rather than failing the ops UI.
log.Printf("[Ops] GetUsersLoadBatch failed: %v", err)
continue
}
for k, v := range part {
out[k] = v
}
}
return out
}
// GetUserConcurrencyStats returns real-time concurrency usage for all active users.
func (s *OpsService) GetUserConcurrencyStats(ctx context.Context) (map[int64]*UserConcurrencyInfo, *time.Time, error) {
if err := s.RequireMonitoringEnabled(ctx); err != nil {
return nil, nil, err
}
users, err := s.listAllActiveUsersForOps(ctx)
if err != nil {
return nil, nil, err
}
collectedAt := time.Now()
loadMap := s.getUsersLoadMapBestEffort(ctx, users)
result := make(map[int64]*UserConcurrencyInfo)
for _, u := range users {
if u.ID <= 0 {
continue
}
load := loadMap[u.ID]
currentInUse := int64(0)
waiting := int64(0)
if load != nil {
currentInUse = int64(load.CurrentConcurrency)
waiting = int64(load.WaitingCount)
}
// Skip users with no concurrency activity
if currentInUse == 0 && waiting == 0 {
continue
}
info := &UserConcurrencyInfo{
UserID: u.ID,
UserEmail: u.Email,
Username: u.Username,
CurrentInUse: currentInUse,
MaxCapacity: int64(u.Concurrency),
WaitingInQueue: waiting,
}
if info.MaxCapacity > 0 {
info.LoadPercentage = float64(info.CurrentInUse) / float64(info.MaxCapacity) * 100
}
result[u.ID] = info
}
return result, &collectedAt, nil
}
......@@ -37,6 +37,17 @@ type AccountConcurrencyInfo struct {
WaitingInQueue int64 `json:"waiting_in_queue"`
}
// UserConcurrencyInfo represents real-time concurrency usage for a single user.
type UserConcurrencyInfo struct {
UserID int64 `json:"user_id"`
UserEmail string `json:"user_email"`
Username string `json:"username"`
CurrentInUse int64 `json:"current_in_use"`
MaxCapacity int64 `json:"max_capacity"`
LoadPercentage float64 `json:"load_percentage"`
WaitingInQueue int64 `json:"waiting_in_queue"`
}
// PlatformAvailability aggregates account availability by platform.
type PlatformAvailability struct {
Platform string `json:"platform"`
......
......@@ -576,7 +576,7 @@ func (s *OpsService) executeWithAccount(ctx context.Context, reqType opsRetryReq
action = "streamGenerateContent"
}
if account.Platform == PlatformAntigravity {
_, err = s.antigravityGatewayService.ForwardGemini(ctx, c, account, modelName, action, errorLog.Stream, body)
_, err = s.antigravityGatewayService.ForwardGemini(ctx, c, account, modelName, action, errorLog.Stream, body, false)
} else {
_, err = s.geminiCompatService.ForwardNative(ctx, c, account, modelName, action, errorLog.Stream, body)
}
......@@ -586,7 +586,7 @@ func (s *OpsService) executeWithAccount(ctx context.Context, reqType opsRetryReq
if s.antigravityGatewayService == nil {
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "antigravity gateway service not available"}
}
_, err = s.antigravityGatewayService.Forward(ctx, c, account, body)
_, err = s.antigravityGatewayService.Forward(ctx, c, account, body, false)
case PlatformGemini:
if s.geminiCompatService == nil {
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "gemini gateway service not available"}
......
......@@ -27,6 +27,7 @@ type OpsService struct {
cfg *config.Config
accountRepo AccountRepository
userRepo UserRepository
// getAccountAvailability is a unit-test hook for overriding account availability lookup.
getAccountAvailability func(ctx context.Context, platformFilter string, groupIDFilter *int64) (*OpsAccountAvailability, error)
......@@ -43,6 +44,7 @@ func NewOpsService(
settingRepo SettingRepository,
cfg *config.Config,
accountRepo AccountRepository,
userRepo UserRepository,
concurrencyService *ConcurrencyService,
gatewayService *GatewayService,
openAIGatewayService *OpenAIGatewayService,
......@@ -55,6 +57,7 @@ func NewOpsService(
cfg: cfg,
accountRepo: accountRepo,
userRepo: userRepo,
concurrencyService: concurrencyService,
gatewayService: gatewayService,
......@@ -424,6 +427,26 @@ func isSensitiveKey(key string) bool {
return false
}
// Token 计数 / 预算字段不是凭据,应保留用于排错。
// 白名单保持尽量窄,避免误把真实敏感信息"反脱敏"。
switch k {
case "max_tokens",
"max_output_tokens",
"max_input_tokens",
"max_completion_tokens",
"max_tokens_to_sample",
"budget_tokens",
"prompt_tokens",
"completion_tokens",
"input_tokens",
"output_tokens",
"total_tokens",
"token_count",
"cache_creation_input_tokens",
"cache_read_input_tokens":
return false
}
// Exact matches (common credential fields).
switch k {
case "authorization",
......@@ -566,7 +589,18 @@ func trimArrayField(root map[string]any, field string, maxBytes int) (map[string
func shrinkToEssentials(root map[string]any) map[string]any {
out := make(map[string]any)
for _, key := range []string{"model", "stream", "max_tokens", "temperature", "top_p", "top_k"} {
for _, key := range []string{
"model",
"stream",
"max_tokens",
"max_output_tokens",
"max_input_tokens",
"max_completion_tokens",
"thinking",
"temperature",
"top_p",
"top_k",
} {
if v, ok := root[key]; ok {
out[key] = v
}
......
package service
import (
"encoding/json"
"testing"
)
func TestIsSensitiveKey_TokenBudgetKeysNotRedacted(t *testing.T) {
t.Parallel()
for _, key := range []string{
"max_tokens",
"max_output_tokens",
"max_input_tokens",
"max_completion_tokens",
"max_tokens_to_sample",
"budget_tokens",
"prompt_tokens",
"completion_tokens",
"input_tokens",
"output_tokens",
"total_tokens",
"token_count",
} {
if isSensitiveKey(key) {
t.Fatalf("expected key %q to NOT be treated as sensitive", key)
}
}
for _, key := range []string{
"authorization",
"Authorization",
"access_token",
"refresh_token",
"id_token",
"session_token",
"token",
"client_secret",
"private_key",
"signature",
} {
if !isSensitiveKey(key) {
t.Fatalf("expected key %q to be treated as sensitive", key)
}
}
}
func TestSanitizeAndTrimRequestBody_PreservesTokenBudgetFields(t *testing.T) {
t.Parallel()
raw := []byte(`{"model":"claude-3","max_tokens":123,"thinking":{"type":"enabled","budget_tokens":456},"access_token":"abc","messages":[{"role":"user","content":"hi"}]}`)
out, _, _ := sanitizeAndTrimRequestBody(raw, 10*1024)
if out == "" {
t.Fatalf("expected non-empty sanitized output")
}
var decoded map[string]any
if err := json.Unmarshal([]byte(out), &decoded); err != nil {
t.Fatalf("unmarshal sanitized output: %v", err)
}
if got, ok := decoded["max_tokens"].(float64); !ok || got != 123 {
t.Fatalf("expected max_tokens=123, got %#v", decoded["max_tokens"])
}
thinking, ok := decoded["thinking"].(map[string]any)
if !ok || thinking == nil {
t.Fatalf("expected thinking object to be preserved, got %#v", decoded["thinking"])
}
if got, ok := thinking["budget_tokens"].(float64); !ok || got != 456 {
t.Fatalf("expected thinking.budget_tokens=456, got %#v", thinking["budget_tokens"])
}
if got := decoded["access_token"]; got != "[REDACTED]" {
t.Fatalf("expected access_token to be redacted, got %#v", got)
}
}
func TestShrinkToEssentials_IncludesThinking(t *testing.T) {
t.Parallel()
root := map[string]any{
"model": "claude-3",
"max_tokens": 100,
"thinking": map[string]any{
"type": "enabled",
"budget_tokens": 200,
},
"messages": []any{
map[string]any{"role": "user", "content": "first"},
map[string]any{"role": "user", "content": "last"},
},
}
out := shrinkToEssentials(root)
if _, ok := out["thinking"]; !ok {
t.Fatalf("expected thinking to be included in essentials: %#v", out)
}
}
......@@ -16,6 +16,7 @@ var (
type ProxyRepository interface {
Create(ctx context.Context, proxy *Proxy) error
GetByID(ctx context.Context, id int64) (*Proxy, error)
ListByIDs(ctx context.Context, ids []int64) ([]Proxy, error)
Update(ctx context.Context, proxy *Proxy) error
Delete(ctx context.Context, id int64) error
......
......@@ -387,14 +387,6 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
// 没有重置时间,使用默认5分钟
resetAt := time.Now().Add(5 * time.Minute)
if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) {
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelRateLimitScopeClaudeSonnet, resetAt); err != nil {
slog.Warn("model_rate_limit_set_failed", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "error", err)
} else {
slog.Info("account_model_rate_limited", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "reset_at", resetAt)
}
return
}
slog.Warn("rate_limit_no_reset_time", "account_id", account.ID, "platform", account.Platform, "using_default", "5m")
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
......@@ -407,14 +399,6 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
if err != nil {
slog.Warn("rate_limit_reset_parse_failed", "reset_timestamp", resetTimestamp, "error", err)
resetAt := time.Now().Add(5 * time.Minute)
if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) {
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelRateLimitScopeClaudeSonnet, resetAt); err != nil {
slog.Warn("model_rate_limit_set_failed", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "error", err)
} else {
slog.Info("account_model_rate_limited", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "reset_at", resetAt)
}
return
}
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
}
......@@ -423,15 +407,6 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
resetAt := time.Unix(ts, 0)
if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) {
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelRateLimitScopeClaudeSonnet, resetAt); err != nil {
slog.Warn("model_rate_limit_set_failed", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "error", err)
return
}
slog.Info("account_model_rate_limited", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "reset_at", resetAt)
return
}
// 标记限流状态
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
......@@ -448,17 +423,6 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
slog.Info("account_rate_limited", "account_id", account.ID, "reset_at", resetAt)
}
func (s *RateLimitService) shouldScopeClaudeSonnetRateLimit(account *Account, responseBody []byte) bool {
if account == nil || account.Platform != PlatformAnthropic {
return false
}
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(responseBody)))
if msg == "" {
return false
}
return strings.Contains(msg, "sonnet")
}
// calculateOpenAI429ResetTime 从 OpenAI 429 响应头计算正确的重置时间
// 返回 nil 表示无法从响应头中确定重置时间
func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *time.Time {
......
//go:build unit
package service
import (
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestFilterByMinPriority(t *testing.T) {
t.Run("empty slice", func(t *testing.T) {
result := filterByMinPriority(nil)
require.Empty(t, result)
})
t.Run("single account", func(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1, Priority: 5}, loadInfo: &AccountLoadInfo{}},
}
result := filterByMinPriority(accounts)
require.Len(t, result, 1)
require.Equal(t, int64(1), result[0].account.ID)
})
t.Run("multiple accounts same priority", func(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1, Priority: 3}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 2, Priority: 3}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 3, Priority: 3}, loadInfo: &AccountLoadInfo{}},
}
result := filterByMinPriority(accounts)
require.Len(t, result, 3)
})
t.Run("filters to min priority only", func(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1, Priority: 5}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 2, Priority: 1}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 3, Priority: 3}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 4, Priority: 1}, loadInfo: &AccountLoadInfo{}},
}
result := filterByMinPriority(accounts)
require.Len(t, result, 2)
require.Equal(t, int64(2), result[0].account.ID)
require.Equal(t, int64(4), result[1].account.ID)
})
}
func TestFilterByMinLoadRate(t *testing.T) {
t.Run("empty slice", func(t *testing.T) {
result := filterByMinLoadRate(nil)
require.Empty(t, result)
})
t.Run("single account", func(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1}, loadInfo: &AccountLoadInfo{LoadRate: 50}},
}
result := filterByMinLoadRate(accounts)
require.Len(t, result, 1)
require.Equal(t, int64(1), result[0].account.ID)
})
t.Run("multiple accounts same load rate", func(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1}, loadInfo: &AccountLoadInfo{LoadRate: 20}},
{account: &Account{ID: 2}, loadInfo: &AccountLoadInfo{LoadRate: 20}},
{account: &Account{ID: 3}, loadInfo: &AccountLoadInfo{LoadRate: 20}},
}
result := filterByMinLoadRate(accounts)
require.Len(t, result, 3)
})
t.Run("filters to min load rate only", func(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1}, loadInfo: &AccountLoadInfo{LoadRate: 80}},
{account: &Account{ID: 2}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
{account: &Account{ID: 3}, loadInfo: &AccountLoadInfo{LoadRate: 50}},
{account: &Account{ID: 4}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
}
result := filterByMinLoadRate(accounts)
require.Len(t, result, 2)
require.Equal(t, int64(2), result[0].account.ID)
require.Equal(t, int64(4), result[1].account.ID)
})
t.Run("zero load rate", func(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1}, loadInfo: &AccountLoadInfo{LoadRate: 0}},
{account: &Account{ID: 2}, loadInfo: &AccountLoadInfo{LoadRate: 50}},
{account: &Account{ID: 3}, loadInfo: &AccountLoadInfo{LoadRate: 0}},
}
result := filterByMinLoadRate(accounts)
require.Len(t, result, 2)
require.Equal(t, int64(1), result[0].account.ID)
require.Equal(t, int64(3), result[1].account.ID)
})
}
func TestSelectByLRU(t *testing.T) {
now := time.Now()
earlier := now.Add(-1 * time.Hour)
muchEarlier := now.Add(-2 * time.Hour)
t.Run("empty slice", func(t *testing.T) {
result := selectByLRU(nil, false)
require.Nil(t, result)
})
t.Run("single account", func(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{}},
}
result := selectByLRU(accounts, false)
require.NotNil(t, result)
require.Equal(t, int64(1), result.account.ID)
})
t.Run("selects least recently used", func(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 2, LastUsedAt: &muchEarlier}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 3, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{}},
}
result := selectByLRU(accounts, false)
require.NotNil(t, result)
require.Equal(t, int64(2), result.account.ID)
})
t.Run("nil LastUsedAt preferred over non-nil", func(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 2, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 3, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{}},
}
result := selectByLRU(accounts, false)
require.NotNil(t, result)
require.Equal(t, int64(2), result.account.ID)
})
t.Run("multiple nil LastUsedAt random selection", func(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1, LastUsedAt: nil, Type: "session"}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 2, LastUsedAt: nil, Type: "session"}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 3, LastUsedAt: nil, Type: "session"}, loadInfo: &AccountLoadInfo{}},
}
// 多次调用应该随机选择,验证结果都在候选范围内
validIDs := map[int64]bool{1: true, 2: true, 3: true}
for i := 0; i < 10; i++ {
result := selectByLRU(accounts, false)
require.NotNil(t, result)
require.True(t, validIDs[result.account.ID], "selected ID should be one of the candidates")
}
})
t.Run("multiple same LastUsedAt random selection", func(t *testing.T) {
sameTime := now
accounts := []accountWithLoad{
{account: &Account{ID: 1, LastUsedAt: &sameTime}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 2, LastUsedAt: &sameTime}, loadInfo: &AccountLoadInfo{}},
}
// 多次调用应该随机选择
validIDs := map[int64]bool{1: true, 2: true}
for i := 0; i < 10; i++ {
result := selectByLRU(accounts, false)
require.NotNil(t, result)
require.True(t, validIDs[result.account.ID], "selected ID should be one of the candidates")
}
})
t.Run("preferOAuth selects from OAuth accounts when multiple nil", func(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1, LastUsedAt: nil, Type: "session"}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 2, LastUsedAt: nil, Type: AccountTypeOAuth}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 3, LastUsedAt: nil, Type: AccountTypeOAuth}, loadInfo: &AccountLoadInfo{}},
}
// preferOAuth 时,应该从 OAuth 类型中选择
oauthIDs := map[int64]bool{2: true, 3: true}
for i := 0; i < 10; i++ {
result := selectByLRU(accounts, true)
require.NotNil(t, result)
require.True(t, oauthIDs[result.account.ID], "should select from OAuth accounts")
}
})
t.Run("preferOAuth falls back to all when no OAuth", func(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1, LastUsedAt: nil, Type: "session"}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 2, LastUsedAt: nil, Type: "session"}, loadInfo: &AccountLoadInfo{}},
}
// 没有 OAuth 时,从所有候选中选择
validIDs := map[int64]bool{1: true, 2: true}
for i := 0; i < 10; i++ {
result := selectByLRU(accounts, true)
require.NotNil(t, result)
require.True(t, validIDs[result.account.ID])
}
})
t.Run("preferOAuth only affects same LastUsedAt accounts", func(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1, LastUsedAt: &earlier, Type: "session"}, loadInfo: &AccountLoadInfo{}},
{account: &Account{ID: 2, LastUsedAt: &now, Type: AccountTypeOAuth}, loadInfo: &AccountLoadInfo{}},
}
result := selectByLRU(accounts, true)
require.NotNil(t, result)
// 有不同 LastUsedAt 时,按时间选择最早的,不受 preferOAuth 影响
require.Equal(t, int64(1), result.account.ID)
})
}
func TestLayeredFilterIntegration(t *testing.T) {
now := time.Now()
earlier := now.Add(-1 * time.Hour)
muchEarlier := now.Add(-2 * time.Hour)
t.Run("full layered selection", func(t *testing.T) {
// 模拟真实场景:多个账号,不同优先级、负载率、最后使用时间
accounts := []accountWithLoad{
// 优先级 1,负载 50%
{account: &Account{ID: 1, Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 50}},
// 优先级 1,负载 20%(最低)
{account: &Account{ID: 2, Priority: 1, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 20}},
// 优先级 1,负载 20%(最低),更早使用
{account: &Account{ID: 3, Priority: 1, LastUsedAt: &muchEarlier}, loadInfo: &AccountLoadInfo{LoadRate: 20}},
// 优先级 2(较低优先)
{account: &Account{ID: 4, Priority: 2, LastUsedAt: &muchEarlier}, loadInfo: &AccountLoadInfo{LoadRate: 0}},
}
// 1. 取优先级最小的集合 → ID: 1, 2, 3
step1 := filterByMinPriority(accounts)
require.Len(t, step1, 3)
// 2. 取负载率最低的集合 → ID: 2, 3
step2 := filterByMinLoadRate(step1)
require.Len(t, step2, 2)
// 3. LRU 选择 → ID: 3(muchEarlier 最早)
selected := selectByLRU(step2, false)
require.NotNil(t, selected)
require.Equal(t, int64(3), selected.account.ID)
})
t.Run("all same priority and load rate", func(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1, Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 50}},
{account: &Account{ID: 2, Priority: 1, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 50}},
{account: &Account{ID: 3, Priority: 1, LastUsedAt: &muchEarlier}, loadInfo: &AccountLoadInfo{LoadRate: 50}},
}
step1 := filterByMinPriority(accounts)
require.Len(t, step1, 3)
step2 := filterByMinLoadRate(step1)
require.Len(t, step2, 3)
// LRU 选择最早的
selected := selectByLRU(step2, false)
require.NotNil(t, selected)
require.Equal(t, int64(3), selected.account.ID)
})
}
......@@ -151,6 +151,14 @@ func (s *SchedulerSnapshotService) GetAccount(ctx context.Context, accountID int
return s.accountRepo.GetByID(fallbackCtx, accountID)
}
// UpdateAccountInCache 立即更新 Redis 中单个账号的数据(用于模型限流后立即生效)
func (s *SchedulerSnapshotService) UpdateAccountInCache(ctx context.Context, account *Account) error {
if s.cache == nil || account == nil {
return nil
}
return s.cache.SetAccount(ctx, account)
}
func (s *SchedulerSnapshotService) runInitialRebuild() {
if s.cache == nil {
return
......
package service
import "sync"
const sseScannerBuf64KSize = 64 * 1024
type sseScannerBuf64K [sseScannerBuf64KSize]byte
var sseScannerBuf64KPool = sync.Pool{
New: func() any {
return new(sseScannerBuf64K)
},
}
func getSSEScannerBuf64K() *sseScannerBuf64K {
return sseScannerBuf64KPool.Get().(*sseScannerBuf64K)
}
func putSSEScannerBuf64K(buf *sseScannerBuf64K) {
if buf == nil {
return
}
sseScannerBuf64KPool.Put(buf)
}
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestSSEScannerBuf64KPool_GetPutDoesNotPanic(t *testing.T) {
buf := getSSEScannerBuf64K()
require.NotNil(t, buf)
require.Equal(t, sseScannerBuf64KSize, len(buf[:]))
buf[0] = 1
putSSEScannerBuf64K(buf)
// 允许传入 nil,确保不会 panic
putSSEScannerBuf64K(nil)
}
......@@ -23,32 +23,90 @@ import (
// - 临时不可调度且未过期:清理
// - 临时不可调度已过期:不清理
// - 正常可调度状态:不清理
// - 模型限流超过阈值:清理
// - 模型限流未超过阈值:不清理
//
// TestShouldClearStickySession tests the sticky session clearing logic.
// Verifies correct behavior for various account states including:
// nil account, error/disabled status, unschedulable, temporary unschedulable.
// nil account, error/disabled status, unschedulable, temporary unschedulable,
// and model rate limiting scenarios.
func TestShouldClearStickySession(t *testing.T) {
now := time.Now()
future := now.Add(1 * time.Hour)
past := now.Add(-1 * time.Hour)
// 短限流时间(低于阈值,不应清除粘性会话)
shortRateLimitReset := now.Add(5 * time.Second).Format(time.RFC3339)
// 长限流时间(超过阈值,应清除粘性会话)
longRateLimitReset := now.Add(30 * time.Second).Format(time.RFC3339)
tests := []struct {
name string
account *Account
requestedModel string
want bool
}{
{name: "nil account", account: nil, want: false},
{name: "status error", account: &Account{Status: StatusError, Schedulable: true}, want: true},
{name: "status disabled", account: &Account{Status: StatusDisabled, Schedulable: true}, want: true},
{name: "schedulable false", account: &Account{Status: StatusActive, Schedulable: false}, want: true},
{name: "temp unschedulable", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &future}, want: true},
{name: "temp unschedulable expired", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &past}, want: false},
{name: "active schedulable", account: &Account{Status: StatusActive, Schedulable: true}, want: false},
{name: "nil account", account: nil, requestedModel: "", want: false},
{name: "status error", account: &Account{Status: StatusError, Schedulable: true}, requestedModel: "", want: true},
{name: "status disabled", account: &Account{Status: StatusDisabled, Schedulable: true}, requestedModel: "", want: true},
{name: "schedulable false", account: &Account{Status: StatusActive, Schedulable: false}, requestedModel: "", want: true},
{name: "temp unschedulable", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &future}, requestedModel: "", want: true},
{name: "temp unschedulable expired", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &past}, requestedModel: "", want: false},
{name: "active schedulable", account: &Account{Status: StatusActive, Schedulable: true}, requestedModel: "", want: false},
// 模型限流测试
{
name: "model rate limited short duration",
account: &Account{
Status: StatusActive,
Schedulable: true,
Extra: map[string]any{
"model_rate_limits": map[string]any{
"claude-sonnet-4": map[string]any{
"rate_limit_reset_at": shortRateLimitReset,
},
},
},
},
requestedModel: "claude-sonnet-4",
want: false, // 低于阈值,不清除
},
{
name: "model rate limited long duration",
account: &Account{
Status: StatusActive,
Schedulable: true,
Extra: map[string]any{
"model_rate_limits": map[string]any{
"claude-sonnet-4": map[string]any{
"rate_limit_reset_at": longRateLimitReset,
},
},
},
},
requestedModel: "claude-sonnet-4",
want: true, // 超过阈值,清除
},
{
name: "model rate limited different model",
account: &Account{
Status: StatusActive,
Schedulable: true,
Extra: map[string]any{
"model_rate_limits": map[string]any{
"claude-sonnet-4": map[string]any{
"rate_limit_reset_at": longRateLimitReset,
},
},
},
},
requestedModel: "claude-opus-4", // 请求不同模型
want: false, // 不同模型不受影响
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, shouldClearStickySession(tt.account))
require.Equal(t, tt.want, shouldClearStickySession(tt.account, tt.requestedModel))
})
}
}
......@@ -4,10 +4,15 @@ import (
"context"
"fmt"
"log"
"math/rand/v2"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/dgraph-io/ristretto"
"golang.org/x/sync/singleflight"
)
// MaxExpiresAt is the maximum allowed expiration date (year 2099)
......@@ -35,15 +40,76 @@ type SubscriptionService struct {
groupRepo GroupRepository
userSubRepo UserSubscriptionRepository
billingCacheService *BillingCacheService
// L1 缓存:加速中间件热路径的订阅查询
subCacheL1 *ristretto.Cache
subCacheGroup singleflight.Group
subCacheTTL time.Duration
subCacheJitter int // 抖动百分比
}
// NewSubscriptionService 创建订阅服务
func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscriptionRepository, billingCacheService *BillingCacheService) *SubscriptionService {
return &SubscriptionService{
func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscriptionRepository, billingCacheService *BillingCacheService, cfg *config.Config) *SubscriptionService {
svc := &SubscriptionService{
groupRepo: groupRepo,
userSubRepo: userSubRepo,
billingCacheService: billingCacheService,
}
svc.initSubCache(cfg)
return svc
}
// initSubCache 初始化订阅 L1 缓存
func (s *SubscriptionService) initSubCache(cfg *config.Config) {
if cfg == nil {
return
}
sc := cfg.SubscriptionCache
if sc.L1Size <= 0 || sc.L1TTLSeconds <= 0 {
return
}
cache, err := ristretto.NewCache(&ristretto.Config{
NumCounters: int64(sc.L1Size) * 10,
MaxCost: int64(sc.L1Size),
BufferItems: 64,
})
if err != nil {
log.Printf("Warning: failed to init subscription L1 cache: %v", err)
return
}
s.subCacheL1 = cache
s.subCacheTTL = time.Duration(sc.L1TTLSeconds) * time.Second
s.subCacheJitter = sc.JitterPercent
}
// subCacheKey 生成订阅缓存 key(热路径,避免 fmt.Sprintf 开销)
func subCacheKey(userID, groupID int64) string {
return "sub:" + strconv.FormatInt(userID, 10) + ":" + strconv.FormatInt(groupID, 10)
}
// jitteredTTL 为 TTL 添加抖动,避免集中过期
func (s *SubscriptionService) jitteredTTL(ttl time.Duration) time.Duration {
if ttl <= 0 || s.subCacheJitter <= 0 {
return ttl
}
pct := s.subCacheJitter
if pct > 100 {
pct = 100
}
delta := float64(pct) / 100
factor := 1 - delta + rand.Float64()*(2*delta)
if factor <= 0 {
return ttl
}
return time.Duration(float64(ttl) * factor)
}
// InvalidateSubCache 失效指定用户+分组的订阅 L1 缓存
func (s *SubscriptionService) InvalidateSubCache(userID, groupID int64) {
if s.subCacheL1 == nil {
return
}
s.subCacheL1.Del(subCacheKey(userID, groupID))
}
// AssignSubscriptionInput 分配订阅输入
......@@ -81,6 +147,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass
}
// 失效订阅缓存
s.InvalidateSubCache(input.UserID, input.GroupID)
if s.billingCacheService != nil {
userID, groupID := input.UserID, input.GroupID
go func() {
......@@ -167,6 +234,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
}
// 失效订阅缓存
s.InvalidateSubCache(input.UserID, input.GroupID)
if s.billingCacheService != nil {
userID, groupID := input.UserID, input.GroupID
go func() {
......@@ -188,6 +256,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
}
// 失效订阅缓存
s.InvalidateSubCache(input.UserID, input.GroupID)
if s.billingCacheService != nil {
userID, groupID := input.UserID, input.GroupID
go func() {
......@@ -297,6 +366,7 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti
}
// 失效订阅缓存
s.InvalidateSubCache(sub.UserID, sub.GroupID)
if s.billingCacheService != nil {
userID, groupID := sub.UserID, sub.GroupID
go func() {
......@@ -363,6 +433,7 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
}
// 失效订阅缓存
s.InvalidateSubCache(sub.UserID, sub.GroupID)
if s.billingCacheService != nil {
userID, groupID := sub.UserID, sub.GroupID
go func() {
......@@ -381,12 +452,39 @@ func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*UserSubsc
}
// GetActiveSubscription 获取用户对特定分组的有效订阅
// 使用 L1 缓存 + singleflight 加速中间件热路径。
// 返回缓存对象的浅拷贝,调用方可安全修改字段而不会污染缓存或触发 data race。
func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, groupID int64) (*UserSubscription, error) {
key := subCacheKey(userID, groupID)
// L1 缓存命中:返回浅拷贝
if s.subCacheL1 != nil {
if v, ok := s.subCacheL1.Get(key); ok {
if sub, ok := v.(*UserSubscription); ok {
cp := *sub
return &cp, nil
}
}
}
// singleflight 防止并发击穿
value, err, _ := s.subCacheGroup.Do(key, func() (any, error) {
sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
if err != nil {
return nil, ErrSubscriptionNotFound
}
// 写入 L1 缓存
if s.subCacheL1 != nil {
_ = s.subCacheL1.SetWithTTL(key, sub, 1, s.jitteredTTL(s.subCacheTTL))
}
return sub, nil
})
if err != nil {
return nil, err
}
// singleflight 返回的也是缓存指针,需要浅拷贝
cp := *value.(*UserSubscription)
return &cp, nil
}
// ListUserSubscriptions 获取用户的所有订阅
......@@ -521,10 +619,13 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *Use
needsInvalidateCache = true
}
// 如果有窗口被重置,失效 Redis 缓存以保持一致性
if needsInvalidateCache && s.billingCacheService != nil {
// 如果有窗口被重置,失效缓存以保持一致性
if needsInvalidateCache {
s.InvalidateSubCache(sub.UserID, sub.GroupID)
if s.billingCacheService != nil {
_ = s.billingCacheService.InvalidateSubscription(ctx, sub.UserID, sub.GroupID)
}
}
return nil
}
......@@ -544,6 +645,78 @@ func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *UserSub
return nil
}
// ValidateAndCheckLimits 合并验证+限额检查(中间件热路径专用)
// 仅做内存检查,不触发 DB 写入。窗口重置的 DB 写入由 DoWindowMaintenance 异步完成。
// 返回 needsMaintenance 表示是否需要异步执行窗口维护。
func (s *SubscriptionService) ValidateAndCheckLimits(sub *UserSubscription, group *Group) (needsMaintenance bool, err error) {
// 1. 验证订阅状态
if sub.Status == SubscriptionStatusExpired {
return false, ErrSubscriptionExpired
}
if sub.Status == SubscriptionStatusSuspended {
return false, ErrSubscriptionSuspended
}
if sub.IsExpired() {
return false, ErrSubscriptionExpired
}
// 2. 内存中修正过期窗口的用量,确保 CheckUsageLimits 不会误拒绝用户
// 实际的 DB 窗口重置由 DoWindowMaintenance 异步完成
if sub.NeedsDailyReset() {
sub.DailyUsageUSD = 0
needsMaintenance = true
}
if sub.NeedsWeeklyReset() {
sub.WeeklyUsageUSD = 0
needsMaintenance = true
}
if sub.NeedsMonthlyReset() {
sub.MonthlyUsageUSD = 0
needsMaintenance = true
}
if !sub.IsWindowActivated() {
needsMaintenance = true
}
// 3. 检查用量限额
if !sub.CheckDailyLimit(group, 0) {
return needsMaintenance, ErrDailyLimitExceeded
}
if !sub.CheckWeeklyLimit(group, 0) {
return needsMaintenance, ErrWeeklyLimitExceeded
}
if !sub.CheckMonthlyLimit(group, 0) {
return needsMaintenance, ErrMonthlyLimitExceeded
}
return needsMaintenance, nil
}
// DoWindowMaintenance 异步执行窗口维护(激活+重置)
// 使用独立 context,不受请求取消影响。
// 注意:此方法仅在 ValidateAndCheckLimits 返回 needsMaintenance=true 时调用,
// 而 IsExpired()=true 的订阅在 ValidateAndCheckLimits 中已被拦截返回错误,
// 因此进入此方法的订阅一定未过期,无需处理过期状态同步。
func (s *SubscriptionService) DoWindowMaintenance(sub *UserSubscription) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// 激活窗口(首次使用时)
if !sub.IsWindowActivated() {
if err := s.CheckAndActivateWindow(ctx, sub); err != nil {
log.Printf("Failed to activate subscription windows: %v", err)
}
}
// 重置过期窗口
if err := s.CheckAndResetWindows(ctx, sub); err != nil {
log.Printf("Failed to reset subscription windows: %v", err)
}
// 失效 L1 缓存,确保后续请求拿到更新后的数据
s.InvalidateSubCache(sub.UserID, sub.GroupID)
}
// RecordUsage 记录使用量到订阅
func (s *SubscriptionService) RecordUsage(ctx context.Context, subscriptionID int64, costUSD float64) error {
return s.userSubRepo.IncrementUsage(ctx, subscriptionID, costUSD)
......
//go:build unit
package service
import (
"testing"
"time"
"github.com/stretchr/testify/require"
)
// ============ 临时限流单元测试 ============
// TestMatchTempUnschedKeyword 测试关键词匹配函数
func TestMatchTempUnschedKeyword(t *testing.T) {
tests := []struct {
name string
body string
keywords []string
want string
}{
{
name: "match_first",
body: "server is overloaded",
keywords: []string{"overloaded", "capacity"},
want: "overloaded",
},
{
name: "match_second",
body: "no capacity available",
keywords: []string{"overloaded", "capacity"},
want: "capacity",
},
{
name: "no_match",
body: "internal error",
keywords: []string{"overloaded", "capacity"},
want: "",
},
{
name: "empty_body",
body: "",
keywords: []string{"overloaded"},
want: "",
},
{
name: "empty_keywords",
body: "server is overloaded",
keywords: []string{},
want: "",
},
{
name: "whitespace_keyword",
body: "server is overloaded",
keywords: []string{" ", "overloaded"},
want: "overloaded",
},
{
// matchTempUnschedKeyword 期望 body 已经是小写的
// 所以要测试大小写不敏感匹配,需要传入小写的 body
name: "case_insensitive_body_lowered",
body: "server is overloaded", // body 已经是小写
keywords: []string{"OVERLOADED"}, // keyword 会被转为小写比较
want: "OVERLOADED",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := matchTempUnschedKeyword(tt.body, tt.keywords)
require.Equal(t, tt.want, got)
})
}
}
// TestAccountIsSchedulable_TempUnschedulable 测试临时限流账号不可调度
func TestAccountIsSchedulable_TempUnschedulable(t *testing.T) {
future := time.Now().Add(10 * time.Minute)
past := time.Now().Add(-10 * time.Minute)
tests := []struct {
name string
account *Account
want bool
}{
{
name: "temp_unschedulable_active",
account: &Account{
Status: StatusActive,
Schedulable: true,
TempUnschedulableUntil: &future,
},
want: false,
},
{
name: "temp_unschedulable_expired",
account: &Account{
Status: StatusActive,
Schedulable: true,
TempUnschedulableUntil: &past,
},
want: true,
},
{
name: "no_temp_unschedulable",
account: &Account{
Status: StatusActive,
Schedulable: true,
TempUnschedulableUntil: nil,
},
want: true,
},
{
name: "temp_unschedulable_with_rate_limit",
account: &Account{
Status: StatusActive,
Schedulable: true,
TempUnschedulableUntil: &future,
RateLimitResetAt: &past, // 过期的限流不影响
},
want: false, // 临时限流生效
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.account.IsSchedulable()
require.Equal(t, tt.want, got)
})
}
}
// TestAccount_IsTempUnschedulableEnabled 测试临时限流开关
func TestAccount_IsTempUnschedulableEnabled(t *testing.T) {
tests := []struct {
name string
account *Account
want bool
}{
{
name: "enabled",
account: &Account{
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
},
},
want: true,
},
{
name: "disabled",
account: &Account{
Credentials: map[string]any{
"temp_unschedulable_enabled": false,
},
},
want: false,
},
{
name: "not_set",
account: &Account{
Credentials: map[string]any{},
},
want: false,
},
{
name: "nil_credentials",
account: &Account{},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.account.IsTempUnschedulableEnabled()
require.Equal(t, tt.want, got)
})
}
}
// TestAccount_GetTempUnschedulableRules 测试获取临时限流规则
func TestAccount_GetTempUnschedulableRules(t *testing.T) {
tests := []struct {
name string
account *Account
wantCount int
}{
{
name: "has_rules",
account: &Account{
Credentials: map[string]any{
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(5),
},
map[string]any{
"error_code": float64(500),
"keywords": []any{"internal"},
"duration_minutes": float64(10),
},
},
},
},
wantCount: 2,
},
{
name: "empty_rules",
account: &Account{
Credentials: map[string]any{
"temp_unschedulable_rules": []any{},
},
},
wantCount: 0,
},
{
name: "no_rules",
account: &Account{
Credentials: map[string]any{},
},
wantCount: 0,
},
{
name: "nil_credentials",
account: &Account{},
wantCount: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rules := tt.account.GetTempUnschedulableRules()
require.Len(t, rules, tt.wantCount)
})
}
}
// TestTempUnschedulableRule_Parse 测试规则解析
func TestTempUnschedulableRule_Parse(t *testing.T) {
account := &Account{
Credentials: map[string]any{
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded", "capacity"},
"duration_minutes": float64(5),
},
},
},
}
rules := account.GetTempUnschedulableRules()
require.Len(t, rules, 1)
rule := rules[0]
require.Equal(t, 503, rule.ErrorCode)
require.Equal(t, []string{"overloaded", "capacity"}, rule.Keywords)
require.Equal(t, 5, rule.DurationMinutes)
}
// TestTruncateTempUnschedMessage 测试消息截断
func TestTruncateTempUnschedMessage(t *testing.T) {
tests := []struct {
name string
body []byte
maxBytes int
want string
}{
{
name: "short_message",
body: []byte("short"),
maxBytes: 100,
want: "short",
},
{
// 截断后会 TrimSpace,所以末尾的空格会被移除
name: "truncate_long_message",
body: []byte("this is a very long message that needs to be truncated"),
maxBytes: 20,
want: "this is a very long", // 截断后 TrimSpace
},
{
name: "empty_body",
body: []byte{},
maxBytes: 100,
want: "",
},
{
name: "zero_max_bytes",
body: []byte("test"),
maxBytes: 0,
want: "",
},
{
name: "whitespace_trimmed",
body: []byte(" test "),
maxBytes: 100,
want: "test",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := truncateTempUnschedMessage(tt.body, tt.maxBytes)
require.Equal(t, tt.want, got)
})
}
}
// TestTempUnschedState 测试临时限流状态结构
func TestTempUnschedState(t *testing.T) {
now := time.Now()
until := now.Add(5 * time.Minute)
state := &TempUnschedState{
UntilUnix: until.Unix(),
TriggeredAtUnix: now.Unix(),
StatusCode: 503,
MatchedKeyword: "overloaded",
RuleIndex: 0,
ErrorMessage: "Server is overloaded",
}
require.Equal(t, 503, state.StatusCode)
require.Equal(t, "overloaded", state.MatchedKeyword)
require.Equal(t, 0, state.RuleIndex)
// 验证时间戳
require.Equal(t, until.Unix(), state.UntilUnix)
require.Equal(t, now.Unix(), state.TriggeredAtUnix)
}
// TestAccount_TempUnschedulableUntil 测试临时限流时间字段
func TestAccount_TempUnschedulableUntil(t *testing.T) {
future := time.Now().Add(10 * time.Minute)
past := time.Now().Add(-10 * time.Minute)
tests := []struct {
name string
account *Account
schedulable bool
}{
{
name: "active_temp_unsched_not_schedulable",
account: &Account{
Status: StatusActive,
Schedulable: true,
TempUnschedulableUntil: &future,
},
schedulable: false,
},
{
name: "expired_temp_unsched_is_schedulable",
account: &Account{
Status: StatusActive,
Schedulable: true,
TempUnschedulableUntil: &past,
},
schedulable: true,
},
{
name: "nil_temp_unsched_is_schedulable",
account: &Account{
Status: StatusActive,
Schedulable: true,
TempUnschedulableUntil: nil,
},
schedulable: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.account.IsSchedulable()
require.Equal(t, tt.schedulable, got)
})
}
}
......@@ -316,8 +316,8 @@ func (s *UsageService) GetUserModelStats(ctx context.Context, userID int64, star
}
// GetBatchAPIKeyUsageStats returns today/total actual_cost for given api keys.
func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
}
......
......@@ -3,6 +3,8 @@ package service
import (
"context"
"fmt"
"log"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
......@@ -62,13 +64,15 @@ type ChangePasswordRequest struct {
type UserService struct {
userRepo UserRepository
authCacheInvalidator APIKeyAuthCacheInvalidator
billingCache BillingCache
}
// NewUserService 创建用户服务实例
func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *UserService {
func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator, billingCache BillingCache) *UserService {
return &UserService{
userRepo: userRepo,
authCacheInvalidator: authCacheInvalidator,
billingCache: billingCache,
}
}
......@@ -183,6 +187,15 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
if s.billingCache != nil {
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.billingCache.InvalidateUserBalance(cacheCtx, userID); err != nil {
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err)
}
}()
}
return nil
}
......
//go:build unit
package service
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// --- mock: UserRepository ---
type mockUserRepo struct {
updateBalanceErr error
updateBalanceFn func(ctx context.Context, id int64, amount float64) error
}
func (m *mockUserRepo) Create(context.Context, *User) error { return nil }
func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) { return &User{}, nil }
func (m *mockUserRepo) GetByEmail(context.Context, string) (*User, error) { return &User{}, nil }
func (m *mockUserRepo) GetFirstAdmin(context.Context) (*User, error) { return &User{}, nil }
func (m *mockUserRepo) Update(context.Context, *User) error { return nil }
func (m *mockUserRepo) Delete(context.Context, int64) error { return nil }
func (m *mockUserRepo) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockUserRepo) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error {
if m.updateBalanceFn != nil {
return m.updateBalanceFn(ctx, id, amount)
}
return m.updateBalanceErr
}
func (m *mockUserRepo) DeductBalance(context.Context, int64, float64) error { return nil }
func (m *mockUserRepo) UpdateConcurrency(context.Context, int64, int) error { return nil }
func (m *mockUserRepo) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
return 0, nil
}
func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil }
func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil }
// --- mock: APIKeyAuthCacheInvalidator ---
type mockAuthCacheInvalidator struct {
invalidatedUserIDs []int64
mu sync.Mutex
}
func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByKey(context.Context, string) {}
func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByGroupID(context.Context, int64) {}
func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByUserID(_ context.Context, userID int64) {
m.mu.Lock()
defer m.mu.Unlock()
m.invalidatedUserIDs = append(m.invalidatedUserIDs, userID)
}
// --- mock: BillingCache ---
type mockBillingCache struct {
invalidateErr error
invalidateCallCount atomic.Int64
invalidatedUserIDs []int64
mu sync.Mutex
}
func (m *mockBillingCache) GetUserBalance(context.Context, int64) (float64, error) { return 0, nil }
func (m *mockBillingCache) SetUserBalance(context.Context, int64, float64) error { return nil }
func (m *mockBillingCache) DeductUserBalance(context.Context, int64, float64) error { return nil }
func (m *mockBillingCache) InvalidateUserBalance(_ context.Context, userID int64) error {
m.invalidateCallCount.Add(1)
m.mu.Lock()
defer m.mu.Unlock()
m.invalidatedUserIDs = append(m.invalidatedUserIDs, userID)
return m.invalidateErr
}
func (m *mockBillingCache) GetSubscriptionCache(context.Context, int64, int64) (*SubscriptionCacheData, error) {
return nil, nil
}
func (m *mockBillingCache) SetSubscriptionCache(context.Context, int64, int64, *SubscriptionCacheData) error {
return nil
}
func (m *mockBillingCache) UpdateSubscriptionUsage(context.Context, int64, int64, float64) error {
return nil
}
func (m *mockBillingCache) InvalidateSubscriptionCache(context.Context, int64, int64) error {
return nil
}
// --- 测试 ---
func TestUpdateBalance_Success(t *testing.T) {
repo := &mockUserRepo{}
cache := &mockBillingCache{}
svc := NewUserService(repo, nil, cache)
err := svc.UpdateBalance(context.Background(), 42, 100.0)
require.NoError(t, err)
// 等待异步 goroutine 完成
require.Eventually(t, func() bool {
return cache.invalidateCallCount.Load() == 1
}, 2*time.Second, 10*time.Millisecond, "应异步调用 InvalidateUserBalance")
cache.mu.Lock()
defer cache.mu.Unlock()
require.Equal(t, []int64{42}, cache.invalidatedUserIDs, "应对 userID=42 失效缓存")
}
func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) {
repo := &mockUserRepo{}
svc := NewUserService(repo, nil, nil) // billingCache = nil
err := svc.UpdateBalance(context.Background(), 1, 50.0)
require.NoError(t, err, "billingCache 为 nil 时不应 panic")
}
func TestUpdateBalance_CacheFailure_DoesNotAffectReturn(t *testing.T) {
repo := &mockUserRepo{}
cache := &mockBillingCache{invalidateErr: errors.New("redis connection refused")}
svc := NewUserService(repo, nil, cache)
err := svc.UpdateBalance(context.Background(), 99, 200.0)
require.NoError(t, err, "缓存失效失败不应影响主流程返回值")
// 等待异步 goroutine 完成(即使失败也应调用)
require.Eventually(t, func() bool {
return cache.invalidateCallCount.Load() == 1
}, 2*time.Second, 10*time.Millisecond, "即使失败也应调用 InvalidateUserBalance")
}
func TestUpdateBalance_RepoError_ReturnsError(t *testing.T) {
repo := &mockUserRepo{updateBalanceErr: errors.New("database error")}
cache := &mockBillingCache{}
svc := NewUserService(repo, nil, cache)
err := svc.UpdateBalance(context.Background(), 1, 100.0)
require.Error(t, err, "repo 失败时应返回错误")
require.Contains(t, err.Error(), "update balance")
// repo 失败时不应触发缓存失效
time.Sleep(100 * time.Millisecond)
require.Equal(t, int64(0), cache.invalidateCallCount.Load(),
"repo 失败时不应调用 InvalidateUserBalance")
}
func TestUpdateBalance_WithAuthCacheInvalidator(t *testing.T) {
repo := &mockUserRepo{}
auth := &mockAuthCacheInvalidator{}
cache := &mockBillingCache{}
svc := NewUserService(repo, auth, cache)
err := svc.UpdateBalance(context.Background(), 77, 300.0)
require.NoError(t, err)
// 验证 auth cache 同步失效
auth.mu.Lock()
require.Equal(t, []int64{77}, auth.invalidatedUserIDs)
auth.mu.Unlock()
// 验证 billing cache 异步失效
require.Eventually(t, func() bool {
return cache.invalidateCallCount.Load() == 1
}, 2*time.Second, 10*time.Millisecond)
}
func TestNewUserService_FieldsAssignment(t *testing.T) {
repo := &mockUserRepo{}
auth := &mockAuthCacheInvalidator{}
cache := &mockBillingCache{}
svc := NewUserService(repo, auth, cache)
require.NotNil(t, svc)
require.Equal(t, repo, svc.userRepo)
require.Equal(t, auth, svc.authCacheInvalidator)
require.Equal(t, cache, svc.billingCache)
}
-- Force set default Antigravity model_mapping.
--
-- Notes:
-- - Applies to both Antigravity OAuth and Upstream accounts.
-- - Overwrites existing credentials.model_mapping.
-- - Removes legacy credentials.model_whitelist.
UPDATE accounts
SET credentials = (COALESCE(credentials, '{}'::jsonb) - 'model_whitelist' - 'model_mapping') || '{
"model_mapping": {
"claude-opus-4-6": "claude-opus-4-6",
"claude-opus-4-5-thinking": "claude-opus-4-5-thinking",
"claude-opus-4-5-20251101": "claude-opus-4-5-thinking",
"claude-sonnet-4-5": "claude-sonnet-4-5",
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
"claude-haiku-4-5": "claude-sonnet-4-5",
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
"gemini-2.5-flash": "gemini-2.5-flash",
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
"gemini-2.5-pro": "gemini-2.5-pro",
"gemini-3-flash": "gemini-3-flash",
"gemini-3-flash-preview": "gemini-3-flash",
"gemini-3-pro-high": "gemini-3-pro-high",
"gemini-3-pro-low": "gemini-3-pro-low",
"gemini-3-pro-image": "gemini-3-pro-image",
"gemini-3-pro-preview": "gemini-3-pro-high",
"gemini-3-pro-image-preview": "gemini-3-pro-image",
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
"tab_flash_lite_preview": "tab_flash_lite_preview"
}
}'::jsonb
WHERE platform = 'antigravity'
AND deleted_at IS NULL;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment