Unverified Commit 9d795061 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #682 from mt21625457/pr/all-code-sync-20260228

feat(openai-ws): support websocket mode v2, optimize relay performance, enhance sora
parents bfc7b339 1d1fc019
......@@ -298,5 +298,6 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
SupportedModelScopes: snapshot.Group.SupportedModelScopes,
}
}
s.compileAPIKeyIPRules(apiKey)
return apiKey
}
......@@ -158,6 +158,14 @@ func NewAPIKeyService(
return svc
}
func (s *APIKeyService) compileAPIKeyIPRules(apiKey *APIKey) {
if apiKey == nil {
return
}
apiKey.CompiledIPWhitelist = ip.CompileIPRules(apiKey.IPWhitelist)
apiKey.CompiledIPBlacklist = ip.CompileIPRules(apiKey.IPBlacklist)
}
// GenerateKey 生成随机API Key
func (s *APIKeyService) GenerateKey() (string, error) {
// 生成32字节随机数据
......@@ -332,6 +340,7 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
}
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
......@@ -363,6 +372,7 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
......@@ -375,6 +385,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
}
......@@ -391,6 +402,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
} else {
......@@ -402,6 +414,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
}
......@@ -411,6 +424,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
return nil, fmt.Errorf("get api key: %w", err)
}
apiKey.Key = key
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
......@@ -510,6 +524,7 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
}
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
......
......@@ -308,6 +308,17 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
}, nil
}
// VerifyTurnstileForRegister 在注册场景下验证 Turnstile。
// 当邮箱验证开启且已提交验证码时,说明验证码发送阶段已完成 Turnstile 校验,
// 此处跳过二次校验,避免一次性 token 在注册提交时重复使用导致误报失败。
func (s *AuthService) VerifyTurnstileForRegister(ctx context.Context, token, remoteIP, verifyCode string) error {
if s.IsEmailVerifyEnabled(ctx) && strings.TrimSpace(verifyCode) != "" {
logger.LegacyPrintf("service.auth", "%s", "[Auth] Email verify flow detected, skip duplicate Turnstile check on register")
return nil
}
return s.VerifyTurnstile(ctx, token, remoteIP)
}
// VerifyTurnstile 验证Turnstile token
func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error {
required := s.cfg != nil && s.cfg.Server.Mode == "release" && s.cfg.Turnstile.Required
......
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type turnstileVerifierSpy struct {
called int
lastToken string
result *TurnstileVerifyResponse
err error
}
func (s *turnstileVerifierSpy) VerifyToken(_ context.Context, _ string, token, _ string) (*TurnstileVerifyResponse, error) {
s.called++
s.lastToken = token
if s.err != nil {
return nil, s.err
}
if s.result != nil {
return s.result, nil
}
return &TurnstileVerifyResponse{Success: true}, nil
}
func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier TurnstileVerifier) *AuthService {
cfg := &config.Config{
Server: config.ServerConfig{
Mode: "release",
},
Turnstile: config.TurnstileConfig{
Required: true,
},
}
settingService := NewSettingService(&settingRepoStub{values: settings}, cfg)
turnstileService := NewTurnstileService(settingService, verifier)
return NewAuthService(
&userRepoStub{},
nil, // redeemRepo
nil, // refreshTokenCache
cfg,
settingService,
nil, // emailService
turnstileService,
nil, // emailQueueService
nil, // promoService
)
}
func TestAuthService_VerifyTurnstileForRegister_SkipWhenEmailVerifyCodeProvided(t *testing.T) {
verifier := &turnstileVerifierSpy{}
service := newAuthServiceForRegisterTurnstileTest(map[string]string{
SettingKeyEmailVerifyEnabled: "true",
SettingKeyTurnstileEnabled: "true",
SettingKeyTurnstileSecretKey: "secret",
SettingKeyRegistrationEnabled: "true",
}, verifier)
err := service.VerifyTurnstileForRegister(context.Background(), "", "127.0.0.1", "123456")
require.NoError(t, err)
require.Equal(t, 0, verifier.called)
}
func TestAuthService_VerifyTurnstileForRegister_RequireWhenVerifyCodeMissing(t *testing.T) {
verifier := &turnstileVerifierSpy{}
service := newAuthServiceForRegisterTurnstileTest(map[string]string{
SettingKeyEmailVerifyEnabled: "true",
SettingKeyTurnstileEnabled: "true",
SettingKeyTurnstileSecretKey: "secret",
}, verifier)
err := service.VerifyTurnstileForRegister(context.Background(), "", "127.0.0.1", "")
require.ErrorIs(t, err, ErrTurnstileVerificationFailed)
}
func TestAuthService_VerifyTurnstileForRegister_NoSkipWhenEmailVerifyDisabled(t *testing.T) {
verifier := &turnstileVerifierSpy{}
service := newAuthServiceForRegisterTurnstileTest(map[string]string{
SettingKeyEmailVerifyEnabled: "false",
SettingKeyTurnstileEnabled: "true",
SettingKeyTurnstileSecretKey: "secret",
}, verifier)
err := service.VerifyTurnstileForRegister(context.Background(), "turnstile-token", "127.0.0.1", "123456")
require.NoError(t, err)
require.Equal(t, 1, verifier.called)
require.Equal(t, "turnstile-token", verifier.lastToken)
}
......@@ -3,6 +3,7 @@ package service
import (
"context"
"fmt"
"strconv"
"sync"
"sync/atomic"
"time"
......@@ -10,6 +11,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"golang.org/x/sync/singleflight"
)
// 错误定义
......@@ -58,6 +60,7 @@ const (
cacheWriteBufferSize = 1000 // 任务队列缓冲大小
cacheWriteTimeout = 2 * time.Second // 单个写入操作超时
cacheWriteDropLogInterval = 5 * time.Second // 丢弃日志节流间隔
balanceLoadTimeout = 3 * time.Second
)
// cacheWriteTask 缓存写入任务
......@@ -82,6 +85,9 @@ type BillingCacheService struct {
cacheWriteChan chan cacheWriteTask
cacheWriteWg sync.WaitGroup
cacheWriteStopOnce sync.Once
cacheWriteMu sync.RWMutex
stopped atomic.Bool
balanceLoadSF singleflight.Group
// 丢弃日志节流计数器(减少高负载下日志噪音)
cacheWriteDropFullCount uint64
cacheWriteDropFullLastLog int64
......@@ -105,35 +111,52 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo
// Stop 关闭缓存写入工作池
func (s *BillingCacheService) Stop() {
s.cacheWriteStopOnce.Do(func() {
if s.cacheWriteChan == nil {
s.stopped.Store(true)
s.cacheWriteMu.Lock()
ch := s.cacheWriteChan
if ch != nil {
close(ch)
}
s.cacheWriteMu.Unlock()
if ch == nil {
return
}
close(s.cacheWriteChan)
s.cacheWriteWg.Wait()
s.cacheWriteMu.Lock()
if s.cacheWriteChan == ch {
s.cacheWriteChan = nil
}
s.cacheWriteMu.Unlock()
})
}
func (s *BillingCacheService) startCacheWriteWorkers() {
s.cacheWriteChan = make(chan cacheWriteTask, cacheWriteBufferSize)
ch := make(chan cacheWriteTask, cacheWriteBufferSize)
s.cacheWriteChan = ch
for i := 0; i < cacheWriteWorkerCount; i++ {
s.cacheWriteWg.Add(1)
go s.cacheWriteWorker()
go s.cacheWriteWorker(ch)
}
}
// enqueueCacheWrite 尝试将任务入队,队列满时返回 false(并记录告警)。
func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued bool) {
if s.cacheWriteChan == nil {
if s.stopped.Load() {
s.logCacheWriteDrop(task, "closed")
return false
}
defer func() {
if recovered := recover(); recovered != nil {
// 队列已关闭时可能触发 panic,记录后静默失败。
s.cacheWriteMu.RLock()
defer s.cacheWriteMu.RUnlock()
if s.cacheWriteChan == nil {
s.logCacheWriteDrop(task, "closed")
enqueued = false
return false
}
}()
select {
case s.cacheWriteChan <- task:
return true
......@@ -144,9 +167,9 @@ func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued b
}
}
func (s *BillingCacheService) cacheWriteWorker() {
func (s *BillingCacheService) cacheWriteWorker(ch <-chan cacheWriteTask) {
defer s.cacheWriteWg.Done()
for task := range s.cacheWriteChan {
for task := range ch {
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
switch task.kind {
case cacheWriteSetBalance:
......@@ -243,10 +266,14 @@ func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64)
return balance, nil
}
// 缓存未命中,从数据库读取
balance, err = s.getUserBalanceFromDB(ctx, userID)
// 缓存未命中:singleflight 合并同一 userID 的并发回源请求。
value, err, _ := s.balanceLoadSF.Do(strconv.FormatInt(userID, 10), func() (any, error) {
loadCtx, cancel := context.WithTimeout(context.Background(), balanceLoadTimeout)
defer cancel()
balance, err := s.getUserBalanceFromDB(loadCtx, userID)
if err != nil {
return 0, err
return nil, err
}
// 异步建立缓存
......@@ -255,7 +282,15 @@ func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64)
userID: userID,
balance: balance,
})
return balance, nil
})
if err != nil {
return 0, err
}
balance, ok := value.(float64)
if !ok {
return 0, fmt.Errorf("unexpected balance type: %T", value)
}
return balance, nil
}
......
//go:build unit
package service
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type billingCacheMissStub struct {
setBalanceCalls atomic.Int64
}
func (s *billingCacheMissStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
return 0, errors.New("cache miss")
}
func (s *billingCacheMissStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
s.setBalanceCalls.Add(1)
return nil
}
func (s *billingCacheMissStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
return nil
}
func (s *billingCacheMissStub) InvalidateUserBalance(ctx context.Context, userID int64) error {
return nil
}
func (s *billingCacheMissStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) {
return nil, errors.New("cache miss")
}
func (s *billingCacheMissStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error {
return nil
}
func (s *billingCacheMissStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
return nil
}
func (s *billingCacheMissStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
return nil
}
type balanceLoadUserRepoStub struct {
mockUserRepo
calls atomic.Int64
delay time.Duration
balance float64
}
func (s *balanceLoadUserRepoStub) GetByID(ctx context.Context, id int64) (*User, error) {
s.calls.Add(1)
if s.delay > 0 {
select {
case <-time.After(s.delay):
case <-ctx.Done():
return nil, ctx.Err()
}
}
return &User{ID: id, Balance: s.balance}, nil
}
func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
cache := &billingCacheMissStub{}
userRepo := &balanceLoadUserRepoStub{
delay: 80 * time.Millisecond,
balance: 12.34,
}
svc := NewBillingCacheService(cache, userRepo, nil, &config.Config{})
t.Cleanup(svc.Stop)
const goroutines = 16
start := make(chan struct{})
var wg sync.WaitGroup
errCh := make(chan error, goroutines)
balCh := make(chan float64, goroutines)
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
<-start
bal, err := svc.GetUserBalance(context.Background(), 99)
errCh <- err
balCh <- bal
}()
}
close(start)
wg.Wait()
close(errCh)
close(balCh)
for err := range errCh {
require.NoError(t, err)
}
for bal := range balCh {
require.Equal(t, 12.34, bal)
}
require.Equal(t, int64(1), userRepo.calls.Load(), "并发穿透应被 singleflight 合并")
require.Eventually(t, func() bool {
return cache.setBalanceCalls.Load() >= 1
}, time.Second, 10*time.Millisecond)
}
......@@ -73,3 +73,16 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
return atomic.LoadInt64(&cache.subscriptionUpdates) > 0
}, 2*time.Second, 10*time.Millisecond)
}
func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) {
cache := &billingCacheWorkerStub{}
svc := NewBillingCacheService(cache, nil, nil, &config.Config{})
svc.Stop()
enqueued := svc.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteDeductBalance,
userID: 1,
amount: 1,
})
require.False(t, enqueued)
}
......@@ -78,7 +78,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo
// Step 3: 检查 max_tokens=1 + haiku 探测请求绕过
// 这类请求用于 Claude Code 验证 API 连通性,不携带 system prompt
if isMaxTokensOneHaiku, ok := r.Context().Value(ctxkey.IsMaxTokensOneHaikuRequest).(bool); ok && isMaxTokensOneHaiku {
if isMaxTokensOneHaiku, ok := IsMaxTokensOneHaikuRequestFromContext(r.Context()); ok && isMaxTokensOneHaiku {
return true // 绕过 system prompt 检查,UA 已在 Step 1 验证
}
......
......@@ -3,8 +3,10 @@ package service
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"encoding/binary"
"os"
"strconv"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
......@@ -18,6 +20,7 @@ type ConcurrencyCache interface {
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)
// 账号等待队列(账号级)
IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error)
......@@ -42,15 +45,25 @@ type ConcurrencyCache interface {
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
}
// generateRequestID generates a unique request ID for concurrency slot tracking
// Uses 8 random bytes (16 hex chars) for uniqueness
func generateRequestID() string {
var (
requestIDPrefix = initRequestIDPrefix()
requestIDCounter atomic.Uint64
)
func initRequestIDPrefix() string {
b := make([]byte, 8)
if _, err := rand.Read(b); err != nil {
// Fallback to nanosecond timestamp (extremely rare case)
return fmt.Sprintf("%x", time.Now().UnixNano())
if _, err := rand.Read(b); err == nil {
return "r" + strconv.FormatUint(binary.BigEndian.Uint64(b), 36)
}
return hex.EncodeToString(b)
fallback := uint64(time.Now().UnixNano()) ^ (uint64(os.Getpid()) << 16)
return "r" + strconv.FormatUint(fallback, 36)
}
// generateRequestID generates a unique request ID for concurrency slot tracking.
// Format: {process_random_prefix}-{base36_counter}
func generateRequestID() string {
seq := requestIDCounter.Add(1)
return requestIDPrefix + "-" + strconv.FormatUint(seq, 36)
}
const (
......@@ -321,16 +334,15 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
// Returns a map of accountID -> current concurrency count
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
result := make(map[int64]int)
for _, accountID := range accountIDs {
count, err := s.cache.GetAccountConcurrency(ctx, accountID)
if err != nil {
// If key doesn't exist in Redis, count is 0
count = 0
if len(accountIDs) == 0 {
return map[int64]int{}, nil
}
result[accountID] = count
if s.cache == nil {
result := make(map[int64]int, len(accountIDs))
for _, accountID := range accountIDs {
result[accountID] = 0
}
return result, nil
}
return s.cache.GetAccountConcurrencyBatch(ctx, accountIDs)
}
......@@ -5,6 +5,8 @@ package service
import (
"context"
"errors"
"strconv"
"strings"
"testing"
"github.com/stretchr/testify/require"
......@@ -45,6 +47,16 @@ func (c *stubConcurrencyCacheForTest) ReleaseAccountSlot(_ context.Context, acco
func (c *stubConcurrencyCacheForTest) GetAccountConcurrency(_ context.Context, _ int64) (int, error) {
return c.concurrency, c.concurrencyErr
}
func (c *stubConcurrencyCacheForTest) GetAccountConcurrencyBatch(_ context.Context, accountIDs []int64) (map[int64]int, error) {
result := make(map[int64]int, len(accountIDs))
for _, accountID := range accountIDs {
if c.concurrencyErr != nil {
return nil, c.concurrencyErr
}
result[accountID] = c.concurrency
}
return result, nil
}
func (c *stubConcurrencyCacheForTest) IncrementAccountWaitCount(_ context.Context, _ int64, _ int) (bool, error) {
return c.waitAllowed, c.waitErr
}
......@@ -155,6 +167,25 @@ func TestAcquireUserSlot_UnlimitedConcurrency(t *testing.T) {
require.True(t, result.Acquired)
}
func TestGenerateRequestID_UsesStablePrefixAndMonotonicCounter(t *testing.T) {
id1 := generateRequestID()
id2 := generateRequestID()
require.NotEmpty(t, id1)
require.NotEmpty(t, id2)
p1 := strings.Split(id1, "-")
p2 := strings.Split(id2, "-")
require.Len(t, p1, 2)
require.Len(t, p2, 2)
require.Equal(t, p1[0], p2[0], "同一进程前缀应保持一致")
n1, err := strconv.ParseUint(p1[1], 36, 64)
require.NoError(t, err)
n2, err := strconv.ParseUint(p2[1], 36, 64)
require.NoError(t, err)
require.Equal(t, n1+1, n2, "计数器应单调递增")
}
func TestGetAccountsLoadBatch_ReturnsCorrectData(t *testing.T) {
expected := map[int64]*AccountLoadInfo{
1: {AccountID: 1, CurrentConcurrency: 3, WaitingCount: 0, LoadRate: 60},
......
......@@ -124,16 +124,16 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D
return stats, nil
}
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
if err != nil {
return nil, fmt.Errorf("get usage trend with filters: %w", err)
}
return trend, nil
}
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
if err != nil {
return nil, fmt.Errorf("get model stats with filters: %w", err)
}
......
package service
import "context"
type DataManagementPostgresConfig struct {
Host string `json:"host"`
Port int32 `json:"port"`
User string `json:"user"`
Password string `json:"password,omitempty"`
PasswordConfigured bool `json:"password_configured"`
Database string `json:"database"`
SSLMode string `json:"ssl_mode"`
ContainerName string `json:"container_name"`
}
type DataManagementRedisConfig struct {
Addr string `json:"addr"`
Username string `json:"username"`
Password string `json:"password,omitempty"`
PasswordConfigured bool `json:"password_configured"`
DB int32 `json:"db"`
ContainerName string `json:"container_name"`
}
type DataManagementS3Config struct {
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key,omitempty"`
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
UseSSL bool `json:"use_ssl"`
}
type DataManagementConfig struct {
SourceMode string `json:"source_mode"`
BackupRoot string `json:"backup_root"`
SQLitePath string `json:"sqlite_path,omitempty"`
RetentionDays int32 `json:"retention_days"`
KeepLast int32 `json:"keep_last"`
ActivePostgresID string `json:"active_postgres_profile_id"`
ActiveRedisID string `json:"active_redis_profile_id"`
Postgres DataManagementPostgresConfig `json:"postgres"`
Redis DataManagementRedisConfig `json:"redis"`
S3 DataManagementS3Config `json:"s3"`
ActiveS3ProfileID string `json:"active_s3_profile_id"`
}
type DataManagementTestS3Result struct {
OK bool `json:"ok"`
Message string `json:"message"`
}
type DataManagementCreateBackupJobInput struct {
BackupType string
UploadToS3 bool
TriggeredBy string
IdempotencyKey string
S3ProfileID string
PostgresID string
RedisID string
}
type DataManagementListBackupJobsInput struct {
PageSize int32
PageToken string
Status string
BackupType string
}
type DataManagementArtifactInfo struct {
LocalPath string `json:"local_path"`
SizeBytes int64 `json:"size_bytes"`
SHA256 string `json:"sha256"`
}
type DataManagementS3ObjectInfo struct {
Bucket string `json:"bucket"`
Key string `json:"key"`
ETag string `json:"etag"`
}
type DataManagementBackupJob struct {
JobID string `json:"job_id"`
BackupType string `json:"backup_type"`
Status string `json:"status"`
TriggeredBy string `json:"triggered_by"`
IdempotencyKey string `json:"idempotency_key,omitempty"`
UploadToS3 bool `json:"upload_to_s3"`
S3ProfileID string `json:"s3_profile_id,omitempty"`
PostgresID string `json:"postgres_profile_id,omitempty"`
RedisID string `json:"redis_profile_id,omitempty"`
StartedAt string `json:"started_at,omitempty"`
FinishedAt string `json:"finished_at,omitempty"`
ErrorMessage string `json:"error_message,omitempty"`
Artifact DataManagementArtifactInfo `json:"artifact"`
S3Object DataManagementS3ObjectInfo `json:"s3"`
}
type DataManagementSourceProfile struct {
SourceType string `json:"source_type"`
ProfileID string `json:"profile_id"`
Name string `json:"name"`
IsActive bool `json:"is_active"`
Config DataManagementSourceConfig `json:"config"`
PasswordConfigured bool `json:"password_configured"`
CreatedAt string `json:"created_at,omitempty"`
UpdatedAt string `json:"updated_at,omitempty"`
}
type DataManagementSourceConfig struct {
Host string `json:"host"`
Port int32 `json:"port"`
User string `json:"user"`
Password string `json:"password,omitempty"`
Database string `json:"database"`
SSLMode string `json:"ssl_mode"`
Addr string `json:"addr"`
Username string `json:"username"`
DB int32 `json:"db"`
ContainerName string `json:"container_name"`
}
type DataManagementCreateSourceProfileInput struct {
SourceType string
ProfileID string
Name string
Config DataManagementSourceConfig
SetActive bool
}
type DataManagementUpdateSourceProfileInput struct {
SourceType string
ProfileID string
Name string
Config DataManagementSourceConfig
}
type DataManagementS3Profile struct {
ProfileID string `json:"profile_id"`
Name string `json:"name"`
IsActive bool `json:"is_active"`
S3 DataManagementS3Config `json:"s3"`
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
CreatedAt string `json:"created_at,omitempty"`
UpdatedAt string `json:"updated_at,omitempty"`
}
type DataManagementCreateS3ProfileInput struct {
ProfileID string
Name string
S3 DataManagementS3Config
SetActive bool
}
type DataManagementUpdateS3ProfileInput struct {
ProfileID string
Name string
S3 DataManagementS3Config
}
type DataManagementListBackupJobsResult struct {
Items []DataManagementBackupJob `json:"items"`
NextPageToken string `json:"next_page_token,omitempty"`
}
func (s *DataManagementService) GetConfig(ctx context.Context) (DataManagementConfig, error) {
_ = ctx
return DataManagementConfig{}, s.deprecatedError()
}
func (s *DataManagementService) UpdateConfig(ctx context.Context, cfg DataManagementConfig) (DataManagementConfig, error) {
_, _ = ctx, cfg
return DataManagementConfig{}, s.deprecatedError()
}
func (s *DataManagementService) ListSourceProfiles(ctx context.Context, sourceType string) ([]DataManagementSourceProfile, error) {
_, _ = ctx, sourceType
return nil, s.deprecatedError()
}
func (s *DataManagementService) CreateSourceProfile(ctx context.Context, input DataManagementCreateSourceProfileInput) (DataManagementSourceProfile, error) {
_, _ = ctx, input
return DataManagementSourceProfile{}, s.deprecatedError()
}
func (s *DataManagementService) UpdateSourceProfile(ctx context.Context, input DataManagementUpdateSourceProfileInput) (DataManagementSourceProfile, error) {
_, _ = ctx, input
return DataManagementSourceProfile{}, s.deprecatedError()
}
func (s *DataManagementService) DeleteSourceProfile(ctx context.Context, sourceType, profileID string) error {
_, _, _ = ctx, sourceType, profileID
return s.deprecatedError()
}
func (s *DataManagementService) SetActiveSourceProfile(ctx context.Context, sourceType, profileID string) (DataManagementSourceProfile, error) {
_, _, _ = ctx, sourceType, profileID
return DataManagementSourceProfile{}, s.deprecatedError()
}
func (s *DataManagementService) ValidateS3(ctx context.Context, cfg DataManagementS3Config) (DataManagementTestS3Result, error) {
_, _ = ctx, cfg
return DataManagementTestS3Result{}, s.deprecatedError()
}
func (s *DataManagementService) ListS3Profiles(ctx context.Context) ([]DataManagementS3Profile, error) {
_ = ctx
return nil, s.deprecatedError()
}
func (s *DataManagementService) CreateS3Profile(ctx context.Context, input DataManagementCreateS3ProfileInput) (DataManagementS3Profile, error) {
_, _ = ctx, input
return DataManagementS3Profile{}, s.deprecatedError()
}
func (s *DataManagementService) UpdateS3Profile(ctx context.Context, input DataManagementUpdateS3ProfileInput) (DataManagementS3Profile, error) {
_, _ = ctx, input
return DataManagementS3Profile{}, s.deprecatedError()
}
func (s *DataManagementService) DeleteS3Profile(ctx context.Context, profileID string) error {
_, _ = ctx, profileID
return s.deprecatedError()
}
func (s *DataManagementService) SetActiveS3Profile(ctx context.Context, profileID string) (DataManagementS3Profile, error) {
_, _ = ctx, profileID
return DataManagementS3Profile{}, s.deprecatedError()
}
func (s *DataManagementService) CreateBackupJob(ctx context.Context, input DataManagementCreateBackupJobInput) (DataManagementBackupJob, error) {
_, _ = ctx, input
return DataManagementBackupJob{}, s.deprecatedError()
}
func (s *DataManagementService) ListBackupJobs(ctx context.Context, input DataManagementListBackupJobsInput) (DataManagementListBackupJobsResult, error) {
_, _ = ctx, input
return DataManagementListBackupJobsResult{}, s.deprecatedError()
}
func (s *DataManagementService) GetBackupJob(ctx context.Context, jobID string) (DataManagementBackupJob, error) {
_, _ = ctx, jobID
return DataManagementBackupJob{}, s.deprecatedError()
}
func (s *DataManagementService) deprecatedError() error {
return ErrDataManagementDeprecated.WithMetadata(map[string]string{"socket_path": s.SocketPath()})
}
package service
import (
"context"
"path/filepath"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
)
func TestDataManagementService_DeprecatedRPCMethods(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(t.TempDir(), "datamanagement.sock")
svc := NewDataManagementServiceWithOptions(socketPath, 0)
_, err := svc.GetConfig(context.Background())
assertDeprecatedDataManagementError(t, err, socketPath)
_, err = svc.CreateBackupJob(context.Background(), DataManagementCreateBackupJobInput{BackupType: "full"})
assertDeprecatedDataManagementError(t, err, socketPath)
err = svc.DeleteS3Profile(context.Background(), "s3-default")
assertDeprecatedDataManagementError(t, err, socketPath)
}
func assertDeprecatedDataManagementError(t *testing.T, err error, socketPath string) {
t.Helper()
require.Error(t, err)
statusCode, status := infraerrors.ToHTTP(err)
require.Equal(t, 503, statusCode)
require.Equal(t, DataManagementDeprecatedReason, status.Reason)
require.Equal(t, socketPath, status.Metadata["socket_path"])
}
package service
import (
"context"
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
const (
DefaultDataManagementAgentSocketPath = "/tmp/sub2api-datamanagement.sock"
LegacyBackupAgentSocketPath = "/tmp/sub2api-backup.sock"
DataManagementDeprecatedReason = "DATA_MANAGEMENT_DEPRECATED"
DataManagementAgentSocketMissingReason = "DATA_MANAGEMENT_AGENT_SOCKET_MISSING"
DataManagementAgentUnavailableReason = "DATA_MANAGEMENT_AGENT_UNAVAILABLE"
// Deprecated: keep old names for compatibility.
DefaultBackupAgentSocketPath = DefaultDataManagementAgentSocketPath
BackupAgentSocketMissingReason = DataManagementAgentSocketMissingReason
BackupAgentUnavailableReason = DataManagementAgentUnavailableReason
)
var (
ErrDataManagementDeprecated = infraerrors.ServiceUnavailable(
DataManagementDeprecatedReason,
"data management feature is deprecated",
)
ErrDataManagementAgentSocketMissing = infraerrors.ServiceUnavailable(
DataManagementAgentSocketMissingReason,
"data management agent socket is missing",
)
ErrDataManagementAgentUnavailable = infraerrors.ServiceUnavailable(
DataManagementAgentUnavailableReason,
"data management agent is unavailable",
)
// Deprecated: keep old names for compatibility.
ErrBackupAgentSocketMissing = ErrDataManagementAgentSocketMissing
ErrBackupAgentUnavailable = ErrDataManagementAgentUnavailable
)
type DataManagementAgentHealth struct {
Enabled bool
Reason string
SocketPath string
Agent *DataManagementAgentInfo
}
type DataManagementAgentInfo struct {
Status string
Version string
UptimeSeconds int64
}
type DataManagementService struct {
socketPath string
}
func NewDataManagementService() *DataManagementService {
return NewDataManagementServiceWithOptions(DefaultDataManagementAgentSocketPath, 500*time.Millisecond)
}
func NewDataManagementServiceWithOptions(socketPath string, dialTimeout time.Duration) *DataManagementService {
_ = dialTimeout
path := strings.TrimSpace(socketPath)
if path == "" {
path = DefaultDataManagementAgentSocketPath
}
return &DataManagementService{
socketPath: path,
}
}
func (s *DataManagementService) SocketPath() string {
if s == nil || strings.TrimSpace(s.socketPath) == "" {
return DefaultDataManagementAgentSocketPath
}
return s.socketPath
}
func (s *DataManagementService) GetAgentHealth(ctx context.Context) DataManagementAgentHealth {
_ = ctx
return DataManagementAgentHealth{
Enabled: false,
Reason: DataManagementDeprecatedReason,
SocketPath: s.SocketPath(),
}
}
func (s *DataManagementService) EnsureAgentEnabled(ctx context.Context) error {
_ = ctx
return ErrDataManagementDeprecated.WithMetadata(map[string]string{"socket_path": s.SocketPath()})
}
package service
import (
"context"
"path/filepath"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
)
func TestDataManagementService_GetAgentHealth_Deprecated(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(t.TempDir(), "unused.sock")
svc := NewDataManagementServiceWithOptions(socketPath, 0)
health := svc.GetAgentHealth(context.Background())
require.False(t, health.Enabled)
require.Equal(t, DataManagementDeprecatedReason, health.Reason)
require.Equal(t, socketPath, health.SocketPath)
require.Nil(t, health.Agent)
}
func TestDataManagementService_EnsureAgentEnabled_Deprecated(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(t.TempDir(), "unused.sock")
svc := NewDataManagementServiceWithOptions(socketPath, 100)
err := svc.EnsureAgentEnabled(context.Background())
require.Error(t, err)
statusCode, status := infraerrors.ToHTTP(err)
require.Equal(t, 503, statusCode)
require.Equal(t, DataManagementDeprecatedReason, status.Reason)
require.Equal(t, socketPath, status.Metadata["socket_path"])
}
......@@ -104,6 +104,7 @@ const (
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
// OEM设置
SettingKeySoraClientEnabled = "sora_client_enabled" // 是否启用 Sora 客户端(管理员手动控制)
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
......@@ -170,6 +171,27 @@ const (
// SettingKeyStreamTimeoutSettings stores JSON config for stream timeout handling.
SettingKeyStreamTimeoutSettings = "stream_timeout_settings"
// =========================
// Sora S3 存储配置
// =========================
SettingKeySoraS3Enabled = "sora_s3_enabled" // 是否启用 Sora S3 存储
SettingKeySoraS3Endpoint = "sora_s3_endpoint" // S3 端点地址
SettingKeySoraS3Region = "sora_s3_region" // S3 区域
SettingKeySoraS3Bucket = "sora_s3_bucket" // S3 存储桶名称
SettingKeySoraS3AccessKeyID = "sora_s3_access_key_id" // S3 Access Key ID
SettingKeySoraS3SecretAccessKey = "sora_s3_secret_access_key" // S3 Secret Access Key(加密存储)
SettingKeySoraS3Prefix = "sora_s3_prefix" // S3 对象键前缀
SettingKeySoraS3ForcePathStyle = "sora_s3_force_path_style" // 是否强制 Path Style(兼容 MinIO 等)
SettingKeySoraS3CDNURL = "sora_s3_cdn_url" // CDN 加速 URL(可选)
SettingKeySoraS3Profiles = "sora_s3_profiles" // Sora S3 多配置(JSON)
// =========================
// Sora 用户存储配额
// =========================
SettingKeySoraDefaultStorageQuotaBytes = "sora_default_storage_quota_bytes" // 新用户默认 Sora 存储配额(字节)
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
......
......@@ -279,10 +279,10 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokens404PassthroughNotE
wantPassthrough: true,
},
{
name: "404 generic not found passes through as 404",
name: "404 generic not found does not passthrough",
statusCode: http.StatusNotFound,
respBody: `{"error":{"message":"resource not found","type":"not_found_error"}}`,
wantPassthrough: true,
wantPassthrough: false,
},
{
name: "400 Invalid URL does not passthrough",
......
......@@ -136,3 +136,67 @@ func TestDroppedBetaSet(t *testing.T) {
require.Contains(t, extended, claude.BetaClaudeCode)
require.Len(t, extended, len(claude.DroppedBetas)+1)
}
func TestBuildBetaTokenSet(t *testing.T) {
got := buildBetaTokenSet([]string{"foo", "", "bar", "foo"})
require.Len(t, got, 2)
require.Contains(t, got, "foo")
require.Contains(t, got, "bar")
require.NotContains(t, got, "")
empty := buildBetaTokenSet(nil)
require.Empty(t, empty)
}
func TestStripBetaTokensWithSet_EmptyDropSet(t *testing.T) {
header := "oauth-2025-04-20,interleaved-thinking-2025-05-14"
got := stripBetaTokensWithSet(header, map[string]struct{}{})
require.Equal(t, header, got)
}
func TestIsCountTokensUnsupported404(t *testing.T) {
tests := []struct {
name string
statusCode int
body string
want bool
}{
{
name: "exact endpoint not found",
statusCode: 404,
body: `{"error":{"message":"Not found: /v1/messages/count_tokens","type":"not_found_error"}}`,
want: true,
},
{
name: "contains count_tokens and not found",
statusCode: 404,
body: `{"error":{"message":"count_tokens route not found","type":"not_found_error"}}`,
want: true,
},
{
name: "generic 404",
statusCode: 404,
body: `{"error":{"message":"resource not found","type":"not_found_error"}}`,
want: false,
},
{
name: "404 with empty error message",
statusCode: 404,
body: `{"error":{"message":"","type":"not_found_error"}}`,
want: false,
},
{
name: "non-404 status",
statusCode: 400,
body: `{"error":{"message":"Not found: /v1/messages/count_tokens","type":"invalid_request_error"}}`,
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isCountTokensUnsupported404(tt.statusCode, []byte(tt.body))
require.Equal(t, tt.want, got)
})
}
}
......@@ -1892,6 +1892,14 @@ func (m *mockConcurrencyCache) GetAccountConcurrency(ctx context.Context, accoun
return 0, nil
}
func (m *mockConcurrencyCache) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
result := make(map[int64]int, len(accountIDs))
for _, accountID := range accountIDs {
result[accountID] = 0
}
return result, nil
}
func (m *mockConcurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
return true, nil
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment