Commit bb664d9b authored by yangjianbo's avatar yangjianbo
Browse files

feat(sync): full code sync from release

parent bfc7b339
...@@ -298,5 +298,6 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho ...@@ -298,5 +298,6 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
SupportedModelScopes: snapshot.Group.SupportedModelScopes, SupportedModelScopes: snapshot.Group.SupportedModelScopes,
} }
} }
s.compileAPIKeyIPRules(apiKey)
return apiKey return apiKey
} }
...@@ -158,6 +158,14 @@ func NewAPIKeyService( ...@@ -158,6 +158,14 @@ func NewAPIKeyService(
return svc return svc
} }
func (s *APIKeyService) compileAPIKeyIPRules(apiKey *APIKey) {
if apiKey == nil {
return
}
apiKey.CompiledIPWhitelist = ip.CompileIPRules(apiKey.IPWhitelist)
apiKey.CompiledIPBlacklist = ip.CompileIPRules(apiKey.IPBlacklist)
}
// GenerateKey 生成随机API Key // GenerateKey 生成随机API Key
func (s *APIKeyService) GenerateKey() (string, error) { func (s *APIKeyService) GenerateKey() (string, error) {
// 生成32字节随机数据 // 生成32字节随机数据
...@@ -332,6 +340,7 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK ...@@ -332,6 +340,7 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
} }
s.InvalidateAuthCacheByKey(ctx, apiKey.Key) s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil return apiKey, nil
} }
...@@ -363,6 +372,7 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error) ...@@ -363,6 +372,7 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error)
if err != nil { if err != nil {
return nil, fmt.Errorf("get api key: %w", err) return nil, fmt.Errorf("get api key: %w", err)
} }
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil return apiKey, nil
} }
...@@ -375,6 +385,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro ...@@ -375,6 +385,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
if err != nil { if err != nil {
return nil, fmt.Errorf("get api key: %w", err) return nil, fmt.Errorf("get api key: %w", err)
} }
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil return apiKey, nil
} }
} }
...@@ -391,6 +402,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro ...@@ -391,6 +402,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
if err != nil { if err != nil {
return nil, fmt.Errorf("get api key: %w", err) return nil, fmt.Errorf("get api key: %w", err)
} }
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil return apiKey, nil
} }
} else { } else {
...@@ -402,6 +414,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro ...@@ -402,6 +414,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
if err != nil { if err != nil {
return nil, fmt.Errorf("get api key: %w", err) return nil, fmt.Errorf("get api key: %w", err)
} }
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil return apiKey, nil
} }
} }
...@@ -411,6 +424,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro ...@@ -411,6 +424,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
return nil, fmt.Errorf("get api key: %w", err) return nil, fmt.Errorf("get api key: %w", err)
} }
apiKey.Key = key apiKey.Key = key
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil return apiKey, nil
} }
...@@ -510,6 +524,7 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req ...@@ -510,6 +524,7 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
} }
s.InvalidateAuthCacheByKey(ctx, apiKey.Key) s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil return apiKey, nil
} }
......
...@@ -308,6 +308,17 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S ...@@ -308,6 +308,17 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
}, nil }, nil
} }
// VerifyTurnstileForRegister 在注册场景下验证 Turnstile。
// 当邮箱验证开启且已提交验证码时,说明验证码发送阶段已完成 Turnstile 校验,
// 此处跳过二次校验,避免一次性 token 在注册提交时重复使用导致误报失败。
func (s *AuthService) VerifyTurnstileForRegister(ctx context.Context, token, remoteIP, verifyCode string) error {
if s.IsEmailVerifyEnabled(ctx) && strings.TrimSpace(verifyCode) != "" {
logger.LegacyPrintf("service.auth", "%s", "[Auth] Email verify flow detected, skip duplicate Turnstile check on register")
return nil
}
return s.VerifyTurnstile(ctx, token, remoteIP)
}
// VerifyTurnstile 验证Turnstile token // VerifyTurnstile 验证Turnstile token
func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error { func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error {
required := s.cfg != nil && s.cfg.Server.Mode == "release" && s.cfg.Turnstile.Required required := s.cfg != nil && s.cfg.Server.Mode == "release" && s.cfg.Turnstile.Required
......
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type turnstileVerifierSpy struct {
called int
lastToken string
result *TurnstileVerifyResponse
err error
}
func (s *turnstileVerifierSpy) VerifyToken(_ context.Context, _ string, token, _ string) (*TurnstileVerifyResponse, error) {
s.called++
s.lastToken = token
if s.err != nil {
return nil, s.err
}
if s.result != nil {
return s.result, nil
}
return &TurnstileVerifyResponse{Success: true}, nil
}
func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier TurnstileVerifier) *AuthService {
cfg := &config.Config{
Server: config.ServerConfig{
Mode: "release",
},
Turnstile: config.TurnstileConfig{
Required: true,
},
}
settingService := NewSettingService(&settingRepoStub{values: settings}, cfg)
turnstileService := NewTurnstileService(settingService, verifier)
return NewAuthService(
&userRepoStub{},
nil, // redeemRepo
nil, // refreshTokenCache
cfg,
settingService,
nil, // emailService
turnstileService,
nil, // emailQueueService
nil, // promoService
)
}
func TestAuthService_VerifyTurnstileForRegister_SkipWhenEmailVerifyCodeProvided(t *testing.T) {
verifier := &turnstileVerifierSpy{}
service := newAuthServiceForRegisterTurnstileTest(map[string]string{
SettingKeyEmailVerifyEnabled: "true",
SettingKeyTurnstileEnabled: "true",
SettingKeyTurnstileSecretKey: "secret",
SettingKeyRegistrationEnabled: "true",
}, verifier)
err := service.VerifyTurnstileForRegister(context.Background(), "", "127.0.0.1", "123456")
require.NoError(t, err)
require.Equal(t, 0, verifier.called)
}
func TestAuthService_VerifyTurnstileForRegister_RequireWhenVerifyCodeMissing(t *testing.T) {
verifier := &turnstileVerifierSpy{}
service := newAuthServiceForRegisterTurnstileTest(map[string]string{
SettingKeyEmailVerifyEnabled: "true",
SettingKeyTurnstileEnabled: "true",
SettingKeyTurnstileSecretKey: "secret",
}, verifier)
err := service.VerifyTurnstileForRegister(context.Background(), "", "127.0.0.1", "")
require.ErrorIs(t, err, ErrTurnstileVerificationFailed)
}
func TestAuthService_VerifyTurnstileForRegister_NoSkipWhenEmailVerifyDisabled(t *testing.T) {
verifier := &turnstileVerifierSpy{}
service := newAuthServiceForRegisterTurnstileTest(map[string]string{
SettingKeyEmailVerifyEnabled: "false",
SettingKeyTurnstileEnabled: "true",
SettingKeyTurnstileSecretKey: "secret",
}, verifier)
err := service.VerifyTurnstileForRegister(context.Background(), "turnstile-token", "127.0.0.1", "123456")
require.NoError(t, err)
require.Equal(t, 1, verifier.called)
require.Equal(t, "turnstile-token", verifier.lastToken)
}
...@@ -3,6 +3,7 @@ package service ...@@ -3,6 +3,7 @@ package service
import ( import (
"context" "context"
"fmt" "fmt"
"strconv"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
...@@ -10,6 +11,7 @@ import ( ...@@ -10,6 +11,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"golang.org/x/sync/singleflight"
) )
// 错误定义 // 错误定义
...@@ -58,6 +60,7 @@ const ( ...@@ -58,6 +60,7 @@ const (
cacheWriteBufferSize = 1000 // 任务队列缓冲大小 cacheWriteBufferSize = 1000 // 任务队列缓冲大小
cacheWriteTimeout = 2 * time.Second // 单个写入操作超时 cacheWriteTimeout = 2 * time.Second // 单个写入操作超时
cacheWriteDropLogInterval = 5 * time.Second // 丢弃日志节流间隔 cacheWriteDropLogInterval = 5 * time.Second // 丢弃日志节流间隔
balanceLoadTimeout = 3 * time.Second
) )
// cacheWriteTask 缓存写入任务 // cacheWriteTask 缓存写入任务
...@@ -82,6 +85,9 @@ type BillingCacheService struct { ...@@ -82,6 +85,9 @@ type BillingCacheService struct {
cacheWriteChan chan cacheWriteTask cacheWriteChan chan cacheWriteTask
cacheWriteWg sync.WaitGroup cacheWriteWg sync.WaitGroup
cacheWriteStopOnce sync.Once cacheWriteStopOnce sync.Once
cacheWriteMu sync.RWMutex
stopped atomic.Bool
balanceLoadSF singleflight.Group
// 丢弃日志节流计数器(减少高负载下日志噪音) // 丢弃日志节流计数器(减少高负载下日志噪音)
cacheWriteDropFullCount uint64 cacheWriteDropFullCount uint64
cacheWriteDropFullLastLog int64 cacheWriteDropFullLastLog int64
...@@ -105,35 +111,52 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo ...@@ -105,35 +111,52 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo
// Stop 关闭缓存写入工作池 // Stop 关闭缓存写入工作池
func (s *BillingCacheService) Stop() { func (s *BillingCacheService) Stop() {
s.cacheWriteStopOnce.Do(func() { s.cacheWriteStopOnce.Do(func() {
if s.cacheWriteChan == nil { s.stopped.Store(true)
s.cacheWriteMu.Lock()
ch := s.cacheWriteChan
if ch != nil {
close(ch)
}
s.cacheWriteMu.Unlock()
if ch == nil {
return return
} }
close(s.cacheWriteChan)
s.cacheWriteWg.Wait() s.cacheWriteWg.Wait()
s.cacheWriteChan = nil
s.cacheWriteMu.Lock()
if s.cacheWriteChan == ch {
s.cacheWriteChan = nil
}
s.cacheWriteMu.Unlock()
}) })
} }
func (s *BillingCacheService) startCacheWriteWorkers() { func (s *BillingCacheService) startCacheWriteWorkers() {
s.cacheWriteChan = make(chan cacheWriteTask, cacheWriteBufferSize) ch := make(chan cacheWriteTask, cacheWriteBufferSize)
s.cacheWriteChan = ch
for i := 0; i < cacheWriteWorkerCount; i++ { for i := 0; i < cacheWriteWorkerCount; i++ {
s.cacheWriteWg.Add(1) s.cacheWriteWg.Add(1)
go s.cacheWriteWorker() go s.cacheWriteWorker(ch)
} }
} }
// enqueueCacheWrite 尝试将任务入队,队列满时返回 false(并记录告警)。 // enqueueCacheWrite 尝试将任务入队,队列满时返回 false(并记录告警)。
func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued bool) { func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued bool) {
if s.stopped.Load() {
s.logCacheWriteDrop(task, "closed")
return false
}
s.cacheWriteMu.RLock()
defer s.cacheWriteMu.RUnlock()
if s.cacheWriteChan == nil { if s.cacheWriteChan == nil {
s.logCacheWriteDrop(task, "closed")
return false return false
} }
defer func() {
if recovered := recover(); recovered != nil {
// 队列已关闭时可能触发 panic,记录后静默失败。
s.logCacheWriteDrop(task, "closed")
enqueued = false
}
}()
select { select {
case s.cacheWriteChan <- task: case s.cacheWriteChan <- task:
return true return true
...@@ -144,9 +167,9 @@ func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued b ...@@ -144,9 +167,9 @@ func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued b
} }
} }
func (s *BillingCacheService) cacheWriteWorker() { func (s *BillingCacheService) cacheWriteWorker(ch <-chan cacheWriteTask) {
defer s.cacheWriteWg.Done() defer s.cacheWriteWg.Done()
for task := range s.cacheWriteChan { for task := range ch {
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout) ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
switch task.kind { switch task.kind {
case cacheWriteSetBalance: case cacheWriteSetBalance:
...@@ -243,19 +266,31 @@ func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64) ...@@ -243,19 +266,31 @@ func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64)
return balance, nil return balance, nil
} }
// 缓存未命中,从数据库读取 // 缓存未命中:singleflight 合并同一 userID 的并发回源请求。
balance, err = s.getUserBalanceFromDB(ctx, userID) value, err, _ := s.balanceLoadSF.Do(strconv.FormatInt(userID, 10), func() (any, error) {
loadCtx, cancel := context.WithTimeout(context.Background(), balanceLoadTimeout)
defer cancel()
balance, err := s.getUserBalanceFromDB(loadCtx, userID)
if err != nil {
return nil, err
}
// 异步建立缓存
_ = s.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteSetBalance,
userID: userID,
balance: balance,
})
return balance, nil
})
if err != nil { if err != nil {
return 0, err return 0, err
} }
balance, ok := value.(float64)
// 异步建立缓存 if !ok {
_ = s.enqueueCacheWrite(cacheWriteTask{ return 0, fmt.Errorf("unexpected balance type: %T", value)
kind: cacheWriteSetBalance, }
userID: userID,
balance: balance,
})
return balance, nil return balance, nil
} }
......
//go:build unit
package service
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type billingCacheMissStub struct {
setBalanceCalls atomic.Int64
}
func (s *billingCacheMissStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
return 0, errors.New("cache miss")
}
func (s *billingCacheMissStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
s.setBalanceCalls.Add(1)
return nil
}
func (s *billingCacheMissStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
return nil
}
func (s *billingCacheMissStub) InvalidateUserBalance(ctx context.Context, userID int64) error {
return nil
}
func (s *billingCacheMissStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) {
return nil, errors.New("cache miss")
}
func (s *billingCacheMissStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error {
return nil
}
func (s *billingCacheMissStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
return nil
}
func (s *billingCacheMissStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
return nil
}
type balanceLoadUserRepoStub struct {
mockUserRepo
calls atomic.Int64
delay time.Duration
balance float64
}
func (s *balanceLoadUserRepoStub) GetByID(ctx context.Context, id int64) (*User, error) {
s.calls.Add(1)
if s.delay > 0 {
select {
case <-time.After(s.delay):
case <-ctx.Done():
return nil, ctx.Err()
}
}
return &User{ID: id, Balance: s.balance}, nil
}
func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
cache := &billingCacheMissStub{}
userRepo := &balanceLoadUserRepoStub{
delay: 80 * time.Millisecond,
balance: 12.34,
}
svc := NewBillingCacheService(cache, userRepo, nil, &config.Config{})
t.Cleanup(svc.Stop)
const goroutines = 16
start := make(chan struct{})
var wg sync.WaitGroup
errCh := make(chan error, goroutines)
balCh := make(chan float64, goroutines)
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
<-start
bal, err := svc.GetUserBalance(context.Background(), 99)
errCh <- err
balCh <- bal
}()
}
close(start)
wg.Wait()
close(errCh)
close(balCh)
for err := range errCh {
require.NoError(t, err)
}
for bal := range balCh {
require.Equal(t, 12.34, bal)
}
require.Equal(t, int64(1), userRepo.calls.Load(), "并发穿透应被 singleflight 合并")
require.Eventually(t, func() bool {
return cache.setBalanceCalls.Load() >= 1
}, time.Second, 10*time.Millisecond)
}
...@@ -73,3 +73,16 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) { ...@@ -73,3 +73,16 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
return atomic.LoadInt64(&cache.subscriptionUpdates) > 0 return atomic.LoadInt64(&cache.subscriptionUpdates) > 0
}, 2*time.Second, 10*time.Millisecond) }, 2*time.Second, 10*time.Millisecond)
} }
func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) {
cache := &billingCacheWorkerStub{}
svc := NewBillingCacheService(cache, nil, nil, &config.Config{})
svc.Stop()
enqueued := svc.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteDeductBalance,
userID: 1,
amount: 1,
})
require.False(t, enqueued)
}
...@@ -63,7 +63,7 @@ func TestCalculateImageCost_RateMultiplier(t *testing.T) { ...@@ -63,7 +63,7 @@ func TestCalculateImageCost_RateMultiplier(t *testing.T) {
// 费率倍数 1.5x // 费率倍数 1.5x
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.5) cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.5)
require.InDelta(t, 0.201, cost.TotalCost, 0.0001) // TotalCost = 0.134 * 1.5 require.InDelta(t, 0.201, cost.TotalCost, 0.0001) // TotalCost = 0.134 * 1.5
require.InDelta(t, 0.3015, cost.ActualCost, 0.0001) // ActualCost = 0.201 * 1.5 require.InDelta(t, 0.3015, cost.ActualCost, 0.0001) // ActualCost = 0.201 * 1.5
// 费率倍数 2.0x // 费率倍数 2.0x
......
...@@ -78,7 +78,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo ...@@ -78,7 +78,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo
// Step 3: 检查 max_tokens=1 + haiku 探测请求绕过 // Step 3: 检查 max_tokens=1 + haiku 探测请求绕过
// 这类请求用于 Claude Code 验证 API 连通性,不携带 system prompt // 这类请求用于 Claude Code 验证 API 连通性,不携带 system prompt
if isMaxTokensOneHaiku, ok := r.Context().Value(ctxkey.IsMaxTokensOneHaikuRequest).(bool); ok && isMaxTokensOneHaiku { if isMaxTokensOneHaiku, ok := IsMaxTokensOneHaikuRequestFromContext(r.Context()); ok && isMaxTokensOneHaiku {
return true // 绕过 system prompt 检查,UA 已在 Step 1 验证 return true // 绕过 system prompt 检查,UA 已在 Step 1 验证
} }
......
...@@ -3,8 +3,10 @@ package service ...@@ -3,8 +3,10 @@ package service
import ( import (
"context" "context"
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/binary"
"fmt" "os"
"strconv"
"sync/atomic"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
...@@ -18,6 +20,7 @@ type ConcurrencyCache interface { ...@@ -18,6 +20,7 @@ type ConcurrencyCache interface {
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)
// 账号等待队列(账号级) // 账号等待队列(账号级)
IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error)
...@@ -42,15 +45,25 @@ type ConcurrencyCache interface { ...@@ -42,15 +45,25 @@ type ConcurrencyCache interface {
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
} }
// generateRequestID generates a unique request ID for concurrency slot tracking var (
// Uses 8 random bytes (16 hex chars) for uniqueness requestIDPrefix = initRequestIDPrefix()
func generateRequestID() string { requestIDCounter atomic.Uint64
)
func initRequestIDPrefix() string {
b := make([]byte, 8) b := make([]byte, 8)
if _, err := rand.Read(b); err != nil { if _, err := rand.Read(b); err == nil {
// Fallback to nanosecond timestamp (extremely rare case) return "r" + strconv.FormatUint(binary.BigEndian.Uint64(b), 36)
return fmt.Sprintf("%x", time.Now().UnixNano())
} }
return hex.EncodeToString(b) fallback := uint64(time.Now().UnixNano()) ^ (uint64(os.Getpid()) << 16)
return "r" + strconv.FormatUint(fallback, 36)
}
// generateRequestID generates a unique request ID for concurrency slot tracking.
// Format: {process_random_prefix}-{base36_counter}
func generateRequestID() string {
seq := requestIDCounter.Add(1)
return requestIDPrefix + "-" + strconv.FormatUint(seq, 36)
} }
const ( const (
...@@ -321,16 +334,15 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor ...@@ -321,16 +334,15 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts // GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
// Returns a map of accountID -> current concurrency count // Returns a map of accountID -> current concurrency count
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
result := make(map[int64]int) if len(accountIDs) == 0 {
return map[int64]int{}, nil
for _, accountID := range accountIDs { }
count, err := s.cache.GetAccountConcurrency(ctx, accountID) if s.cache == nil {
if err != nil { result := make(map[int64]int, len(accountIDs))
// If key doesn't exist in Redis, count is 0 for _, accountID := range accountIDs {
count = 0 result[accountID] = 0
} }
result[accountID] = count return result, nil
} }
return s.cache.GetAccountConcurrencyBatch(ctx, accountIDs)
return result, nil
} }
...@@ -5,6 +5,8 @@ package service ...@@ -5,6 +5,8 @@ package service
import ( import (
"context" "context"
"errors" "errors"
"strconv"
"strings"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
...@@ -12,20 +14,20 @@ import ( ...@@ -12,20 +14,20 @@ import (
// stubConcurrencyCacheForTest 用于并发服务单元测试的缓存桩 // stubConcurrencyCacheForTest 用于并发服务单元测试的缓存桩
type stubConcurrencyCacheForTest struct { type stubConcurrencyCacheForTest struct {
acquireResult bool acquireResult bool
acquireErr error acquireErr error
releaseErr error releaseErr error
concurrency int concurrency int
concurrencyErr error concurrencyErr error
waitAllowed bool waitAllowed bool
waitErr error waitErr error
waitCount int waitCount int
waitCountErr error waitCountErr error
loadBatch map[int64]*AccountLoadInfo loadBatch map[int64]*AccountLoadInfo
loadBatchErr error loadBatchErr error
usersLoadBatch map[int64]*UserLoadInfo usersLoadBatch map[int64]*UserLoadInfo
usersLoadErr error usersLoadErr error
cleanupErr error cleanupErr error
// 记录调用 // 记录调用
releasedAccountIDs []int64 releasedAccountIDs []int64
...@@ -45,6 +47,16 @@ func (c *stubConcurrencyCacheForTest) ReleaseAccountSlot(_ context.Context, acco ...@@ -45,6 +47,16 @@ func (c *stubConcurrencyCacheForTest) ReleaseAccountSlot(_ context.Context, acco
func (c *stubConcurrencyCacheForTest) GetAccountConcurrency(_ context.Context, _ int64) (int, error) { func (c *stubConcurrencyCacheForTest) GetAccountConcurrency(_ context.Context, _ int64) (int, error) {
return c.concurrency, c.concurrencyErr return c.concurrency, c.concurrencyErr
} }
func (c *stubConcurrencyCacheForTest) GetAccountConcurrencyBatch(_ context.Context, accountIDs []int64) (map[int64]int, error) {
result := make(map[int64]int, len(accountIDs))
for _, accountID := range accountIDs {
if c.concurrencyErr != nil {
return nil, c.concurrencyErr
}
result[accountID] = c.concurrency
}
return result, nil
}
func (c *stubConcurrencyCacheForTest) IncrementAccountWaitCount(_ context.Context, _ int64, _ int) (bool, error) { func (c *stubConcurrencyCacheForTest) IncrementAccountWaitCount(_ context.Context, _ int64, _ int) (bool, error) {
return c.waitAllowed, c.waitErr return c.waitAllowed, c.waitErr
} }
...@@ -155,6 +167,25 @@ func TestAcquireUserSlot_UnlimitedConcurrency(t *testing.T) { ...@@ -155,6 +167,25 @@ func TestAcquireUserSlot_UnlimitedConcurrency(t *testing.T) {
require.True(t, result.Acquired) require.True(t, result.Acquired)
} }
func TestGenerateRequestID_UsesStablePrefixAndMonotonicCounter(t *testing.T) {
id1 := generateRequestID()
id2 := generateRequestID()
require.NotEmpty(t, id1)
require.NotEmpty(t, id2)
p1 := strings.Split(id1, "-")
p2 := strings.Split(id2, "-")
require.Len(t, p1, 2)
require.Len(t, p2, 2)
require.Equal(t, p1[0], p2[0], "同一进程前缀应保持一致")
n1, err := strconv.ParseUint(p1[1], 36, 64)
require.NoError(t, err)
n2, err := strconv.ParseUint(p2[1], 36, 64)
require.NoError(t, err)
require.Equal(t, n1+1, n2, "计数器应单调递增")
}
func TestGetAccountsLoadBatch_ReturnsCorrectData(t *testing.T) { func TestGetAccountsLoadBatch_ReturnsCorrectData(t *testing.T) {
expected := map[int64]*AccountLoadInfo{ expected := map[int64]*AccountLoadInfo{
1: {AccountID: 1, CurrentConcurrency: 3, WaitingCount: 0, LoadRate: 60}, 1: {AccountID: 1, CurrentConcurrency: 3, WaitingCount: 0, LoadRate: 60},
......
...@@ -124,16 +124,16 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D ...@@ -124,16 +124,16 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D
return stats, nil return stats, nil
} }
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) { func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType) trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
if err != nil { if err != nil {
return nil, fmt.Errorf("get usage trend with filters: %w", err) return nil, fmt.Errorf("get usage trend with filters: %w", err)
} }
return trend, nil return trend, nil
} }
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType) stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
if err != nil { if err != nil {
return nil, fmt.Errorf("get model stats with filters: %w", err) return nil, fmt.Errorf("get model stats with filters: %w", err)
} }
......
package service
import "context"
type DataManagementPostgresConfig struct {
Host string `json:"host"`
Port int32 `json:"port"`
User string `json:"user"`
Password string `json:"password,omitempty"`
PasswordConfigured bool `json:"password_configured"`
Database string `json:"database"`
SSLMode string `json:"ssl_mode"`
ContainerName string `json:"container_name"`
}
type DataManagementRedisConfig struct {
Addr string `json:"addr"`
Username string `json:"username"`
Password string `json:"password,omitempty"`
PasswordConfigured bool `json:"password_configured"`
DB int32 `json:"db"`
ContainerName string `json:"container_name"`
}
type DataManagementS3Config struct {
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key,omitempty"`
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
UseSSL bool `json:"use_ssl"`
}
type DataManagementConfig struct {
SourceMode string `json:"source_mode"`
BackupRoot string `json:"backup_root"`
SQLitePath string `json:"sqlite_path,omitempty"`
RetentionDays int32 `json:"retention_days"`
KeepLast int32 `json:"keep_last"`
ActivePostgresID string `json:"active_postgres_profile_id"`
ActiveRedisID string `json:"active_redis_profile_id"`
Postgres DataManagementPostgresConfig `json:"postgres"`
Redis DataManagementRedisConfig `json:"redis"`
S3 DataManagementS3Config `json:"s3"`
ActiveS3ProfileID string `json:"active_s3_profile_id"`
}
type DataManagementTestS3Result struct {
OK bool `json:"ok"`
Message string `json:"message"`
}
type DataManagementCreateBackupJobInput struct {
BackupType string
UploadToS3 bool
TriggeredBy string
IdempotencyKey string
S3ProfileID string
PostgresID string
RedisID string
}
type DataManagementListBackupJobsInput struct {
PageSize int32
PageToken string
Status string
BackupType string
}
type DataManagementArtifactInfo struct {
LocalPath string `json:"local_path"`
SizeBytes int64 `json:"size_bytes"`
SHA256 string `json:"sha256"`
}
type DataManagementS3ObjectInfo struct {
Bucket string `json:"bucket"`
Key string `json:"key"`
ETag string `json:"etag"`
}
type DataManagementBackupJob struct {
JobID string `json:"job_id"`
BackupType string `json:"backup_type"`
Status string `json:"status"`
TriggeredBy string `json:"triggered_by"`
IdempotencyKey string `json:"idempotency_key,omitempty"`
UploadToS3 bool `json:"upload_to_s3"`
S3ProfileID string `json:"s3_profile_id,omitempty"`
PostgresID string `json:"postgres_profile_id,omitempty"`
RedisID string `json:"redis_profile_id,omitempty"`
StartedAt string `json:"started_at,omitempty"`
FinishedAt string `json:"finished_at,omitempty"`
ErrorMessage string `json:"error_message,omitempty"`
Artifact DataManagementArtifactInfo `json:"artifact"`
S3Object DataManagementS3ObjectInfo `json:"s3"`
}
type DataManagementSourceProfile struct {
SourceType string `json:"source_type"`
ProfileID string `json:"profile_id"`
Name string `json:"name"`
IsActive bool `json:"is_active"`
Config DataManagementSourceConfig `json:"config"`
PasswordConfigured bool `json:"password_configured"`
CreatedAt string `json:"created_at,omitempty"`
UpdatedAt string `json:"updated_at,omitempty"`
}
type DataManagementSourceConfig struct {
Host string `json:"host"`
Port int32 `json:"port"`
User string `json:"user"`
Password string `json:"password,omitempty"`
Database string `json:"database"`
SSLMode string `json:"ssl_mode"`
Addr string `json:"addr"`
Username string `json:"username"`
DB int32 `json:"db"`
ContainerName string `json:"container_name"`
}
type DataManagementCreateSourceProfileInput struct {
SourceType string
ProfileID string
Name string
Config DataManagementSourceConfig
SetActive bool
}
type DataManagementUpdateSourceProfileInput struct {
SourceType string
ProfileID string
Name string
Config DataManagementSourceConfig
}
type DataManagementS3Profile struct {
ProfileID string `json:"profile_id"`
Name string `json:"name"`
IsActive bool `json:"is_active"`
S3 DataManagementS3Config `json:"s3"`
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
CreatedAt string `json:"created_at,omitempty"`
UpdatedAt string `json:"updated_at,omitempty"`
}
type DataManagementCreateS3ProfileInput struct {
ProfileID string
Name string
S3 DataManagementS3Config
SetActive bool
}
type DataManagementUpdateS3ProfileInput struct {
ProfileID string
Name string
S3 DataManagementS3Config
}
type DataManagementListBackupJobsResult struct {
Items []DataManagementBackupJob `json:"items"`
NextPageToken string `json:"next_page_token,omitempty"`
}
func (s *DataManagementService) GetConfig(ctx context.Context) (DataManagementConfig, error) {
_ = ctx
return DataManagementConfig{}, s.deprecatedError()
}
func (s *DataManagementService) UpdateConfig(ctx context.Context, cfg DataManagementConfig) (DataManagementConfig, error) {
_, _ = ctx, cfg
return DataManagementConfig{}, s.deprecatedError()
}
func (s *DataManagementService) ListSourceProfiles(ctx context.Context, sourceType string) ([]DataManagementSourceProfile, error) {
_, _ = ctx, sourceType
return nil, s.deprecatedError()
}
func (s *DataManagementService) CreateSourceProfile(ctx context.Context, input DataManagementCreateSourceProfileInput) (DataManagementSourceProfile, error) {
_, _ = ctx, input
return DataManagementSourceProfile{}, s.deprecatedError()
}
func (s *DataManagementService) UpdateSourceProfile(ctx context.Context, input DataManagementUpdateSourceProfileInput) (DataManagementSourceProfile, error) {
_, _ = ctx, input
return DataManagementSourceProfile{}, s.deprecatedError()
}
func (s *DataManagementService) DeleteSourceProfile(ctx context.Context, sourceType, profileID string) error {
_, _, _ = ctx, sourceType, profileID
return s.deprecatedError()
}
func (s *DataManagementService) SetActiveSourceProfile(ctx context.Context, sourceType, profileID string) (DataManagementSourceProfile, error) {
_, _, _ = ctx, sourceType, profileID
return DataManagementSourceProfile{}, s.deprecatedError()
}
func (s *DataManagementService) ValidateS3(ctx context.Context, cfg DataManagementS3Config) (DataManagementTestS3Result, error) {
_, _ = ctx, cfg
return DataManagementTestS3Result{}, s.deprecatedError()
}
func (s *DataManagementService) ListS3Profiles(ctx context.Context) ([]DataManagementS3Profile, error) {
_ = ctx
return nil, s.deprecatedError()
}
func (s *DataManagementService) CreateS3Profile(ctx context.Context, input DataManagementCreateS3ProfileInput) (DataManagementS3Profile, error) {
_, _ = ctx, input
return DataManagementS3Profile{}, s.deprecatedError()
}
func (s *DataManagementService) UpdateS3Profile(ctx context.Context, input DataManagementUpdateS3ProfileInput) (DataManagementS3Profile, error) {
_, _ = ctx, input
return DataManagementS3Profile{}, s.deprecatedError()
}
func (s *DataManagementService) DeleteS3Profile(ctx context.Context, profileID string) error {
_, _ = ctx, profileID
return s.deprecatedError()
}
func (s *DataManagementService) SetActiveS3Profile(ctx context.Context, profileID string) (DataManagementS3Profile, error) {
_, _ = ctx, profileID
return DataManagementS3Profile{}, s.deprecatedError()
}
func (s *DataManagementService) CreateBackupJob(ctx context.Context, input DataManagementCreateBackupJobInput) (DataManagementBackupJob, error) {
_, _ = ctx, input
return DataManagementBackupJob{}, s.deprecatedError()
}
func (s *DataManagementService) ListBackupJobs(ctx context.Context, input DataManagementListBackupJobsInput) (DataManagementListBackupJobsResult, error) {
_, _ = ctx, input
return DataManagementListBackupJobsResult{}, s.deprecatedError()
}
func (s *DataManagementService) GetBackupJob(ctx context.Context, jobID string) (DataManagementBackupJob, error) {
_, _ = ctx, jobID
return DataManagementBackupJob{}, s.deprecatedError()
}
func (s *DataManagementService) deprecatedError() error {
return ErrDataManagementDeprecated.WithMetadata(map[string]string{"socket_path": s.SocketPath()})
}
package service
import (
"context"
"path/filepath"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
)
func TestDataManagementService_DeprecatedRPCMethods(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(t.TempDir(), "datamanagement.sock")
svc := NewDataManagementServiceWithOptions(socketPath, 0)
_, err := svc.GetConfig(context.Background())
assertDeprecatedDataManagementError(t, err, socketPath)
_, err = svc.CreateBackupJob(context.Background(), DataManagementCreateBackupJobInput{BackupType: "full"})
assertDeprecatedDataManagementError(t, err, socketPath)
err = svc.DeleteS3Profile(context.Background(), "s3-default")
assertDeprecatedDataManagementError(t, err, socketPath)
}
func assertDeprecatedDataManagementError(t *testing.T, err error, socketPath string) {
t.Helper()
require.Error(t, err)
statusCode, status := infraerrors.ToHTTP(err)
require.Equal(t, 503, statusCode)
require.Equal(t, DataManagementDeprecatedReason, status.Reason)
require.Equal(t, socketPath, status.Metadata["socket_path"])
}
package service
import (
"context"
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
const (
DefaultDataManagementAgentSocketPath = "/tmp/sub2api-datamanagement.sock"
LegacyBackupAgentSocketPath = "/tmp/sub2api-backup.sock"
DataManagementDeprecatedReason = "DATA_MANAGEMENT_DEPRECATED"
DataManagementAgentSocketMissingReason = "DATA_MANAGEMENT_AGENT_SOCKET_MISSING"
DataManagementAgentUnavailableReason = "DATA_MANAGEMENT_AGENT_UNAVAILABLE"
// Deprecated: keep old names for compatibility.
DefaultBackupAgentSocketPath = DefaultDataManagementAgentSocketPath
BackupAgentSocketMissingReason = DataManagementAgentSocketMissingReason
BackupAgentUnavailableReason = DataManagementAgentUnavailableReason
)
var (
ErrDataManagementDeprecated = infraerrors.ServiceUnavailable(
DataManagementDeprecatedReason,
"data management feature is deprecated",
)
ErrDataManagementAgentSocketMissing = infraerrors.ServiceUnavailable(
DataManagementAgentSocketMissingReason,
"data management agent socket is missing",
)
ErrDataManagementAgentUnavailable = infraerrors.ServiceUnavailable(
DataManagementAgentUnavailableReason,
"data management agent is unavailable",
)
// Deprecated: keep old names for compatibility.
ErrBackupAgentSocketMissing = ErrDataManagementAgentSocketMissing
ErrBackupAgentUnavailable = ErrDataManagementAgentUnavailable
)
type DataManagementAgentHealth struct {
Enabled bool
Reason string
SocketPath string
Agent *DataManagementAgentInfo
}
type DataManagementAgentInfo struct {
Status string
Version string
UptimeSeconds int64
}
type DataManagementService struct {
socketPath string
dialTimeout time.Duration
}
func NewDataManagementService() *DataManagementService {
return NewDataManagementServiceWithOptions(DefaultDataManagementAgentSocketPath, 500*time.Millisecond)
}
func NewDataManagementServiceWithOptions(socketPath string, dialTimeout time.Duration) *DataManagementService {
path := strings.TrimSpace(socketPath)
if path == "" {
path = DefaultDataManagementAgentSocketPath
}
if dialTimeout <= 0 {
dialTimeout = 500 * time.Millisecond
}
return &DataManagementService{
socketPath: path,
dialTimeout: dialTimeout,
}
}
func (s *DataManagementService) SocketPath() string {
if s == nil || strings.TrimSpace(s.socketPath) == "" {
return DefaultDataManagementAgentSocketPath
}
return s.socketPath
}
func (s *DataManagementService) GetAgentHealth(ctx context.Context) DataManagementAgentHealth {
_ = ctx
return DataManagementAgentHealth{
Enabled: false,
Reason: DataManagementDeprecatedReason,
SocketPath: s.SocketPath(),
}
}
func (s *DataManagementService) EnsureAgentEnabled(ctx context.Context) error {
_ = ctx
return ErrDataManagementDeprecated.WithMetadata(map[string]string{"socket_path": s.SocketPath()})
}
package service
import (
"context"
"path/filepath"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
)
func TestDataManagementService_GetAgentHealth_Deprecated(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(t.TempDir(), "unused.sock")
svc := NewDataManagementServiceWithOptions(socketPath, 0)
health := svc.GetAgentHealth(context.Background())
require.False(t, health.Enabled)
require.Equal(t, DataManagementDeprecatedReason, health.Reason)
require.Equal(t, socketPath, health.SocketPath)
require.Nil(t, health.Agent)
}
func TestDataManagementService_EnsureAgentEnabled_Deprecated(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(t.TempDir(), "unused.sock")
svc := NewDataManagementServiceWithOptions(socketPath, 100)
err := svc.EnsureAgentEnabled(context.Background())
require.Error(t, err)
statusCode, status := infraerrors.ToHTTP(err)
require.Equal(t, 503, statusCode)
require.Equal(t, DataManagementDeprecatedReason, status.Reason)
require.Equal(t, socketPath, status.Metadata["socket_path"])
}
...@@ -104,6 +104,7 @@ const ( ...@@ -104,6 +104,7 @@ const (
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url" SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
// OEM设置 // OEM设置
SettingKeySoraClientEnabled = "sora_client_enabled" // 是否启用 Sora 客户端(管理员手动控制)
SettingKeySiteName = "site_name" // 网站名称 SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64) SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题 SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
...@@ -170,6 +171,27 @@ const ( ...@@ -170,6 +171,27 @@ const (
// SettingKeyStreamTimeoutSettings stores JSON config for stream timeout handling. // SettingKeyStreamTimeoutSettings stores JSON config for stream timeout handling.
SettingKeyStreamTimeoutSettings = "stream_timeout_settings" SettingKeyStreamTimeoutSettings = "stream_timeout_settings"
// =========================
// Sora S3 存储配置
// =========================
SettingKeySoraS3Enabled = "sora_s3_enabled" // 是否启用 Sora S3 存储
SettingKeySoraS3Endpoint = "sora_s3_endpoint" // S3 端点地址
SettingKeySoraS3Region = "sora_s3_region" // S3 区域
SettingKeySoraS3Bucket = "sora_s3_bucket" // S3 存储桶名称
SettingKeySoraS3AccessKeyID = "sora_s3_access_key_id" // S3 Access Key ID
SettingKeySoraS3SecretAccessKey = "sora_s3_secret_access_key" // S3 Secret Access Key(加密存储)
SettingKeySoraS3Prefix = "sora_s3_prefix" // S3 对象键前缀
SettingKeySoraS3ForcePathStyle = "sora_s3_force_path_style" // 是否强制 Path Style(兼容 MinIO 等)
SettingKeySoraS3CDNURL = "sora_s3_cdn_url" // CDN 加速 URL(可选)
SettingKeySoraS3Profiles = "sora_s3_profiles" // Sora S3 多配置(JSON)
// =========================
// Sora 用户存储配额
// =========================
SettingKeySoraDefaultStorageQuotaBytes = "sora_default_storage_quota_bytes" // 新用户默认 Sora 存储配额(字节)
) )
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys). // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
......
...@@ -279,10 +279,10 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokens404PassthroughNotE ...@@ -279,10 +279,10 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokens404PassthroughNotE
wantPassthrough: true, wantPassthrough: true,
}, },
{ {
name: "404 generic not found passes through as 404", name: "404 generic not found does not passthrough",
statusCode: http.StatusNotFound, statusCode: http.StatusNotFound,
respBody: `{"error":{"message":"resource not found","type":"not_found_error"}}`, respBody: `{"error":{"message":"resource not found","type":"not_found_error"}}`,
wantPassthrough: true, wantPassthrough: false,
}, },
{ {
name: "400 Invalid URL does not passthrough", name: "400 Invalid URL does not passthrough",
......
...@@ -136,3 +136,67 @@ func TestDroppedBetaSet(t *testing.T) { ...@@ -136,3 +136,67 @@ func TestDroppedBetaSet(t *testing.T) {
require.Contains(t, extended, claude.BetaClaudeCode) require.Contains(t, extended, claude.BetaClaudeCode)
require.Len(t, extended, len(claude.DroppedBetas)+1) require.Len(t, extended, len(claude.DroppedBetas)+1)
} }
func TestBuildBetaTokenSet(t *testing.T) {
got := buildBetaTokenSet([]string{"foo", "", "bar", "foo"})
require.Len(t, got, 2)
require.Contains(t, got, "foo")
require.Contains(t, got, "bar")
require.NotContains(t, got, "")
empty := buildBetaTokenSet(nil)
require.Empty(t, empty)
}
func TestStripBetaTokensWithSet_EmptyDropSet(t *testing.T) {
header := "oauth-2025-04-20,interleaved-thinking-2025-05-14"
got := stripBetaTokensWithSet(header, map[string]struct{}{})
require.Equal(t, header, got)
}
func TestIsCountTokensUnsupported404(t *testing.T) {
tests := []struct {
name string
statusCode int
body string
want bool
}{
{
name: "exact endpoint not found",
statusCode: 404,
body: `{"error":{"message":"Not found: /v1/messages/count_tokens","type":"not_found_error"}}`,
want: true,
},
{
name: "contains count_tokens and not found",
statusCode: 404,
body: `{"error":{"message":"count_tokens route not found","type":"not_found_error"}}`,
want: true,
},
{
name: "generic 404",
statusCode: 404,
body: `{"error":{"message":"resource not found","type":"not_found_error"}}`,
want: false,
},
{
name: "404 with empty error message",
statusCode: 404,
body: `{"error":{"message":"","type":"not_found_error"}}`,
want: false,
},
{
name: "non-404 status",
statusCode: 400,
body: `{"error":{"message":"Not found: /v1/messages/count_tokens","type":"invalid_request_error"}}`,
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isCountTokensUnsupported404(tt.statusCode, []byte(tt.body))
require.Equal(t, tt.want, got)
})
}
}
...@@ -1892,6 +1892,14 @@ func (m *mockConcurrencyCache) GetAccountConcurrency(ctx context.Context, accoun ...@@ -1892,6 +1892,14 @@ func (m *mockConcurrencyCache) GetAccountConcurrency(ctx context.Context, accoun
return 0, nil return 0, nil
} }
func (m *mockConcurrencyCache) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
result := make(map[int64]int, len(accountIDs))
for _, accountID := range accountIDs {
result[accountID] = 0
}
return result, nil
}
func (m *mockConcurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { func (m *mockConcurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
return true, nil return true, nil
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment