Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
bb664d9b
Commit
bb664d9b
authored
Feb 28, 2026
by
yangjianbo
Browse files
feat(sync): full code sync from release
parent
bfc7b339
Changes
244
Hide whitespace changes
Inline
Side-by-side
Too many changes to show.
To preserve performance only
244 of 244+
files are displayed.
Plain diff
Email patch
backend/internal/service/api_key_auth_cache_impl.go
View file @
bb664d9b
...
@@ -298,5 +298,6 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
...
@@ -298,5 +298,6 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
SupportedModelScopes
:
snapshot
.
Group
.
SupportedModelScopes
,
SupportedModelScopes
:
snapshot
.
Group
.
SupportedModelScopes
,
}
}
}
}
s
.
compileAPIKeyIPRules
(
apiKey
)
return
apiKey
return
apiKey
}
}
backend/internal/service/api_key_service.go
View file @
bb664d9b
...
@@ -158,6 +158,14 @@ func NewAPIKeyService(
...
@@ -158,6 +158,14 @@ func NewAPIKeyService(
return
svc
return
svc
}
}
func
(
s
*
APIKeyService
)
compileAPIKeyIPRules
(
apiKey
*
APIKey
)
{
if
apiKey
==
nil
{
return
}
apiKey
.
CompiledIPWhitelist
=
ip
.
CompileIPRules
(
apiKey
.
IPWhitelist
)
apiKey
.
CompiledIPBlacklist
=
ip
.
CompileIPRules
(
apiKey
.
IPBlacklist
)
}
// GenerateKey 生成随机API Key
// GenerateKey 生成随机API Key
func
(
s
*
APIKeyService
)
GenerateKey
()
(
string
,
error
)
{
func
(
s
*
APIKeyService
)
GenerateKey
()
(
string
,
error
)
{
// 生成32字节随机数据
// 生成32字节随机数据
...
@@ -332,6 +340,7 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
...
@@ -332,6 +340,7 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
}
}
s
.
InvalidateAuthCacheByKey
(
ctx
,
apiKey
.
Key
)
s
.
InvalidateAuthCacheByKey
(
ctx
,
apiKey
.
Key
)
s
.
compileAPIKeyIPRules
(
apiKey
)
return
apiKey
,
nil
return
apiKey
,
nil
}
}
...
@@ -363,6 +372,7 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error)
...
@@ -363,6 +372,7 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get api key: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"get api key: %w"
,
err
)
}
}
s
.
compileAPIKeyIPRules
(
apiKey
)
return
apiKey
,
nil
return
apiKey
,
nil
}
}
...
@@ -375,6 +385,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
...
@@ -375,6 +385,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get api key: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"get api key: %w"
,
err
)
}
}
s
.
compileAPIKeyIPRules
(
apiKey
)
return
apiKey
,
nil
return
apiKey
,
nil
}
}
}
}
...
@@ -391,6 +402,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
...
@@ -391,6 +402,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get api key: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"get api key: %w"
,
err
)
}
}
s
.
compileAPIKeyIPRules
(
apiKey
)
return
apiKey
,
nil
return
apiKey
,
nil
}
}
}
else
{
}
else
{
...
@@ -402,6 +414,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
...
@@ -402,6 +414,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get api key: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"get api key: %w"
,
err
)
}
}
s
.
compileAPIKeyIPRules
(
apiKey
)
return
apiKey
,
nil
return
apiKey
,
nil
}
}
}
}
...
@@ -411,6 +424,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
...
@@ -411,6 +424,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
return
nil
,
fmt
.
Errorf
(
"get api key: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"get api key: %w"
,
err
)
}
}
apiKey
.
Key
=
key
apiKey
.
Key
=
key
s
.
compileAPIKeyIPRules
(
apiKey
)
return
apiKey
,
nil
return
apiKey
,
nil
}
}
...
@@ -510,6 +524,7 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
...
@@ -510,6 +524,7 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
}
}
s
.
InvalidateAuthCacheByKey
(
ctx
,
apiKey
.
Key
)
s
.
InvalidateAuthCacheByKey
(
ctx
,
apiKey
.
Key
)
s
.
compileAPIKeyIPRules
(
apiKey
)
return
apiKey
,
nil
return
apiKey
,
nil
}
}
...
...
backend/internal/service/auth_service.go
View file @
bb664d9b
...
@@ -308,6 +308,17 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
...
@@ -308,6 +308,17 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
},
nil
},
nil
}
}
// VerifyTurnstileForRegister 在注册场景下验证 Turnstile。
// 当邮箱验证开启且已提交验证码时,说明验证码发送阶段已完成 Turnstile 校验,
// 此处跳过二次校验,避免一次性 token 在注册提交时重复使用导致误报失败。
func
(
s
*
AuthService
)
VerifyTurnstileForRegister
(
ctx
context
.
Context
,
token
,
remoteIP
,
verifyCode
string
)
error
{
if
s
.
IsEmailVerifyEnabled
(
ctx
)
&&
strings
.
TrimSpace
(
verifyCode
)
!=
""
{
logger
.
LegacyPrintf
(
"service.auth"
,
"%s"
,
"[Auth] Email verify flow detected, skip duplicate Turnstile check on register"
)
return
nil
}
return
s
.
VerifyTurnstile
(
ctx
,
token
,
remoteIP
)
}
// VerifyTurnstile 验证Turnstile token
// VerifyTurnstile 验证Turnstile token
func
(
s
*
AuthService
)
VerifyTurnstile
(
ctx
context
.
Context
,
token
string
,
remoteIP
string
)
error
{
func
(
s
*
AuthService
)
VerifyTurnstile
(
ctx
context
.
Context
,
token
string
,
remoteIP
string
)
error
{
required
:=
s
.
cfg
!=
nil
&&
s
.
cfg
.
Server
.
Mode
==
"release"
&&
s
.
cfg
.
Turnstile
.
Required
required
:=
s
.
cfg
!=
nil
&&
s
.
cfg
.
Server
.
Mode
==
"release"
&&
s
.
cfg
.
Turnstile
.
Required
...
...
backend/internal/service/auth_service_turnstile_register_test.go
0 → 100644
View file @
bb664d9b
//go:build unit
package
service
import
(
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type
turnstileVerifierSpy
struct
{
called
int
lastToken
string
result
*
TurnstileVerifyResponse
err
error
}
func
(
s
*
turnstileVerifierSpy
)
VerifyToken
(
_
context
.
Context
,
_
string
,
token
,
_
string
)
(
*
TurnstileVerifyResponse
,
error
)
{
s
.
called
++
s
.
lastToken
=
token
if
s
.
err
!=
nil
{
return
nil
,
s
.
err
}
if
s
.
result
!=
nil
{
return
s
.
result
,
nil
}
return
&
TurnstileVerifyResponse
{
Success
:
true
},
nil
}
func
newAuthServiceForRegisterTurnstileTest
(
settings
map
[
string
]
string
,
verifier
TurnstileVerifier
)
*
AuthService
{
cfg
:=
&
config
.
Config
{
Server
:
config
.
ServerConfig
{
Mode
:
"release"
,
},
Turnstile
:
config
.
TurnstileConfig
{
Required
:
true
,
},
}
settingService
:=
NewSettingService
(
&
settingRepoStub
{
values
:
settings
},
cfg
)
turnstileService
:=
NewTurnstileService
(
settingService
,
verifier
)
return
NewAuthService
(
&
userRepoStub
{},
nil
,
// redeemRepo
nil
,
// refreshTokenCache
cfg
,
settingService
,
nil
,
// emailService
turnstileService
,
nil
,
// emailQueueService
nil
,
// promoService
)
}
func
TestAuthService_VerifyTurnstileForRegister_SkipWhenEmailVerifyCodeProvided
(
t
*
testing
.
T
)
{
verifier
:=
&
turnstileVerifierSpy
{}
service
:=
newAuthServiceForRegisterTurnstileTest
(
map
[
string
]
string
{
SettingKeyEmailVerifyEnabled
:
"true"
,
SettingKeyTurnstileEnabled
:
"true"
,
SettingKeyTurnstileSecretKey
:
"secret"
,
SettingKeyRegistrationEnabled
:
"true"
,
},
verifier
)
err
:=
service
.
VerifyTurnstileForRegister
(
context
.
Background
(),
""
,
"127.0.0.1"
,
"123456"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
0
,
verifier
.
called
)
}
func
TestAuthService_VerifyTurnstileForRegister_RequireWhenVerifyCodeMissing
(
t
*
testing
.
T
)
{
verifier
:=
&
turnstileVerifierSpy
{}
service
:=
newAuthServiceForRegisterTurnstileTest
(
map
[
string
]
string
{
SettingKeyEmailVerifyEnabled
:
"true"
,
SettingKeyTurnstileEnabled
:
"true"
,
SettingKeyTurnstileSecretKey
:
"secret"
,
},
verifier
)
err
:=
service
.
VerifyTurnstileForRegister
(
context
.
Background
(),
""
,
"127.0.0.1"
,
""
)
require
.
ErrorIs
(
t
,
err
,
ErrTurnstileVerificationFailed
)
}
func
TestAuthService_VerifyTurnstileForRegister_NoSkipWhenEmailVerifyDisabled
(
t
*
testing
.
T
)
{
verifier
:=
&
turnstileVerifierSpy
{}
service
:=
newAuthServiceForRegisterTurnstileTest
(
map
[
string
]
string
{
SettingKeyEmailVerifyEnabled
:
"false"
,
SettingKeyTurnstileEnabled
:
"true"
,
SettingKeyTurnstileSecretKey
:
"secret"
,
},
verifier
)
err
:=
service
.
VerifyTurnstileForRegister
(
context
.
Background
(),
"turnstile-token"
,
"127.0.0.1"
,
"123456"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
verifier
.
called
)
require
.
Equal
(
t
,
"turnstile-token"
,
verifier
.
lastToken
)
}
backend/internal/service/billing_cache_service.go
View file @
bb664d9b
...
@@ -3,6 +3,7 @@ package service
...
@@ -3,6 +3,7 @@ package service
import
(
import
(
"context"
"context"
"fmt"
"fmt"
"strconv"
"sync"
"sync"
"sync/atomic"
"sync/atomic"
"time"
"time"
...
@@ -10,6 +11,7 @@ import (
...
@@ -10,6 +11,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"golang.org/x/sync/singleflight"
)
)
// 错误定义
// 错误定义
...
@@ -58,6 +60,7 @@ const (
...
@@ -58,6 +60,7 @@ const (
cacheWriteBufferSize
=
1000
// 任务队列缓冲大小
cacheWriteBufferSize
=
1000
// 任务队列缓冲大小
cacheWriteTimeout
=
2
*
time
.
Second
// 单个写入操作超时
cacheWriteTimeout
=
2
*
time
.
Second
// 单个写入操作超时
cacheWriteDropLogInterval
=
5
*
time
.
Second
// 丢弃日志节流间隔
cacheWriteDropLogInterval
=
5
*
time
.
Second
// 丢弃日志节流间隔
balanceLoadTimeout
=
3
*
time
.
Second
)
)
// cacheWriteTask 缓存写入任务
// cacheWriteTask 缓存写入任务
...
@@ -82,6 +85,9 @@ type BillingCacheService struct {
...
@@ -82,6 +85,9 @@ type BillingCacheService struct {
cacheWriteChan
chan
cacheWriteTask
cacheWriteChan
chan
cacheWriteTask
cacheWriteWg
sync
.
WaitGroup
cacheWriteWg
sync
.
WaitGroup
cacheWriteStopOnce
sync
.
Once
cacheWriteStopOnce
sync
.
Once
cacheWriteMu
sync
.
RWMutex
stopped
atomic
.
Bool
balanceLoadSF
singleflight
.
Group
// 丢弃日志节流计数器(减少高负载下日志噪音)
// 丢弃日志节流计数器(减少高负载下日志噪音)
cacheWriteDropFullCount
uint64
cacheWriteDropFullCount
uint64
cacheWriteDropFullLastLog
int64
cacheWriteDropFullLastLog
int64
...
@@ -105,35 +111,52 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo
...
@@ -105,35 +111,52 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo
// Stop 关闭缓存写入工作池
// Stop 关闭缓存写入工作池
func
(
s
*
BillingCacheService
)
Stop
()
{
func
(
s
*
BillingCacheService
)
Stop
()
{
s
.
cacheWriteStopOnce
.
Do
(
func
()
{
s
.
cacheWriteStopOnce
.
Do
(
func
()
{
if
s
.
cacheWriteChan
==
nil
{
s
.
stopped
.
Store
(
true
)
s
.
cacheWriteMu
.
Lock
()
ch
:=
s
.
cacheWriteChan
if
ch
!=
nil
{
close
(
ch
)
}
s
.
cacheWriteMu
.
Unlock
()
if
ch
==
nil
{
return
return
}
}
close
(
s
.
cacheWriteChan
)
s
.
cacheWriteWg
.
Wait
()
s
.
cacheWriteWg
.
Wait
()
s
.
cacheWriteChan
=
nil
s
.
cacheWriteMu
.
Lock
()
if
s
.
cacheWriteChan
==
ch
{
s
.
cacheWriteChan
=
nil
}
s
.
cacheWriteMu
.
Unlock
()
})
})
}
}
func
(
s
*
BillingCacheService
)
startCacheWriteWorkers
()
{
func
(
s
*
BillingCacheService
)
startCacheWriteWorkers
()
{
s
.
cacheWriteChan
=
make
(
chan
cacheWriteTask
,
cacheWriteBufferSize
)
ch
:=
make
(
chan
cacheWriteTask
,
cacheWriteBufferSize
)
s
.
cacheWriteChan
=
ch
for
i
:=
0
;
i
<
cacheWriteWorkerCount
;
i
++
{
for
i
:=
0
;
i
<
cacheWriteWorkerCount
;
i
++
{
s
.
cacheWriteWg
.
Add
(
1
)
s
.
cacheWriteWg
.
Add
(
1
)
go
s
.
cacheWriteWorker
()
go
s
.
cacheWriteWorker
(
ch
)
}
}
}
}
// enqueueCacheWrite 尝试将任务入队,队列满时返回 false(并记录告警)。
// enqueueCacheWrite 尝试将任务入队,队列满时返回 false(并记录告警)。
func
(
s
*
BillingCacheService
)
enqueueCacheWrite
(
task
cacheWriteTask
)
(
enqueued
bool
)
{
func
(
s
*
BillingCacheService
)
enqueueCacheWrite
(
task
cacheWriteTask
)
(
enqueued
bool
)
{
if
s
.
stopped
.
Load
()
{
s
.
logCacheWriteDrop
(
task
,
"closed"
)
return
false
}
s
.
cacheWriteMu
.
RLock
()
defer
s
.
cacheWriteMu
.
RUnlock
()
if
s
.
cacheWriteChan
==
nil
{
if
s
.
cacheWriteChan
==
nil
{
s
.
logCacheWriteDrop
(
task
,
"closed"
)
return
false
return
false
}
}
defer
func
()
{
if
recovered
:=
recover
();
recovered
!=
nil
{
// 队列已关闭时可能触发 panic,记录后静默失败。
s
.
logCacheWriteDrop
(
task
,
"closed"
)
enqueued
=
false
}
}()
select
{
select
{
case
s
.
cacheWriteChan
<-
task
:
case
s
.
cacheWriteChan
<-
task
:
return
true
return
true
...
@@ -144,9 +167,9 @@ func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued b
...
@@ -144,9 +167,9 @@ func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued b
}
}
}
}
func
(
s
*
BillingCacheService
)
cacheWriteWorker
()
{
func
(
s
*
BillingCacheService
)
cacheWriteWorker
(
ch
<-
chan
cacheWriteTask
)
{
defer
s
.
cacheWriteWg
.
Done
()
defer
s
.
cacheWriteWg
.
Done
()
for
task
:=
range
s
.
cacheWriteChan
{
for
task
:=
range
ch
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
cacheWriteTimeout
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
cacheWriteTimeout
)
switch
task
.
kind
{
switch
task
.
kind
{
case
cacheWriteSetBalance
:
case
cacheWriteSetBalance
:
...
@@ -243,19 +266,31 @@ func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64)
...
@@ -243,19 +266,31 @@ func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64)
return
balance
,
nil
return
balance
,
nil
}
}
// 缓存未命中,从数据库读取
// 缓存未命中:singleflight 合并同一 userID 的并发回源请求。
balance
,
err
=
s
.
getUserBalanceFromDB
(
ctx
,
userID
)
value
,
err
,
_
:=
s
.
balanceLoadSF
.
Do
(
strconv
.
FormatInt
(
userID
,
10
),
func
()
(
any
,
error
)
{
loadCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
balanceLoadTimeout
)
defer
cancel
()
balance
,
err
:=
s
.
getUserBalanceFromDB
(
loadCtx
,
userID
)
if
err
!=
nil
{
return
nil
,
err
}
// 异步建立缓存
_
=
s
.
enqueueCacheWrite
(
cacheWriteTask
{
kind
:
cacheWriteSetBalance
,
userID
:
userID
,
balance
:
balance
,
})
return
balance
,
nil
})
if
err
!=
nil
{
if
err
!=
nil
{
return
0
,
err
return
0
,
err
}
}
balance
,
ok
:=
value
.
(
float64
)
// 异步建立缓存
if
!
ok
{
_
=
s
.
enqueueCacheWrite
(
cacheWriteTask
{
return
0
,
fmt
.
Errorf
(
"unexpected balance type: %T"
,
value
)
kind
:
cacheWriteSetBalance
,
}
userID
:
userID
,
balance
:
balance
,
})
return
balance
,
nil
return
balance
,
nil
}
}
...
...
backend/internal/service/billing_cache_service_singleflight_test.go
0 → 100644
View file @
bb664d9b
//go:build unit
package
service
import
(
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type
billingCacheMissStub
struct
{
setBalanceCalls
atomic
.
Int64
}
func
(
s
*
billingCacheMissStub
)
GetUserBalance
(
ctx
context
.
Context
,
userID
int64
)
(
float64
,
error
)
{
return
0
,
errors
.
New
(
"cache miss"
)
}
func
(
s
*
billingCacheMissStub
)
SetUserBalance
(
ctx
context
.
Context
,
userID
int64
,
balance
float64
)
error
{
s
.
setBalanceCalls
.
Add
(
1
)
return
nil
}
func
(
s
*
billingCacheMissStub
)
DeductUserBalance
(
ctx
context
.
Context
,
userID
int64
,
amount
float64
)
error
{
return
nil
}
func
(
s
*
billingCacheMissStub
)
InvalidateUserBalance
(
ctx
context
.
Context
,
userID
int64
)
error
{
return
nil
}
func
(
s
*
billingCacheMissStub
)
GetSubscriptionCache
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
SubscriptionCacheData
,
error
)
{
return
nil
,
errors
.
New
(
"cache miss"
)
}
func
(
s
*
billingCacheMissStub
)
SetSubscriptionCache
(
ctx
context
.
Context
,
userID
,
groupID
int64
,
data
*
SubscriptionCacheData
)
error
{
return
nil
}
func
(
s
*
billingCacheMissStub
)
UpdateSubscriptionUsage
(
ctx
context
.
Context
,
userID
,
groupID
int64
,
cost
float64
)
error
{
return
nil
}
func
(
s
*
billingCacheMissStub
)
InvalidateSubscriptionCache
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
error
{
return
nil
}
type
balanceLoadUserRepoStub
struct
{
mockUserRepo
calls
atomic
.
Int64
delay
time
.
Duration
balance
float64
}
func
(
s
*
balanceLoadUserRepoStub
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
User
,
error
)
{
s
.
calls
.
Add
(
1
)
if
s
.
delay
>
0
{
select
{
case
<-
time
.
After
(
s
.
delay
)
:
case
<-
ctx
.
Done
()
:
return
nil
,
ctx
.
Err
()
}
}
return
&
User
{
ID
:
id
,
Balance
:
s
.
balance
},
nil
}
func
TestBillingCacheServiceGetUserBalance_Singleflight
(
t
*
testing
.
T
)
{
cache
:=
&
billingCacheMissStub
{}
userRepo
:=
&
balanceLoadUserRepoStub
{
delay
:
80
*
time
.
Millisecond
,
balance
:
12.34
,
}
svc
:=
NewBillingCacheService
(
cache
,
userRepo
,
nil
,
&
config
.
Config
{})
t
.
Cleanup
(
svc
.
Stop
)
const
goroutines
=
16
start
:=
make
(
chan
struct
{})
var
wg
sync
.
WaitGroup
errCh
:=
make
(
chan
error
,
goroutines
)
balCh
:=
make
(
chan
float64
,
goroutines
)
for
i
:=
0
;
i
<
goroutines
;
i
++
{
wg
.
Add
(
1
)
go
func
()
{
defer
wg
.
Done
()
<-
start
bal
,
err
:=
svc
.
GetUserBalance
(
context
.
Background
(),
99
)
errCh
<-
err
balCh
<-
bal
}()
}
close
(
start
)
wg
.
Wait
()
close
(
errCh
)
close
(
balCh
)
for
err
:=
range
errCh
{
require
.
NoError
(
t
,
err
)
}
for
bal
:=
range
balCh
{
require
.
Equal
(
t
,
12.34
,
bal
)
}
require
.
Equal
(
t
,
int64
(
1
),
userRepo
.
calls
.
Load
(),
"并发穿透应被 singleflight 合并"
)
require
.
Eventually
(
t
,
func
()
bool
{
return
cache
.
setBalanceCalls
.
Load
()
>=
1
},
time
.
Second
,
10
*
time
.
Millisecond
)
}
backend/internal/service/billing_cache_service_test.go
View file @
bb664d9b
...
@@ -73,3 +73,16 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
...
@@ -73,3 +73,16 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
return
atomic
.
LoadInt64
(
&
cache
.
subscriptionUpdates
)
>
0
return
atomic
.
LoadInt64
(
&
cache
.
subscriptionUpdates
)
>
0
},
2
*
time
.
Second
,
10
*
time
.
Millisecond
)
},
2
*
time
.
Second
,
10
*
time
.
Millisecond
)
}
}
func
TestBillingCacheServiceEnqueueAfterStopReturnsFalse
(
t
*
testing
.
T
)
{
cache
:=
&
billingCacheWorkerStub
{}
svc
:=
NewBillingCacheService
(
cache
,
nil
,
nil
,
&
config
.
Config
{})
svc
.
Stop
()
enqueued
:=
svc
.
enqueueCacheWrite
(
cacheWriteTask
{
kind
:
cacheWriteDeductBalance
,
userID
:
1
,
amount
:
1
,
})
require
.
False
(
t
,
enqueued
)
}
backend/internal/service/billing_service_image_test.go
View file @
bb664d9b
...
@@ -63,7 +63,7 @@ func TestCalculateImageCost_RateMultiplier(t *testing.T) {
...
@@ -63,7 +63,7 @@ func TestCalculateImageCost_RateMultiplier(t *testing.T) {
// 费率倍数 1.5x
// 费率倍数 1.5x
cost
:=
svc
.
CalculateImageCost
(
"gemini-3-pro-image"
,
"2K"
,
1
,
nil
,
1.5
)
cost
:=
svc
.
CalculateImageCost
(
"gemini-3-pro-image"
,
"2K"
,
1
,
nil
,
1.5
)
require
.
InDelta
(
t
,
0.201
,
cost
.
TotalCost
,
0.0001
)
// TotalCost = 0.134 * 1.5
require
.
InDelta
(
t
,
0.201
,
cost
.
TotalCost
,
0.0001
)
// TotalCost = 0.134 * 1.5
require
.
InDelta
(
t
,
0.3015
,
cost
.
ActualCost
,
0.0001
)
// ActualCost = 0.201 * 1.5
require
.
InDelta
(
t
,
0.3015
,
cost
.
ActualCost
,
0.0001
)
// ActualCost = 0.201 * 1.5
// 费率倍数 2.0x
// 费率倍数 2.0x
...
...
backend/internal/service/claude_code_validator.go
View file @
bb664d9b
...
@@ -78,7 +78,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo
...
@@ -78,7 +78,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo
// Step 3: 检查 max_tokens=1 + haiku 探测请求绕过
// Step 3: 检查 max_tokens=1 + haiku 探测请求绕过
// 这类请求用于 Claude Code 验证 API 连通性,不携带 system prompt
// 这类请求用于 Claude Code 验证 API 连通性,不携带 system prompt
if
isMaxTokensOneHaiku
,
ok
:=
r
.
Context
()
.
Value
(
ctxkey
.
IsMaxTokensOneHaikuRequest
)
.
(
bool
);
ok
&&
isMaxTokensOneHaiku
{
if
isMaxTokensOneHaiku
,
ok
:=
IsMaxTokensOneHaikuRequest
FromContext
(
r
.
Context
()
);
ok
&&
isMaxTokensOneHaiku
{
return
true
// 绕过 system prompt 检查,UA 已在 Step 1 验证
return
true
// 绕过 system prompt 检查,UA 已在 Step 1 验证
}
}
...
...
backend/internal/service/concurrency_service.go
View file @
bb664d9b
...
@@ -3,8 +3,10 @@ package service
...
@@ -3,8 +3,10 @@ package service
import
(
import
(
"context"
"context"
"crypto/rand"
"crypto/rand"
"encoding/hex"
"encoding/binary"
"fmt"
"os"
"strconv"
"sync/atomic"
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
...
@@ -18,6 +20,7 @@ type ConcurrencyCache interface {
...
@@ -18,6 +20,7 @@ type ConcurrencyCache interface {
AcquireAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
AcquireAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
ReleaseAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
requestID
string
)
error
ReleaseAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
requestID
string
)
error
GetAccountConcurrency
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
GetAccountConcurrency
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
GetAccountConcurrencyBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
int
,
error
)
// 账号等待队列(账号级)
// 账号等待队列(账号级)
IncrementAccountWaitCount
(
ctx
context
.
Context
,
accountID
int64
,
maxWait
int
)
(
bool
,
error
)
IncrementAccountWaitCount
(
ctx
context
.
Context
,
accountID
int64
,
maxWait
int
)
(
bool
,
error
)
...
@@ -42,15 +45,25 @@ type ConcurrencyCache interface {
...
@@ -42,15 +45,25 @@ type ConcurrencyCache interface {
CleanupExpiredAccountSlots
(
ctx
context
.
Context
,
accountID
int64
)
error
CleanupExpiredAccountSlots
(
ctx
context
.
Context
,
accountID
int64
)
error
}
}
// generateRequestID generates a unique request ID for concurrency slot tracking
var
(
// Uses 8 random bytes (16 hex chars) for uniqueness
requestIDPrefix
=
initRequestIDPrefix
()
func
generateRequestID
()
string
{
requestIDCounter
atomic
.
Uint64
)
func
initRequestIDPrefix
()
string
{
b
:=
make
([]
byte
,
8
)
b
:=
make
([]
byte
,
8
)
if
_
,
err
:=
rand
.
Read
(
b
);
err
!=
nil
{
if
_
,
err
:=
rand
.
Read
(
b
);
err
==
nil
{
// Fallback to nanosecond timestamp (extremely rare case)
return
"r"
+
strconv
.
FormatUint
(
binary
.
BigEndian
.
Uint64
(
b
),
36
)
return
fmt
.
Sprintf
(
"%x"
,
time
.
Now
()
.
UnixNano
())
}
}
return
hex
.
EncodeToString
(
b
)
fallback
:=
uint64
(
time
.
Now
()
.
UnixNano
())
^
(
uint64
(
os
.
Getpid
())
<<
16
)
return
"r"
+
strconv
.
FormatUint
(
fallback
,
36
)
}
// generateRequestID generates a unique request ID for concurrency slot tracking.
// Format: {process_random_prefix}-{base36_counter}
func
generateRequestID
()
string
{
seq
:=
requestIDCounter
.
Add
(
1
)
return
requestIDPrefix
+
"-"
+
strconv
.
FormatUint
(
seq
,
36
)
}
}
const
(
const
(
...
@@ -321,16 +334,15 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor
...
@@ -321,16 +334,15 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
// Returns a map of accountID -> current concurrency count
// Returns a map of accountID -> current concurrency count
func
(
s
*
ConcurrencyService
)
GetAccountConcurrencyBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
int
,
error
)
{
func
(
s
*
ConcurrencyService
)
GetAccountConcurrencyBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
int
,
error
)
{
result
:=
make
(
map
[
int64
]
int
)
if
len
(
accountIDs
)
==
0
{
return
map
[
int64
]
int
{},
nil
for
_
,
accountID
:=
range
accountIDs
{
}
count
,
err
:=
s
.
cache
.
GetAccountConcurrency
(
ctx
,
accountID
)
if
s
.
cache
==
nil
{
if
err
!=
nil
{
result
:=
make
(
map
[
int64
]
int
,
len
(
accountIDs
))
// If key doesn't exist in Redis,
count
i
s
0
for
_
,
accountID
:=
range
ac
count
ID
s
{
count
=
0
result
[
ac
count
ID
]
=
0
}
}
re
sult
[
accountID
]
=
count
re
turn
result
,
nil
}
}
return
s
.
cache
.
GetAccountConcurrencyBatch
(
ctx
,
accountIDs
)
return
result
,
nil
}
}
backend/internal/service/concurrency_service_test.go
View file @
bb664d9b
...
@@ -5,6 +5,8 @@ package service
...
@@ -5,6 +5,8 @@ package service
import
(
import
(
"context"
"context"
"errors"
"errors"
"strconv"
"strings"
"testing"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/require"
...
@@ -12,20 +14,20 @@ import (
...
@@ -12,20 +14,20 @@ import (
// stubConcurrencyCacheForTest 用于并发服务单元测试的缓存桩
// stubConcurrencyCacheForTest 用于并发服务单元测试的缓存桩
type
stubConcurrencyCacheForTest
struct
{
type
stubConcurrencyCacheForTest
struct
{
acquireResult
bool
acquireResult
bool
acquireErr
error
acquireErr
error
releaseErr
error
releaseErr
error
concurrency
int
concurrency
int
concurrencyErr
error
concurrencyErr
error
waitAllowed
bool
waitAllowed
bool
waitErr
error
waitErr
error
waitCount
int
waitCount
int
waitCountErr
error
waitCountErr
error
loadBatch
map
[
int64
]
*
AccountLoadInfo
loadBatch
map
[
int64
]
*
AccountLoadInfo
loadBatchErr
error
loadBatchErr
error
usersLoadBatch
map
[
int64
]
*
UserLoadInfo
usersLoadBatch
map
[
int64
]
*
UserLoadInfo
usersLoadErr
error
usersLoadErr
error
cleanupErr
error
cleanupErr
error
// 记录调用
// 记录调用
releasedAccountIDs
[]
int64
releasedAccountIDs
[]
int64
...
@@ -45,6 +47,16 @@ func (c *stubConcurrencyCacheForTest) ReleaseAccountSlot(_ context.Context, acco
...
@@ -45,6 +47,16 @@ func (c *stubConcurrencyCacheForTest) ReleaseAccountSlot(_ context.Context, acco
func
(
c
*
stubConcurrencyCacheForTest
)
GetAccountConcurrency
(
_
context
.
Context
,
_
int64
)
(
int
,
error
)
{
func
(
c
*
stubConcurrencyCacheForTest
)
GetAccountConcurrency
(
_
context
.
Context
,
_
int64
)
(
int
,
error
)
{
return
c
.
concurrency
,
c
.
concurrencyErr
return
c
.
concurrency
,
c
.
concurrencyErr
}
}
func
(
c
*
stubConcurrencyCacheForTest
)
GetAccountConcurrencyBatch
(
_
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
int
,
error
)
{
result
:=
make
(
map
[
int64
]
int
,
len
(
accountIDs
))
for
_
,
accountID
:=
range
accountIDs
{
if
c
.
concurrencyErr
!=
nil
{
return
nil
,
c
.
concurrencyErr
}
result
[
accountID
]
=
c
.
concurrency
}
return
result
,
nil
}
func
(
c
*
stubConcurrencyCacheForTest
)
IncrementAccountWaitCount
(
_
context
.
Context
,
_
int64
,
_
int
)
(
bool
,
error
)
{
func
(
c
*
stubConcurrencyCacheForTest
)
IncrementAccountWaitCount
(
_
context
.
Context
,
_
int64
,
_
int
)
(
bool
,
error
)
{
return
c
.
waitAllowed
,
c
.
waitErr
return
c
.
waitAllowed
,
c
.
waitErr
}
}
...
@@ -155,6 +167,25 @@ func TestAcquireUserSlot_UnlimitedConcurrency(t *testing.T) {
...
@@ -155,6 +167,25 @@ func TestAcquireUserSlot_UnlimitedConcurrency(t *testing.T) {
require
.
True
(
t
,
result
.
Acquired
)
require
.
True
(
t
,
result
.
Acquired
)
}
}
func
TestGenerateRequestID_UsesStablePrefixAndMonotonicCounter
(
t
*
testing
.
T
)
{
id1
:=
generateRequestID
()
id2
:=
generateRequestID
()
require
.
NotEmpty
(
t
,
id1
)
require
.
NotEmpty
(
t
,
id2
)
p1
:=
strings
.
Split
(
id1
,
"-"
)
p2
:=
strings
.
Split
(
id2
,
"-"
)
require
.
Len
(
t
,
p1
,
2
)
require
.
Len
(
t
,
p2
,
2
)
require
.
Equal
(
t
,
p1
[
0
],
p2
[
0
],
"同一进程前缀应保持一致"
)
n1
,
err
:=
strconv
.
ParseUint
(
p1
[
1
],
36
,
64
)
require
.
NoError
(
t
,
err
)
n2
,
err
:=
strconv
.
ParseUint
(
p2
[
1
],
36
,
64
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
n1
+
1
,
n2
,
"计数器应单调递增"
)
}
func
TestGetAccountsLoadBatch_ReturnsCorrectData
(
t
*
testing
.
T
)
{
func
TestGetAccountsLoadBatch_ReturnsCorrectData
(
t
*
testing
.
T
)
{
expected
:=
map
[
int64
]
*
AccountLoadInfo
{
expected
:=
map
[
int64
]
*
AccountLoadInfo
{
1
:
{
AccountID
:
1
,
CurrentConcurrency
:
3
,
WaitingCount
:
0
,
LoadRate
:
60
},
1
:
{
AccountID
:
1
,
CurrentConcurrency
:
3
,
WaitingCount
:
0
,
LoadRate
:
60
},
...
...
backend/internal/service/dashboard_service.go
View file @
bb664d9b
...
@@ -124,16 +124,16 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D
...
@@ -124,16 +124,16 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D
return
stats
,
nil
return
stats
,
nil
}
}
func
(
s
*
DashboardService
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
TrendDataPoint
,
error
)
{
func
(
s
*
DashboardService
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
TrendDataPoint
,
error
)
{
trend
,
err
:=
s
.
usageRepo
.
GetUsageTrendWithFilters
(
ctx
,
startTime
,
endTime
,
granularity
,
userID
,
apiKeyID
,
accountID
,
groupID
,
model
,
stream
,
billingType
)
trend
,
err
:=
s
.
usageRepo
.
GetUsageTrendWithFilters
(
ctx
,
startTime
,
endTime
,
granularity
,
userID
,
apiKeyID
,
accountID
,
groupID
,
model
,
requestType
,
stream
,
billingType
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get usage trend with filters: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"get usage trend with filters: %w"
,
err
)
}
}
return
trend
,
nil
return
trend
,
nil
}
}
func
(
s
*
DashboardService
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
ModelStat
,
error
)
{
func
(
s
*
DashboardService
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
ModelStat
,
error
)
{
stats
,
err
:=
s
.
usageRepo
.
GetModelStatsWithFilters
(
ctx
,
startTime
,
endTime
,
userID
,
apiKeyID
,
accountID
,
groupID
,
stream
,
billingType
)
stats
,
err
:=
s
.
usageRepo
.
GetModelStatsWithFilters
(
ctx
,
startTime
,
endTime
,
userID
,
apiKeyID
,
accountID
,
groupID
,
requestType
,
stream
,
billingType
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get model stats with filters: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"get model stats with filters: %w"
,
err
)
}
}
...
...
backend/internal/service/data_management_grpc.go
0 → 100644
View file @
bb664d9b
package
service
import
"context"
type
DataManagementPostgresConfig
struct
{
Host
string
`json:"host"`
Port
int32
`json:"port"`
User
string
`json:"user"`
Password
string
`json:"password,omitempty"`
PasswordConfigured
bool
`json:"password_configured"`
Database
string
`json:"database"`
SSLMode
string
`json:"ssl_mode"`
ContainerName
string
`json:"container_name"`
}
type
DataManagementRedisConfig
struct
{
Addr
string
`json:"addr"`
Username
string
`json:"username"`
Password
string
`json:"password,omitempty"`
PasswordConfigured
bool
`json:"password_configured"`
DB
int32
`json:"db"`
ContainerName
string
`json:"container_name"`
}
type
DataManagementS3Config
struct
{
Enabled
bool
`json:"enabled"`
Endpoint
string
`json:"endpoint"`
Region
string
`json:"region"`
Bucket
string
`json:"bucket"`
AccessKeyID
string
`json:"access_key_id"`
SecretAccessKey
string
`json:"secret_access_key,omitempty"`
SecretAccessKeyConfigured
bool
`json:"secret_access_key_configured"`
Prefix
string
`json:"prefix"`
ForcePathStyle
bool
`json:"force_path_style"`
UseSSL
bool
`json:"use_ssl"`
}
type
DataManagementConfig
struct
{
SourceMode
string
`json:"source_mode"`
BackupRoot
string
`json:"backup_root"`
SQLitePath
string
`json:"sqlite_path,omitempty"`
RetentionDays
int32
`json:"retention_days"`
KeepLast
int32
`json:"keep_last"`
ActivePostgresID
string
`json:"active_postgres_profile_id"`
ActiveRedisID
string
`json:"active_redis_profile_id"`
Postgres
DataManagementPostgresConfig
`json:"postgres"`
Redis
DataManagementRedisConfig
`json:"redis"`
S3
DataManagementS3Config
`json:"s3"`
ActiveS3ProfileID
string
`json:"active_s3_profile_id"`
}
type
DataManagementTestS3Result
struct
{
OK
bool
`json:"ok"`
Message
string
`json:"message"`
}
type
DataManagementCreateBackupJobInput
struct
{
BackupType
string
UploadToS3
bool
TriggeredBy
string
IdempotencyKey
string
S3ProfileID
string
PostgresID
string
RedisID
string
}
type
DataManagementListBackupJobsInput
struct
{
PageSize
int32
PageToken
string
Status
string
BackupType
string
}
type
DataManagementArtifactInfo
struct
{
LocalPath
string
`json:"local_path"`
SizeBytes
int64
`json:"size_bytes"`
SHA256
string
`json:"sha256"`
}
type
DataManagementS3ObjectInfo
struct
{
Bucket
string
`json:"bucket"`
Key
string
`json:"key"`
ETag
string
`json:"etag"`
}
type
DataManagementBackupJob
struct
{
JobID
string
`json:"job_id"`
BackupType
string
`json:"backup_type"`
Status
string
`json:"status"`
TriggeredBy
string
`json:"triggered_by"`
IdempotencyKey
string
`json:"idempotency_key,omitempty"`
UploadToS3
bool
`json:"upload_to_s3"`
S3ProfileID
string
`json:"s3_profile_id,omitempty"`
PostgresID
string
`json:"postgres_profile_id,omitempty"`
RedisID
string
`json:"redis_profile_id,omitempty"`
StartedAt
string
`json:"started_at,omitempty"`
FinishedAt
string
`json:"finished_at,omitempty"`
ErrorMessage
string
`json:"error_message,omitempty"`
Artifact
DataManagementArtifactInfo
`json:"artifact"`
S3Object
DataManagementS3ObjectInfo
`json:"s3"`
}
type
DataManagementSourceProfile
struct
{
SourceType
string
`json:"source_type"`
ProfileID
string
`json:"profile_id"`
Name
string
`json:"name"`
IsActive
bool
`json:"is_active"`
Config
DataManagementSourceConfig
`json:"config"`
PasswordConfigured
bool
`json:"password_configured"`
CreatedAt
string
`json:"created_at,omitempty"`
UpdatedAt
string
`json:"updated_at,omitempty"`
}
type
DataManagementSourceConfig
struct
{
Host
string
`json:"host"`
Port
int32
`json:"port"`
User
string
`json:"user"`
Password
string
`json:"password,omitempty"`
Database
string
`json:"database"`
SSLMode
string
`json:"ssl_mode"`
Addr
string
`json:"addr"`
Username
string
`json:"username"`
DB
int32
`json:"db"`
ContainerName
string
`json:"container_name"`
}
type
DataManagementCreateSourceProfileInput
struct
{
SourceType
string
ProfileID
string
Name
string
Config
DataManagementSourceConfig
SetActive
bool
}
type
DataManagementUpdateSourceProfileInput
struct
{
SourceType
string
ProfileID
string
Name
string
Config
DataManagementSourceConfig
}
type
DataManagementS3Profile
struct
{
ProfileID
string
`json:"profile_id"`
Name
string
`json:"name"`
IsActive
bool
`json:"is_active"`
S3
DataManagementS3Config
`json:"s3"`
SecretAccessKeyConfigured
bool
`json:"secret_access_key_configured"`
CreatedAt
string
`json:"created_at,omitempty"`
UpdatedAt
string
`json:"updated_at,omitempty"`
}
type
DataManagementCreateS3ProfileInput
struct
{
ProfileID
string
Name
string
S3
DataManagementS3Config
SetActive
bool
}
type
DataManagementUpdateS3ProfileInput
struct
{
ProfileID
string
Name
string
S3
DataManagementS3Config
}
type
DataManagementListBackupJobsResult
struct
{
Items
[]
DataManagementBackupJob
`json:"items"`
NextPageToken
string
`json:"next_page_token,omitempty"`
}
func
(
s
*
DataManagementService
)
GetConfig
(
ctx
context
.
Context
)
(
DataManagementConfig
,
error
)
{
_
=
ctx
return
DataManagementConfig
{},
s
.
deprecatedError
()
}
func
(
s
*
DataManagementService
)
UpdateConfig
(
ctx
context
.
Context
,
cfg
DataManagementConfig
)
(
DataManagementConfig
,
error
)
{
_
,
_
=
ctx
,
cfg
return
DataManagementConfig
{},
s
.
deprecatedError
()
}
func
(
s
*
DataManagementService
)
ListSourceProfiles
(
ctx
context
.
Context
,
sourceType
string
)
([]
DataManagementSourceProfile
,
error
)
{
_
,
_
=
ctx
,
sourceType
return
nil
,
s
.
deprecatedError
()
}
func
(
s
*
DataManagementService
)
CreateSourceProfile
(
ctx
context
.
Context
,
input
DataManagementCreateSourceProfileInput
)
(
DataManagementSourceProfile
,
error
)
{
_
,
_
=
ctx
,
input
return
DataManagementSourceProfile
{},
s
.
deprecatedError
()
}
func
(
s
*
DataManagementService
)
UpdateSourceProfile
(
ctx
context
.
Context
,
input
DataManagementUpdateSourceProfileInput
)
(
DataManagementSourceProfile
,
error
)
{
_
,
_
=
ctx
,
input
return
DataManagementSourceProfile
{},
s
.
deprecatedError
()
}
func
(
s
*
DataManagementService
)
DeleteSourceProfile
(
ctx
context
.
Context
,
sourceType
,
profileID
string
)
error
{
_
,
_
,
_
=
ctx
,
sourceType
,
profileID
return
s
.
deprecatedError
()
}
func
(
s
*
DataManagementService
)
SetActiveSourceProfile
(
ctx
context
.
Context
,
sourceType
,
profileID
string
)
(
DataManagementSourceProfile
,
error
)
{
_
,
_
,
_
=
ctx
,
sourceType
,
profileID
return
DataManagementSourceProfile
{},
s
.
deprecatedError
()
}
func
(
s
*
DataManagementService
)
ValidateS3
(
ctx
context
.
Context
,
cfg
DataManagementS3Config
)
(
DataManagementTestS3Result
,
error
)
{
_
,
_
=
ctx
,
cfg
return
DataManagementTestS3Result
{},
s
.
deprecatedError
()
}
func
(
s
*
DataManagementService
)
ListS3Profiles
(
ctx
context
.
Context
)
([]
DataManagementS3Profile
,
error
)
{
_
=
ctx
return
nil
,
s
.
deprecatedError
()
}
func
(
s
*
DataManagementService
)
CreateS3Profile
(
ctx
context
.
Context
,
input
DataManagementCreateS3ProfileInput
)
(
DataManagementS3Profile
,
error
)
{
_
,
_
=
ctx
,
input
return
DataManagementS3Profile
{},
s
.
deprecatedError
()
}
func
(
s
*
DataManagementService
)
UpdateS3Profile
(
ctx
context
.
Context
,
input
DataManagementUpdateS3ProfileInput
)
(
DataManagementS3Profile
,
error
)
{
_
,
_
=
ctx
,
input
return
DataManagementS3Profile
{},
s
.
deprecatedError
()
}
func
(
s
*
DataManagementService
)
DeleteS3Profile
(
ctx
context
.
Context
,
profileID
string
)
error
{
_
,
_
=
ctx
,
profileID
return
s
.
deprecatedError
()
}
func
(
s
*
DataManagementService
)
SetActiveS3Profile
(
ctx
context
.
Context
,
profileID
string
)
(
DataManagementS3Profile
,
error
)
{
_
,
_
=
ctx
,
profileID
return
DataManagementS3Profile
{},
s
.
deprecatedError
()
}
func
(
s
*
DataManagementService
)
CreateBackupJob
(
ctx
context
.
Context
,
input
DataManagementCreateBackupJobInput
)
(
DataManagementBackupJob
,
error
)
{
_
,
_
=
ctx
,
input
return
DataManagementBackupJob
{},
s
.
deprecatedError
()
}
func
(
s
*
DataManagementService
)
ListBackupJobs
(
ctx
context
.
Context
,
input
DataManagementListBackupJobsInput
)
(
DataManagementListBackupJobsResult
,
error
)
{
_
,
_
=
ctx
,
input
return
DataManagementListBackupJobsResult
{},
s
.
deprecatedError
()
}
func
(
s
*
DataManagementService
)
GetBackupJob
(
ctx
context
.
Context
,
jobID
string
)
(
DataManagementBackupJob
,
error
)
{
_
,
_
=
ctx
,
jobID
return
DataManagementBackupJob
{},
s
.
deprecatedError
()
}
func
(
s
*
DataManagementService
)
deprecatedError
()
error
{
return
ErrDataManagementDeprecated
.
WithMetadata
(
map
[
string
]
string
{
"socket_path"
:
s
.
SocketPath
()})
}
backend/internal/service/data_management_grpc_test.go
0 → 100644
View file @
bb664d9b
package
service
import
(
"context"
"path/filepath"
"testing"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
)
func
TestDataManagementService_DeprecatedRPCMethods
(
t
*
testing
.
T
)
{
t
.
Parallel
()
socketPath
:=
filepath
.
Join
(
t
.
TempDir
(),
"datamanagement.sock"
)
svc
:=
NewDataManagementServiceWithOptions
(
socketPath
,
0
)
_
,
err
:=
svc
.
GetConfig
(
context
.
Background
())
assertDeprecatedDataManagementError
(
t
,
err
,
socketPath
)
_
,
err
=
svc
.
CreateBackupJob
(
context
.
Background
(),
DataManagementCreateBackupJobInput
{
BackupType
:
"full"
})
assertDeprecatedDataManagementError
(
t
,
err
,
socketPath
)
err
=
svc
.
DeleteS3Profile
(
context
.
Background
(),
"s3-default"
)
assertDeprecatedDataManagementError
(
t
,
err
,
socketPath
)
}
func
assertDeprecatedDataManagementError
(
t
*
testing
.
T
,
err
error
,
socketPath
string
)
{
t
.
Helper
()
require
.
Error
(
t
,
err
)
statusCode
,
status
:=
infraerrors
.
ToHTTP
(
err
)
require
.
Equal
(
t
,
503
,
statusCode
)
require
.
Equal
(
t
,
DataManagementDeprecatedReason
,
status
.
Reason
)
require
.
Equal
(
t
,
socketPath
,
status
.
Metadata
[
"socket_path"
])
}
backend/internal/service/data_management_service.go
0 → 100644
View file @
bb664d9b
package
service
import
(
"context"
"strings"
"time"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
const
(
DefaultDataManagementAgentSocketPath
=
"/tmp/sub2api-datamanagement.sock"
LegacyBackupAgentSocketPath
=
"/tmp/sub2api-backup.sock"
DataManagementDeprecatedReason
=
"DATA_MANAGEMENT_DEPRECATED"
DataManagementAgentSocketMissingReason
=
"DATA_MANAGEMENT_AGENT_SOCKET_MISSING"
DataManagementAgentUnavailableReason
=
"DATA_MANAGEMENT_AGENT_UNAVAILABLE"
// Deprecated: keep old names for compatibility.
DefaultBackupAgentSocketPath
=
DefaultDataManagementAgentSocketPath
BackupAgentSocketMissingReason
=
DataManagementAgentSocketMissingReason
BackupAgentUnavailableReason
=
DataManagementAgentUnavailableReason
)
var
(
ErrDataManagementDeprecated
=
infraerrors
.
ServiceUnavailable
(
DataManagementDeprecatedReason
,
"data management feature is deprecated"
,
)
ErrDataManagementAgentSocketMissing
=
infraerrors
.
ServiceUnavailable
(
DataManagementAgentSocketMissingReason
,
"data management agent socket is missing"
,
)
ErrDataManagementAgentUnavailable
=
infraerrors
.
ServiceUnavailable
(
DataManagementAgentUnavailableReason
,
"data management agent is unavailable"
,
)
// Deprecated: keep old names for compatibility.
ErrBackupAgentSocketMissing
=
ErrDataManagementAgentSocketMissing
ErrBackupAgentUnavailable
=
ErrDataManagementAgentUnavailable
)
type
DataManagementAgentHealth
struct
{
Enabled
bool
Reason
string
SocketPath
string
Agent
*
DataManagementAgentInfo
}
type
DataManagementAgentInfo
struct
{
Status
string
Version
string
UptimeSeconds
int64
}
type
DataManagementService
struct
{
socketPath
string
dialTimeout
time
.
Duration
}
func
NewDataManagementService
()
*
DataManagementService
{
return
NewDataManagementServiceWithOptions
(
DefaultDataManagementAgentSocketPath
,
500
*
time
.
Millisecond
)
}
func
NewDataManagementServiceWithOptions
(
socketPath
string
,
dialTimeout
time
.
Duration
)
*
DataManagementService
{
path
:=
strings
.
TrimSpace
(
socketPath
)
if
path
==
""
{
path
=
DefaultDataManagementAgentSocketPath
}
if
dialTimeout
<=
0
{
dialTimeout
=
500
*
time
.
Millisecond
}
return
&
DataManagementService
{
socketPath
:
path
,
dialTimeout
:
dialTimeout
,
}
}
func
(
s
*
DataManagementService
)
SocketPath
()
string
{
if
s
==
nil
||
strings
.
TrimSpace
(
s
.
socketPath
)
==
""
{
return
DefaultDataManagementAgentSocketPath
}
return
s
.
socketPath
}
func
(
s
*
DataManagementService
)
GetAgentHealth
(
ctx
context
.
Context
)
DataManagementAgentHealth
{
_
=
ctx
return
DataManagementAgentHealth
{
Enabled
:
false
,
Reason
:
DataManagementDeprecatedReason
,
SocketPath
:
s
.
SocketPath
(),
}
}
func
(
s
*
DataManagementService
)
EnsureAgentEnabled
(
ctx
context
.
Context
)
error
{
_
=
ctx
return
ErrDataManagementDeprecated
.
WithMetadata
(
map
[
string
]
string
{
"socket_path"
:
s
.
SocketPath
()})
}
backend/internal/service/data_management_service_test.go
0 → 100644
View file @
bb664d9b
package
service
import
(
"context"
"path/filepath"
"testing"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
)
func
TestDataManagementService_GetAgentHealth_Deprecated
(
t
*
testing
.
T
)
{
t
.
Parallel
()
socketPath
:=
filepath
.
Join
(
t
.
TempDir
(),
"unused.sock"
)
svc
:=
NewDataManagementServiceWithOptions
(
socketPath
,
0
)
health
:=
svc
.
GetAgentHealth
(
context
.
Background
())
require
.
False
(
t
,
health
.
Enabled
)
require
.
Equal
(
t
,
DataManagementDeprecatedReason
,
health
.
Reason
)
require
.
Equal
(
t
,
socketPath
,
health
.
SocketPath
)
require
.
Nil
(
t
,
health
.
Agent
)
}
func
TestDataManagementService_EnsureAgentEnabled_Deprecated
(
t
*
testing
.
T
)
{
t
.
Parallel
()
socketPath
:=
filepath
.
Join
(
t
.
TempDir
(),
"unused.sock"
)
svc
:=
NewDataManagementServiceWithOptions
(
socketPath
,
100
)
err
:=
svc
.
EnsureAgentEnabled
(
context
.
Background
())
require
.
Error
(
t
,
err
)
statusCode
,
status
:=
infraerrors
.
ToHTTP
(
err
)
require
.
Equal
(
t
,
503
,
statusCode
)
require
.
Equal
(
t
,
DataManagementDeprecatedReason
,
status
.
Reason
)
require
.
Equal
(
t
,
socketPath
,
status
.
Metadata
[
"socket_path"
])
}
backend/internal/service/domain_constants.go
View file @
bb664d9b
...
@@ -104,6 +104,7 @@ const (
...
@@ -104,6 +104,7 @@ const (
SettingKeyLinuxDoConnectRedirectURL
=
"linuxdo_connect_redirect_url"
SettingKeyLinuxDoConnectRedirectURL
=
"linuxdo_connect_redirect_url"
// OEM设置
// OEM设置
SettingKeySoraClientEnabled
=
"sora_client_enabled"
// 是否启用 Sora 客户端(管理员手动控制)
SettingKeySiteName
=
"site_name"
// 网站名称
SettingKeySiteName
=
"site_name"
// 网站名称
SettingKeySiteLogo
=
"site_logo"
// 网站Logo (base64)
SettingKeySiteLogo
=
"site_logo"
// 网站Logo (base64)
SettingKeySiteSubtitle
=
"site_subtitle"
// 网站副标题
SettingKeySiteSubtitle
=
"site_subtitle"
// 网站副标题
...
@@ -170,6 +171,27 @@ const (
...
@@ -170,6 +171,27 @@ const (
// SettingKeyStreamTimeoutSettings stores JSON config for stream timeout handling.
// SettingKeyStreamTimeoutSettings stores JSON config for stream timeout handling.
SettingKeyStreamTimeoutSettings
=
"stream_timeout_settings"
SettingKeyStreamTimeoutSettings
=
"stream_timeout_settings"
// =========================
// Sora S3 存储配置
// =========================
SettingKeySoraS3Enabled
=
"sora_s3_enabled"
// 是否启用 Sora S3 存储
SettingKeySoraS3Endpoint
=
"sora_s3_endpoint"
// S3 端点地址
SettingKeySoraS3Region
=
"sora_s3_region"
// S3 区域
SettingKeySoraS3Bucket
=
"sora_s3_bucket"
// S3 存储桶名称
SettingKeySoraS3AccessKeyID
=
"sora_s3_access_key_id"
// S3 Access Key ID
SettingKeySoraS3SecretAccessKey
=
"sora_s3_secret_access_key"
// S3 Secret Access Key(加密存储)
SettingKeySoraS3Prefix
=
"sora_s3_prefix"
// S3 对象键前缀
SettingKeySoraS3ForcePathStyle
=
"sora_s3_force_path_style"
// 是否强制 Path Style(兼容 MinIO 等)
SettingKeySoraS3CDNURL
=
"sora_s3_cdn_url"
// CDN 加速 URL(可选)
SettingKeySoraS3Profiles
=
"sora_s3_profiles"
// Sora S3 多配置(JSON)
// =========================
// Sora 用户存储配额
// =========================
SettingKeySoraDefaultStorageQuotaBytes
=
"sora_default_storage_quota_bytes"
// 新用户默认 Sora 存储配额(字节)
)
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
...
...
backend/internal/service/gateway_anthropic_apikey_passthrough_test.go
View file @
bb664d9b
...
@@ -279,10 +279,10 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokens404PassthroughNotE
...
@@ -279,10 +279,10 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokens404PassthroughNotE
wantPassthrough
:
true
,
wantPassthrough
:
true
,
},
},
{
{
name
:
"404 generic not found pass
es
through
as 404
"
,
name
:
"404 generic not found
does not
passthrough"
,
statusCode
:
http
.
StatusNotFound
,
statusCode
:
http
.
StatusNotFound
,
respBody
:
`{"error":{"message":"resource not found","type":"not_found_error"}}`
,
respBody
:
`{"error":{"message":"resource not found","type":"not_found_error"}}`
,
wantPassthrough
:
tru
e
,
wantPassthrough
:
fals
e
,
},
},
{
{
name
:
"400 Invalid URL does not passthrough"
,
name
:
"400 Invalid URL does not passthrough"
,
...
...
backend/internal/service/gateway_beta_test.go
View file @
bb664d9b
...
@@ -136,3 +136,67 @@ func TestDroppedBetaSet(t *testing.T) {
...
@@ -136,3 +136,67 @@ func TestDroppedBetaSet(t *testing.T) {
require
.
Contains
(
t
,
extended
,
claude
.
BetaClaudeCode
)
require
.
Contains
(
t
,
extended
,
claude
.
BetaClaudeCode
)
require
.
Len
(
t
,
extended
,
len
(
claude
.
DroppedBetas
)
+
1
)
require
.
Len
(
t
,
extended
,
len
(
claude
.
DroppedBetas
)
+
1
)
}
}
func
TestBuildBetaTokenSet
(
t
*
testing
.
T
)
{
got
:=
buildBetaTokenSet
([]
string
{
"foo"
,
""
,
"bar"
,
"foo"
})
require
.
Len
(
t
,
got
,
2
)
require
.
Contains
(
t
,
got
,
"foo"
)
require
.
Contains
(
t
,
got
,
"bar"
)
require
.
NotContains
(
t
,
got
,
""
)
empty
:=
buildBetaTokenSet
(
nil
)
require
.
Empty
(
t
,
empty
)
}
func
TestStripBetaTokensWithSet_EmptyDropSet
(
t
*
testing
.
T
)
{
header
:=
"oauth-2025-04-20,interleaved-thinking-2025-05-14"
got
:=
stripBetaTokensWithSet
(
header
,
map
[
string
]
struct
{}{})
require
.
Equal
(
t
,
header
,
got
)
}
func
TestIsCountTokensUnsupported404
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
statusCode
int
body
string
want
bool
}{
{
name
:
"exact endpoint not found"
,
statusCode
:
404
,
body
:
`{"error":{"message":"Not found: /v1/messages/count_tokens","type":"not_found_error"}}`
,
want
:
true
,
},
{
name
:
"contains count_tokens and not found"
,
statusCode
:
404
,
body
:
`{"error":{"message":"count_tokens route not found","type":"not_found_error"}}`
,
want
:
true
,
},
{
name
:
"generic 404"
,
statusCode
:
404
,
body
:
`{"error":{"message":"resource not found","type":"not_found_error"}}`
,
want
:
false
,
},
{
name
:
"404 with empty error message"
,
statusCode
:
404
,
body
:
`{"error":{"message":"","type":"not_found_error"}}`
,
want
:
false
,
},
{
name
:
"non-404 status"
,
statusCode
:
400
,
body
:
`{"error":{"message":"Not found: /v1/messages/count_tokens","type":"invalid_request_error"}}`
,
want
:
false
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
isCountTokensUnsupported404
(
tt
.
statusCode
,
[]
byte
(
tt
.
body
))
require
.
Equal
(
t
,
tt
.
want
,
got
)
})
}
}
backend/internal/service/gateway_multiplatform_test.go
View file @
bb664d9b
...
@@ -1892,6 +1892,14 @@ func (m *mockConcurrencyCache) GetAccountConcurrency(ctx context.Context, accoun
...
@@ -1892,6 +1892,14 @@ func (m *mockConcurrencyCache) GetAccountConcurrency(ctx context.Context, accoun
return
0
,
nil
return
0
,
nil
}
}
func
(
m
*
mockConcurrencyCache
)
GetAccountConcurrencyBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
int
,
error
)
{
result
:=
make
(
map
[
int64
]
int
,
len
(
accountIDs
))
for
_
,
accountID
:=
range
accountIDs
{
result
[
accountID
]
=
0
}
return
result
,
nil
}
func
(
m
*
mockConcurrencyCache
)
IncrementAccountWaitCount
(
ctx
context
.
Context
,
accountID
int64
,
maxWait
int
)
(
bool
,
error
)
{
func
(
m
*
mockConcurrencyCache
)
IncrementAccountWaitCount
(
ctx
context
.
Context
,
accountID
int64
,
maxWait
int
)
(
bool
,
error
)
{
return
true
,
nil
return
true
,
nil
}
}
...
...
Prev
1
…
5
6
7
8
9
10
11
12
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment