Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
6bccb8a8
Unverified
Commit
6bccb8a8
authored
Feb 24, 2026
by
Wesley Liddick
Committed by
GitHub
Feb 24, 2026
Browse files
Merge branch 'main' into feature/antigravity-user-agent-configurable
parents
1fc6ef3d
3de1e0e4
Changes
270
Hide whitespace changes
Inline
Side-by-side
Too many changes to show.
To preserve performance only
270 of 270+
files are displayed.
Plain diff
Email patch
backend/internal/service/api_key_service_touch_last_used_test.go
0 → 100644
View file @
6bccb8a8
//go:build unit
package
service
import
(
"context"
"errors"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func
TestAPIKeyService_TouchLastUsed_InvalidKeyID
(
t
*
testing
.
T
)
{
repo
:=
&
apiKeyRepoStub
{
updateLastUsed
:
func
(
ctx
context
.
Context
,
id
int64
,
usedAt
time
.
Time
)
error
{
return
errors
.
New
(
"should not be called"
)
},
}
svc
:=
&
APIKeyService
{
apiKeyRepo
:
repo
}
require
.
NoError
(
t
,
svc
.
TouchLastUsed
(
context
.
Background
(),
0
))
require
.
NoError
(
t
,
svc
.
TouchLastUsed
(
context
.
Background
(),
-
1
))
require
.
Empty
(
t
,
repo
.
touchedIDs
)
}
func
TestAPIKeyService_TouchLastUsed_FirstTouchSucceeds
(
t
*
testing
.
T
)
{
repo
:=
&
apiKeyRepoStub
{}
svc
:=
&
APIKeyService
{
apiKeyRepo
:
repo
}
err
:=
svc
.
TouchLastUsed
(
context
.
Background
(),
123
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
[]
int64
{
123
},
repo
.
touchedIDs
)
require
.
Len
(
t
,
repo
.
touchedUsedAts
,
1
)
require
.
False
(
t
,
repo
.
touchedUsedAts
[
0
]
.
IsZero
())
cached
,
ok
:=
svc
.
lastUsedTouchL1
.
Load
(
int64
(
123
))
require
.
True
(
t
,
ok
,
"successful touch should update debounce cache"
)
_
,
isTime
:=
cached
.
(
time
.
Time
)
require
.
True
(
t
,
isTime
)
}
func
TestAPIKeyService_TouchLastUsed_DebouncedWithinWindow
(
t
*
testing
.
T
)
{
repo
:=
&
apiKeyRepoStub
{}
svc
:=
&
APIKeyService
{
apiKeyRepo
:
repo
}
require
.
NoError
(
t
,
svc
.
TouchLastUsed
(
context
.
Background
(),
123
))
require
.
NoError
(
t
,
svc
.
TouchLastUsed
(
context
.
Background
(),
123
))
require
.
Equal
(
t
,
[]
int64
{
123
},
repo
.
touchedIDs
,
"second touch within debounce window should not hit repository"
)
}
func
TestAPIKeyService_TouchLastUsed_ExpiredDebounceTouchesAgain
(
t
*
testing
.
T
)
{
repo
:=
&
apiKeyRepoStub
{}
svc
:=
&
APIKeyService
{
apiKeyRepo
:
repo
}
require
.
NoError
(
t
,
svc
.
TouchLastUsed
(
context
.
Background
(),
123
))
// 强制将 debounce 时间回拨到窗口之外,触发第二次写库。
svc
.
lastUsedTouchL1
.
Store
(
int64
(
123
),
time
.
Now
()
.
Add
(
-
apiKeyLastUsedMinTouch
-
time
.
Second
))
require
.
NoError
(
t
,
svc
.
TouchLastUsed
(
context
.
Background
(),
123
))
require
.
Len
(
t
,
repo
.
touchedIDs
,
2
)
require
.
Equal
(
t
,
int64
(
123
),
repo
.
touchedIDs
[
0
])
require
.
Equal
(
t
,
int64
(
123
),
repo
.
touchedIDs
[
1
])
}
func
TestAPIKeyService_TouchLastUsed_RepoError
(
t
*
testing
.
T
)
{
repo
:=
&
apiKeyRepoStub
{
updateLastUsed
:
func
(
ctx
context
.
Context
,
id
int64
,
usedAt
time
.
Time
)
error
{
return
errors
.
New
(
"db write failed"
)
},
}
svc
:=
&
APIKeyService
{
apiKeyRepo
:
repo
}
err
:=
svc
.
TouchLastUsed
(
context
.
Background
(),
123
)
require
.
Error
(
t
,
err
)
require
.
ErrorContains
(
t
,
err
,
"touch api key last used"
)
require
.
Equal
(
t
,
[]
int64
{
123
},
repo
.
touchedIDs
)
cached
,
ok
:=
svc
.
lastUsedTouchL1
.
Load
(
int64
(
123
))
require
.
True
(
t
,
ok
,
"failed touch should still update retry debounce cache"
)
_
,
isTime
:=
cached
.
(
time
.
Time
)
require
.
True
(
t
,
isTime
)
}
func
TestAPIKeyService_TouchLastUsed_RepoErrorDebounced
(
t
*
testing
.
T
)
{
repo
:=
&
apiKeyRepoStub
{
updateLastUsed
:
func
(
ctx
context
.
Context
,
id
int64
,
usedAt
time
.
Time
)
error
{
return
errors
.
New
(
"db write failed"
)
},
}
svc
:=
&
APIKeyService
{
apiKeyRepo
:
repo
}
firstErr
:=
svc
.
TouchLastUsed
(
context
.
Background
(),
456
)
require
.
Error
(
t
,
firstErr
)
require
.
ErrorContains
(
t
,
firstErr
,
"touch api key last used"
)
secondErr
:=
svc
.
TouchLastUsed
(
context
.
Background
(),
456
)
require
.
NoError
(
t
,
secondErr
,
"failed touch should be debounced and skip immediate retry"
)
require
.
Equal
(
t
,
[]
int64
{
456
},
repo
.
touchedIDs
,
"debounced retry should not hit repository again"
)
}
type
touchSingleflightRepo
struct
{
*
apiKeyRepoStub
mu
sync
.
Mutex
calls
int
blockCh
chan
struct
{}
}
func
(
r
*
touchSingleflightRepo
)
UpdateLastUsed
(
ctx
context
.
Context
,
id
int64
,
usedAt
time
.
Time
)
error
{
r
.
mu
.
Lock
()
r
.
calls
++
r
.
mu
.
Unlock
()
<-
r
.
blockCh
return
nil
}
func
TestAPIKeyService_TouchLastUsed_ConcurrentFirstTouchDeduplicated
(
t
*
testing
.
T
)
{
repo
:=
&
touchSingleflightRepo
{
apiKeyRepoStub
:
&
apiKeyRepoStub
{},
blockCh
:
make
(
chan
struct
{}),
}
svc
:=
&
APIKeyService
{
apiKeyRepo
:
repo
}
const
workers
=
20
startCh
:=
make
(
chan
struct
{})
errCh
:=
make
(
chan
error
,
workers
)
var
wg
sync
.
WaitGroup
for
i
:=
0
;
i
<
workers
;
i
++
{
wg
.
Add
(
1
)
go
func
()
{
defer
wg
.
Done
()
<-
startCh
errCh
<-
svc
.
TouchLastUsed
(
context
.
Background
(),
321
)
}()
}
close
(
startCh
)
require
.
Eventually
(
t
,
func
()
bool
{
repo
.
mu
.
Lock
()
defer
repo
.
mu
.
Unlock
()
return
repo
.
calls
>=
1
},
time
.
Second
,
10
*
time
.
Millisecond
)
close
(
repo
.
blockCh
)
wg
.
Wait
()
close
(
errCh
)
for
err
:=
range
errCh
{
require
.
NoError
(
t
,
err
)
}
repo
.
mu
.
Lock
()
defer
repo
.
mu
.
Unlock
()
require
.
Equal
(
t
,
1
,
repo
.
calls
,
"并发首次 touch 只应写库一次"
)
}
backend/internal/service/auth_service.go
View file @
6bccb8a8
...
@@ -7,13 +7,13 @@ import (
...
@@ -7,13 +7,13 @@ import (
"encoding/hex"
"encoding/hex"
"errors"
"errors"
"fmt"
"fmt"
"log"
"net/mail"
"net/mail"
"strings"
"strings"
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/golang-jwt/jwt/v5"
"github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt"
"golang.org/x/crypto/bcrypt"
...
@@ -118,12 +118,12 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
...
@@ -118,12 +118,12 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
// 验证邀请码
// 验证邀请码
redeemCode
,
err
:=
s
.
redeemRepo
.
GetByCode
(
ctx
,
invitationCode
)
redeemCode
,
err
:=
s
.
redeemRepo
.
GetByCode
(
ctx
,
invitationCode
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Printf
(
"[Auth] Invalid invitation code: %s, error: %v"
,
invitationCode
,
err
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Invalid invitation code: %s, error: %v"
,
invitationCode
,
err
)
return
""
,
nil
,
ErrInvitationCodeInvalid
return
""
,
nil
,
ErrInvitationCodeInvalid
}
}
// 检查类型和状态
// 检查类型和状态
if
redeemCode
.
Type
!=
RedeemTypeInvitation
||
redeemCode
.
Status
!=
StatusUnused
{
if
redeemCode
.
Type
!=
RedeemTypeInvitation
||
redeemCode
.
Status
!=
StatusUnused
{
log
.
Printf
(
"[Auth] Invitation code invalid: type=%s, status=%s"
,
redeemCode
.
Type
,
redeemCode
.
Status
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Invitation code invalid: type=%s, status=%s"
,
redeemCode
.
Type
,
redeemCode
.
Status
)
return
""
,
nil
,
ErrInvitationCodeInvalid
return
""
,
nil
,
ErrInvitationCodeInvalid
}
}
invitationRedeemCode
=
redeemCode
invitationRedeemCode
=
redeemCode
...
@@ -134,7 +134,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
...
@@ -134,7 +134,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
// 如果邮件验证已开启但邮件服务未配置,拒绝注册
// 如果邮件验证已开启但邮件服务未配置,拒绝注册
// 这是一个配置错误,不应该允许绕过验证
// 这是一个配置错误,不应该允许绕过验证
if
s
.
emailService
==
nil
{
if
s
.
emailService
==
nil
{
log
.
Println
(
"[Auth] Email verification enabled but email service not configured, rejecting registration"
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"%s"
,
"[Auth] Email verification enabled but email service not configured, rejecting registration"
)
return
""
,
nil
,
ErrServiceUnavailable
return
""
,
nil
,
ErrServiceUnavailable
}
}
if
verifyCode
==
""
{
if
verifyCode
==
""
{
...
@@ -149,7 +149,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
...
@@ -149,7 +149,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
// 检查邮箱是否已存在
// 检查邮箱是否已存在
existsEmail
,
err
:=
s
.
userRepo
.
ExistsByEmail
(
ctx
,
email
)
existsEmail
,
err
:=
s
.
userRepo
.
ExistsByEmail
(
ctx
,
email
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Printf
(
"[Auth] Database error checking email exists: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Database error checking email exists: %v"
,
err
)
return
""
,
nil
,
ErrServiceUnavailable
return
""
,
nil
,
ErrServiceUnavailable
}
}
if
existsEmail
{
if
existsEmail
{
...
@@ -185,7 +185,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
...
@@ -185,7 +185,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
if
errors
.
Is
(
err
,
ErrEmailExists
)
{
if
errors
.
Is
(
err
,
ErrEmailExists
)
{
return
""
,
nil
,
ErrEmailExists
return
""
,
nil
,
ErrEmailExists
}
}
log
.
Printf
(
"[Auth] Database error creating user: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Database error creating user: %v"
,
err
)
return
""
,
nil
,
ErrServiceUnavailable
return
""
,
nil
,
ErrServiceUnavailable
}
}
...
@@ -193,14 +193,14 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
...
@@ -193,14 +193,14 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
if
invitationRedeemCode
!=
nil
{
if
invitationRedeemCode
!=
nil
{
if
err
:=
s
.
redeemRepo
.
Use
(
ctx
,
invitationRedeemCode
.
ID
,
user
.
ID
);
err
!=
nil
{
if
err
:=
s
.
redeemRepo
.
Use
(
ctx
,
invitationRedeemCode
.
ID
,
user
.
ID
);
err
!=
nil
{
// 邀请码标记失败不影响注册,只记录日志
// 邀请码标记失败不影响注册,只记录日志
log
.
Printf
(
"[Auth] Failed to mark invitation code as used for user %d: %v"
,
user
.
ID
,
err
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Failed to mark invitation code as used for user %d: %v"
,
user
.
ID
,
err
)
}
}
}
}
// 应用优惠码(如果提供且功能已启用)
// 应用优惠码(如果提供且功能已启用)
if
promoCode
!=
""
&&
s
.
promoService
!=
nil
&&
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsPromoCodeEnabled
(
ctx
)
{
if
promoCode
!=
""
&&
s
.
promoService
!=
nil
&&
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsPromoCodeEnabled
(
ctx
)
{
if
err
:=
s
.
promoService
.
ApplyPromoCode
(
ctx
,
user
.
ID
,
promoCode
);
err
!=
nil
{
if
err
:=
s
.
promoService
.
ApplyPromoCode
(
ctx
,
user
.
ID
,
promoCode
);
err
!=
nil
{
// 优惠码应用失败不影响注册,只记录日志
// 优惠码应用失败不影响注册,只记录日志
log
.
Printf
(
"[Auth] Failed to apply promo code for user %d: %v"
,
user
.
ID
,
err
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Failed to apply promo code for user %d: %v"
,
user
.
ID
,
err
)
}
else
{
}
else
{
// 重新获取用户信息以获取更新后的余额
// 重新获取用户信息以获取更新后的余额
if
updatedUser
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
user
.
ID
);
err
==
nil
{
if
updatedUser
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
user
.
ID
);
err
==
nil
{
...
@@ -237,7 +237,7 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
...
@@ -237,7 +237,7 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
// 检查邮箱是否已存在
// 检查邮箱是否已存在
existsEmail
,
err
:=
s
.
userRepo
.
ExistsByEmail
(
ctx
,
email
)
existsEmail
,
err
:=
s
.
userRepo
.
ExistsByEmail
(
ctx
,
email
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Printf
(
"[Auth] Database error checking email exists: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Database error checking email exists: %v"
,
err
)
return
ErrServiceUnavailable
return
ErrServiceUnavailable
}
}
if
existsEmail
{
if
existsEmail
{
...
@@ -260,11 +260,11 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
...
@@ -260,11 +260,11 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
// SendVerifyCodeAsync 异步发送邮箱验证码并返回倒计时
// SendVerifyCodeAsync 异步发送邮箱验证码并返回倒计时
func
(
s
*
AuthService
)
SendVerifyCodeAsync
(
ctx
context
.
Context
,
email
string
)
(
*
SendVerifyCodeResult
,
error
)
{
func
(
s
*
AuthService
)
SendVerifyCodeAsync
(
ctx
context
.
Context
,
email
string
)
(
*
SendVerifyCodeResult
,
error
)
{
log
.
Printf
(
"[Auth] SendVerifyCodeAsync called for email: %s"
,
email
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] SendVerifyCodeAsync called for email: %s"
,
email
)
// 检查是否开放注册(默认关闭)
// 检查是否开放注册(默认关闭)
if
s
.
settingService
==
nil
||
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
if
s
.
settingService
==
nil
||
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
log
.
Println
(
"[Auth] Registration is disabled"
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"%s"
,
"[Auth] Registration is disabled"
)
return
nil
,
ErrRegDisabled
return
nil
,
ErrRegDisabled
}
}
...
@@ -275,17 +275,17 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
...
@@ -275,17 +275,17 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
// 检查邮箱是否已存在
// 检查邮箱是否已存在
existsEmail
,
err
:=
s
.
userRepo
.
ExistsByEmail
(
ctx
,
email
)
existsEmail
,
err
:=
s
.
userRepo
.
ExistsByEmail
(
ctx
,
email
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Printf
(
"[Auth] Database error checking email exists: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Database error checking email exists: %v"
,
err
)
return
nil
,
ErrServiceUnavailable
return
nil
,
ErrServiceUnavailable
}
}
if
existsEmail
{
if
existsEmail
{
log
.
Printf
(
"[Auth] Email already exists: %s"
,
email
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Email already exists: %s"
,
email
)
return
nil
,
ErrEmailExists
return
nil
,
ErrEmailExists
}
}
// 检查邮件队列服务是否配置
// 检查邮件队列服务是否配置
if
s
.
emailQueueService
==
nil
{
if
s
.
emailQueueService
==
nil
{
log
.
Println
(
"[Auth] Email queue service not configured"
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"%s"
,
"[Auth] Email queue service not configured"
)
return
nil
,
errors
.
New
(
"email queue service not configured"
)
return
nil
,
errors
.
New
(
"email queue service not configured"
)
}
}
...
@@ -296,13 +296,13 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
...
@@ -296,13 +296,13 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
}
}
// 异步发送
// 异步发送
log
.
Printf
(
"[Auth] Enqueueing verify code for: %s"
,
email
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Enqueueing verify code for: %s"
,
email
)
if
err
:=
s
.
emailQueueService
.
EnqueueVerifyCode
(
email
,
siteName
);
err
!=
nil
{
if
err
:=
s
.
emailQueueService
.
EnqueueVerifyCode
(
email
,
siteName
);
err
!=
nil
{
log
.
Printf
(
"[Auth] Failed to enqueue: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Failed to enqueue: %v"
,
err
)
return
nil
,
fmt
.
Errorf
(
"enqueue verify code: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"enqueue verify code: %w"
,
err
)
}
}
log
.
Printf
(
"[Auth] Verify code enqueued successfully for: %s"
,
email
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Verify code enqueued successfully for: %s"
,
email
)
return
&
SendVerifyCodeResult
{
return
&
SendVerifyCodeResult
{
Countdown
:
60
,
// 60秒倒计时
Countdown
:
60
,
// 60秒倒计时
},
nil
},
nil
...
@@ -314,27 +314,27 @@ func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteI
...
@@ -314,27 +314,27 @@ func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteI
if
required
{
if
required
{
if
s
.
settingService
==
nil
{
if
s
.
settingService
==
nil
{
log
.
Println
(
"[Auth] Turnstile required but settings service is not configured"
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"%s"
,
"[Auth] Turnstile required but settings service is not configured"
)
return
ErrTurnstileNotConfigured
return
ErrTurnstileNotConfigured
}
}
enabled
:=
s
.
settingService
.
IsTurnstileEnabled
(
ctx
)
enabled
:=
s
.
settingService
.
IsTurnstileEnabled
(
ctx
)
secretConfigured
:=
s
.
settingService
.
GetTurnstileSecretKey
(
ctx
)
!=
""
secretConfigured
:=
s
.
settingService
.
GetTurnstileSecretKey
(
ctx
)
!=
""
if
!
enabled
||
!
secretConfigured
{
if
!
enabled
||
!
secretConfigured
{
log
.
Printf
(
"[Auth] Turnstile required but not configured (enabled=%v, secret_configured=%v)"
,
enabled
,
secretConfigured
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Turnstile required but not configured (enabled=%v, secret_configured=%v)"
,
enabled
,
secretConfigured
)
return
ErrTurnstileNotConfigured
return
ErrTurnstileNotConfigured
}
}
}
}
if
s
.
turnstileService
==
nil
{
if
s
.
turnstileService
==
nil
{
if
required
{
if
required
{
log
.
Println
(
"[Auth] Turnstile required but service not configured"
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"%s"
,
"[Auth] Turnstile required but service not configured"
)
return
ErrTurnstileNotConfigured
return
ErrTurnstileNotConfigured
}
}
return
nil
// 服务未配置则跳过验证
return
nil
// 服务未配置则跳过验证
}
}
if
!
required
&&
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsTurnstileEnabled
(
ctx
)
&&
s
.
settingService
.
GetTurnstileSecretKey
(
ctx
)
==
""
{
if
!
required
&&
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsTurnstileEnabled
(
ctx
)
&&
s
.
settingService
.
GetTurnstileSecretKey
(
ctx
)
==
""
{
log
.
Println
(
"[Auth] Turnstile enabled but secret key not configured"
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"%s"
,
"[Auth] Turnstile enabled but secret key not configured"
)
}
}
return
s
.
turnstileService
.
VerifyToken
(
ctx
,
token
,
remoteIP
)
return
s
.
turnstileService
.
VerifyToken
(
ctx
,
token
,
remoteIP
)
...
@@ -373,7 +373,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
...
@@ -373,7 +373,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
return
""
,
nil
,
ErrInvalidCredentials
return
""
,
nil
,
ErrInvalidCredentials
}
}
// 记录数据库错误但不暴露给用户
// 记录数据库错误但不暴露给用户
log
.
Printf
(
"[Auth] Database error during login: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Database error during login: %v"
,
err
)
return
""
,
nil
,
ErrServiceUnavailable
return
""
,
nil
,
ErrServiceUnavailable
}
}
...
@@ -426,7 +426,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
...
@@ -426,7 +426,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
randomPassword
,
err
:=
randomHexString
(
32
)
randomPassword
,
err
:=
randomHexString
(
32
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Printf
(
"[Auth] Failed to generate random password for oauth signup: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Failed to generate random password for oauth signup: %v"
,
err
)
return
""
,
nil
,
ErrServiceUnavailable
return
""
,
nil
,
ErrServiceUnavailable
}
}
hashedPassword
,
err
:=
s
.
HashPassword
(
randomPassword
)
hashedPassword
,
err
:=
s
.
HashPassword
(
randomPassword
)
...
@@ -457,18 +457,18 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
...
@@ -457,18 +457,18 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
// 并发场景:GetByEmail 与 Create 之间用户被创建。
// 并发场景:GetByEmail 与 Create 之间用户被创建。
user
,
err
=
s
.
userRepo
.
GetByEmail
(
ctx
,
email
)
user
,
err
=
s
.
userRepo
.
GetByEmail
(
ctx
,
email
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Printf
(
"[Auth] Database error getting user after conflict: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Database error getting user after conflict: %v"
,
err
)
return
""
,
nil
,
ErrServiceUnavailable
return
""
,
nil
,
ErrServiceUnavailable
}
}
}
else
{
}
else
{
log
.
Printf
(
"[Auth] Database error creating oauth user: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Database error creating oauth user: %v"
,
err
)
return
""
,
nil
,
ErrServiceUnavailable
return
""
,
nil
,
ErrServiceUnavailable
}
}
}
else
{
}
else
{
user
=
newUser
user
=
newUser
}
}
}
else
{
}
else
{
log
.
Printf
(
"[Auth] Database error during oauth login: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Database error during oauth login: %v"
,
err
)
return
""
,
nil
,
ErrServiceUnavailable
return
""
,
nil
,
ErrServiceUnavailable
}
}
}
}
...
@@ -481,7 +481,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
...
@@ -481,7 +481,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
if
user
.
Username
==
""
&&
username
!=
""
{
if
user
.
Username
==
""
&&
username
!=
""
{
user
.
Username
=
username
user
.
Username
=
username
if
err
:=
s
.
userRepo
.
Update
(
ctx
,
user
);
err
!=
nil
{
if
err
:=
s
.
userRepo
.
Update
(
ctx
,
user
);
err
!=
nil
{
log
.
Printf
(
"[Auth] Failed to update username after oauth login: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Failed to update username after oauth login: %v"
,
err
)
}
}
}
}
...
@@ -523,7 +523,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
...
@@ -523,7 +523,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
randomPassword
,
err
:=
randomHexString
(
32
)
randomPassword
,
err
:=
randomHexString
(
32
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Printf
(
"[Auth] Failed to generate random password for oauth signup: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Failed to generate random password for oauth signup: %v"
,
err
)
return
nil
,
nil
,
ErrServiceUnavailable
return
nil
,
nil
,
ErrServiceUnavailable
}
}
hashedPassword
,
err
:=
s
.
HashPassword
(
randomPassword
)
hashedPassword
,
err
:=
s
.
HashPassword
(
randomPassword
)
...
@@ -552,18 +552,18 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
...
@@ -552,18 +552,18 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
if
errors
.
Is
(
err
,
ErrEmailExists
)
{
if
errors
.
Is
(
err
,
ErrEmailExists
)
{
user
,
err
=
s
.
userRepo
.
GetByEmail
(
ctx
,
email
)
user
,
err
=
s
.
userRepo
.
GetByEmail
(
ctx
,
email
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Printf
(
"[Auth] Database error getting user after conflict: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Database error getting user after conflict: %v"
,
err
)
return
nil
,
nil
,
ErrServiceUnavailable
return
nil
,
nil
,
ErrServiceUnavailable
}
}
}
else
{
}
else
{
log
.
Printf
(
"[Auth] Database error creating oauth user: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Database error creating oauth user: %v"
,
err
)
return
nil
,
nil
,
ErrServiceUnavailable
return
nil
,
nil
,
ErrServiceUnavailable
}
}
}
else
{
}
else
{
user
=
newUser
user
=
newUser
}
}
}
else
{
}
else
{
log
.
Printf
(
"[Auth] Database error during oauth login: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Database error during oauth login: %v"
,
err
)
return
nil
,
nil
,
ErrServiceUnavailable
return
nil
,
nil
,
ErrServiceUnavailable
}
}
}
}
...
@@ -575,7 +575,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
...
@@ -575,7 +575,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
if
user
.
Username
==
""
&&
username
!=
""
{
if
user
.
Username
==
""
&&
username
!=
""
{
user
.
Username
=
username
user
.
Username
=
username
if
err
:=
s
.
userRepo
.
Update
(
ctx
,
user
);
err
!=
nil
{
if
err
:=
s
.
userRepo
.
Update
(
ctx
,
user
);
err
!=
nil
{
log
.
Printf
(
"[Auth] Failed to update username after oauth login: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Failed to update username after oauth login: %v"
,
err
)
}
}
}
}
...
@@ -715,7 +715,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
...
@@ -715,7 +715,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
if
errors
.
Is
(
err
,
ErrUserNotFound
)
{
if
errors
.
Is
(
err
,
ErrUserNotFound
)
{
return
""
,
ErrInvalidToken
return
""
,
ErrInvalidToken
}
}
log
.
Printf
(
"[Auth] Database error refreshing token: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Database error refreshing token: %v"
,
err
)
return
""
,
ErrServiceUnavailable
return
""
,
ErrServiceUnavailable
}
}
...
@@ -756,16 +756,16 @@ func (s *AuthService) preparePasswordReset(ctx context.Context, email, frontendB
...
@@ -756,16 +756,16 @@ func (s *AuthService) preparePasswordReset(ctx context.Context, email, frontendB
if
err
!=
nil
{
if
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrUserNotFound
)
{
if
errors
.
Is
(
err
,
ErrUserNotFound
)
{
// Security: Log but don't reveal that user doesn't exist
// Security: Log but don't reveal that user doesn't exist
log
.
Printf
(
"[Auth] Password reset requested for non-existent email: %s"
,
email
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Password reset requested for non-existent email: %s"
,
email
)
return
""
,
""
,
false
return
""
,
""
,
false
}
}
log
.
Printf
(
"[Auth] Database error checking email for password reset: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Database error checking email for password reset: %v"
,
err
)
return
""
,
""
,
false
return
""
,
""
,
false
}
}
// Check if user is active
// Check if user is active
if
!
user
.
IsActive
()
{
if
!
user
.
IsActive
()
{
log
.
Printf
(
"[Auth] Password reset requested for inactive user: %s"
,
email
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Password reset requested for inactive user: %s"
,
email
)
return
""
,
""
,
false
return
""
,
""
,
false
}
}
...
@@ -797,11 +797,11 @@ func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendB
...
@@ -797,11 +797,11 @@ func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendB
}
}
if
err
:=
s
.
emailService
.
SendPasswordResetEmail
(
ctx
,
email
,
siteName
,
resetURL
);
err
!=
nil
{
if
err
:=
s
.
emailService
.
SendPasswordResetEmail
(
ctx
,
email
,
siteName
,
resetURL
);
err
!=
nil
{
log
.
Printf
(
"[Auth] Failed to send password reset email to %s: %v"
,
email
,
err
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Failed to send password reset email to %s: %v"
,
email
,
err
)
return
nil
// Silent success to prevent enumeration
return
nil
// Silent success to prevent enumeration
}
}
log
.
Printf
(
"[Auth] Password reset email sent to: %s"
,
email
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Password reset email sent to: %s"
,
email
)
return
nil
return
nil
}
}
...
@@ -821,11 +821,11 @@ func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, fron
...
@@ -821,11 +821,11 @@ func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, fron
}
}
if
err
:=
s
.
emailQueueService
.
EnqueuePasswordReset
(
email
,
siteName
,
resetURL
);
err
!=
nil
{
if
err
:=
s
.
emailQueueService
.
EnqueuePasswordReset
(
email
,
siteName
,
resetURL
);
err
!=
nil
{
log
.
Printf
(
"[Auth] Failed to enqueue password reset email for %s: %v"
,
email
,
err
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Failed to enqueue password reset email for %s: %v"
,
email
,
err
)
return
nil
// Silent success to prevent enumeration
return
nil
// Silent success to prevent enumeration
}
}
log
.
Printf
(
"[Auth] Password reset email enqueued for: %s"
,
email
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Password reset email enqueued for: %s"
,
email
)
return
nil
return
nil
}
}
...
@@ -852,7 +852,7 @@ func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPasswo
...
@@ -852,7 +852,7 @@ func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPasswo
if
errors
.
Is
(
err
,
ErrUserNotFound
)
{
if
errors
.
Is
(
err
,
ErrUserNotFound
)
{
return
ErrInvalidResetToken
// Token was valid but user was deleted
return
ErrInvalidResetToken
// Token was valid but user was deleted
}
}
log
.
Printf
(
"[Auth] Database error getting user for password reset: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Database error getting user for password reset: %v"
,
err
)
return
ErrServiceUnavailable
return
ErrServiceUnavailable
}
}
...
@@ -872,17 +872,17 @@ func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPasswo
...
@@ -872,17 +872,17 @@ func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPasswo
user
.
TokenVersion
++
// Invalidate all existing tokens
user
.
TokenVersion
++
// Invalidate all existing tokens
if
err
:=
s
.
userRepo
.
Update
(
ctx
,
user
);
err
!=
nil
{
if
err
:=
s
.
userRepo
.
Update
(
ctx
,
user
);
err
!=
nil
{
log
.
Printf
(
"[Auth] Database error updating password for user %d: %v"
,
user
.
ID
,
err
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Database error updating password for user %d: %v"
,
user
.
ID
,
err
)
return
ErrServiceUnavailable
return
ErrServiceUnavailable
}
}
// Also revoke all refresh tokens for this user
// Also revoke all refresh tokens for this user
if
err
:=
s
.
RevokeAllUserSessions
(
ctx
,
user
.
ID
);
err
!=
nil
{
if
err
:=
s
.
RevokeAllUserSessions
(
ctx
,
user
.
ID
);
err
!=
nil
{
log
.
Printf
(
"[Auth] Failed to revoke refresh tokens for user %d: %v"
,
user
.
ID
,
err
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Failed to revoke refresh tokens for user %d: %v"
,
user
.
ID
,
err
)
// Don't return error - password was already changed successfully
// Don't return error - password was already changed successfully
}
}
log
.
Printf
(
"[Auth] Password reset successful for user: %s"
,
email
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Password reset successful for user: %s"
,
email
)
return
nil
return
nil
}
}
...
@@ -961,13 +961,13 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami
...
@@ -961,13 +961,13 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami
// 添加到用户Token集合
// 添加到用户Token集合
if
err
:=
s
.
refreshTokenCache
.
AddToUserTokenSet
(
ctx
,
user
.
ID
,
tokenHash
,
ttl
);
err
!=
nil
{
if
err
:=
s
.
refreshTokenCache
.
AddToUserTokenSet
(
ctx
,
user
.
ID
,
tokenHash
,
ttl
);
err
!=
nil
{
log
.
Printf
(
"[Auth] Failed to add token to user set: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Failed to add token to user set: %v"
,
err
)
// 不影响主流程
// 不影响主流程
}
}
// 添加到家族Token集合
// 添加到家族Token集合
if
err
:=
s
.
refreshTokenCache
.
AddToFamilyTokenSet
(
ctx
,
familyID
,
tokenHash
,
ttl
);
err
!=
nil
{
if
err
:=
s
.
refreshTokenCache
.
AddToFamilyTokenSet
(
ctx
,
familyID
,
tokenHash
,
ttl
);
err
!=
nil
{
log
.
Printf
(
"[Auth] Failed to add token to family set: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Failed to add token to family set: %v"
,
err
)
// 不影响主流程
// 不影响主流程
}
}
...
@@ -994,10 +994,10 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
...
@@ -994,10 +994,10 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
if
err
!=
nil
{
if
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrRefreshTokenNotFound
)
{
if
errors
.
Is
(
err
,
ErrRefreshTokenNotFound
)
{
// Token不存在,可能是已被使用(Token轮转)或已过期
// Token不存在,可能是已被使用(Token轮转)或已过期
log
.
Printf
(
"[Auth] Refresh token not found, possible reuse attack"
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Refresh token not found, possible reuse attack"
)
return
nil
,
ErrRefreshTokenInvalid
return
nil
,
ErrRefreshTokenInvalid
}
}
log
.
Printf
(
"[Auth] Error getting refresh token: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Error getting refresh token: %v"
,
err
)
return
nil
,
ErrServiceUnavailable
return
nil
,
ErrServiceUnavailable
}
}
...
@@ -1016,7 +1016,7 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
...
@@ -1016,7 +1016,7 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
_
=
s
.
refreshTokenCache
.
DeleteTokenFamily
(
ctx
,
data
.
FamilyID
)
_
=
s
.
refreshTokenCache
.
DeleteTokenFamily
(
ctx
,
data
.
FamilyID
)
return
nil
,
ErrRefreshTokenInvalid
return
nil
,
ErrRefreshTokenInvalid
}
}
log
.
Printf
(
"[Auth] Database error getting user for token refresh: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Database error getting user for token refresh: %v"
,
err
)
return
nil
,
ErrServiceUnavailable
return
nil
,
ErrServiceUnavailable
}
}
...
@@ -1036,7 +1036,7 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
...
@@ -1036,7 +1036,7 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
// Token轮转:立即使旧Token失效
// Token轮转:立即使旧Token失效
if
err
:=
s
.
refreshTokenCache
.
DeleteRefreshToken
(
ctx
,
tokenHash
);
err
!=
nil
{
if
err
:=
s
.
refreshTokenCache
.
DeleteRefreshToken
(
ctx
,
tokenHash
);
err
!=
nil
{
log
.
Printf
(
"[Auth] Failed to delete old refresh token: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Failed to delete old refresh token: %v"
,
err
)
// 继续处理,不影响主流程
// 继续处理,不影响主流程
}
}
...
...
backend/internal/service/auth_service_register_test.go
View file @
6bccb8a8
...
@@ -315,3 +315,69 @@ func TestAuthService_RefreshToken_ExpiredTokenNoPanic(t *testing.T) {
...
@@ -315,3 +315,69 @@ func TestAuthService_RefreshToken_ExpiredTokenNoPanic(t *testing.T) {
require
.
NotEmpty
(
t
,
newToken
)
require
.
NotEmpty
(
t
,
newToken
)
})
})
}
}
func
TestAuthService_GetAccessTokenExpiresIn_FallbackToExpireHour
(
t
*
testing
.
T
)
{
service
:=
newAuthService
(
&
userRepoStub
{},
nil
,
nil
)
service
.
cfg
.
JWT
.
ExpireHour
=
24
service
.
cfg
.
JWT
.
AccessTokenExpireMinutes
=
0
require
.
Equal
(
t
,
24
*
3600
,
service
.
GetAccessTokenExpiresIn
())
}
func
TestAuthService_GetAccessTokenExpiresIn_MinutesHasPriority
(
t
*
testing
.
T
)
{
service
:=
newAuthService
(
&
userRepoStub
{},
nil
,
nil
)
service
.
cfg
.
JWT
.
ExpireHour
=
24
service
.
cfg
.
JWT
.
AccessTokenExpireMinutes
=
90
require
.
Equal
(
t
,
90
*
60
,
service
.
GetAccessTokenExpiresIn
())
}
func
TestAuthService_GenerateToken_UsesExpireHourWhenMinutesZero
(
t
*
testing
.
T
)
{
service
:=
newAuthService
(
&
userRepoStub
{},
nil
,
nil
)
service
.
cfg
.
JWT
.
ExpireHour
=
24
service
.
cfg
.
JWT
.
AccessTokenExpireMinutes
=
0
user
:=
&
User
{
ID
:
1
,
Email
:
"test@test.com"
,
Role
:
RoleUser
,
Status
:
StatusActive
,
TokenVersion
:
1
,
}
token
,
err
:=
service
.
GenerateToken
(
user
)
require
.
NoError
(
t
,
err
)
claims
,
err
:=
service
.
ValidateToken
(
token
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
claims
)
require
.
NotNil
(
t
,
claims
.
IssuedAt
)
require
.
NotNil
(
t
,
claims
.
ExpiresAt
)
require
.
WithinDuration
(
t
,
claims
.
IssuedAt
.
Time
.
Add
(
24
*
time
.
Hour
),
claims
.
ExpiresAt
.
Time
,
2
*
time
.
Second
)
}
func
TestAuthService_GenerateToken_UsesMinutesWhenConfigured
(
t
*
testing
.
T
)
{
service
:=
newAuthService
(
&
userRepoStub
{},
nil
,
nil
)
service
.
cfg
.
JWT
.
ExpireHour
=
24
service
.
cfg
.
JWT
.
AccessTokenExpireMinutes
=
90
user
:=
&
User
{
ID
:
2
,
Email
:
"test2@test.com"
,
Role
:
RoleUser
,
Status
:
StatusActive
,
TokenVersion
:
1
,
}
token
,
err
:=
service
.
GenerateToken
(
user
)
require
.
NoError
(
t
,
err
)
claims
,
err
:=
service
.
ValidateToken
(
token
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
claims
)
require
.
NotNil
(
t
,
claims
.
IssuedAt
)
require
.
NotNil
(
t
,
claims
.
ExpiresAt
)
require
.
WithinDuration
(
t
,
claims
.
IssuedAt
.
Time
.
Add
(
90
*
time
.
Minute
),
claims
.
ExpiresAt
.
Time
,
2
*
time
.
Second
)
}
backend/internal/service/billing_cache_service.go
View file @
6bccb8a8
...
@@ -3,13 +3,13 @@ package service
...
@@ -3,13 +3,13 @@ package service
import
(
import
(
"context"
"context"
"fmt"
"fmt"
"log"
"sync"
"sync"
"sync/atomic"
"sync/atomic"
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
)
// 错误定义
// 错误定义
...
@@ -156,13 +156,13 @@ func (s *BillingCacheService) cacheWriteWorker() {
...
@@ -156,13 +156,13 @@ func (s *BillingCacheService) cacheWriteWorker() {
case
cacheWriteUpdateSubscriptionUsage
:
case
cacheWriteUpdateSubscriptionUsage
:
if
s
.
cache
!=
nil
{
if
s
.
cache
!=
nil
{
if
err
:=
s
.
cache
.
UpdateSubscriptionUsage
(
ctx
,
task
.
userID
,
task
.
groupID
,
task
.
amount
);
err
!=
nil
{
if
err
:=
s
.
cache
.
UpdateSubscriptionUsage
(
ctx
,
task
.
userID
,
task
.
groupID
,
task
.
amount
);
err
!=
nil
{
log
.
Printf
(
"Warning: update subscription cache failed for user %d group %d: %v"
,
task
.
userID
,
task
.
groupID
,
err
)
log
ger
.
LegacyPrintf
(
"service.billing_cache"
,
"Warning: update subscription cache failed for user %d group %d: %v"
,
task
.
userID
,
task
.
groupID
,
err
)
}
}
}
}
case
cacheWriteDeductBalance
:
case
cacheWriteDeductBalance
:
if
s
.
cache
!=
nil
{
if
s
.
cache
!=
nil
{
if
err
:=
s
.
cache
.
DeductUserBalance
(
ctx
,
task
.
userID
,
task
.
amount
);
err
!=
nil
{
if
err
:=
s
.
cache
.
DeductUserBalance
(
ctx
,
task
.
userID
,
task
.
amount
);
err
!=
nil
{
log
.
Printf
(
"Warning: deduct balance cache failed for user %d: %v"
,
task
.
userID
,
err
)
log
ger
.
LegacyPrintf
(
"service.billing_cache"
,
"Warning: deduct balance cache failed for user %d: %v"
,
task
.
userID
,
err
)
}
}
}
}
}
}
...
@@ -216,7 +216,7 @@ func (s *BillingCacheService) logCacheWriteDrop(task cacheWriteTask, reason stri
...
@@ -216,7 +216,7 @@ func (s *BillingCacheService) logCacheWriteDrop(task cacheWriteTask, reason stri
if
dropped
==
0
{
if
dropped
==
0
{
return
return
}
}
log
.
Printf
(
"Warning: cache write queue %s, dropped %d tasks in last %s (latest kind=%s user %d group %d)"
,
log
ger
.
LegacyPrintf
(
"service.billing_cache"
,
"Warning: cache write queue %s, dropped %d tasks in last %s (latest kind=%s user %d group %d)"
,
reason
,
reason
,
dropped
,
dropped
,
cacheWriteDropLogInterval
,
cacheWriteDropLogInterval
,
...
@@ -274,7 +274,7 @@ func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64,
...
@@ -274,7 +274,7 @@ func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64,
return
return
}
}
if
err
:=
s
.
cache
.
SetUserBalance
(
ctx
,
userID
,
balance
);
err
!=
nil
{
if
err
:=
s
.
cache
.
SetUserBalance
(
ctx
,
userID
,
balance
);
err
!=
nil
{
log
.
Printf
(
"Warning: set balance cache failed for user %d: %v"
,
userID
,
err
)
log
ger
.
LegacyPrintf
(
"service.billing_cache"
,
"Warning: set balance cache failed for user %d: %v"
,
userID
,
err
)
}
}
}
}
...
@@ -302,7 +302,7 @@ func (s *BillingCacheService) QueueDeductBalance(userID int64, amount float64) {
...
@@ -302,7 +302,7 @@ func (s *BillingCacheService) QueueDeductBalance(userID int64, amount float64) {
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
cacheWriteTimeout
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
cacheWriteTimeout
)
defer
cancel
()
defer
cancel
()
if
err
:=
s
.
DeductBalanceCache
(
ctx
,
userID
,
amount
);
err
!=
nil
{
if
err
:=
s
.
DeductBalanceCache
(
ctx
,
userID
,
amount
);
err
!=
nil
{
log
.
Printf
(
"Warning: deduct balance cache fallback failed for user %d: %v"
,
userID
,
err
)
log
ger
.
LegacyPrintf
(
"service.billing_cache"
,
"Warning: deduct balance cache fallback failed for user %d: %v"
,
userID
,
err
)
}
}
}
}
...
@@ -312,7 +312,7 @@ func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID
...
@@ -312,7 +312,7 @@ func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID
return
nil
return
nil
}
}
if
err
:=
s
.
cache
.
InvalidateUserBalance
(
ctx
,
userID
);
err
!=
nil
{
if
err
:=
s
.
cache
.
InvalidateUserBalance
(
ctx
,
userID
);
err
!=
nil
{
log
.
Printf
(
"Warning: invalidate balance cache failed for user %d: %v"
,
userID
,
err
)
log
ger
.
LegacyPrintf
(
"service.billing_cache"
,
"Warning: invalidate balance cache failed for user %d: %v"
,
userID
,
err
)
return
err
return
err
}
}
return
nil
return
nil
...
@@ -396,7 +396,7 @@ func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID,
...
@@ -396,7 +396,7 @@ func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID,
return
return
}
}
if
err
:=
s
.
cache
.
SetSubscriptionCache
(
ctx
,
userID
,
groupID
,
s
.
convertToPortsData
(
data
));
err
!=
nil
{
if
err
:=
s
.
cache
.
SetSubscriptionCache
(
ctx
,
userID
,
groupID
,
s
.
convertToPortsData
(
data
));
err
!=
nil
{
log
.
Printf
(
"Warning: set subscription cache failed for user %d group %d: %v"
,
userID
,
groupID
,
err
)
log
ger
.
LegacyPrintf
(
"service.billing_cache"
,
"Warning: set subscription cache failed for user %d group %d: %v"
,
userID
,
groupID
,
err
)
}
}
}
}
...
@@ -425,7 +425,7 @@ func (s *BillingCacheService) QueueUpdateSubscriptionUsage(userID, groupID int64
...
@@ -425,7 +425,7 @@ func (s *BillingCacheService) QueueUpdateSubscriptionUsage(userID, groupID int64
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
cacheWriteTimeout
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
cacheWriteTimeout
)
defer
cancel
()
defer
cancel
()
if
err
:=
s
.
UpdateSubscriptionUsage
(
ctx
,
userID
,
groupID
,
costUSD
);
err
!=
nil
{
if
err
:=
s
.
UpdateSubscriptionUsage
(
ctx
,
userID
,
groupID
,
costUSD
);
err
!=
nil
{
log
.
Printf
(
"Warning: update subscription cache fallback failed for user %d group %d: %v"
,
userID
,
groupID
,
err
)
log
ger
.
LegacyPrintf
(
"service.billing_cache"
,
"Warning: update subscription cache fallback failed for user %d group %d: %v"
,
userID
,
groupID
,
err
)
}
}
}
}
...
@@ -435,7 +435,7 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
...
@@ -435,7 +435,7 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
return
nil
return
nil
}
}
if
err
:=
s
.
cache
.
InvalidateSubscriptionCache
(
ctx
,
userID
,
groupID
);
err
!=
nil
{
if
err
:=
s
.
cache
.
InvalidateSubscriptionCache
(
ctx
,
userID
,
groupID
);
err
!=
nil
{
log
.
Printf
(
"Warning: invalidate subscription cache failed for user %d group %d: %v"
,
userID
,
groupID
,
err
)
log
ger
.
LegacyPrintf
(
"service.billing_cache"
,
"Warning: invalidate subscription cache failed for user %d group %d: %v"
,
userID
,
groupID
,
err
)
return
err
return
err
}
}
return
nil
return
nil
...
@@ -474,7 +474,7 @@ func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userI
...
@@ -474,7 +474,7 @@ func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userI
if
s
.
circuitBreaker
!=
nil
{
if
s
.
circuitBreaker
!=
nil
{
s
.
circuitBreaker
.
OnFailure
(
err
)
s
.
circuitBreaker
.
OnFailure
(
err
)
}
}
log
.
Printf
(
"ALERT: billing balance check failed for user %d: %v"
,
userID
,
err
)
log
ger
.
LegacyPrintf
(
"service.billing_cache"
,
"ALERT: billing balance check failed for user %d: %v"
,
userID
,
err
)
return
ErrBillingServiceUnavailable
.
WithCause
(
err
)
return
ErrBillingServiceUnavailable
.
WithCause
(
err
)
}
}
if
s
.
circuitBreaker
!=
nil
{
if
s
.
circuitBreaker
!=
nil
{
...
@@ -496,7 +496,7 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
...
@@ -496,7 +496,7 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
if
s
.
circuitBreaker
!=
nil
{
if
s
.
circuitBreaker
!=
nil
{
s
.
circuitBreaker
.
OnFailure
(
err
)
s
.
circuitBreaker
.
OnFailure
(
err
)
}
}
log
.
Printf
(
"ALERT: billing subscription check failed for user %d group %d: %v"
,
userID
,
group
.
ID
,
err
)
log
ger
.
LegacyPrintf
(
"service.billing_cache"
,
"ALERT: billing subscription check failed for user %d group %d: %v"
,
userID
,
group
.
ID
,
err
)
return
ErrBillingServiceUnavailable
.
WithCause
(
err
)
return
ErrBillingServiceUnavailable
.
WithCause
(
err
)
}
}
if
s
.
circuitBreaker
!=
nil
{
if
s
.
circuitBreaker
!=
nil
{
...
@@ -585,7 +585,7 @@ func (b *billingCircuitBreaker) Allow() bool {
...
@@ -585,7 +585,7 @@ func (b *billingCircuitBreaker) Allow() bool {
}
}
b
.
state
=
billingCircuitHalfOpen
b
.
state
=
billingCircuitHalfOpen
b
.
halfOpenRemaining
=
b
.
halfOpenRequests
b
.
halfOpenRemaining
=
b
.
halfOpenRequests
log
.
Printf
(
"ALERT: billing circuit breaker entering half-open state"
)
log
ger
.
LegacyPrintf
(
"service.billing_cache"
,
"ALERT: billing circuit breaker entering half-open state"
)
fallthrough
fallthrough
case
billingCircuitHalfOpen
:
case
billingCircuitHalfOpen
:
if
b
.
halfOpenRemaining
<=
0
{
if
b
.
halfOpenRemaining
<=
0
{
...
@@ -612,7 +612,7 @@ func (b *billingCircuitBreaker) OnFailure(err error) {
...
@@ -612,7 +612,7 @@ func (b *billingCircuitBreaker) OnFailure(err error) {
b
.
state
=
billingCircuitOpen
b
.
state
=
billingCircuitOpen
b
.
openedAt
=
time
.
Now
()
b
.
openedAt
=
time
.
Now
()
b
.
halfOpenRemaining
=
0
b
.
halfOpenRemaining
=
0
log
.
Printf
(
"ALERT: billing circuit breaker opened after half-open failure: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.billing_cache"
,
"ALERT: billing circuit breaker opened after half-open failure: %v"
,
err
)
return
return
default
:
default
:
b
.
failures
++
b
.
failures
++
...
@@ -620,7 +620,7 @@ func (b *billingCircuitBreaker) OnFailure(err error) {
...
@@ -620,7 +620,7 @@ func (b *billingCircuitBreaker) OnFailure(err error) {
b
.
state
=
billingCircuitOpen
b
.
state
=
billingCircuitOpen
b
.
openedAt
=
time
.
Now
()
b
.
openedAt
=
time
.
Now
()
b
.
halfOpenRemaining
=
0
b
.
halfOpenRemaining
=
0
log
.
Printf
(
"ALERT: billing circuit breaker opened after %d failures: %v"
,
b
.
failures
,
err
)
log
ger
.
LegacyPrintf
(
"service.billing_cache"
,
"ALERT: billing circuit breaker opened after %d failures: %v"
,
b
.
failures
,
err
)
}
}
}
}
}
}
...
@@ -641,9 +641,9 @@ func (b *billingCircuitBreaker) OnSuccess() {
...
@@ -641,9 +641,9 @@ func (b *billingCircuitBreaker) OnSuccess() {
// 只有状态真正发生变化时才记录日志
// 只有状态真正发生变化时才记录日志
if
previousState
!=
billingCircuitClosed
{
if
previousState
!=
billingCircuitClosed
{
log
.
Printf
(
"ALERT: billing circuit breaker closed (was %s)"
,
circuitStateString
(
previousState
))
log
ger
.
LegacyPrintf
(
"service.billing_cache"
,
"ALERT: billing circuit breaker closed (was %s)"
,
circuitStateString
(
previousState
))
}
else
if
previousFailures
>
0
{
}
else
if
previousFailures
>
0
{
log
.
Printf
(
"INFO: billing circuit breaker failures reset from %d"
,
previousFailures
)
log
ger
.
LegacyPrintf
(
"service.billing_cache"
,
"INFO: billing circuit breaker failures reset from %d"
,
previousFailures
)
}
}
}
}
...
...
backend/internal/service/billing_service.go
View file @
6bccb8a8
...
@@ -312,7 +312,7 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage
...
@@ -312,7 +312,7 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage
}
}
outRangeCost
,
err
:=
s
.
CalculateCost
(
model
,
outRangeTokens
,
rateMultiplier
*
extraMultiplier
)
outRangeCost
,
err
:=
s
.
CalculateCost
(
model
,
outRangeTokens
,
rateMultiplier
*
extraMultiplier
)
if
err
!=
nil
{
if
err
!=
nil
{
return
inRangeCost
,
nil
// 出错时返回范围内成本
return
inRangeCost
,
fmt
.
Errorf
(
"out-range cost: %w"
,
err
)
}
}
// 合并成本
// 合并成本
...
@@ -388,6 +388,14 @@ type ImagePriceConfig struct {
...
@@ -388,6 +388,14 @@ type ImagePriceConfig struct {
Price4K
*
float64
// 4K 尺寸价格(nil 表示使用默认值)
Price4K
*
float64
// 4K 尺寸价格(nil 表示使用默认值)
}
}
// SoraPriceConfig Sora 按次计费配置
type
SoraPriceConfig
struct
{
ImagePrice360
*
float64
ImagePrice540
*
float64
VideoPricePerRequest
*
float64
VideoPricePerRequestHD
*
float64
}
// CalculateImageCost 计算图片生成费用
// CalculateImageCost 计算图片生成费用
// model: 请求的模型名称(用于获取 LiteLLM 默认价格)
// model: 请求的模型名称(用于获取 LiteLLM 默认价格)
// imageSize: 图片尺寸 "1K", "2K", "4K"
// imageSize: 图片尺寸 "1K", "2K", "4K"
...
@@ -417,6 +425,65 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag
...
@@ -417,6 +425,65 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag
}
}
}
}
// CalculateSoraImageCost 计算 Sora 图片按次费用
func
(
s
*
BillingService
)
CalculateSoraImageCost
(
imageSize
string
,
imageCount
int
,
groupConfig
*
SoraPriceConfig
,
rateMultiplier
float64
)
*
CostBreakdown
{
if
imageCount
<=
0
{
return
&
CostBreakdown
{}
}
unitPrice
:=
0.0
if
groupConfig
!=
nil
{
switch
imageSize
{
case
"540"
:
if
groupConfig
.
ImagePrice540
!=
nil
{
unitPrice
=
*
groupConfig
.
ImagePrice540
}
default
:
if
groupConfig
.
ImagePrice360
!=
nil
{
unitPrice
=
*
groupConfig
.
ImagePrice360
}
}
}
totalCost
:=
unitPrice
*
float64
(
imageCount
)
if
rateMultiplier
<=
0
{
rateMultiplier
=
1.0
}
actualCost
:=
totalCost
*
rateMultiplier
return
&
CostBreakdown
{
TotalCost
:
totalCost
,
ActualCost
:
actualCost
,
}
}
// CalculateSoraVideoCost 计算 Sora 视频按次费用
func
(
s
*
BillingService
)
CalculateSoraVideoCost
(
model
string
,
groupConfig
*
SoraPriceConfig
,
rateMultiplier
float64
)
*
CostBreakdown
{
unitPrice
:=
0.0
if
groupConfig
!=
nil
{
modelLower
:=
strings
.
ToLower
(
model
)
if
strings
.
Contains
(
modelLower
,
"sora2pro-hd"
)
{
if
groupConfig
.
VideoPricePerRequestHD
!=
nil
{
unitPrice
=
*
groupConfig
.
VideoPricePerRequestHD
}
}
if
unitPrice
<=
0
&&
groupConfig
.
VideoPricePerRequest
!=
nil
{
unitPrice
=
*
groupConfig
.
VideoPricePerRequest
}
}
totalCost
:=
unitPrice
if
rateMultiplier
<=
0
{
rateMultiplier
=
1.0
}
actualCost
:=
totalCost
*
rateMultiplier
return
&
CostBreakdown
{
TotalCost
:
totalCost
,
ActualCost
:
actualCost
,
}
}
// getImageUnitPrice 获取图片单价
// getImageUnitPrice 获取图片单价
func
(
s
*
BillingService
)
getImageUnitPrice
(
model
string
,
imageSize
string
,
groupConfig
*
ImagePriceConfig
)
float64
{
func
(
s
*
BillingService
)
getImageUnitPrice
(
model
string
,
imageSize
string
,
groupConfig
*
ImagePriceConfig
)
float64
{
// 优先使用分组配置的价格
// 优先使用分组配置的价格
...
...
backend/internal/service/billing_service_test.go
0 → 100644
View file @
6bccb8a8
//go:build unit
package
service
import
(
"math"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func
newTestBillingService
()
*
BillingService
{
return
NewBillingService
(
&
config
.
Config
{},
nil
)
}
func
TestCalculateCost_BasicComputation
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
// 使用 claude-sonnet-4 的回退价格:Input $3/MTok, Output $15/MTok
tokens
:=
UsageTokens
{
InputTokens
:
1000
,
OutputTokens
:
500
,
}
cost
,
err
:=
svc
.
CalculateCost
(
"claude-sonnet-4"
,
tokens
,
1.0
)
require
.
NoError
(
t
,
err
)
// 1000 * 3e-6 = 0.003, 500 * 15e-6 = 0.0075
expectedInput
:=
1000
*
3e-6
expectedOutput
:=
500
*
15e-6
require
.
InDelta
(
t
,
expectedInput
,
cost
.
InputCost
,
1e-10
)
require
.
InDelta
(
t
,
expectedOutput
,
cost
.
OutputCost
,
1e-10
)
require
.
InDelta
(
t
,
expectedInput
+
expectedOutput
,
cost
.
TotalCost
,
1e-10
)
require
.
InDelta
(
t
,
expectedInput
+
expectedOutput
,
cost
.
ActualCost
,
1e-10
)
}
func
TestCalculateCost_WithCacheTokens
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
tokens
:=
UsageTokens
{
InputTokens
:
1000
,
OutputTokens
:
500
,
CacheCreationTokens
:
2000
,
CacheReadTokens
:
3000
,
}
cost
,
err
:=
svc
.
CalculateCost
(
"claude-sonnet-4"
,
tokens
,
1.0
)
require
.
NoError
(
t
,
err
)
expectedCacheCreation
:=
2000
*
3.75e-6
expectedCacheRead
:=
3000
*
0.3e-6
require
.
InDelta
(
t
,
expectedCacheCreation
,
cost
.
CacheCreationCost
,
1e-10
)
require
.
InDelta
(
t
,
expectedCacheRead
,
cost
.
CacheReadCost
,
1e-10
)
expectedTotal
:=
cost
.
InputCost
+
cost
.
OutputCost
+
expectedCacheCreation
+
expectedCacheRead
require
.
InDelta
(
t
,
expectedTotal
,
cost
.
TotalCost
,
1e-10
)
}
func
TestCalculateCost_RateMultiplier
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
tokens
:=
UsageTokens
{
InputTokens
:
1000
,
OutputTokens
:
500
}
cost1x
,
err
:=
svc
.
CalculateCost
(
"claude-sonnet-4"
,
tokens
,
1.0
)
require
.
NoError
(
t
,
err
)
cost2x
,
err
:=
svc
.
CalculateCost
(
"claude-sonnet-4"
,
tokens
,
2.0
)
require
.
NoError
(
t
,
err
)
// TotalCost 不受倍率影响,ActualCost 翻倍
require
.
InDelta
(
t
,
cost1x
.
TotalCost
,
cost2x
.
TotalCost
,
1e-10
)
require
.
InDelta
(
t
,
cost1x
.
ActualCost
*
2
,
cost2x
.
ActualCost
,
1e-10
)
}
func
TestCalculateCost_ZeroMultiplierDefaultsToOne
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
tokens
:=
UsageTokens
{
InputTokens
:
1000
}
costZero
,
err
:=
svc
.
CalculateCost
(
"claude-sonnet-4"
,
tokens
,
0
)
require
.
NoError
(
t
,
err
)
costOne
,
err
:=
svc
.
CalculateCost
(
"claude-sonnet-4"
,
tokens
,
1.0
)
require
.
NoError
(
t
,
err
)
require
.
InDelta
(
t
,
costOne
.
ActualCost
,
costZero
.
ActualCost
,
1e-10
)
}
func
TestCalculateCost_NegativeMultiplierDefaultsToOne
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
tokens
:=
UsageTokens
{
InputTokens
:
1000
}
costNeg
,
err
:=
svc
.
CalculateCost
(
"claude-sonnet-4"
,
tokens
,
-
1.0
)
require
.
NoError
(
t
,
err
)
costOne
,
err
:=
svc
.
CalculateCost
(
"claude-sonnet-4"
,
tokens
,
1.0
)
require
.
NoError
(
t
,
err
)
require
.
InDelta
(
t
,
costOne
.
ActualCost
,
costNeg
.
ActualCost
,
1e-10
)
}
func
TestGetModelPricing_FallbackMatchesByFamily
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
tests
:=
[]
struct
{
model
string
expectedInput
float64
}{
{
"claude-opus-4.5-20250101"
,
5e-6
},
{
"claude-3-opus-20240229"
,
15e-6
},
{
"claude-sonnet-4-20250514"
,
3e-6
},
{
"claude-3-5-sonnet-20241022"
,
3e-6
},
{
"claude-3-5-haiku-20241022"
,
1e-6
},
{
"claude-3-haiku-20240307"
,
0.25e-6
},
}
for
_
,
tt
:=
range
tests
{
pricing
,
err
:=
svc
.
GetModelPricing
(
tt
.
model
)
require
.
NoError
(
t
,
err
,
"模型 %s"
,
tt
.
model
)
require
.
InDelta
(
t
,
tt
.
expectedInput
,
pricing
.
InputPricePerToken
,
1e-12
,
"模型 %s 输入价格"
,
tt
.
model
)
}
}
func
TestGetModelPricing_CaseInsensitive
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
p1
,
err
:=
svc
.
GetModelPricing
(
"Claude-Sonnet-4"
)
require
.
NoError
(
t
,
err
)
p2
,
err
:=
svc
.
GetModelPricing
(
"claude-sonnet-4"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
p1
.
InputPricePerToken
,
p2
.
InputPricePerToken
)
}
func
TestGetModelPricing_UnknownModelFallsBackToSonnet
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
// 不包含 opus/sonnet/haiku 关键词的 Claude 模型会走默认 Sonnet 价格
pricing
,
err
:=
svc
.
GetModelPricing
(
"claude-unknown-model"
)
require
.
NoError
(
t
,
err
)
require
.
InDelta
(
t
,
3e-6
,
pricing
.
InputPricePerToken
,
1e-12
)
}
func
TestCalculateCostWithLongContext_BelowThreshold
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
tokens
:=
UsageTokens
{
InputTokens
:
50000
,
OutputTokens
:
1000
,
CacheReadTokens
:
100000
,
}
// 总输入 150k < 200k 阈值,应走正常计费
cost
,
err
:=
svc
.
CalculateCostWithLongContext
(
"claude-sonnet-4"
,
tokens
,
1.0
,
200000
,
2.0
)
require
.
NoError
(
t
,
err
)
normalCost
,
err
:=
svc
.
CalculateCost
(
"claude-sonnet-4"
,
tokens
,
1.0
)
require
.
NoError
(
t
,
err
)
require
.
InDelta
(
t
,
normalCost
.
ActualCost
,
cost
.
ActualCost
,
1e-10
)
}
func
TestCalculateCostWithLongContext_AboveThreshold_CacheExceedsThreshold
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
// 缓存 210k + 输入 10k = 220k > 200k 阈值
// 缓存已超阈值:范围内 200k 缓存,范围外 10k 缓存 + 10k 输入
tokens
:=
UsageTokens
{
InputTokens
:
10000
,
OutputTokens
:
1000
,
CacheReadTokens
:
210000
,
}
cost
,
err
:=
svc
.
CalculateCostWithLongContext
(
"claude-sonnet-4"
,
tokens
,
1.0
,
200000
,
2.0
)
require
.
NoError
(
t
,
err
)
// 范围内:200k cache + 0 input + 1k output
inRange
,
_
:=
svc
.
CalculateCost
(
"claude-sonnet-4"
,
UsageTokens
{
InputTokens
:
0
,
OutputTokens
:
1000
,
CacheReadTokens
:
200000
,
},
1.0
)
// 范围外:10k cache + 10k input,倍率 2.0
outRange
,
_
:=
svc
.
CalculateCost
(
"claude-sonnet-4"
,
UsageTokens
{
InputTokens
:
10000
,
CacheReadTokens
:
10000
,
},
2.0
)
require
.
InDelta
(
t
,
inRange
.
ActualCost
+
outRange
.
ActualCost
,
cost
.
ActualCost
,
1e-10
)
}
func
TestCalculateCostWithLongContext_AboveThreshold_CacheBelowThreshold
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
// 缓存 100k + 输入 150k = 250k > 200k 阈值
// 缓存未超阈值:范围内 100k 缓存 + 100k 输入,范围外 50k 输入
tokens
:=
UsageTokens
{
InputTokens
:
150000
,
OutputTokens
:
1000
,
CacheReadTokens
:
100000
,
}
cost
,
err
:=
svc
.
CalculateCostWithLongContext
(
"claude-sonnet-4"
,
tokens
,
1.0
,
200000
,
2.0
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
cost
.
ActualCost
>
0
,
"费用应大于 0"
)
// 正常费用不含长上下文
normalCost
,
_
:=
svc
.
CalculateCost
(
"claude-sonnet-4"
,
tokens
,
1.0
)
require
.
True
(
t
,
cost
.
ActualCost
>
normalCost
.
ActualCost
,
"长上下文费用应高于正常费用"
)
}
func
TestCalculateCostWithLongContext_DisabledThreshold
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
tokens
:=
UsageTokens
{
InputTokens
:
300000
,
CacheReadTokens
:
0
}
// threshold <= 0 应禁用长上下文计费
cost1
,
err
:=
svc
.
CalculateCostWithLongContext
(
"claude-sonnet-4"
,
tokens
,
1.0
,
0
,
2.0
)
require
.
NoError
(
t
,
err
)
cost2
,
err
:=
svc
.
CalculateCost
(
"claude-sonnet-4"
,
tokens
,
1.0
)
require
.
NoError
(
t
,
err
)
require
.
InDelta
(
t
,
cost2
.
ActualCost
,
cost1
.
ActualCost
,
1e-10
)
}
func
TestCalculateCostWithLongContext_ExtraMultiplierLessEqualOne
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
tokens
:=
UsageTokens
{
InputTokens
:
300000
}
// extraMultiplier <= 1 应禁用长上下文计费
cost
,
err
:=
svc
.
CalculateCostWithLongContext
(
"claude-sonnet-4"
,
tokens
,
1.0
,
200000
,
1.0
)
require
.
NoError
(
t
,
err
)
normalCost
,
err
:=
svc
.
CalculateCost
(
"claude-sonnet-4"
,
tokens
,
1.0
)
require
.
NoError
(
t
,
err
)
require
.
InDelta
(
t
,
normalCost
.
ActualCost
,
cost
.
ActualCost
,
1e-10
)
}
func
TestCalculateImageCost
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
price
:=
0.134
cfg
:=
&
ImagePriceConfig
{
Price1K
:
&
price
}
cost
:=
svc
.
CalculateImageCost
(
"gpt-image-1"
,
"1K"
,
3
,
cfg
,
1.0
)
require
.
InDelta
(
t
,
0.134
*
3
,
cost
.
TotalCost
,
1e-10
)
require
.
InDelta
(
t
,
0.134
*
3
,
cost
.
ActualCost
,
1e-10
)
}
func
TestCalculateSoraVideoCost
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
price
:=
0.5
cfg
:=
&
SoraPriceConfig
{
VideoPricePerRequest
:
&
price
}
cost
:=
svc
.
CalculateSoraVideoCost
(
"sora-video"
,
cfg
,
1.0
)
require
.
InDelta
(
t
,
0.5
,
cost
.
TotalCost
,
1e-10
)
}
func
TestCalculateSoraVideoCost_HDModel
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
hdPrice
:=
1.0
normalPrice
:=
0.5
cfg
:=
&
SoraPriceConfig
{
VideoPricePerRequest
:
&
normalPrice
,
VideoPricePerRequestHD
:
&
hdPrice
,
}
cost
:=
svc
.
CalculateSoraVideoCost
(
"sora2pro-hd"
,
cfg
,
1.0
)
require
.
InDelta
(
t
,
1.0
,
cost
.
TotalCost
,
1e-10
)
}
func
TestIsModelSupported
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
require
.
True
(
t
,
svc
.
IsModelSupported
(
"claude-sonnet-4"
))
require
.
True
(
t
,
svc
.
IsModelSupported
(
"Claude-Opus-4.5"
))
require
.
True
(
t
,
svc
.
IsModelSupported
(
"claude-3-haiku"
))
require
.
False
(
t
,
svc
.
IsModelSupported
(
"gpt-4o"
))
require
.
False
(
t
,
svc
.
IsModelSupported
(
"gemini-pro"
))
}
func
TestCalculateCost_ZeroTokens
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
cost
,
err
:=
svc
.
CalculateCost
(
"claude-sonnet-4"
,
UsageTokens
{},
1.0
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
0.0
,
cost
.
TotalCost
)
require
.
Equal
(
t
,
0.0
,
cost
.
ActualCost
)
}
func
TestCalculateCostWithConfig
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
cfg
.
Default
.
RateMultiplier
=
1.5
svc
:=
NewBillingService
(
cfg
,
nil
)
tokens
:=
UsageTokens
{
InputTokens
:
1000
,
OutputTokens
:
500
}
cost
,
err
:=
svc
.
CalculateCostWithConfig
(
"claude-sonnet-4"
,
tokens
)
require
.
NoError
(
t
,
err
)
expected
,
_
:=
svc
.
CalculateCost
(
"claude-sonnet-4"
,
tokens
,
1.5
)
require
.
InDelta
(
t
,
expected
.
ActualCost
,
cost
.
ActualCost
,
1e-10
)
}
func
TestCalculateCostWithConfig_ZeroMultiplier
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{}
cfg
.
Default
.
RateMultiplier
=
0
svc
:=
NewBillingService
(
cfg
,
nil
)
tokens
:=
UsageTokens
{
InputTokens
:
1000
}
cost
,
err
:=
svc
.
CalculateCostWithConfig
(
"claude-sonnet-4"
,
tokens
)
require
.
NoError
(
t
,
err
)
// 倍率 <=0 时默认 1.0
expected
,
_
:=
svc
.
CalculateCost
(
"claude-sonnet-4"
,
tokens
,
1.0
)
require
.
InDelta
(
t
,
expected
.
ActualCost
,
cost
.
ActualCost
,
1e-10
)
}
func
TestGetEstimatedCost
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
est
,
err
:=
svc
.
GetEstimatedCost
(
"claude-sonnet-4"
,
1000
,
500
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
est
>
0
)
}
func
TestListSupportedModels
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
models
:=
svc
.
ListSupportedModels
()
require
.
NotEmpty
(
t
,
models
)
require
.
GreaterOrEqual
(
t
,
len
(
models
),
6
)
}
func
TestGetPricingServiceStatus_NilService
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
status
:=
svc
.
GetPricingServiceStatus
()
require
.
NotNil
(
t
,
status
)
require
.
Equal
(
t
,
"using fallback"
,
status
[
"last_updated"
])
}
func
TestForceUpdatePricing_NilService
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
err
:=
svc
.
ForceUpdatePricing
()
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"not initialized"
)
}
func
TestCalculateSoraImageCost
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
price360
:=
0.05
price540
:=
0.08
cfg
:=
&
SoraPriceConfig
{
ImagePrice360
:
&
price360
,
ImagePrice540
:
&
price540
}
cost
:=
svc
.
CalculateSoraImageCost
(
"360"
,
2
,
cfg
,
1.0
)
require
.
InDelta
(
t
,
0.10
,
cost
.
TotalCost
,
1e-10
)
cost540
:=
svc
.
CalculateSoraImageCost
(
"540"
,
1
,
cfg
,
2.0
)
require
.
InDelta
(
t
,
0.08
,
cost540
.
TotalCost
,
1e-10
)
require
.
InDelta
(
t
,
0.16
,
cost540
.
ActualCost
,
1e-10
)
}
func
TestCalculateSoraImageCost_ZeroCount
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
cost
:=
svc
.
CalculateSoraImageCost
(
"360"
,
0
,
nil
,
1.0
)
require
.
Equal
(
t
,
0.0
,
cost
.
TotalCost
)
}
func
TestCalculateSoraVideoCost_NilConfig
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
cost
:=
svc
.
CalculateSoraVideoCost
(
"sora-video"
,
nil
,
1.0
)
require
.
Equal
(
t
,
0.0
,
cost
.
TotalCost
)
}
func
TestCalculateCostWithLongContext_PropagatesError
(
t
*
testing
.
T
)
{
// 使用空的 fallback prices 让 GetModelPricing 失败
svc
:=
&
BillingService
{
cfg
:
&
config
.
Config
{},
fallbackPrices
:
make
(
map
[
string
]
*
ModelPricing
),
}
tokens
:=
UsageTokens
{
InputTokens
:
300000
,
CacheReadTokens
:
0
}
_
,
err
:=
svc
.
CalculateCostWithLongContext
(
"unknown-model"
,
tokens
,
1.0
,
200000
,
2.0
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"pricing not found"
)
}
func
TestCalculateCost_SupportsCacheBreakdown
(
t
*
testing
.
T
)
{
svc
:=
&
BillingService
{
cfg
:
&
config
.
Config
{},
fallbackPrices
:
map
[
string
]
*
ModelPricing
{
"claude-sonnet-4"
:
{
InputPricePerToken
:
3e-6
,
OutputPricePerToken
:
15e-6
,
SupportsCacheBreakdown
:
true
,
CacheCreation5mPrice
:
4e-6
,
// per token
CacheCreation1hPrice
:
5e-6
,
// per token
},
},
}
tokens
:=
UsageTokens
{
InputTokens
:
1000
,
OutputTokens
:
500
,
CacheCreation5mTokens
:
100000
,
CacheCreation1hTokens
:
50000
,
}
cost
,
err
:=
svc
.
CalculateCost
(
"claude-sonnet-4"
,
tokens
,
1.0
)
require
.
NoError
(
t
,
err
)
expected5m
:=
float64
(
tokens
.
CacheCreation5mTokens
)
*
4e-6
expected1h
:=
float64
(
tokens
.
CacheCreation1hTokens
)
*
5e-6
require
.
InDelta
(
t
,
expected5m
+
expected1h
,
cost
.
CacheCreationCost
,
1e-10
)
}
func
TestCalculateCost_LargeTokenCount
(
t
*
testing
.
T
)
{
svc
:=
newTestBillingService
()
tokens
:=
UsageTokens
{
InputTokens
:
1
_000_000
,
OutputTokens
:
1
_000_000
,
}
cost
,
err
:=
svc
.
CalculateCost
(
"claude-sonnet-4"
,
tokens
,
1.0
)
require
.
NoError
(
t
,
err
)
// Input: 1M * 3e-6 = $3, Output: 1M * 15e-6 = $15
require
.
InDelta
(
t
,
3.0
,
cost
.
InputCost
,
1e-6
)
require
.
InDelta
(
t
,
15.0
,
cost
.
OutputCost
,
1e-6
)
require
.
False
(
t
,
math
.
IsNaN
(
cost
.
TotalCost
))
require
.
False
(
t
,
math
.
IsInf
(
cost
.
TotalCost
,
0
))
}
backend/internal/service/claude_code_detection_test.go
0 → 100644
View file @
6bccb8a8
//go:build unit
package
service
import
(
"context"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
func
newTestValidator
()
*
ClaudeCodeValidator
{
return
NewClaudeCodeValidator
()
}
// validClaudeCodeBody 构造一个完整有效的 Claude Code 请求体
func
validClaudeCodeBody
()
map
[
string
]
any
{
return
map
[
string
]
any
{
"model"
:
"claude-sonnet-4-20250514"
,
"system"
:
[]
any
{
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
"You are Claude Code, Anthropic's official CLI for Claude."
,
},
},
"metadata"
:
map
[
string
]
any
{
"user_id"
:
"user_"
+
"a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
+
"_account__session_"
+
"12345678-1234-1234-1234-123456789abc"
,
},
}
}
func
TestValidate_ClaudeCLIUserAgent
(
t
*
testing
.
T
)
{
v
:=
newTestValidator
()
tests
:=
[]
struct
{
name
string
ua
string
want
bool
}{
{
"标准版本号"
,
"claude-cli/1.0.0"
,
true
},
{
"多位版本号"
,
"claude-cli/12.34.56"
,
true
},
{
"大写开头"
,
"Claude-CLI/1.0.0"
,
true
},
{
"非 claude-cli"
,
"curl/7.64.1"
,
false
},
{
"空 User-Agent"
,
""
,
false
},
{
"部分匹配"
,
"not-claude-cli/1.0.0"
,
false
},
{
"缺少版本号"
,
"claude-cli/"
,
false
},
{
"版本格式不对"
,
"claude-cli/1.0"
,
false
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
tt
.
want
,
v
.
ValidateUserAgent
(
tt
.
ua
),
"UA: %q"
,
tt
.
ua
)
})
}
}
func
TestValidate_NonMessagesPath_UAOnly
(
t
*
testing
.
T
)
{
v
:=
newTestValidator
()
// 非 messages 路径只检查 UA
req
:=
httptest
.
NewRequest
(
"GET"
,
"/v1/models"
,
nil
)
req
.
Header
.
Set
(
"User-Agent"
,
"claude-cli/1.0.0"
)
result
:=
v
.
Validate
(
req
,
nil
)
require
.
True
(
t
,
result
,
"非 messages 路径只需 UA 匹配"
)
}
func
TestValidate_NonMessagesPath_InvalidUA
(
t
*
testing
.
T
)
{
v
:=
newTestValidator
()
req
:=
httptest
.
NewRequest
(
"GET"
,
"/v1/models"
,
nil
)
req
.
Header
.
Set
(
"User-Agent"
,
"curl/7.64.1"
)
result
:=
v
.
Validate
(
req
,
nil
)
require
.
False
(
t
,
result
,
"UA 不匹配时应返回 false"
)
}
func
TestValidate_MessagesPath_FullValid
(
t
*
testing
.
T
)
{
v
:=
newTestValidator
()
req
:=
httptest
.
NewRequest
(
"POST"
,
"/v1/messages"
,
nil
)
req
.
Header
.
Set
(
"User-Agent"
,
"claude-cli/1.0.0"
)
req
.
Header
.
Set
(
"X-App"
,
"claude-code"
)
req
.
Header
.
Set
(
"anthropic-beta"
,
"max-tokens-3-5-sonnet-2024-07-15"
)
req
.
Header
.
Set
(
"anthropic-version"
,
"2023-06-01"
)
result
:=
v
.
Validate
(
req
,
validClaudeCodeBody
())
require
.
True
(
t
,
result
,
"完整有效请求应通过"
)
}
func
TestValidate_MessagesPath_MissingHeaders
(
t
*
testing
.
T
)
{
v
:=
newTestValidator
()
body
:=
validClaudeCodeBody
()
tests
:=
[]
struct
{
name
string
missingHeader
string
}{
{
"缺少 X-App"
,
"X-App"
},
{
"缺少 anthropic-beta"
,
"anthropic-beta"
},
{
"缺少 anthropic-version"
,
"anthropic-version"
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
req
:=
httptest
.
NewRequest
(
"POST"
,
"/v1/messages"
,
nil
)
req
.
Header
.
Set
(
"User-Agent"
,
"claude-cli/1.0.0"
)
req
.
Header
.
Set
(
"X-App"
,
"claude-code"
)
req
.
Header
.
Set
(
"anthropic-beta"
,
"beta"
)
req
.
Header
.
Set
(
"anthropic-version"
,
"2023-06-01"
)
req
.
Header
.
Del
(
tt
.
missingHeader
)
result
:=
v
.
Validate
(
req
,
body
)
require
.
False
(
t
,
result
,
"缺少 %s 应返回 false"
,
tt
.
missingHeader
)
})
}
}
func
TestValidate_MessagesPath_InvalidMetadataUserID
(
t
*
testing
.
T
)
{
v
:=
newTestValidator
()
tests
:=
[]
struct
{
name
string
metadata
map
[
string
]
any
}{
{
"缺少 metadata"
,
nil
},
{
"缺少 user_id"
,
map
[
string
]
any
{
"other"
:
"value"
}},
{
"空 user_id"
,
map
[
string
]
any
{
"user_id"
:
""
}},
{
"格式错误"
,
map
[
string
]
any
{
"user_id"
:
"invalid-format"
}},
{
"hex 长度不足"
,
map
[
string
]
any
{
"user_id"
:
"user_abc_account__session_uuid"
}},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
req
:=
httptest
.
NewRequest
(
"POST"
,
"/v1/messages"
,
nil
)
req
.
Header
.
Set
(
"User-Agent"
,
"claude-cli/1.0.0"
)
req
.
Header
.
Set
(
"X-App"
,
"claude-code"
)
req
.
Header
.
Set
(
"anthropic-beta"
,
"beta"
)
req
.
Header
.
Set
(
"anthropic-version"
,
"2023-06-01"
)
body
:=
map
[
string
]
any
{
"model"
:
"claude-sonnet-4"
,
"system"
:
[]
any
{
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
"You are Claude Code, Anthropic's official CLI for Claude."
,
},
},
}
if
tt
.
metadata
!=
nil
{
body
[
"metadata"
]
=
tt
.
metadata
}
result
:=
v
.
Validate
(
req
,
body
)
require
.
False
(
t
,
result
,
"metadata.user_id: %v"
,
tt
.
metadata
)
})
}
}
func
TestValidate_MessagesPath_InvalidSystemPrompt
(
t
*
testing
.
T
)
{
v
:=
newTestValidator
()
req
:=
httptest
.
NewRequest
(
"POST"
,
"/v1/messages"
,
nil
)
req
.
Header
.
Set
(
"User-Agent"
,
"claude-cli/1.0.0"
)
req
.
Header
.
Set
(
"X-App"
,
"claude-code"
)
req
.
Header
.
Set
(
"anthropic-beta"
,
"beta"
)
req
.
Header
.
Set
(
"anthropic-version"
,
"2023-06-01"
)
body
:=
map
[
string
]
any
{
"model"
:
"claude-sonnet-4"
,
"system"
:
[]
any
{
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
"Generate JSON data for testing database migrations."
,
},
},
"metadata"
:
map
[
string
]
any
{
"user_id"
:
"user_"
+
"a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
+
"_account__session_12345678-1234-1234-1234-123456789abc"
,
},
}
result
:=
v
.
Validate
(
req
,
body
)
require
.
False
(
t
,
result
,
"无关系统提示词应返回 false"
)
}
func
TestValidate_MaxTokensOneHaikuBypass
(
t
*
testing
.
T
)
{
v
:=
newTestValidator
()
req
:=
httptest
.
NewRequest
(
"POST"
,
"/v1/messages"
,
nil
)
req
.
Header
.
Set
(
"User-Agent"
,
"claude-cli/1.0.0"
)
// 不设置 X-App 等头,通过 context 标记为 haiku 探测请求
ctx
:=
context
.
WithValue
(
req
.
Context
(),
ctxkey
.
IsMaxTokensOneHaikuRequest
,
true
)
req
=
req
.
WithContext
(
ctx
)
// 即使 body 不包含 system prompt,也应通过
result
:=
v
.
Validate
(
req
,
map
[
string
]
any
{
"model"
:
"claude-3-haiku"
,
"max_tokens"
:
1
})
require
.
True
(
t
,
result
,
"max_tokens=1+haiku 探测请求应绕过严格验证"
)
}
func
TestSystemPromptSimilarity
(
t
*
testing
.
T
)
{
v
:=
newTestValidator
()
tests
:=
[]
struct
{
name
string
prompt
string
want
bool
}{
{
"精确匹配"
,
"You are Claude Code, Anthropic's official CLI for Claude."
,
true
},
{
"带多余空格"
,
"You are Claude Code, Anthropic's official CLI for Claude."
,
true
},
{
"Agent SDK 模板"
,
"You are a Claude agent, built on Anthropic's Claude Agent SDK."
,
true
},
{
"文件搜索专家模板"
,
"You are a file search specialist for Claude Code, Anthropic's official CLI for Claude."
,
true
},
{
"对话摘要模板"
,
"You are a helpful AI assistant tasked with summarizing conversations."
,
true
},
{
"交互式 CLI 模板"
,
"You are an interactive CLI tool that helps users"
,
true
},
{
"无关文本"
,
"Write me a poem about cats"
,
false
},
{
"空文本"
,
""
,
false
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
body
:=
map
[
string
]
any
{
"model"
:
"claude-sonnet-4"
,
"system"
:
[]
any
{
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
tt
.
prompt
},
},
}
result
:=
v
.
IncludesClaudeCodeSystemPrompt
(
body
)
require
.
Equal
(
t
,
tt
.
want
,
result
,
"提示词: %q"
,
tt
.
prompt
)
})
}
}
func
TestDiceCoefficient
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
a
string
b
string
want
float64
tol
float64
}{
{
"相同字符串"
,
"hello"
,
"hello"
,
1.0
,
0.001
},
{
"完全不同"
,
"abc"
,
"xyz"
,
0.0
,
0.001
},
{
"空字符串"
,
""
,
"hello"
,
0.0
,
0.001
},
{
"单字符"
,
"a"
,
"b"
,
0.0
,
0.001
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
result
:=
diceCoefficient
(
tt
.
a
,
tt
.
b
)
require
.
InDelta
(
t
,
tt
.
want
,
result
,
tt
.
tol
)
})
}
}
func
TestIsClaudeCodeClient_Context
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
// 默认应为 false
require
.
False
(
t
,
IsClaudeCodeClient
(
ctx
))
// 设置为 true
ctx
=
SetClaudeCodeClient
(
ctx
,
true
)
require
.
True
(
t
,
IsClaudeCodeClient
(
ctx
))
// 设置为 false
ctx
=
SetClaudeCodeClient
(
ctx
,
false
)
require
.
False
(
t
,
IsClaudeCodeClient
(
ctx
))
}
func
TestValidate_NilBody_MessagesPath
(
t
*
testing
.
T
)
{
v
:=
newTestValidator
()
req
:=
httptest
.
NewRequest
(
"POST"
,
"/v1/messages"
,
nil
)
req
.
Header
.
Set
(
"User-Agent"
,
"claude-cli/1.0.0"
)
req
.
Header
.
Set
(
"X-App"
,
"claude-code"
)
req
.
Header
.
Set
(
"anthropic-beta"
,
"beta"
)
req
.
Header
.
Set
(
"anthropic-version"
,
"2023-06-01"
)
result
:=
v
.
Validate
(
req
,
nil
)
require
.
False
(
t
,
result
,
"nil body 的 messages 请求应返回 false"
)
}
backend/internal/service/concurrency_service.go
View file @
6bccb8a8
...
@@ -5,8 +5,9 @@ import (
...
@@ -5,8 +5,9 @@ import (
"crypto/rand"
"crypto/rand"
"encoding/hex"
"encoding/hex"
"fmt"
"fmt"
"log"
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
)
// ConcurrencyCache 定义并发控制的缓存接口
// ConcurrencyCache 定义并发控制的缓存接口
...
@@ -124,7 +125,7 @@ func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID i
...
@@ -124,7 +125,7 @@ func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID i
bgCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
5
*
time
.
Second
)
bgCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
5
*
time
.
Second
)
defer
cancel
()
defer
cancel
()
if
err
:=
s
.
cache
.
ReleaseAccountSlot
(
bgCtx
,
accountID
,
requestID
);
err
!=
nil
{
if
err
:=
s
.
cache
.
ReleaseAccountSlot
(
bgCtx
,
accountID
,
requestID
);
err
!=
nil
{
log
.
Printf
(
"Warning: failed to release account slot for %d (req=%s): %v"
,
accountID
,
requestID
,
err
)
log
ger
.
LegacyPrintf
(
"service.concurrency"
,
"Warning: failed to release account slot for %d (req=%s): %v"
,
accountID
,
requestID
,
err
)
}
}
},
},
},
nil
},
nil
...
@@ -163,7 +164,7 @@ func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64,
...
@@ -163,7 +164,7 @@ func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64,
bgCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
5
*
time
.
Second
)
bgCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
5
*
time
.
Second
)
defer
cancel
()
defer
cancel
()
if
err
:=
s
.
cache
.
ReleaseUserSlot
(
bgCtx
,
userID
,
requestID
);
err
!=
nil
{
if
err
:=
s
.
cache
.
ReleaseUserSlot
(
bgCtx
,
userID
,
requestID
);
err
!=
nil
{
log
.
Printf
(
"Warning: failed to release user slot for %d (req=%s): %v"
,
userID
,
requestID
,
err
)
log
ger
.
LegacyPrintf
(
"service.concurrency"
,
"Warning: failed to release user slot for %d (req=%s): %v"
,
userID
,
requestID
,
err
)
}
}
},
},
},
nil
},
nil
...
@@ -191,7 +192,7 @@ func (s *ConcurrencyService) IncrementWaitCount(ctx context.Context, userID int6
...
@@ -191,7 +192,7 @@ func (s *ConcurrencyService) IncrementWaitCount(ctx context.Context, userID int6
result
,
err
:=
s
.
cache
.
IncrementWaitCount
(
ctx
,
userID
,
maxWait
)
result
,
err
:=
s
.
cache
.
IncrementWaitCount
(
ctx
,
userID
,
maxWait
)
if
err
!=
nil
{
if
err
!=
nil
{
// On error, allow the request to proceed (fail open)
// On error, allow the request to proceed (fail open)
log
.
Printf
(
"Warning: increment wait count failed for user %d: %v"
,
userID
,
err
)
log
ger
.
LegacyPrintf
(
"service.concurrency"
,
"Warning: increment wait count failed for user %d: %v"
,
userID
,
err
)
return
true
,
nil
return
true
,
nil
}
}
return
result
,
nil
return
result
,
nil
...
@@ -209,7 +210,7 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6
...
@@ -209,7 +210,7 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6
defer
cancel
()
defer
cancel
()
if
err
:=
s
.
cache
.
DecrementWaitCount
(
bgCtx
,
userID
);
err
!=
nil
{
if
err
:=
s
.
cache
.
DecrementWaitCount
(
bgCtx
,
userID
);
err
!=
nil
{
log
.
Printf
(
"Warning: decrement wait count failed for user %d: %v"
,
userID
,
err
)
log
ger
.
LegacyPrintf
(
"service.concurrency"
,
"Warning: decrement wait count failed for user %d: %v"
,
userID
,
err
)
}
}
}
}
...
@@ -221,7 +222,7 @@ func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, acco
...
@@ -221,7 +222,7 @@ func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, acco
result
,
err
:=
s
.
cache
.
IncrementAccountWaitCount
(
ctx
,
accountID
,
maxWait
)
result
,
err
:=
s
.
cache
.
IncrementAccountWaitCount
(
ctx
,
accountID
,
maxWait
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Printf
(
"Warning: increment wait count failed for account %d: %v"
,
accountID
,
err
)
log
ger
.
LegacyPrintf
(
"service.concurrency"
,
"Warning: increment wait count failed for account %d: %v"
,
accountID
,
err
)
return
true
,
nil
return
true
,
nil
}
}
return
result
,
nil
return
result
,
nil
...
@@ -237,7 +238,7 @@ func (s *ConcurrencyService) DecrementAccountWaitCount(ctx context.Context, acco
...
@@ -237,7 +238,7 @@ func (s *ConcurrencyService) DecrementAccountWaitCount(ctx context.Context, acco
defer
cancel
()
defer
cancel
()
if
err
:=
s
.
cache
.
DecrementAccountWaitCount
(
bgCtx
,
accountID
);
err
!=
nil
{
if
err
:=
s
.
cache
.
DecrementAccountWaitCount
(
bgCtx
,
accountID
);
err
!=
nil
{
log
.
Printf
(
"Warning: decrement wait count failed for account %d: %v"
,
accountID
,
err
)
log
ger
.
LegacyPrintf
(
"service.concurrency"
,
"Warning: decrement wait count failed for account %d: %v"
,
accountID
,
err
)
}
}
}
}
...
@@ -293,7 +294,7 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor
...
@@ -293,7 +294,7 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor
accounts
,
err
:=
accountRepo
.
ListSchedulable
(
listCtx
)
accounts
,
err
:=
accountRepo
.
ListSchedulable
(
listCtx
)
cancel
()
cancel
()
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Printf
(
"Warning: list schedulable accounts failed: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.concurrency"
,
"Warning: list schedulable accounts failed: %v"
,
err
)
return
return
}
}
for
_
,
account
:=
range
accounts
{
for
_
,
account
:=
range
accounts
{
...
@@ -301,7 +302,7 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor
...
@@ -301,7 +302,7 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor
err
:=
s
.
cache
.
CleanupExpiredAccountSlots
(
accountCtx
,
account
.
ID
)
err
:=
s
.
cache
.
CleanupExpiredAccountSlots
(
accountCtx
,
account
.
ID
)
accountCancel
()
accountCancel
()
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Printf
(
"Warning: cleanup expired slots failed for account %d: %v"
,
account
.
ID
,
err
)
log
ger
.
LegacyPrintf
(
"service.concurrency"
,
"Warning: cleanup expired slots failed for account %d: %v"
,
account
.
ID
,
err
)
}
}
}
}
}
}
...
...
backend/internal/service/concurrency_service_test.go
0 → 100644
View file @
6bccb8a8
//go:build unit
package
service
import
(
"context"
"errors"
"testing"
"github.com/stretchr/testify/require"
)
// stubConcurrencyCacheForTest 用于并发服务单元测试的缓存桩
type
stubConcurrencyCacheForTest
struct
{
acquireResult
bool
acquireErr
error
releaseErr
error
concurrency
int
concurrencyErr
error
waitAllowed
bool
waitErr
error
waitCount
int
waitCountErr
error
loadBatch
map
[
int64
]
*
AccountLoadInfo
loadBatchErr
error
usersLoadBatch
map
[
int64
]
*
UserLoadInfo
usersLoadErr
error
cleanupErr
error
// 记录调用
releasedAccountIDs
[]
int64
releasedRequestIDs
[]
string
}
var
_
ConcurrencyCache
=
(
*
stubConcurrencyCacheForTest
)(
nil
)
func
(
c
*
stubConcurrencyCacheForTest
)
AcquireAccountSlot
(
_
context
.
Context
,
_
int64
,
_
int
,
_
string
)
(
bool
,
error
)
{
return
c
.
acquireResult
,
c
.
acquireErr
}
func
(
c
*
stubConcurrencyCacheForTest
)
ReleaseAccountSlot
(
_
context
.
Context
,
accountID
int64
,
requestID
string
)
error
{
c
.
releasedAccountIDs
=
append
(
c
.
releasedAccountIDs
,
accountID
)
c
.
releasedRequestIDs
=
append
(
c
.
releasedRequestIDs
,
requestID
)
return
c
.
releaseErr
}
func
(
c
*
stubConcurrencyCacheForTest
)
GetAccountConcurrency
(
_
context
.
Context
,
_
int64
)
(
int
,
error
)
{
return
c
.
concurrency
,
c
.
concurrencyErr
}
func
(
c
*
stubConcurrencyCacheForTest
)
IncrementAccountWaitCount
(
_
context
.
Context
,
_
int64
,
_
int
)
(
bool
,
error
)
{
return
c
.
waitAllowed
,
c
.
waitErr
}
func
(
c
*
stubConcurrencyCacheForTest
)
DecrementAccountWaitCount
(
_
context
.
Context
,
_
int64
)
error
{
return
nil
}
func
(
c
*
stubConcurrencyCacheForTest
)
GetAccountWaitingCount
(
_
context
.
Context
,
_
int64
)
(
int
,
error
)
{
return
c
.
waitCount
,
c
.
waitCountErr
}
func
(
c
*
stubConcurrencyCacheForTest
)
AcquireUserSlot
(
_
context
.
Context
,
_
int64
,
_
int
,
_
string
)
(
bool
,
error
)
{
return
c
.
acquireResult
,
c
.
acquireErr
}
func
(
c
*
stubConcurrencyCacheForTest
)
ReleaseUserSlot
(
_
context
.
Context
,
_
int64
,
_
string
)
error
{
return
c
.
releaseErr
}
func
(
c
*
stubConcurrencyCacheForTest
)
GetUserConcurrency
(
_
context
.
Context
,
_
int64
)
(
int
,
error
)
{
return
c
.
concurrency
,
c
.
concurrencyErr
}
func
(
c
*
stubConcurrencyCacheForTest
)
IncrementWaitCount
(
_
context
.
Context
,
_
int64
,
_
int
)
(
bool
,
error
)
{
return
c
.
waitAllowed
,
c
.
waitErr
}
func
(
c
*
stubConcurrencyCacheForTest
)
DecrementWaitCount
(
_
context
.
Context
,
_
int64
)
error
{
return
nil
}
func
(
c
*
stubConcurrencyCacheForTest
)
GetAccountsLoadBatch
(
_
context
.
Context
,
_
[]
AccountWithConcurrency
)
(
map
[
int64
]
*
AccountLoadInfo
,
error
)
{
return
c
.
loadBatch
,
c
.
loadBatchErr
}
func
(
c
*
stubConcurrencyCacheForTest
)
GetUsersLoadBatch
(
_
context
.
Context
,
_
[]
UserWithConcurrency
)
(
map
[
int64
]
*
UserLoadInfo
,
error
)
{
return
c
.
usersLoadBatch
,
c
.
usersLoadErr
}
func
(
c
*
stubConcurrencyCacheForTest
)
CleanupExpiredAccountSlots
(
_
context
.
Context
,
_
int64
)
error
{
return
c
.
cleanupErr
}
func
TestAcquireAccountSlot_Success
(
t
*
testing
.
T
)
{
cache
:=
&
stubConcurrencyCacheForTest
{
acquireResult
:
true
}
svc
:=
NewConcurrencyService
(
cache
)
result
,
err
:=
svc
.
AcquireAccountSlot
(
context
.
Background
(),
1
,
5
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
result
.
Acquired
)
require
.
NotNil
(
t
,
result
.
ReleaseFunc
)
}
func
TestAcquireAccountSlot_Failure
(
t
*
testing
.
T
)
{
cache
:=
&
stubConcurrencyCacheForTest
{
acquireResult
:
false
}
svc
:=
NewConcurrencyService
(
cache
)
result
,
err
:=
svc
.
AcquireAccountSlot
(
context
.
Background
(),
1
,
5
)
require
.
NoError
(
t
,
err
)
require
.
False
(
t
,
result
.
Acquired
)
require
.
Nil
(
t
,
result
.
ReleaseFunc
)
}
func
TestAcquireAccountSlot_UnlimitedConcurrency
(
t
*
testing
.
T
)
{
svc
:=
NewConcurrencyService
(
&
stubConcurrencyCacheForTest
{})
for
_
,
maxConcurrency
:=
range
[]
int
{
0
,
-
1
}
{
result
,
err
:=
svc
.
AcquireAccountSlot
(
context
.
Background
(),
1
,
maxConcurrency
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
result
.
Acquired
,
"maxConcurrency=%d 应无限制通过"
,
maxConcurrency
)
require
.
NotNil
(
t
,
result
.
ReleaseFunc
,
"ReleaseFunc 应为 no-op 函数"
)
}
}
func
TestAcquireAccountSlot_CacheError
(
t
*
testing
.
T
)
{
cache
:=
&
stubConcurrencyCacheForTest
{
acquireErr
:
errors
.
New
(
"redis down"
)}
svc
:=
NewConcurrencyService
(
cache
)
result
,
err
:=
svc
.
AcquireAccountSlot
(
context
.
Background
(),
1
,
5
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
result
)
}
func
TestAcquireAccountSlot_ReleaseDecrements
(
t
*
testing
.
T
)
{
cache
:=
&
stubConcurrencyCacheForTest
{
acquireResult
:
true
}
svc
:=
NewConcurrencyService
(
cache
)
result
,
err
:=
svc
.
AcquireAccountSlot
(
context
.
Background
(),
42
,
5
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
result
.
Acquired
)
// 调用 ReleaseFunc 应释放槽位
result
.
ReleaseFunc
()
require
.
Len
(
t
,
cache
.
releasedAccountIDs
,
1
)
require
.
Equal
(
t
,
int64
(
42
),
cache
.
releasedAccountIDs
[
0
])
require
.
Len
(
t
,
cache
.
releasedRequestIDs
,
1
)
require
.
NotEmpty
(
t
,
cache
.
releasedRequestIDs
[
0
],
"requestID 不应为空"
)
}
func
TestAcquireUserSlot_IndependentFromAccount
(
t
*
testing
.
T
)
{
cache
:=
&
stubConcurrencyCacheForTest
{
acquireResult
:
true
}
svc
:=
NewConcurrencyService
(
cache
)
// 用户槽位获取应独立于账户槽位
result
,
err
:=
svc
.
AcquireUserSlot
(
context
.
Background
(),
100
,
3
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
result
.
Acquired
)
require
.
NotNil
(
t
,
result
.
ReleaseFunc
)
}
func
TestAcquireUserSlot_UnlimitedConcurrency
(
t
*
testing
.
T
)
{
svc
:=
NewConcurrencyService
(
&
stubConcurrencyCacheForTest
{})
result
,
err
:=
svc
.
AcquireUserSlot
(
context
.
Background
(),
1
,
0
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
result
.
Acquired
)
}
func
TestGetAccountsLoadBatch_ReturnsCorrectData
(
t
*
testing
.
T
)
{
expected
:=
map
[
int64
]
*
AccountLoadInfo
{
1
:
{
AccountID
:
1
,
CurrentConcurrency
:
3
,
WaitingCount
:
0
,
LoadRate
:
60
},
2
:
{
AccountID
:
2
,
CurrentConcurrency
:
5
,
WaitingCount
:
2
,
LoadRate
:
100
},
}
cache
:=
&
stubConcurrencyCacheForTest
{
loadBatch
:
expected
}
svc
:=
NewConcurrencyService
(
cache
)
accounts
:=
[]
AccountWithConcurrency
{
{
ID
:
1
,
MaxConcurrency
:
5
},
{
ID
:
2
,
MaxConcurrency
:
5
},
}
result
,
err
:=
svc
.
GetAccountsLoadBatch
(
context
.
Background
(),
accounts
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
expected
,
result
)
}
func
TestGetAccountsLoadBatch_NilCache
(
t
*
testing
.
T
)
{
svc
:=
&
ConcurrencyService
{
cache
:
nil
}
result
,
err
:=
svc
.
GetAccountsLoadBatch
(
context
.
Background
(),
nil
)
require
.
NoError
(
t
,
err
)
require
.
Empty
(
t
,
result
)
}
func
TestIncrementWaitCount_Success
(
t
*
testing
.
T
)
{
cache
:=
&
stubConcurrencyCacheForTest
{
waitAllowed
:
true
}
svc
:=
NewConcurrencyService
(
cache
)
allowed
,
err
:=
svc
.
IncrementWaitCount
(
context
.
Background
(),
1
,
25
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
allowed
)
}
func
TestIncrementWaitCount_QueueFull
(
t
*
testing
.
T
)
{
cache
:=
&
stubConcurrencyCacheForTest
{
waitAllowed
:
false
}
svc
:=
NewConcurrencyService
(
cache
)
allowed
,
err
:=
svc
.
IncrementWaitCount
(
context
.
Background
(),
1
,
25
)
require
.
NoError
(
t
,
err
)
require
.
False
(
t
,
allowed
)
}
func
TestIncrementWaitCount_FailOpen
(
t
*
testing
.
T
)
{
// Redis 错误时应 fail-open(允许请求通过)
cache
:=
&
stubConcurrencyCacheForTest
{
waitErr
:
errors
.
New
(
"redis timeout"
)}
svc
:=
NewConcurrencyService
(
cache
)
allowed
,
err
:=
svc
.
IncrementWaitCount
(
context
.
Background
(),
1
,
25
)
require
.
NoError
(
t
,
err
,
"Redis 错误不应传播"
)
require
.
True
(
t
,
allowed
,
"Redis 错误时应 fail-open"
)
}
func
TestIncrementWaitCount_NilCache
(
t
*
testing
.
T
)
{
svc
:=
&
ConcurrencyService
{
cache
:
nil
}
allowed
,
err
:=
svc
.
IncrementWaitCount
(
context
.
Background
(),
1
,
25
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
allowed
,
"nil cache 应 fail-open"
)
}
func
TestCalculateMaxWait
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
concurrency
int
expected
int
}{
{
5
,
25
},
// 5 + 20
{
1
,
21
},
// 1 + 20
{
0
,
21
},
// min(1) + 20
{
-
1
,
21
},
// min(1) + 20
{
10
,
30
},
// 10 + 20
}
for
_
,
tt
:=
range
tests
{
result
:=
CalculateMaxWait
(
tt
.
concurrency
)
require
.
Equal
(
t
,
tt
.
expected
,
result
,
"CalculateMaxWait(%d)"
,
tt
.
concurrency
)
}
}
func
TestGetAccountWaitingCount
(
t
*
testing
.
T
)
{
cache
:=
&
stubConcurrencyCacheForTest
{
waitCount
:
5
}
svc
:=
NewConcurrencyService
(
cache
)
count
,
err
:=
svc
.
GetAccountWaitingCount
(
context
.
Background
(),
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
5
,
count
)
}
func
TestGetAccountWaitingCount_NilCache
(
t
*
testing
.
T
)
{
svc
:=
&
ConcurrencyService
{
cache
:
nil
}
count
,
err
:=
svc
.
GetAccountWaitingCount
(
context
.
Background
(),
1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
0
,
count
)
}
func
TestGetAccountConcurrencyBatch
(
t
*
testing
.
T
)
{
cache
:=
&
stubConcurrencyCacheForTest
{
concurrency
:
3
}
svc
:=
NewConcurrencyService
(
cache
)
result
,
err
:=
svc
.
GetAccountConcurrencyBatch
(
context
.
Background
(),
[]
int64
{
1
,
2
,
3
})
require
.
NoError
(
t
,
err
)
require
.
Len
(
t
,
result
,
3
)
for
_
,
id
:=
range
[]
int64
{
1
,
2
,
3
}
{
require
.
Equal
(
t
,
3
,
result
[
id
])
}
}
func
TestIncrementAccountWaitCount_FailOpen
(
t
*
testing
.
T
)
{
cache
:=
&
stubConcurrencyCacheForTest
{
waitErr
:
errors
.
New
(
"redis error"
)}
svc
:=
NewConcurrencyService
(
cache
)
allowed
,
err
:=
svc
.
IncrementAccountWaitCount
(
context
.
Background
(),
1
,
10
)
require
.
NoError
(
t
,
err
,
"Redis 错误不应传播"
)
require
.
True
(
t
,
allowed
,
"Redis 错误时应 fail-open"
)
}
func
TestIncrementAccountWaitCount_NilCache
(
t
*
testing
.
T
)
{
svc
:=
&
ConcurrencyService
{
cache
:
nil
}
allowed
,
err
:=
svc
.
IncrementAccountWaitCount
(
context
.
Background
(),
1
,
10
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
allowed
)
}
backend/internal/service/dashboard_aggregation_service.go
View file @
6bccb8a8
...
@@ -3,11 +3,12 @@ package service
...
@@ -3,11 +3,12 @@ package service
import
(
import
(
"context"
"context"
"errors"
"errors"
"log"
"log
/slog
"
"sync/atomic"
"sync/atomic"
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
)
const
(
const
(
...
@@ -65,7 +66,7 @@ func (s *DashboardAggregationService) Start() {
...
@@ -65,7 +66,7 @@ func (s *DashboardAggregationService) Start() {
return
return
}
}
if
!
s
.
cfg
.
Enabled
{
if
!
s
.
cfg
.
Enabled
{
log
.
Printf
(
"[DashboardAggregation] 聚合作业已禁用"
)
log
ger
.
LegacyPrintf
(
"service.dashboard_aggregation"
,
"[DashboardAggregation] 聚合作业已禁用"
)
return
return
}
}
...
@@ -81,9 +82,9 @@ func (s *DashboardAggregationService) Start() {
...
@@ -81,9 +82,9 @@ func (s *DashboardAggregationService) Start() {
s
.
timingWheel
.
ScheduleRecurring
(
"dashboard:aggregation"
,
interval
,
func
()
{
s
.
timingWheel
.
ScheduleRecurring
(
"dashboard:aggregation"
,
interval
,
func
()
{
s
.
runScheduledAggregation
()
s
.
runScheduledAggregation
()
})
})
log
.
Printf
(
"[DashboardAggregation] 聚合作业启动 (interval=%v, lookback=%ds)"
,
interval
,
s
.
cfg
.
LookbackSeconds
)
log
ger
.
LegacyPrintf
(
"service.dashboard_aggregation"
,
"[DashboardAggregation] 聚合作业启动 (interval=%v, lookback=%ds)"
,
interval
,
s
.
cfg
.
LookbackSeconds
)
if
!
s
.
cfg
.
BackfillEnabled
{
if
!
s
.
cfg
.
BackfillEnabled
{
log
.
Printf
(
"[DashboardAggregation] 回填已禁用,如需补齐保留窗口以外历史数据请手动回填"
)
log
ger
.
LegacyPrintf
(
"service.dashboard_aggregation"
,
"[DashboardAggregation] 回填已禁用,如需补齐保留窗口以外历史数据请手动回填"
)
}
}
}
}
...
@@ -93,7 +94,7 @@ func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) erro
...
@@ -93,7 +94,7 @@ func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) erro
return
errors
.
New
(
"聚合服务未初始化"
)
return
errors
.
New
(
"聚合服务未初始化"
)
}
}
if
!
s
.
cfg
.
BackfillEnabled
{
if
!
s
.
cfg
.
BackfillEnabled
{
log
.
Printf
(
"[DashboardAggregation] 回填被拒绝: backfill_enabled=false"
)
log
ger
.
LegacyPrintf
(
"service.dashboard_aggregation"
,
"[DashboardAggregation] 回填被拒绝: backfill_enabled=false"
)
return
ErrDashboardBackfillDisabled
return
ErrDashboardBackfillDisabled
}
}
if
!
end
.
After
(
start
)
{
if
!
end
.
After
(
start
)
{
...
@@ -110,7 +111,7 @@ func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) erro
...
@@ -110,7 +111,7 @@ func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) erro
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
defaultDashboardAggregationBackfillTimeout
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
defaultDashboardAggregationBackfillTimeout
)
defer
cancel
()
defer
cancel
()
if
err
:=
s
.
backfillRange
(
ctx
,
start
,
end
);
err
!=
nil
{
if
err
:=
s
.
backfillRange
(
ctx
,
start
,
end
);
err
!=
nil
{
log
.
Printf
(
"[DashboardAggregation] 回填失败: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.dashboard_aggregation"
,
"[DashboardAggregation] 回填失败: %v"
,
err
)
}
}
}()
}()
return
nil
return
nil
...
@@ -141,12 +142,12 @@ func (s *DashboardAggregationService) TriggerRecomputeRange(start, end time.Time
...
@@ -141,12 +142,12 @@ func (s *DashboardAggregationService) TriggerRecomputeRange(start, end time.Time
return
return
}
}
if
!
errors
.
Is
(
err
,
errDashboardAggregationRunning
)
{
if
!
errors
.
Is
(
err
,
errDashboardAggregationRunning
)
{
log
.
Printf
(
"[DashboardAggregation] 重新计算失败: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.dashboard_aggregation"
,
"[DashboardAggregation] 重新计算失败: %v"
,
err
)
return
return
}
}
time
.
Sleep
(
5
*
time
.
Second
)
time
.
Sleep
(
5
*
time
.
Second
)
}
}
log
.
Printf
(
"[DashboardAggregation] 重新计算放弃: 聚合作业持续占用"
)
log
ger
.
LegacyPrintf
(
"service.dashboard_aggregation"
,
"[DashboardAggregation] 重新计算放弃: 聚合作业持续占用"
)
}()
}()
return
nil
return
nil
}
}
...
@@ -162,7 +163,7 @@ func (s *DashboardAggregationService) recomputeRecentDays() {
...
@@ -162,7 +163,7 @@ func (s *DashboardAggregationService) recomputeRecentDays() {
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
defaultDashboardAggregationBackfillTimeout
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
defaultDashboardAggregationBackfillTimeout
)
defer
cancel
()
defer
cancel
()
if
err
:=
s
.
backfillRange
(
ctx
,
start
,
now
);
err
!=
nil
{
if
err
:=
s
.
backfillRange
(
ctx
,
start
,
now
);
err
!=
nil
{
log
.
Printf
(
"[DashboardAggregation] 启动重算失败: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.dashboard_aggregation"
,
"[DashboardAggregation] 启动重算失败: %v"
,
err
)
return
return
}
}
}
}
...
@@ -177,7 +178,7 @@ func (s *DashboardAggregationService) recomputeRange(ctx context.Context, start,
...
@@ -177,7 +178,7 @@ func (s *DashboardAggregationService) recomputeRange(ctx context.Context, start,
if
err
:=
s
.
repo
.
RecomputeRange
(
ctx
,
start
,
end
);
err
!=
nil
{
if
err
:=
s
.
repo
.
RecomputeRange
(
ctx
,
start
,
end
);
err
!=
nil
{
return
err
return
err
}
}
log
.
Printf
(
"[DashboardAggregation] 重新计算完成 (start=%s end=%s duration=%s)"
,
log
ger
.
LegacyPrintf
(
"service.dashboard_aggregation"
,
"[DashboardAggregation] 重新计算完成 (start=%s end=%s duration=%s)"
,
start
.
UTC
()
.
Format
(
time
.
RFC3339
),
start
.
UTC
()
.
Format
(
time
.
RFC3339
),
end
.
UTC
()
.
Format
(
time
.
RFC3339
),
end
.
UTC
()
.
Format
(
time
.
RFC3339
),
time
.
Since
(
jobStart
)
.
String
(),
time
.
Since
(
jobStart
)
.
String
(),
...
@@ -198,7 +199,7 @@ func (s *DashboardAggregationService) runScheduledAggregation() {
...
@@ -198,7 +199,7 @@ func (s *DashboardAggregationService) runScheduledAggregation() {
now
:=
time
.
Now
()
.
UTC
()
now
:=
time
.
Now
()
.
UTC
()
last
,
err
:=
s
.
repo
.
GetAggregationWatermark
(
ctx
)
last
,
err
:=
s
.
repo
.
GetAggregationWatermark
(
ctx
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Printf
(
"[DashboardAggregation] 读取水位失败: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.dashboard_aggregation"
,
"[DashboardAggregation] 读取水位失败: %v"
,
err
)
last
=
time
.
Unix
(
0
,
0
)
.
UTC
()
last
=
time
.
Unix
(
0
,
0
)
.
UTC
()
}
}
...
@@ -216,19 +217,19 @@ func (s *DashboardAggregationService) runScheduledAggregation() {
...
@@ -216,19 +217,19 @@ func (s *DashboardAggregationService) runScheduledAggregation() {
}
}
if
err
:=
s
.
aggregateRange
(
ctx
,
start
,
now
);
err
!=
nil
{
if
err
:=
s
.
aggregateRange
(
ctx
,
start
,
now
);
err
!=
nil
{
log
.
Printf
(
"[DashboardAggregation] 聚合失败: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.dashboard_aggregation"
,
"[DashboardAggregation] 聚合失败: %v"
,
err
)
return
return
}
}
updateErr
:=
s
.
repo
.
UpdateAggregationWatermark
(
ctx
,
now
)
updateErr
:=
s
.
repo
.
UpdateAggregationWatermark
(
ctx
,
now
)
if
updateErr
!=
nil
{
if
updateErr
!=
nil
{
log
.
Printf
(
"[DashboardAggregation] 更新水位失败: %v"
,
updateErr
)
log
ger
.
LegacyPrintf
(
"service.dashboard_aggregation"
,
"[DashboardAggregation] 更新水位失败: %v"
,
updateErr
)
}
}
log
.
Printf
(
"[DashboardAggregation] 聚合完成
(start=%s end=%s duration=%s watermark_updated=%t)
"
,
s
log
.
Debug
(
"[DashboardAggregation] 聚合完成"
,
start
.
Format
(
time
.
RFC3339
),
"start"
,
start
.
Format
(
time
.
RFC3339
),
now
.
Format
(
time
.
RFC3339
),
"end"
,
now
.
Format
(
time
.
RFC3339
),
time
.
Since
(
jobStart
)
.
String
(),
"duration"
,
time
.
Since
(
jobStart
)
.
String
(),
updateErr
==
nil
,
"watermark_updated"
,
updateErr
==
nil
,
)
)
s
.
maybeCleanupRetention
(
ctx
,
now
)
s
.
maybeCleanupRetention
(
ctx
,
now
)
...
@@ -261,9 +262,9 @@ func (s *DashboardAggregationService) backfillRange(ctx context.Context, start,
...
@@ -261,9 +262,9 @@ func (s *DashboardAggregationService) backfillRange(ctx context.Context, start,
updateErr
:=
s
.
repo
.
UpdateAggregationWatermark
(
ctx
,
endUTC
)
updateErr
:=
s
.
repo
.
UpdateAggregationWatermark
(
ctx
,
endUTC
)
if
updateErr
!=
nil
{
if
updateErr
!=
nil
{
log
.
Printf
(
"[DashboardAggregation] 更新水位失败: %v"
,
updateErr
)
log
ger
.
LegacyPrintf
(
"service.dashboard_aggregation"
,
"[DashboardAggregation] 更新水位失败: %v"
,
updateErr
)
}
}
log
.
Printf
(
"[DashboardAggregation] 回填聚合完成 (start=%s end=%s duration=%s watermark_updated=%t)"
,
log
ger
.
LegacyPrintf
(
"service.dashboard_aggregation"
,
"[DashboardAggregation] 回填聚合完成 (start=%s end=%s duration=%s watermark_updated=%t)"
,
startUTC
.
Format
(
time
.
RFC3339
),
startUTC
.
Format
(
time
.
RFC3339
),
endUTC
.
Format
(
time
.
RFC3339
),
endUTC
.
Format
(
time
.
RFC3339
),
time
.
Since
(
jobStart
)
.
String
(),
time
.
Since
(
jobStart
)
.
String
(),
...
@@ -279,7 +280,7 @@ func (s *DashboardAggregationService) aggregateRange(ctx context.Context, start,
...
@@ -279,7 +280,7 @@ func (s *DashboardAggregationService) aggregateRange(ctx context.Context, start,
return
nil
return
nil
}
}
if
err
:=
s
.
repo
.
EnsureUsageLogsPartitions
(
ctx
,
end
);
err
!=
nil
{
if
err
:=
s
.
repo
.
EnsureUsageLogsPartitions
(
ctx
,
end
);
err
!=
nil
{
log
.
Printf
(
"[DashboardAggregation] 分区检查失败: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.dashboard_aggregation"
,
"[DashboardAggregation] 分区检查失败: %v"
,
err
)
}
}
return
s
.
repo
.
AggregateRange
(
ctx
,
start
,
end
)
return
s
.
repo
.
AggregateRange
(
ctx
,
start
,
end
)
}
}
...
@@ -298,11 +299,11 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
...
@@ -298,11 +299,11 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
aggErr
:=
s
.
repo
.
CleanupAggregates
(
ctx
,
hourlyCutoff
,
dailyCutoff
)
aggErr
:=
s
.
repo
.
CleanupAggregates
(
ctx
,
hourlyCutoff
,
dailyCutoff
)
if
aggErr
!=
nil
{
if
aggErr
!=
nil
{
log
.
Printf
(
"[DashboardAggregation] 聚合保留清理失败: %v"
,
aggErr
)
log
ger
.
LegacyPrintf
(
"service.dashboard_aggregation"
,
"[DashboardAggregation] 聚合保留清理失败: %v"
,
aggErr
)
}
}
usageErr
:=
s
.
repo
.
CleanupUsageLogs
(
ctx
,
usageCutoff
)
usageErr
:=
s
.
repo
.
CleanupUsageLogs
(
ctx
,
usageCutoff
)
if
usageErr
!=
nil
{
if
usageErr
!=
nil
{
log
.
Printf
(
"[DashboardAggregation] usage_logs 保留清理失败: %v"
,
usageErr
)
log
ger
.
LegacyPrintf
(
"service.dashboard_aggregation"
,
"[DashboardAggregation] usage_logs 保留清理失败: %v"
,
usageErr
)
}
}
if
aggErr
==
nil
&&
usageErr
==
nil
{
if
aggErr
==
nil
&&
usageErr
==
nil
{
s
.
lastRetentionCleanup
.
Store
(
now
)
s
.
lastRetentionCleanup
.
Store
(
now
)
...
...
backend/internal/service/dashboard_service.go
View file @
6bccb8a8
...
@@ -5,11 +5,11 @@ import (
...
@@ -5,11 +5,11 @@ import (
"encoding/json"
"encoding/json"
"errors"
"errors"
"fmt"
"fmt"
"log"
"sync/atomic"
"sync/atomic"
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
)
)
...
@@ -113,7 +113,7 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D
...
@@ -113,7 +113,7 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D
return
cached
,
nil
return
cached
,
nil
}
}
if
err
!=
nil
&&
!
errors
.
Is
(
err
,
ErrDashboardStatsCacheMiss
)
{
if
err
!=
nil
&&
!
errors
.
Is
(
err
,
ErrDashboardStatsCacheMiss
)
{
log
.
Printf
(
"[Dashboard] 仪表盘缓存读取失败: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.dashboard"
,
"[Dashboard] 仪表盘缓存读取失败: %v"
,
err
)
}
}
}
}
...
@@ -188,7 +188,7 @@ func (s *DashboardService) refreshDashboardStatsAsync() {
...
@@ -188,7 +188,7 @@ func (s *DashboardService) refreshDashboardStatsAsync() {
stats
,
err
:=
s
.
fetchDashboardStats
(
ctx
)
stats
,
err
:=
s
.
fetchDashboardStats
(
ctx
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Printf
(
"[Dashboard] 仪表盘缓存异步刷新失败: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.dashboard"
,
"[Dashboard] 仪表盘缓存异步刷新失败: %v"
,
err
)
return
return
}
}
s
.
applyAggregationStatus
(
ctx
,
stats
)
s
.
applyAggregationStatus
(
ctx
,
stats
)
...
@@ -220,12 +220,12 @@ func (s *DashboardService) saveDashboardStatsCache(ctx context.Context, stats *u
...
@@ -220,12 +220,12 @@ func (s *DashboardService) saveDashboardStatsCache(ctx context.Context, stats *u
}
}
data
,
err
:=
json
.
Marshal
(
entry
)
data
,
err
:=
json
.
Marshal
(
entry
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Printf
(
"[Dashboard] 仪表盘缓存序列化失败: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.dashboard"
,
"[Dashboard] 仪表盘缓存序列化失败: %v"
,
err
)
return
return
}
}
if
err
:=
s
.
cache
.
SetDashboardStats
(
ctx
,
string
(
data
),
s
.
cacheTTL
);
err
!=
nil
{
if
err
:=
s
.
cache
.
SetDashboardStats
(
ctx
,
string
(
data
),
s
.
cacheTTL
);
err
!=
nil
{
log
.
Printf
(
"[Dashboard] 仪表盘缓存写入失败: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.dashboard"
,
"[Dashboard] 仪表盘缓存写入失败: %v"
,
err
)
}
}
}
}
...
@@ -237,10 +237,10 @@ func (s *DashboardService) evictDashboardStatsCache(reason error) {
...
@@ -237,10 +237,10 @@ func (s *DashboardService) evictDashboardStatsCache(reason error) {
defer
cancel
()
defer
cancel
()
if
err
:=
s
.
cache
.
DeleteDashboardStats
(
cacheCtx
);
err
!=
nil
{
if
err
:=
s
.
cache
.
DeleteDashboardStats
(
cacheCtx
);
err
!=
nil
{
log
.
Printf
(
"[Dashboard] 仪表盘缓存清理失败: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.dashboard"
,
"[Dashboard] 仪表盘缓存清理失败: %v"
,
err
)
}
}
if
reason
!=
nil
{
if
reason
!=
nil
{
log
.
Printf
(
"[Dashboard] 仪表盘缓存异常,已清理: %v"
,
reason
)
log
ger
.
LegacyPrintf
(
"service.dashboard"
,
"[Dashboard] 仪表盘缓存异常,已清理: %v"
,
reason
)
}
}
}
}
...
@@ -271,7 +271,7 @@ func (s *DashboardService) fetchAggregationUpdatedAt(ctx context.Context) time.T
...
@@ -271,7 +271,7 @@ func (s *DashboardService) fetchAggregationUpdatedAt(ctx context.Context) time.T
}
}
updatedAt
,
err
:=
s
.
aggRepo
.
GetAggregationWatermark
(
ctx
)
updatedAt
,
err
:=
s
.
aggRepo
.
GetAggregationWatermark
(
ctx
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Printf
(
"[Dashboard] 读取聚合水位失败: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.dashboard"
,
"[Dashboard] 读取聚合水位失败: %v"
,
err
)
return
time
.
Unix
(
0
,
0
)
.
UTC
()
return
time
.
Unix
(
0
,
0
)
.
UTC
()
}
}
if
updatedAt
.
IsZero
()
{
if
updatedAt
.
IsZero
()
{
...
@@ -319,16 +319,16 @@ func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, end
...
@@ -319,16 +319,16 @@ func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, end
return
trend
,
nil
return
trend
,
nil
}
}
func
(
s
*
DashboardService
)
GetBatchUserUsageStats
(
ctx
context
.
Context
,
userIDs
[]
int64
)
(
map
[
int64
]
*
usagestats
.
BatchUserUsageStats
,
error
)
{
func
(
s
*
DashboardService
)
GetBatchUserUsageStats
(
ctx
context
.
Context
,
userIDs
[]
int64
,
startTime
,
endTime
time
.
Time
)
(
map
[
int64
]
*
usagestats
.
BatchUserUsageStats
,
error
)
{
stats
,
err
:=
s
.
usageRepo
.
GetBatchUserUsageStats
(
ctx
,
userIDs
)
stats
,
err
:=
s
.
usageRepo
.
GetBatchUserUsageStats
(
ctx
,
userIDs
,
startTime
,
endTime
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get batch user usage stats: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"get batch user usage stats: %w"
,
err
)
}
}
return
stats
,
nil
return
stats
,
nil
}
}
func
(
s
*
DashboardService
)
GetBatchAPIKeyUsageStats
(
ctx
context
.
Context
,
apiKeyIDs
[]
int64
)
(
map
[
int64
]
*
usagestats
.
BatchAPIKeyUsageStats
,
error
)
{
func
(
s
*
DashboardService
)
GetBatchAPIKeyUsageStats
(
ctx
context
.
Context
,
apiKeyIDs
[]
int64
,
startTime
,
endTime
time
.
Time
)
(
map
[
int64
]
*
usagestats
.
BatchAPIKeyUsageStats
,
error
)
{
stats
,
err
:=
s
.
usageRepo
.
GetBatchAPIKeyUsageStats
(
ctx
,
apiKeyIDs
)
stats
,
err
:=
s
.
usageRepo
.
GetBatchAPIKeyUsageStats
(
ctx
,
apiKeyIDs
,
startTime
,
endTime
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get batch api key usage stats: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"get batch api key usage stats: %w"
,
err
)
}
}
...
...
backend/internal/service/domain_constants.go
View file @
6bccb8a8
...
@@ -24,6 +24,7 @@ const (
...
@@ -24,6 +24,7 @@ const (
PlatformOpenAI
=
domain
.
PlatformOpenAI
PlatformOpenAI
=
domain
.
PlatformOpenAI
PlatformGemini
=
domain
.
PlatformGemini
PlatformGemini
=
domain
.
PlatformGemini
PlatformAntigravity
=
domain
.
PlatformAntigravity
PlatformAntigravity
=
domain
.
PlatformAntigravity
PlatformSora
=
domain
.
PlatformSora
)
)
// Account type constants
// Account type constants
...
@@ -160,6 +161,9 @@ const (
...
@@ -160,6 +161,9 @@ const (
// SettingKeyOpsAdvancedSettings stores JSON config for ops advanced settings (data retention, aggregation).
// SettingKeyOpsAdvancedSettings stores JSON config for ops advanced settings (data retention, aggregation).
SettingKeyOpsAdvancedSettings
=
"ops_advanced_settings"
SettingKeyOpsAdvancedSettings
=
"ops_advanced_settings"
// SettingKeyOpsRuntimeLogConfig stores JSON config for runtime log settings.
SettingKeyOpsRuntimeLogConfig
=
"ops_runtime_log_config"
// =========================
// =========================
// Stream Timeout Handling
// Stream Timeout Handling
// =========================
// =========================
...
...
backend/internal/service/email_queue_service.go
View file @
6bccb8a8
...
@@ -3,9 +3,10 @@ package service
...
@@ -3,9 +3,10 @@ package service
import
(
import
(
"context"
"context"
"fmt"
"fmt"
"log"
"sync"
"sync"
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
)
// Task type constants
// Task type constants
...
@@ -56,7 +57,7 @@ func (s *EmailQueueService) start() {
...
@@ -56,7 +57,7 @@ func (s *EmailQueueService) start() {
s
.
wg
.
Add
(
1
)
s
.
wg
.
Add
(
1
)
go
s
.
worker
(
i
)
go
s
.
worker
(
i
)
}
}
log
.
Printf
(
"[EmailQueue] Started %d workers"
,
s
.
workers
)
log
ger
.
LegacyPrintf
(
"service.email_queue"
,
"[EmailQueue] Started %d workers"
,
s
.
workers
)
}
}
// worker 工作协程
// worker 工作协程
...
@@ -68,7 +69,7 @@ func (s *EmailQueueService) worker(id int) {
...
@@ -68,7 +69,7 @@ func (s *EmailQueueService) worker(id int) {
case
task
:=
<-
s
.
taskChan
:
case
task
:=
<-
s
.
taskChan
:
s
.
processTask
(
id
,
task
)
s
.
processTask
(
id
,
task
)
case
<-
s
.
stopChan
:
case
<-
s
.
stopChan
:
log
.
Printf
(
"[EmailQueue] Worker %d stopping"
,
id
)
log
ger
.
LegacyPrintf
(
"service.email_queue"
,
"[EmailQueue] Worker %d stopping"
,
id
)
return
return
}
}
}
}
...
@@ -82,18 +83,18 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) {
...
@@ -82,18 +83,18 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) {
switch
task
.
TaskType
{
switch
task
.
TaskType
{
case
TaskTypeVerifyCode
:
case
TaskTypeVerifyCode
:
if
err
:=
s
.
emailService
.
SendVerifyCode
(
ctx
,
task
.
Email
,
task
.
SiteName
);
err
!=
nil
{
if
err
:=
s
.
emailService
.
SendVerifyCode
(
ctx
,
task
.
Email
,
task
.
SiteName
);
err
!=
nil
{
log
.
Printf
(
"[EmailQueue] Worker %d failed to send verify code to %s: %v"
,
workerID
,
task
.
Email
,
err
)
log
ger
.
LegacyPrintf
(
"service.email_queue"
,
"[EmailQueue] Worker %d failed to send verify code to %s: %v"
,
workerID
,
task
.
Email
,
err
)
}
else
{
}
else
{
log
.
Printf
(
"[EmailQueue] Worker %d sent verify code to %s"
,
workerID
,
task
.
Email
)
log
ger
.
LegacyPrintf
(
"service.email_queue"
,
"[EmailQueue] Worker %d sent verify code to %s"
,
workerID
,
task
.
Email
)
}
}
case
TaskTypePasswordReset
:
case
TaskTypePasswordReset
:
if
err
:=
s
.
emailService
.
SendPasswordResetEmailWithCooldown
(
ctx
,
task
.
Email
,
task
.
SiteName
,
task
.
ResetURL
);
err
!=
nil
{
if
err
:=
s
.
emailService
.
SendPasswordResetEmailWithCooldown
(
ctx
,
task
.
Email
,
task
.
SiteName
,
task
.
ResetURL
);
err
!=
nil
{
log
.
Printf
(
"[EmailQueue] Worker %d failed to send password reset to %s: %v"
,
workerID
,
task
.
Email
,
err
)
log
ger
.
LegacyPrintf
(
"service.email_queue"
,
"[EmailQueue] Worker %d failed to send password reset to %s: %v"
,
workerID
,
task
.
Email
,
err
)
}
else
{
}
else
{
log
.
Printf
(
"[EmailQueue] Worker %d sent password reset to %s"
,
workerID
,
task
.
Email
)
log
ger
.
LegacyPrintf
(
"service.email_queue"
,
"[EmailQueue] Worker %d sent password reset to %s"
,
workerID
,
task
.
Email
)
}
}
default
:
default
:
log
.
Printf
(
"[EmailQueue] Worker %d unknown task type: %s"
,
workerID
,
task
.
TaskType
)
log
ger
.
LegacyPrintf
(
"service.email_queue"
,
"[EmailQueue] Worker %d unknown task type: %s"
,
workerID
,
task
.
TaskType
)
}
}
}
}
...
@@ -107,7 +108,7 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
...
@@ -107,7 +108,7 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
select
{
select
{
case
s
.
taskChan
<-
task
:
case
s
.
taskChan
<-
task
:
log
.
Printf
(
"[EmailQueue] Enqueued verify code task for %s"
,
email
)
log
ger
.
LegacyPrintf
(
"service.email_queue"
,
"[EmailQueue] Enqueued verify code task for %s"
,
email
)
return
nil
return
nil
default
:
default
:
return
fmt
.
Errorf
(
"email queue is full"
)
return
fmt
.
Errorf
(
"email queue is full"
)
...
@@ -125,7 +126,7 @@ func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL strin
...
@@ -125,7 +126,7 @@ func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL strin
select
{
select
{
case
s
.
taskChan
<-
task
:
case
s
.
taskChan
<-
task
:
log
.
Printf
(
"[EmailQueue] Enqueued password reset task for %s"
,
email
)
log
ger
.
LegacyPrintf
(
"service.email_queue"
,
"[EmailQueue] Enqueued password reset task for %s"
,
email
)
return
nil
return
nil
default
:
default
:
return
fmt
.
Errorf
(
"email queue is full"
)
return
fmt
.
Errorf
(
"email queue is full"
)
...
@@ -136,5 +137,5 @@ func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL strin
...
@@ -136,5 +137,5 @@ func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL strin
func
(
s
*
EmailQueueService
)
Stop
()
{
func
(
s
*
EmailQueueService
)
Stop
()
{
close
(
s
.
stopChan
)
close
(
s
.
stopChan
)
s
.
wg
.
Wait
()
s
.
wg
.
Wait
()
log
.
Println
(
"[EmailQueue] All workers stopped"
)
log
ger
.
LegacyPrintf
(
"service.email_queue"
,
"%s"
,
"[EmailQueue] All workers stopped"
)
}
}
backend/internal/service/error_passthrough_runtime_test.go
View file @
6bccb8a8
...
@@ -76,7 +76,7 @@ func TestOpenAIHandleErrorResponse_NoRuleKeepsDefault(t *testing.T) {
...
@@ -76,7 +76,7 @@ func TestOpenAIHandleErrorResponse_NoRuleKeepsDefault(t *testing.T) {
}
}
account
:=
&
Account
{
ID
:
12
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
}
account
:=
&
Account
{
ID
:
12
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
}
_
,
err
:=
svc
.
handleErrorResponse
(
context
.
Background
(),
resp
,
c
,
account
)
_
,
err
:=
svc
.
handleErrorResponse
(
context
.
Background
(),
resp
,
c
,
account
,
nil
)
require
.
Error
(
t
,
err
)
require
.
Error
(
t
,
err
)
assert
.
Equal
(
t
,
http
.
StatusBadGateway
,
rec
.
Code
)
assert
.
Equal
(
t
,
http
.
StatusBadGateway
,
rec
.
Code
)
...
@@ -157,7 +157,7 @@ func TestOpenAIHandleErrorResponse_AppliesRuleFor422(t *testing.T) {
...
@@ -157,7 +157,7 @@ func TestOpenAIHandleErrorResponse_AppliesRuleFor422(t *testing.T) {
}
}
account
:=
&
Account
{
ID
:
2
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
}
account
:=
&
Account
{
ID
:
2
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
}
_
,
err
:=
svc
.
handleErrorResponse
(
context
.
Background
(),
resp
,
c
,
account
)
_
,
err
:=
svc
.
handleErrorResponse
(
context
.
Background
(),
resp
,
c
,
account
,
nil
)
require
.
Error
(
t
,
err
)
require
.
Error
(
t
,
err
)
assert
.
Equal
(
t
,
http
.
StatusTeapot
,
rec
.
Code
)
assert
.
Equal
(
t
,
http
.
StatusTeapot
,
rec
.
Code
)
...
...
backend/internal/service/error_passthrough_service.go
View file @
6bccb8a8
...
@@ -2,13 +2,13 @@ package service
...
@@ -2,13 +2,13 @@ package service
import
(
import
(
"context"
"context"
"log"
"sort"
"sort"
"strings"
"strings"
"sync"
"sync"
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
)
// ErrorPassthroughRepository 定义错误透传规则的数据访问接口
// ErrorPassthroughRepository 定义错误透传规则的数据访问接口
...
@@ -72,9 +72,9 @@ func NewErrorPassthroughService(
...
@@ -72,9 +72,9 @@ func NewErrorPassthroughService(
// 启动时加载规则到本地缓存
// 启动时加载规则到本地缓存
ctx
:=
context
.
Background
()
ctx
:=
context
.
Background
()
if
err
:=
svc
.
reloadRulesFromDB
(
ctx
);
err
!=
nil
{
if
err
:=
svc
.
reloadRulesFromDB
(
ctx
);
err
!=
nil
{
log
.
Printf
(
"[ErrorPassthroughService] Failed to load rules from DB on startup: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.error_passthrough"
,
"[ErrorPassthroughService] Failed to load rules from DB on startup: %v"
,
err
)
if
fallbackErr
:=
svc
.
refreshLocalCache
(
ctx
);
fallbackErr
!=
nil
{
if
fallbackErr
:=
svc
.
refreshLocalCache
(
ctx
);
fallbackErr
!=
nil
{
log
.
Printf
(
"[ErrorPassthroughService] Failed to load rules from cache fallback on startup: %v"
,
fallbackErr
)
log
ger
.
LegacyPrintf
(
"service.error_passthrough"
,
"[ErrorPassthroughService] Failed to load rules from cache fallback on startup: %v"
,
fallbackErr
)
}
}
}
}
...
@@ -82,7 +82,7 @@ func NewErrorPassthroughService(
...
@@ -82,7 +82,7 @@ func NewErrorPassthroughService(
if
cache
!=
nil
{
if
cache
!=
nil
{
cache
.
SubscribeUpdates
(
ctx
,
func
()
{
cache
.
SubscribeUpdates
(
ctx
,
func
()
{
if
err
:=
svc
.
refreshLocalCache
(
context
.
Background
());
err
!=
nil
{
if
err
:=
svc
.
refreshLocalCache
(
context
.
Background
());
err
!=
nil
{
log
.
Printf
(
"[ErrorPassthroughService] Failed to refresh cache on notification: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.error_passthrough"
,
"[ErrorPassthroughService] Failed to refresh cache on notification: %v"
,
err
)
}
}
})
})
}
}
...
@@ -192,7 +192,7 @@ func (s *ErrorPassthroughService) getCachedRules() []*cachedPassthroughRule {
...
@@ -192,7 +192,7 @@ func (s *ErrorPassthroughService) getCachedRules() []*cachedPassthroughRule {
// 如果本地缓存为空,尝试刷新
// 如果本地缓存为空,尝试刷新
ctx
:=
context
.
Background
()
ctx
:=
context
.
Background
()
if
err
:=
s
.
refreshLocalCache
(
ctx
);
err
!=
nil
{
if
err
:=
s
.
refreshLocalCache
(
ctx
);
err
!=
nil
{
log
.
Printf
(
"[ErrorPassthroughService] Failed to refresh cache: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.error_passthrough"
,
"[ErrorPassthroughService] Failed to refresh cache: %v"
,
err
)
return
nil
return
nil
}
}
...
@@ -225,7 +225,7 @@ func (s *ErrorPassthroughService) reloadRulesFromDB(ctx context.Context) error {
...
@@ -225,7 +225,7 @@ func (s *ErrorPassthroughService) reloadRulesFromDB(ctx context.Context) error {
// 更新 Redis 缓存
// 更新 Redis 缓存
if
s
.
cache
!=
nil
{
if
s
.
cache
!=
nil
{
if
err
:=
s
.
cache
.
Set
(
ctx
,
rules
);
err
!=
nil
{
if
err
:=
s
.
cache
.
Set
(
ctx
,
rules
);
err
!=
nil
{
log
.
Printf
(
"[ErrorPassthroughService] Failed to set cache: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.error_passthrough"
,
"[ErrorPassthroughService] Failed to set cache: %v"
,
err
)
}
}
}
}
...
@@ -288,13 +288,13 @@ func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) {
...
@@ -288,13 +288,13 @@ func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) {
// 先失效缓存,避免后续刷新读到陈旧规则。
// 先失效缓存,避免后续刷新读到陈旧规则。
if
s
.
cache
!=
nil
{
if
s
.
cache
!=
nil
{
if
err
:=
s
.
cache
.
Invalidate
(
ctx
);
err
!=
nil
{
if
err
:=
s
.
cache
.
Invalidate
(
ctx
);
err
!=
nil
{
log
.
Printf
(
"[ErrorPassthroughService] Failed to invalidate cache: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.error_passthrough"
,
"[ErrorPassthroughService] Failed to invalidate cache: %v"
,
err
)
}
}
}
}
// 刷新本地缓存
// 刷新本地缓存
if
err
:=
s
.
reloadRulesFromDB
(
ctx
);
err
!=
nil
{
if
err
:=
s
.
reloadRulesFromDB
(
ctx
);
err
!=
nil
{
log
.
Printf
(
"[ErrorPassthroughService] Failed to refresh local cache: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.error_passthrough"
,
"[ErrorPassthroughService] Failed to refresh local cache: %v"
,
err
)
// 刷新失败时清空本地缓存,避免继续使用陈旧规则。
// 刷新失败时清空本地缓存,避免继续使用陈旧规则。
s
.
clearLocalCache
()
s
.
clearLocalCache
()
}
}
...
@@ -302,7 +302,7 @@ func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) {
...
@@ -302,7 +302,7 @@ func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) {
// 通知其他实例
// 通知其他实例
if
s
.
cache
!=
nil
{
if
s
.
cache
!=
nil
{
if
err
:=
s
.
cache
.
NotifyUpdate
(
ctx
);
err
!=
nil
{
if
err
:=
s
.
cache
.
NotifyUpdate
(
ctx
);
err
!=
nil
{
log
.
Printf
(
"[ErrorPassthroughService] Failed to notify cache update: %v"
,
err
)
log
ger
.
LegacyPrintf
(
"service.error_passthrough"
,
"[ErrorPassthroughService] Failed to notify cache update: %v"
,
err
)
}
}
}
}
}
}
...
...
backend/internal/service/gateway_account_selection_test.go
0 → 100644
View file @
6bccb8a8
//go:build unit
package
service
import
(
"testing"
"time"
"github.com/stretchr/testify/require"
)
// --- helpers ---
func
testTimePtr
(
t
time
.
Time
)
*
time
.
Time
{
return
&
t
}
func
makeAccWithLoad
(
id
int64
,
priority
int
,
loadRate
int
,
lastUsed
*
time
.
Time
,
accType
string
)
accountWithLoad
{
return
accountWithLoad
{
account
:
&
Account
{
ID
:
id
,
Priority
:
priority
,
LastUsedAt
:
lastUsed
,
Type
:
accType
,
Schedulable
:
true
,
Status
:
StatusActive
,
},
loadInfo
:
&
AccountLoadInfo
{
AccountID
:
id
,
CurrentConcurrency
:
0
,
LoadRate
:
loadRate
,
},
}
}
// --- sortAccountsByPriorityAndLastUsed ---
func
TestSortAccountsByPriorityAndLastUsed_ByPriority
(
t
*
testing
.
T
)
{
now
:=
time
.
Now
()
accounts
:=
[]
*
Account
{
{
ID
:
1
,
Priority
:
5
,
LastUsedAt
:
testTimePtr
(
now
)},
{
ID
:
2
,
Priority
:
1
,
LastUsedAt
:
testTimePtr
(
now
)},
{
ID
:
3
,
Priority
:
3
,
LastUsedAt
:
testTimePtr
(
now
)},
}
sortAccountsByPriorityAndLastUsed
(
accounts
,
false
)
require
.
Equal
(
t
,
int64
(
2
),
accounts
[
0
]
.
ID
,
"优先级最低的排第一"
)
require
.
Equal
(
t
,
int64
(
3
),
accounts
[
1
]
.
ID
)
require
.
Equal
(
t
,
int64
(
1
),
accounts
[
2
]
.
ID
)
}
func
TestSortAccountsByPriorityAndLastUsed_SamePriorityByLastUsed
(
t
*
testing
.
T
)
{
now
:=
time
.
Now
()
accounts
:=
[]
*
Account
{
{
ID
:
1
,
Priority
:
1
,
LastUsedAt
:
testTimePtr
(
now
)},
{
ID
:
2
,
Priority
:
1
,
LastUsedAt
:
testTimePtr
(
now
.
Add
(
-
1
*
time
.
Hour
))},
{
ID
:
3
,
Priority
:
1
,
LastUsedAt
:
nil
},
}
sortAccountsByPriorityAndLastUsed
(
accounts
,
false
)
require
.
Equal
(
t
,
int64
(
3
),
accounts
[
0
]
.
ID
,
"nil LastUsedAt 排最前"
)
require
.
Equal
(
t
,
int64
(
2
),
accounts
[
1
]
.
ID
,
"更早使用的排前面"
)
require
.
Equal
(
t
,
int64
(
1
),
accounts
[
2
]
.
ID
)
}
func
TestSortAccountsByPriorityAndLastUsed_PreferOAuth
(
t
*
testing
.
T
)
{
accounts
:=
[]
*
Account
{
{
ID
:
1
,
Priority
:
1
,
LastUsedAt
:
nil
,
Type
:
AccountTypeAPIKey
},
{
ID
:
2
,
Priority
:
1
,
LastUsedAt
:
nil
,
Type
:
AccountTypeOAuth
},
}
sortAccountsByPriorityAndLastUsed
(
accounts
,
true
)
require
.
Equal
(
t
,
int64
(
2
),
accounts
[
0
]
.
ID
,
"preferOAuth 时 OAuth 账号排前面"
)
}
func
TestSortAccountsByPriorityAndLastUsed_StableSort
(
t
*
testing
.
T
)
{
accounts
:=
[]
*
Account
{
{
ID
:
1
,
Priority
:
1
,
LastUsedAt
:
nil
,
Type
:
AccountTypeAPIKey
},
{
ID
:
2
,
Priority
:
1
,
LastUsedAt
:
nil
,
Type
:
AccountTypeAPIKey
},
{
ID
:
3
,
Priority
:
1
,
LastUsedAt
:
nil
,
Type
:
AccountTypeAPIKey
},
}
// sortAccountsByPriorityAndLastUsed 内部会在同组(Priority+LastUsedAt)内做随机打散,
// 因此这里不再断言“稳定排序”。我们只验证:
// 1) 元素集合不变;2) 多次运行能产生不同的顺序。
seenFirst
:=
map
[
int64
]
bool
{}
for
i
:=
0
;
i
<
100
;
i
++
{
cpy
:=
make
([]
*
Account
,
len
(
accounts
))
copy
(
cpy
,
accounts
)
sortAccountsByPriorityAndLastUsed
(
cpy
,
false
)
seenFirst
[
cpy
[
0
]
.
ID
]
=
true
ids
:=
map
[
int64
]
bool
{}
for
_
,
a
:=
range
cpy
{
ids
[
a
.
ID
]
=
true
}
require
.
True
(
t
,
ids
[
1
]
&&
ids
[
2
]
&&
ids
[
3
])
}
require
.
GreaterOrEqual
(
t
,
len
(
seenFirst
),
2
,
"同组账号应能被随机打散"
)
}
func
TestSortAccountsByPriorityAndLastUsed_MixedPriorityAndTime
(
t
*
testing
.
T
)
{
now
:=
time
.
Now
()
accounts
:=
[]
*
Account
{
{
ID
:
1
,
Priority
:
2
,
LastUsedAt
:
nil
},
{
ID
:
2
,
Priority
:
1
,
LastUsedAt
:
testTimePtr
(
now
)},
{
ID
:
3
,
Priority
:
1
,
LastUsedAt
:
testTimePtr
(
now
.
Add
(
-
1
*
time
.
Hour
))},
{
ID
:
4
,
Priority
:
2
,
LastUsedAt
:
testTimePtr
(
now
.
Add
(
-
2
*
time
.
Hour
))},
}
sortAccountsByPriorityAndLastUsed
(
accounts
,
false
)
// 优先级1排前:nil < earlier
require
.
Equal
(
t
,
int64
(
3
),
accounts
[
0
]
.
ID
,
"优先级1 + 更早"
)
require
.
Equal
(
t
,
int64
(
2
),
accounts
[
1
]
.
ID
,
"优先级1 + 现在"
)
// 优先级2排后:nil < time
require
.
Equal
(
t
,
int64
(
1
),
accounts
[
2
]
.
ID
,
"优先级2 + nil"
)
require
.
Equal
(
t
,
int64
(
4
),
accounts
[
3
]
.
ID
,
"优先级2 + 有时间"
)
}
// --- filterByMinPriority ---
func
TestFilterByMinPriority_Empty
(
t
*
testing
.
T
)
{
result
:=
filterByMinPriority
(
nil
)
require
.
Nil
(
t
,
result
)
}
func
TestFilterByMinPriority_SelectsMinPriority
(
t
*
testing
.
T
)
{
accounts
:=
[]
accountWithLoad
{
makeAccWithLoad
(
1
,
5
,
10
,
nil
,
AccountTypeAPIKey
),
makeAccWithLoad
(
2
,
1
,
10
,
nil
,
AccountTypeAPIKey
),
makeAccWithLoad
(
3
,
1
,
20
,
nil
,
AccountTypeAPIKey
),
makeAccWithLoad
(
4
,
2
,
10
,
nil
,
AccountTypeAPIKey
),
}
result
:=
filterByMinPriority
(
accounts
)
require
.
Len
(
t
,
result
,
2
)
require
.
Equal
(
t
,
int64
(
2
),
result
[
0
]
.
account
.
ID
)
require
.
Equal
(
t
,
int64
(
3
),
result
[
1
]
.
account
.
ID
)
}
// --- filterByMinLoadRate ---
func
TestFilterByMinLoadRate_Empty
(
t
*
testing
.
T
)
{
result
:=
filterByMinLoadRate
(
nil
)
require
.
Nil
(
t
,
result
)
}
func
TestFilterByMinLoadRate_SelectsMinLoadRate
(
t
*
testing
.
T
)
{
accounts
:=
[]
accountWithLoad
{
makeAccWithLoad
(
1
,
1
,
30
,
nil
,
AccountTypeAPIKey
),
makeAccWithLoad
(
2
,
1
,
10
,
nil
,
AccountTypeAPIKey
),
makeAccWithLoad
(
3
,
1
,
10
,
nil
,
AccountTypeAPIKey
),
makeAccWithLoad
(
4
,
1
,
20
,
nil
,
AccountTypeAPIKey
),
}
result
:=
filterByMinLoadRate
(
accounts
)
require
.
Len
(
t
,
result
,
2
)
require
.
Equal
(
t
,
int64
(
2
),
result
[
0
]
.
account
.
ID
)
require
.
Equal
(
t
,
int64
(
3
),
result
[
1
]
.
account
.
ID
)
}
// --- selectByLRU ---
func
TestSelectByLRU_Empty
(
t
*
testing
.
T
)
{
result
:=
selectByLRU
(
nil
,
false
)
require
.
Nil
(
t
,
result
)
}
func
TestSelectByLRU_Single
(
t
*
testing
.
T
)
{
accounts
:=
[]
accountWithLoad
{
makeAccWithLoad
(
1
,
1
,
10
,
nil
,
AccountTypeAPIKey
)}
result
:=
selectByLRU
(
accounts
,
false
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
int64
(
1
),
result
.
account
.
ID
)
}
func
TestSelectByLRU_NilLastUsedAtWins
(
t
*
testing
.
T
)
{
now
:=
time
.
Now
()
accounts
:=
[]
accountWithLoad
{
makeAccWithLoad
(
1
,
1
,
10
,
testTimePtr
(
now
),
AccountTypeAPIKey
),
makeAccWithLoad
(
2
,
1
,
10
,
nil
,
AccountTypeAPIKey
),
makeAccWithLoad
(
3
,
1
,
10
,
testTimePtr
(
now
.
Add
(
-
1
*
time
.
Hour
)),
AccountTypeAPIKey
),
}
result
:=
selectByLRU
(
accounts
,
false
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
account
.
ID
)
}
func
TestSelectByLRU_EarliestTimeWins
(
t
*
testing
.
T
)
{
now
:=
time
.
Now
()
accounts
:=
[]
accountWithLoad
{
makeAccWithLoad
(
1
,
1
,
10
,
testTimePtr
(
now
),
AccountTypeAPIKey
),
makeAccWithLoad
(
2
,
1
,
10
,
testTimePtr
(
now
.
Add
(
-
1
*
time
.
Hour
)),
AccountTypeAPIKey
),
makeAccWithLoad
(
3
,
1
,
10
,
testTimePtr
(
now
.
Add
(
-
2
*
time
.
Hour
)),
AccountTypeAPIKey
),
}
result
:=
selectByLRU
(
accounts
,
false
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
int64
(
3
),
result
.
account
.
ID
)
}
func
TestSelectByLRU_TiePreferOAuth
(
t
*
testing
.
T
)
{
now
:=
time
.
Now
()
// 账号 1/2 LastUsedAt 相同,且同为最小值。
accounts
:=
[]
accountWithLoad
{
makeAccWithLoad
(
1
,
1
,
10
,
testTimePtr
(
now
),
AccountTypeAPIKey
),
makeAccWithLoad
(
2
,
1
,
10
,
testTimePtr
(
now
),
AccountTypeOAuth
),
makeAccWithLoad
(
3
,
1
,
10
,
testTimePtr
(
now
.
Add
(
1
*
time
.
Hour
)),
AccountTypeAPIKey
),
}
for
i
:=
0
;
i
<
50
;
i
++
{
result
:=
selectByLRU
(
accounts
,
true
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
AccountTypeOAuth
,
result
.
account
.
Type
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
account
.
ID
)
}
}
backend/internal/service/gateway_anthropic_apikey_passthrough_benchmark_test.go
0 → 100644
View file @
6bccb8a8
package
service
import
"testing"
func
BenchmarkGatewayService_ParseSSEUsage_MessageStart
(
b
*
testing
.
B
)
{
svc
:=
&
GatewayService
{}
data
:=
`{"type":"message_start","message":{"usage":{"input_tokens":123,"cache_creation_input_tokens":45,"cache_read_input_tokens":6,"cached_tokens":6,"cache_creation":{"ephemeral_5m_input_tokens":20,"ephemeral_1h_input_tokens":25}}}}`
b
.
ReportAllocs
()
b
.
ResetTimer
()
for
i
:=
0
;
i
<
b
.
N
;
i
++
{
usage
:=
&
ClaudeUsage
{}
svc
.
parseSSEUsage
(
data
,
usage
)
}
}
func
BenchmarkGatewayService_ParseSSEUsagePassthrough_MessageStart
(
b
*
testing
.
B
)
{
svc
:=
&
GatewayService
{}
data
:=
`{"type":"message_start","message":{"usage":{"input_tokens":123,"cache_creation_input_tokens":45,"cache_read_input_tokens":6,"cached_tokens":6,"cache_creation":{"ephemeral_5m_input_tokens":20,"ephemeral_1h_input_tokens":25}}}}`
b
.
ReportAllocs
()
b
.
ResetTimer
()
for
i
:=
0
;
i
<
b
.
N
;
i
++
{
usage
:=
&
ClaudeUsage
{}
svc
.
parseSSEUsagePassthrough
(
data
,
usage
)
}
}
func
BenchmarkGatewayService_ParseSSEUsage_MessageDelta
(
b
*
testing
.
B
)
{
svc
:=
&
GatewayService
{}
data
:=
`{"type":"message_delta","usage":{"output_tokens":456,"cache_creation_input_tokens":30,"cache_read_input_tokens":7,"cached_tokens":7,"cache_creation":{"ephemeral_5m_input_tokens":10,"ephemeral_1h_input_tokens":20}}}`
b
.
ReportAllocs
()
b
.
ResetTimer
()
for
i
:=
0
;
i
<
b
.
N
;
i
++
{
usage
:=
&
ClaudeUsage
{}
svc
.
parseSSEUsage
(
data
,
usage
)
}
}
func
BenchmarkGatewayService_ParseSSEUsagePassthrough_MessageDelta
(
b
*
testing
.
B
)
{
svc
:=
&
GatewayService
{}
data
:=
`{"type":"message_delta","usage":{"output_tokens":456,"cache_creation_input_tokens":30,"cache_read_input_tokens":7,"cached_tokens":7,"cache_creation":{"ephemeral_5m_input_tokens":10,"ephemeral_1h_input_tokens":20}}}`
b
.
ReportAllocs
()
b
.
ResetTimer
()
for
i
:=
0
;
i
<
b
.
N
;
i
++
{
usage
:=
&
ClaudeUsage
{}
svc
.
parseSSEUsagePassthrough
(
data
,
usage
)
}
}
func
BenchmarkParseClaudeUsageFromResponseBody
(
b
*
testing
.
B
)
{
body
:=
[]
byte
(
`{"id":"msg_123","type":"message","usage":{"input_tokens":123,"output_tokens":456,"cache_creation_input_tokens":45,"cache_read_input_tokens":6,"cached_tokens":6,"cache_creation":{"ephemeral_5m_input_tokens":20,"ephemeral_1h_input_tokens":25}}}`
)
b
.
ReportAllocs
()
b
.
ResetTimer
()
for
i
:=
0
;
i
<
b
.
N
;
i
++
{
_
=
parseClaudeUsageFromResponseBody
(
body
)
}
}
backend/internal/service/gateway_anthropic_apikey_passthrough_test.go
0 → 100644
View file @
6bccb8a8
package
service
import
(
"bufio"
"bytes"
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
type
anthropicHTTPUpstreamRecorder
struct
{
lastReq
*
http
.
Request
lastBody
[]
byte
resp
*
http
.
Response
err
error
}
func
newAnthropicAPIKeyAccountForTest
()
*
Account
{
return
&
Account
{
ID
:
201
,
Name
:
"anthropic-apikey-pass-test"
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"upstream-anthropic-key"
,
"base_url"
:
"https://api.anthropic.com"
,
},
Extra
:
map
[
string
]
any
{
"anthropic_passthrough"
:
true
,
},
Status
:
StatusActive
,
Schedulable
:
true
,
}
}
func
(
u
*
anthropicHTTPUpstreamRecorder
)
Do
(
req
*
http
.
Request
,
proxyURL
string
,
accountID
int64
,
accountConcurrency
int
)
(
*
http
.
Response
,
error
)
{
u
.
lastReq
=
req
if
req
!=
nil
&&
req
.
Body
!=
nil
{
b
,
_
:=
io
.
ReadAll
(
req
.
Body
)
u
.
lastBody
=
b
_
=
req
.
Body
.
Close
()
req
.
Body
=
io
.
NopCloser
(
bytes
.
NewReader
(
b
))
}
if
u
.
err
!=
nil
{
return
nil
,
u
.
err
}
return
u
.
resp
,
nil
}
func
(
u
*
anthropicHTTPUpstreamRecorder
)
DoWithTLS
(
req
*
http
.
Request
,
proxyURL
string
,
accountID
int64
,
accountConcurrency
int
,
enableTLSFingerprint
bool
)
(
*
http
.
Response
,
error
)
{
return
u
.
Do
(
req
,
proxyURL
,
accountID
,
accountConcurrency
)
}
type
streamReadCloser
struct
{
payload
[]
byte
sent
bool
err
error
}
func
(
r
*
streamReadCloser
)
Read
(
p
[]
byte
)
(
int
,
error
)
{
if
!
r
.
sent
{
r
.
sent
=
true
n
:=
copy
(
p
,
r
.
payload
)
return
n
,
nil
}
if
r
.
err
!=
nil
{
return
0
,
r
.
err
}
return
0
,
io
.
EOF
}
func
(
r
*
streamReadCloser
)
Close
()
error
{
return
nil
}
type
failWriteResponseWriter
struct
{
gin
.
ResponseWriter
}
func
(
w
*
failWriteResponseWriter
)
Write
(
data
[]
byte
)
(
int
,
error
)
{
return
0
,
errors
.
New
(
"client disconnected"
)
}
func
(
w
*
failWriteResponseWriter
)
WriteString
(
_
string
)
(
int
,
error
)
{
return
0
,
errors
.
New
(
"client disconnected"
)
}
func
TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAndAuthReplacement
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
nil
)
c
.
Request
.
Header
.
Set
(
"User-Agent"
,
"claude-cli/1.0.0"
)
c
.
Request
.
Header
.
Set
(
"Authorization"
,
"Bearer inbound-token"
)
c
.
Request
.
Header
.
Set
(
"X-Api-Key"
,
"inbound-api-key"
)
c
.
Request
.
Header
.
Set
(
"X-Goog-Api-Key"
,
"inbound-goog-key"
)
c
.
Request
.
Header
.
Set
(
"Cookie"
,
"secret=1"
)
c
.
Request
.
Header
.
Set
(
"Anthropic-Beta"
,
"interleaved-thinking-2025-05-14"
)
body
:=
[]
byte
(
`{"model":"claude-3-7-sonnet-20250219","stream":true,"system":[{"type":"text","text":"x-anthropic-billing-header keep"}],"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`
)
parsed
:=
&
ParsedRequest
{
Body
:
body
,
Model
:
"claude-3-7-sonnet-20250219"
,
Stream
:
true
,
}
upstreamSSE
:=
strings
.
Join
([]
string
{
`data: {"type":"message_start","message":{"usage":{"input_tokens":9,"cached_tokens":7}}}`
,
""
,
`data: {"type":"message_delta","usage":{"output_tokens":3}}`
,
""
,
"data: [DONE]"
,
""
,
},
"
\n
"
)
upstream
:=
&
anthropicHTTPUpstreamRecorder
{
resp
:
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"text/event-stream"
},
"x-request-id"
:
[]
string
{
"rid-anthropic-pass"
},
"Set-Cookie"
:
[]
string
{
"secret=upstream"
},
},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
upstreamSSE
)),
},
}
svc
:=
&
GatewayService
{
cfg
:
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
,
},
},
httpUpstream
:
upstream
,
rateLimitService
:
&
RateLimitService
{},
deferredService
:
&
DeferredService
{},
billingCacheService
:
nil
,
}
account
:=
&
Account
{
ID
:
101
,
Name
:
"anthropic-apikey-pass"
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"upstream-anthropic-key"
,
"base_url"
:
"https://api.anthropic.com"
,
"model_mapping"
:
map
[
string
]
any
{
"claude-3-7-sonnet-20250219"
:
"claude-3-haiku-20240307"
},
},
Extra
:
map
[
string
]
any
{
"anthropic_passthrough"
:
true
,
},
Status
:
StatusActive
,
Schedulable
:
true
,
}
result
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c
,
account
,
parsed
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
True
(
t
,
result
.
Stream
)
require
.
Equal
(
t
,
body
,
upstream
.
lastBody
,
"透传模式不应改写上游请求体"
)
require
.
Equal
(
t
,
"claude-3-7-sonnet-20250219"
,
gjson
.
GetBytes
(
upstream
.
lastBody
,
"model"
)
.
String
())
require
.
Equal
(
t
,
"upstream-anthropic-key"
,
upstream
.
lastReq
.
Header
.
Get
(
"x-api-key"
))
require
.
Empty
(
t
,
upstream
.
lastReq
.
Header
.
Get
(
"authorization"
))
require
.
Empty
(
t
,
upstream
.
lastReq
.
Header
.
Get
(
"x-goog-api-key"
))
require
.
Empty
(
t
,
upstream
.
lastReq
.
Header
.
Get
(
"cookie"
))
require
.
Equal
(
t
,
"2023-06-01"
,
upstream
.
lastReq
.
Header
.
Get
(
"anthropic-version"
))
require
.
Equal
(
t
,
"interleaved-thinking-2025-05-14"
,
upstream
.
lastReq
.
Header
.
Get
(
"anthropic-beta"
))
require
.
Empty
(
t
,
upstream
.
lastReq
.
Header
.
Get
(
"x-stainless-lang"
),
"API Key 透传不应注入 OAuth 指纹头"
)
require
.
Contains
(
t
,
rec
.
Body
.
String
(),
`"cached_tokens":7`
)
require
.
NotContains
(
t
,
rec
.
Body
.
String
(),
`"cache_read_input_tokens":7`
,
"透传输出不应被网关改写"
)
require
.
Equal
(
t
,
7
,
result
.
Usage
.
CacheReadInputTokens
,
"计费 usage 解析应保留 cached_tokens 兼容"
)
require
.
Empty
(
t
,
rec
.
Header
()
.
Get
(
"Set-Cookie"
),
"响应头应经过安全过滤"
)
rawBody
,
ok
:=
c
.
Get
(
OpsUpstreamRequestBodyKey
)
require
.
True
(
t
,
ok
)
bodyBytes
,
ok
:=
rawBody
.
([]
byte
)
require
.
True
(
t
,
ok
,
"应以 []byte 形式缓存上游请求体,避免重复 string 拷贝"
)
require
.
Equal
(
t
,
body
,
bodyBytes
)
}
func
TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBody
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages/count_tokens"
,
nil
)
c
.
Request
.
Header
.
Set
(
"Authorization"
,
"Bearer inbound-token"
)
c
.
Request
.
Header
.
Set
(
"X-Api-Key"
,
"inbound-api-key"
)
c
.
Request
.
Header
.
Set
(
"Cookie"
,
"secret=1"
)
body
:=
[]
byte
(
`{"model":"claude-3-5-sonnet-latest","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}],"thinking":{"type":"enabled"}}`
)
parsed
:=
&
ParsedRequest
{
Body
:
body
,
Model
:
"claude-3-5-sonnet-latest"
,
}
upstreamRespBody
:=
`{"input_tokens":42}`
upstream
:=
&
anthropicHTTPUpstreamRecorder
{
resp
:
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"application/json"
},
"x-request-id"
:
[]
string
{
"rid-count"
},
"Set-Cookie"
:
[]
string
{
"secret=upstream"
},
},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
upstreamRespBody
)),
},
}
svc
:=
&
GatewayService
{
cfg
:
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
,
},
},
httpUpstream
:
upstream
,
rateLimitService
:
&
RateLimitService
{},
}
account
:=
&
Account
{
ID
:
102
,
Name
:
"anthropic-apikey-pass-count"
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"upstream-anthropic-key"
,
"base_url"
:
"https://api.anthropic.com"
,
"model_mapping"
:
map
[
string
]
any
{
"claude-3-5-sonnet-latest"
:
"claude-3-opus-20240229"
},
},
Extra
:
map
[
string
]
any
{
"anthropic_passthrough"
:
true
,
},
Status
:
StatusActive
,
Schedulable
:
true
,
}
err
:=
svc
.
ForwardCountTokens
(
context
.
Background
(),
c
,
account
,
parsed
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
body
,
upstream
.
lastBody
,
"count_tokens 透传模式不应改写请求体"
)
require
.
Equal
(
t
,
"claude-3-5-sonnet-latest"
,
gjson
.
GetBytes
(
upstream
.
lastBody
,
"model"
)
.
String
())
require
.
Equal
(
t
,
"upstream-anthropic-key"
,
upstream
.
lastReq
.
Header
.
Get
(
"x-api-key"
))
require
.
Empty
(
t
,
upstream
.
lastReq
.
Header
.
Get
(
"authorization"
))
require
.
Empty
(
t
,
upstream
.
lastReq
.
Header
.
Get
(
"cookie"
))
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
JSONEq
(
t
,
upstreamRespBody
,
rec
.
Body
.
String
())
require
.
Empty
(
t
,
rec
.
Header
()
.
Get
(
"Set-Cookie"
))
}
func
TestGatewayService_AnthropicAPIKeyPassthrough_BuildRequestRejectsInvalidBaseURL
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
nil
)
svc
:=
&
GatewayService
{
cfg
:
&
config
.
Config
{
Security
:
config
.
SecurityConfig
{
URLAllowlist
:
config
.
URLAllowlistConfig
{
Enabled
:
false
,
},
},
},
}
account
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"k"
,
"base_url"
:
"://invalid-url"
,
},
}
_
,
err
:=
svc
.
buildUpstreamRequestAnthropicAPIKeyPassthrough
(
context
.
Background
(),
c
,
account
,
[]
byte
(
`{}`
),
"k"
)
require
.
Error
(
t
,
err
)
}
func
TestGatewayService_AnthropicOAuth_NotAffectedByAPIKeyPassthroughToggle
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
nil
)
svc
:=
&
GatewayService
{
cfg
:
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
},
},
}
account
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
"anthropic_passthrough"
:
true
,
},
}
require
.
False
(
t
,
account
.
IsAnthropicAPIKeyPassthroughEnabled
())
req
,
err
:=
svc
.
buildUpstreamRequest
(
context
.
Background
(),
c
,
account
,
[]
byte
(
`{"model":"claude-3-7-sonnet-20250219"}`
),
"oauth-token"
,
"oauth"
,
"claude-3-7-sonnet-20250219"
,
true
,
false
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"Bearer oauth-token"
,
req
.
Header
.
Get
(
"authorization"
))
require
.
Contains
(
t
,
req
.
Header
.
Get
(
"anthropic-beta"
),
claude
.
BetaOAuth
,
"OAuth 链路仍应按原逻辑补齐 oauth beta"
)
}
func
TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAfterClientDisconnect
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
// Use a canceled context recorder to simulate client disconnect behavior.
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
nil
)
ctx
,
cancel
:=
context
.
WithCancel
(
req
.
Context
())
cancel
()
req
=
req
.
WithContext
(
ctx
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
req
svc
:=
&
GatewayService
{
cfg
:
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
,
},
},
rateLimitService
:
&
RateLimitService
{},
}
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"text/event-stream"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
strings
.
Join
([]
string
{
`data: {"type":"message_start","message":{"usage":{"input_tokens":11}}}`
,
""
,
`data: {"type":"message_delta","usage":{"output_tokens":5}}`
,
""
,
"data: [DONE]"
,
""
,
},
"
\n
"
))),
}
result
,
err
:=
svc
.
handleStreamingResponseAnthropicAPIKeyPassthrough
(
context
.
Background
(),
resp
,
c
,
&
Account
{
ID
:
1
},
time
.
Now
(),
"claude-3-7-sonnet-20250219"
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
usage
)
require
.
Equal
(
t
,
11
,
result
.
usage
.
InputTokens
)
require
.
Equal
(
t
,
5
,
result
.
usage
.
OutputTokens
)
}
func
TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuccess
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
nil
)
body
:=
[]
byte
(
`{"model":"claude-3-5-sonnet-latest","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`
)
upstreamJSON
:=
`{"id":"msg_1","type":"message","usage":{"input_tokens":12,"output_tokens":7,"cache_creation":{"ephemeral_5m_input_tokens":2,"ephemeral_1h_input_tokens":3},"cached_tokens":4}}`
upstream
:=
&
anthropicHTTPUpstreamRecorder
{
resp
:
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"application/json"
},
"x-request-id"
:
[]
string
{
"rid-nonstream"
},
},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
upstreamJSON
)),
},
}
svc
:=
&
GatewayService
{
cfg
:
&
config
.
Config
{},
httpUpstream
:
upstream
,
rateLimitService
:
&
RateLimitService
{},
}
result
,
err
:=
svc
.
forwardAnthropicAPIKeyPassthrough
(
context
.
Background
(),
c
,
newAnthropicAPIKeyAccountForTest
(),
body
,
"claude-3-5-sonnet-latest"
,
false
,
time
.
Now
())
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
12
,
result
.
Usage
.
InputTokens
)
require
.
Equal
(
t
,
7
,
result
.
Usage
.
OutputTokens
)
require
.
Equal
(
t
,
5
,
result
.
Usage
.
CacheCreationInputTokens
)
require
.
Equal
(
t
,
4
,
result
.
Usage
.
CacheReadInputTokens
)
require
.
Equal
(
t
,
upstreamJSON
,
rec
.
Body
.
String
())
}
func
TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_InvalidTokenType
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
nil
)
account
:=
&
Account
{
ID
:
202
,
Name
:
"anthropic-oauth"
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"oauth-token"
,
},
}
svc
:=
&
GatewayService
{}
result
,
err
:=
svc
.
forwardAnthropicAPIKeyPassthrough
(
context
.
Background
(),
c
,
account
,
[]
byte
(
`{}`
),
"claude-3-5-sonnet-latest"
,
false
,
time
.
Now
())
require
.
Nil
(
t
,
result
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"requires apikey token"
)
}
func
TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_UpstreamRequestError
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
nil
)
upstream
:=
&
anthropicHTTPUpstreamRecorder
{
err
:
errors
.
New
(
"dial tcp timeout"
),
}
svc
:=
&
GatewayService
{
cfg
:
&
config
.
Config
{
Security
:
config
.
SecurityConfig
{
URLAllowlist
:
config
.
URLAllowlistConfig
{
Enabled
:
false
},
},
},
httpUpstream
:
upstream
,
}
account
:=
newAnthropicAPIKeyAccountForTest
()
result
,
err
:=
svc
.
forwardAnthropicAPIKeyPassthrough
(
context
.
Background
(),
c
,
account
,
[]
byte
(
`{"model":"x"}`
),
"x"
,
false
,
time
.
Now
())
require
.
Nil
(
t
,
result
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"upstream request failed"
)
require
.
Equal
(
t
,
http
.
StatusBadGateway
,
rec
.
Code
)
rawBody
,
ok
:=
c
.
Get
(
OpsUpstreamRequestBodyKey
)
require
.
True
(
t
,
ok
)
_
,
ok
=
rawBody
.
([]
byte
)
require
.
True
(
t
,
ok
)
}
func
TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_EmptyResponseBody
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
nil
)
upstream
:=
&
anthropicHTTPUpstreamRecorder
{
resp
:
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"x-request-id"
:
[]
string
{
"rid-empty-body"
}},
Body
:
nil
,
},
}
svc
:=
&
GatewayService
{
cfg
:
&
config
.
Config
{
Security
:
config
.
SecurityConfig
{
URLAllowlist
:
config
.
URLAllowlistConfig
{
Enabled
:
false
},
},
},
httpUpstream
:
upstream
,
}
result
,
err
:=
svc
.
forwardAnthropicAPIKeyPassthrough
(
context
.
Background
(),
c
,
newAnthropicAPIKeyAccountForTest
(),
[]
byte
(
`{"model":"x"}`
),
"x"
,
false
,
time
.
Now
())
require
.
Nil
(
t
,
result
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"empty response"
)
}
func
TestExtractAnthropicSSEDataLine
(
t
*
testing
.
T
)
{
t
.
Run
(
"valid data line with spaces"
,
func
(
t
*
testing
.
T
)
{
data
,
ok
:=
extractAnthropicSSEDataLine
(
"data: {
\"
type
\"
:
\"
message_start
\"
}"
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
`{"type":"message_start"}`
,
data
)
})
t
.
Run
(
"non data line"
,
func
(
t
*
testing
.
T
)
{
data
,
ok
:=
extractAnthropicSSEDataLine
(
"event: message_start"
)
require
.
False
(
t
,
ok
)
require
.
Empty
(
t
,
data
)
})
}
func
TestGatewayService_ParseSSEUsagePassthrough_MessageStartFallbacks
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
usage
:=
&
ClaudeUsage
{}
data
:=
`{"type":"message_start","message":{"usage":{"input_tokens":12,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cached_tokens":9,"cache_creation":{"ephemeral_5m_input_tokens":3,"ephemeral_1h_input_tokens":4}}}}`
svc
.
parseSSEUsagePassthrough
(
data
,
usage
)
require
.
Equal
(
t
,
12
,
usage
.
InputTokens
)
require
.
Equal
(
t
,
9
,
usage
.
CacheReadInputTokens
,
"应兼容 cached_tokens 字段"
)
require
.
Equal
(
t
,
7
,
usage
.
CacheCreationInputTokens
,
"聚合字段为空时应从 5m/1h 明细回填"
)
require
.
Equal
(
t
,
3
,
usage
.
CacheCreation5mTokens
)
require
.
Equal
(
t
,
4
,
usage
.
CacheCreation1hTokens
)
}
func
TestGatewayService_ParseSSEUsagePassthrough_MessageDeltaSelectiveOverwrite
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
usage
:=
&
ClaudeUsage
{
InputTokens
:
10
,
CacheCreation5mTokens
:
2
,
CacheCreation1hTokens
:
6
,
}
data
:=
`{"type":"message_delta","usage":{"input_tokens":0,"output_tokens":5,"cache_creation_input_tokens":8,"cache_read_input_tokens":0,"cached_tokens":11,"cache_creation":{"ephemeral_5m_input_tokens":1,"ephemeral_1h_input_tokens":0}}}`
svc
.
parseSSEUsagePassthrough
(
data
,
usage
)
require
.
Equal
(
t
,
10
,
usage
.
InputTokens
,
"message_delta 中 0 值不应覆盖已有 input_tokens"
)
require
.
Equal
(
t
,
5
,
usage
.
OutputTokens
)
require
.
Equal
(
t
,
8
,
usage
.
CacheCreationInputTokens
)
require
.
Equal
(
t
,
11
,
usage
.
CacheReadInputTokens
,
"cache_read_input_tokens 为空时应回退到 cached_tokens"
)
require
.
Equal
(
t
,
1
,
usage
.
CacheCreation5mTokens
)
require
.
Equal
(
t
,
6
,
usage
.
CacheCreation1hTokens
,
"message_delta 中 0 值不应覆盖已有 1h 明细"
)
}
func
TestGatewayService_ParseSSEUsagePassthrough_NoopCases
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
usage
:=
&
ClaudeUsage
{
InputTokens
:
3
}
svc
.
parseSSEUsagePassthrough
(
""
,
usage
)
require
.
Equal
(
t
,
3
,
usage
.
InputTokens
)
svc
.
parseSSEUsagePassthrough
(
"[DONE]"
,
usage
)
require
.
Equal
(
t
,
3
,
usage
.
InputTokens
)
svc
.
parseSSEUsagePassthrough
(
"not-json"
,
usage
)
require
.
Equal
(
t
,
3
,
usage
.
InputTokens
)
// nil usage 不应 panic
svc
.
parseSSEUsagePassthrough
(
`{"type":"message_start"}`
,
nil
)
}
func
TestGatewayService_ParseSSEUsagePassthrough_FallbackFromUsageNode
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
usage
:=
&
ClaudeUsage
{}
data
:=
`{"type":"content_block_delta","usage":{"cached_tokens":6,"cache_creation":{"ephemeral_5m_input_tokens":2,"ephemeral_1h_input_tokens":1}}}`
svc
.
parseSSEUsagePassthrough
(
data
,
usage
)
require
.
Equal
(
t
,
6
,
usage
.
CacheReadInputTokens
)
require
.
Equal
(
t
,
3
,
usage
.
CacheCreationInputTokens
)
}
func
TestParseClaudeUsageFromResponseBody
(
t
*
testing
.
T
)
{
t
.
Run
(
"empty or missing usage"
,
func
(
t
*
testing
.
T
)
{
got
:=
parseClaudeUsageFromResponseBody
(
nil
)
require
.
NotNil
(
t
,
got
)
require
.
Equal
(
t
,
0
,
got
.
InputTokens
)
got
=
parseClaudeUsageFromResponseBody
([]
byte
(
`{"id":"x"}`
))
require
.
NotNil
(
t
,
got
)
require
.
Equal
(
t
,
0
,
got
.
OutputTokens
)
})
t
.
Run
(
"parse all usage fields and fallback"
,
func
(
t
*
testing
.
T
)
{
body
:=
[]
byte
(
`{"usage":{"input_tokens":21,"output_tokens":34,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cached_tokens":13,"cache_creation":{"ephemeral_5m_input_tokens":5,"ephemeral_1h_input_tokens":8}}}`
)
got
:=
parseClaudeUsageFromResponseBody
(
body
)
require
.
Equal
(
t
,
21
,
got
.
InputTokens
)
require
.
Equal
(
t
,
34
,
got
.
OutputTokens
)
require
.
Equal
(
t
,
13
,
got
.
CacheReadInputTokens
,
"cache_read_input_tokens 为空时应回退 cached_tokens"
)
require
.
Equal
(
t
,
13
,
got
.
CacheCreationInputTokens
,
"聚合字段为空时应由 5m/1h 回填"
)
require
.
Equal
(
t
,
5
,
got
.
CacheCreation5mTokens
)
require
.
Equal
(
t
,
8
,
got
.
CacheCreation1hTokens
)
})
t
.
Run
(
"keep explicit aggregate values"
,
func
(
t
*
testing
.
T
)
{
body
:=
[]
byte
(
`{"usage":{"input_tokens":1,"output_tokens":2,"cache_creation_input_tokens":9,"cache_read_input_tokens":7,"cached_tokens":99,"cache_creation":{"ephemeral_5m_input_tokens":4,"ephemeral_1h_input_tokens":5}}}`
)
got
:=
parseClaudeUsageFromResponseBody
(
body
)
require
.
Equal
(
t
,
9
,
got
.
CacheCreationInputTokens
,
"已显式提供聚合字段时不应被明细覆盖"
)
require
.
Equal
(
t
,
7
,
got
.
CacheReadInputTokens
,
"已显式提供 cache_read_input_tokens 时不应回退 cached_tokens"
)
})
}
func
TestGatewayService_AnthropicAPIKeyPassthrough_StreamingErrTooLong
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
nil
)
svc
:=
&
GatewayService
{
cfg
:
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
32
,
},
},
}
// Scanner 初始缓冲为 64KB,构造更长单行触发 bufio.ErrTooLong。
longLine
:=
"data: "
+
strings
.
Repeat
(
"x"
,
80
*
1024
)
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"text/event-stream"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
longLine
)),
}
result
,
err
:=
svc
.
handleStreamingResponseAnthropicAPIKeyPassthrough
(
context
.
Background
(),
resp
,
c
,
&
Account
{
ID
:
2
},
time
.
Now
(),
"claude-3-7-sonnet-20250219"
)
require
.
Error
(
t
,
err
)
require
.
ErrorIs
(
t
,
err
,
bufio
.
ErrTooLong
)
require
.
NotNil
(
t
,
result
)
}
func
TestGatewayService_AnthropicAPIKeyPassthrough_StreamingDataIntervalTimeout
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
nil
)
svc
:=
&
GatewayService
{
cfg
:
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
StreamDataIntervalTimeout
:
1
,
MaxLineSize
:
defaultMaxLineSize
,
},
},
rateLimitService
:
&
RateLimitService
{},
}
pr
,
pw
:=
io
.
Pipe
()
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"text/event-stream"
}},
Body
:
pr
,
}
result
,
err
:=
svc
.
handleStreamingResponseAnthropicAPIKeyPassthrough
(
context
.
Background
(),
resp
,
c
,
&
Account
{
ID
:
5
},
time
.
Now
(),
"claude-3-7-sonnet-20250219"
)
_
=
pw
.
Close
()
_
=
pr
.
Close
()
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"stream data interval timeout"
)
require
.
NotNil
(
t
,
result
)
require
.
False
(
t
,
result
.
clientDisconnect
)
}
func
TestGatewayService_AnthropicAPIKeyPassthrough_StreamingReadError
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
nil
)
svc
:=
&
GatewayService
{
cfg
:
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
,
},
},
}
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"text/event-stream"
}},
Body
:
&
streamReadCloser
{
err
:
io
.
ErrUnexpectedEOF
,
},
}
result
,
err
:=
svc
.
handleStreamingResponseAnthropicAPIKeyPassthrough
(
context
.
Background
(),
resp
,
c
,
&
Account
{
ID
:
6
},
time
.
Now
(),
"claude-3-7-sonnet-20250219"
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"stream read error"
)
require
.
NotNil
(
t
,
result
)
require
.
False
(
t
,
result
.
clientDisconnect
)
}
func
TestGatewayService_AnthropicAPIKeyPassthrough_StreamingTimeoutAfterClientDisconnect
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
nil
)
c
.
Writer
=
&
failWriteResponseWriter
{
ResponseWriter
:
c
.
Writer
}
svc
:=
&
GatewayService
{
cfg
:
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
StreamDataIntervalTimeout
:
1
,
MaxLineSize
:
defaultMaxLineSize
,
},
},
rateLimitService
:
&
RateLimitService
{},
}
pr
,
pw
:=
io
.
Pipe
()
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"text/event-stream"
}},
Body
:
pr
,
}
done
:=
make
(
chan
struct
{})
go
func
()
{
defer
close
(
done
)
_
,
_
=
pw
.
Write
([]
byte
(
`data: {"type":"message_start","message":{"usage":{"input_tokens":9}}}`
+
"
\n
"
))
// 保持上游连接静默,触发数据间隔超时分支。
time
.
Sleep
(
1500
*
time
.
Millisecond
)
_
=
pw
.
Close
()
}()
result
,
err
:=
svc
.
handleStreamingResponseAnthropicAPIKeyPassthrough
(
context
.
Background
(),
resp
,
c
,
&
Account
{
ID
:
7
},
time
.
Now
(),
"claude-3-7-sonnet-20250219"
)
_
=
pr
.
Close
()
<-
done
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
True
(
t
,
result
.
clientDisconnect
)
require
.
Equal
(
t
,
9
,
result
.
usage
.
InputTokens
)
}
func
TestGatewayService_AnthropicAPIKeyPassthrough_StreamingContextCanceled
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
nil
)
svc
:=
&
GatewayService
{
cfg
:
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
,
},
},
}
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"text/event-stream"
}},
Body
:
&
streamReadCloser
{
err
:
context
.
Canceled
,
},
}
result
,
err
:=
svc
.
handleStreamingResponseAnthropicAPIKeyPassthrough
(
context
.
Background
(),
resp
,
c
,
&
Account
{
ID
:
3
},
time
.
Now
(),
"claude-3-7-sonnet-20250219"
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
True
(
t
,
result
.
clientDisconnect
)
}
func
TestGatewayService_AnthropicAPIKeyPassthrough_StreamingUpstreamReadErrorAfterClientDisconnect
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
nil
)
c
.
Writer
=
&
failWriteResponseWriter
{
ResponseWriter
:
c
.
Writer
}
svc
:=
&
GatewayService
{
cfg
:
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
,
},
},
}
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"text/event-stream"
}},
Body
:
&
streamReadCloser
{
payload
:
[]
byte
(
`data: {"type":"message_start","message":{"usage":{"input_tokens":8}}}`
+
"
\n\n
"
),
err
:
io
.
ErrUnexpectedEOF
,
},
}
result
,
err
:=
svc
.
handleStreamingResponseAnthropicAPIKeyPassthrough
(
context
.
Background
(),
resp
,
c
,
&
Account
{
ID
:
4
},
time
.
Now
(),
"claude-3-7-sonnet-20250219"
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
True
(
t
,
result
.
clientDisconnect
)
require
.
Equal
(
t
,
8
,
result
.
usage
.
InputTokens
)
}
backend/internal/service/gateway_hotpath_optimization_test.go
0 → 100644
View file @
6bccb8a8
package
service
import
(
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
gocache
"github.com/patrickmn/go-cache"
"github.com/stretchr/testify/require"
)
type
userGroupRateRepoHotpathStub
struct
{
UserGroupRateRepository
rate
*
float64
err
error
wait
<-
chan
struct
{}
calls
atomic
.
Int64
}
func
(
s
*
userGroupRateRepoHotpathStub
)
GetByUserAndGroup
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
float64
,
error
)
{
s
.
calls
.
Add
(
1
)
if
s
.
wait
!=
nil
{
<-
s
.
wait
}
if
s
.
err
!=
nil
{
return
nil
,
s
.
err
}
return
s
.
rate
,
nil
}
type
usageLogWindowBatchRepoStub
struct
{
UsageLogRepository
batchResult
map
[
int64
]
*
usagestats
.
AccountStats
batchErr
error
batchCalls
atomic
.
Int64
singleResult
map
[
int64
]
*
usagestats
.
AccountStats
singleErr
error
singleCalls
atomic
.
Int64
}
func
(
s
*
usageLogWindowBatchRepoStub
)
GetAccountWindowStatsBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
,
startTime
time
.
Time
)
(
map
[
int64
]
*
usagestats
.
AccountStats
,
error
)
{
s
.
batchCalls
.
Add
(
1
)
if
s
.
batchErr
!=
nil
{
return
nil
,
s
.
batchErr
}
out
:=
make
(
map
[
int64
]
*
usagestats
.
AccountStats
,
len
(
accountIDs
))
for
_
,
id
:=
range
accountIDs
{
if
stats
,
ok
:=
s
.
batchResult
[
id
];
ok
{
out
[
id
]
=
stats
}
}
return
out
,
nil
}
func
(
s
*
usageLogWindowBatchRepoStub
)
GetAccountWindowStats
(
ctx
context
.
Context
,
accountID
int64
,
startTime
time
.
Time
)
(
*
usagestats
.
AccountStats
,
error
)
{
s
.
singleCalls
.
Add
(
1
)
if
s
.
singleErr
!=
nil
{
return
nil
,
s
.
singleErr
}
if
stats
,
ok
:=
s
.
singleResult
[
accountID
];
ok
{
return
stats
,
nil
}
return
&
usagestats
.
AccountStats
{},
nil
}
type
sessionLimitCacheHotpathStub
struct
{
SessionLimitCache
batchData
map
[
int64
]
float64
batchErr
error
setData
map
[
int64
]
float64
setErr
error
}
func
(
s
*
sessionLimitCacheHotpathStub
)
GetWindowCostBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
float64
,
error
)
{
if
s
.
batchErr
!=
nil
{
return
nil
,
s
.
batchErr
}
out
:=
make
(
map
[
int64
]
float64
,
len
(
accountIDs
))
for
_
,
id
:=
range
accountIDs
{
if
v
,
ok
:=
s
.
batchData
[
id
];
ok
{
out
[
id
]
=
v
}
}
return
out
,
nil
}
func
(
s
*
sessionLimitCacheHotpathStub
)
SetWindowCost
(
ctx
context
.
Context
,
accountID
int64
,
cost
float64
)
error
{
if
s
.
setErr
!=
nil
{
return
s
.
setErr
}
if
s
.
setData
==
nil
{
s
.
setData
=
make
(
map
[
int64
]
float64
)
}
s
.
setData
[
accountID
]
=
cost
return
nil
}
type
modelsListAccountRepoStub
struct
{
AccountRepository
byGroup
map
[
int64
][]
Account
all
[]
Account
err
error
listByGroupCalls
atomic
.
Int64
listAllCalls
atomic
.
Int64
}
type
stickyGatewayCacheHotpathStub
struct
{
GatewayCache
stickyID
int64
getCalls
atomic
.
Int64
}
func
(
s
*
stickyGatewayCacheHotpathStub
)
GetSessionAccountID
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
)
(
int64
,
error
)
{
s
.
getCalls
.
Add
(
1
)
if
s
.
stickyID
>
0
{
return
s
.
stickyID
,
nil
}
return
0
,
errors
.
New
(
"not found"
)
}
func
(
s
*
stickyGatewayCacheHotpathStub
)
SetSessionAccountID
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
,
accountID
int64
,
ttl
time
.
Duration
)
error
{
return
nil
}
func
(
s
*
stickyGatewayCacheHotpathStub
)
RefreshSessionTTL
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
,
ttl
time
.
Duration
)
error
{
return
nil
}
func
(
s
*
stickyGatewayCacheHotpathStub
)
DeleteSessionAccountID
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
)
error
{
return
nil
}
func
(
s
*
modelsListAccountRepoStub
)
ListSchedulableByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
Account
,
error
)
{
s
.
listByGroupCalls
.
Add
(
1
)
if
s
.
err
!=
nil
{
return
nil
,
s
.
err
}
accounts
,
ok
:=
s
.
byGroup
[
groupID
]
if
!
ok
{
return
nil
,
nil
}
out
:=
make
([]
Account
,
len
(
accounts
))
copy
(
out
,
accounts
)
return
out
,
nil
}
func
(
s
*
modelsListAccountRepoStub
)
ListSchedulable
(
ctx
context
.
Context
)
([]
Account
,
error
)
{
s
.
listAllCalls
.
Add
(
1
)
if
s
.
err
!=
nil
{
return
nil
,
s
.
err
}
out
:=
make
([]
Account
,
len
(
s
.
all
))
copy
(
out
,
s
.
all
)
return
out
,
nil
}
func
resetGatewayHotpathStatsForTest
()
{
windowCostPrefetchCacheHitTotal
.
Store
(
0
)
windowCostPrefetchCacheMissTotal
.
Store
(
0
)
windowCostPrefetchBatchSQLTotal
.
Store
(
0
)
windowCostPrefetchFallbackTotal
.
Store
(
0
)
windowCostPrefetchErrorTotal
.
Store
(
0
)
userGroupRateCacheHitTotal
.
Store
(
0
)
userGroupRateCacheMissTotal
.
Store
(
0
)
userGroupRateCacheLoadTotal
.
Store
(
0
)
userGroupRateCacheSFSharedTotal
.
Store
(
0
)
userGroupRateCacheFallbackTotal
.
Store
(
0
)
modelsListCacheHitTotal
.
Store
(
0
)
modelsListCacheMissTotal
.
Store
(
0
)
modelsListCacheStoreTotal
.
Store
(
0
)
}
func
TestGetUserGroupRateMultiplier_UsesCacheAndSingleflight
(
t
*
testing
.
T
)
{
resetGatewayHotpathStatsForTest
()
rate
:=
1.7
unblock
:=
make
(
chan
struct
{})
repo
:=
&
userGroupRateRepoHotpathStub
{
rate
:
&
rate
,
wait
:
unblock
,
}
svc
:=
&
GatewayService
{
userGroupRateRepo
:
repo
,
userGroupRateCache
:
gocache
.
New
(
time
.
Minute
,
time
.
Minute
),
cfg
:
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
UserGroupRateCacheTTLSeconds
:
30
,
},
},
}
const
concurrent
=
12
results
:=
make
([]
float64
,
concurrent
)
start
:=
make
(
chan
struct
{})
var
wg
sync
.
WaitGroup
wg
.
Add
(
concurrent
)
for
i
:=
0
;
i
<
concurrent
;
i
++
{
go
func
(
idx
int
)
{
defer
wg
.
Done
()
<-
start
results
[
idx
]
=
svc
.
getUserGroupRateMultiplier
(
context
.
Background
(),
101
,
202
,
1.2
)
}(
i
)
}
close
(
start
)
time
.
Sleep
(
20
*
time
.
Millisecond
)
close
(
unblock
)
wg
.
Wait
()
for
_
,
got
:=
range
results
{
require
.
Equal
(
t
,
rate
,
got
)
}
require
.
Equal
(
t
,
int64
(
1
),
repo
.
calls
.
Load
())
// 再次读取应命中缓存,不再回源。
got
:=
svc
.
getUserGroupRateMultiplier
(
context
.
Background
(),
101
,
202
,
1.2
)
require
.
Equal
(
t
,
rate
,
got
)
require
.
Equal
(
t
,
int64
(
1
),
repo
.
calls
.
Load
())
hit
,
miss
,
load
,
sfShared
,
fallback
:=
GatewayUserGroupRateCacheStats
()
require
.
GreaterOrEqual
(
t
,
hit
,
int64
(
1
))
require
.
Equal
(
t
,
int64
(
12
),
miss
)
require
.
Equal
(
t
,
int64
(
1
),
load
)
require
.
GreaterOrEqual
(
t
,
sfShared
,
int64
(
1
))
require
.
Equal
(
t
,
int64
(
0
),
fallback
)
}
func
TestGetUserGroupRateMultiplier_FallbackOnRepoError
(
t
*
testing
.
T
)
{
resetGatewayHotpathStatsForTest
()
repo
:=
&
userGroupRateRepoHotpathStub
{
err
:
errors
.
New
(
"db down"
),
}
svc
:=
&
GatewayService
{
userGroupRateRepo
:
repo
,
userGroupRateCache
:
gocache
.
New
(
time
.
Minute
,
time
.
Minute
),
cfg
:
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
UserGroupRateCacheTTLSeconds
:
30
,
},
},
}
got
:=
svc
.
getUserGroupRateMultiplier
(
context
.
Background
(),
101
,
202
,
1.25
)
require
.
Equal
(
t
,
1.25
,
got
)
require
.
Equal
(
t
,
int64
(
1
),
repo
.
calls
.
Load
())
_
,
_
,
_
,
_
,
fallback
:=
GatewayUserGroupRateCacheStats
()
require
.
Equal
(
t
,
int64
(
1
),
fallback
)
}
func
TestGetUserGroupRateMultiplier_CacheHitAndNilRepo
(
t
*
testing
.
T
)
{
resetGatewayHotpathStatsForTest
()
repo
:=
&
userGroupRateRepoHotpathStub
{
err
:
errors
.
New
(
"should not be called"
),
}
svc
:=
&
GatewayService
{
userGroupRateRepo
:
repo
,
userGroupRateCache
:
gocache
.
New
(
time
.
Minute
,
time
.
Minute
),
}
key
:=
"101:202"
svc
.
userGroupRateCache
.
Set
(
key
,
2.3
,
time
.
Minute
)
got
:=
svc
.
getUserGroupRateMultiplier
(
context
.
Background
(),
101
,
202
,
1.1
)
require
.
Equal
(
t
,
2.3
,
got
)
hit
,
miss
,
load
,
_
,
fallback
:=
GatewayUserGroupRateCacheStats
()
require
.
Equal
(
t
,
int64
(
1
),
hit
)
require
.
Equal
(
t
,
int64
(
0
),
miss
)
require
.
Equal
(
t
,
int64
(
0
),
load
)
require
.
Equal
(
t
,
int64
(
0
),
fallback
)
require
.
Equal
(
t
,
int64
(
0
),
repo
.
calls
.
Load
())
// 无 repo 时直接返回分组默认倍率
svc2
:=
&
GatewayService
{
userGroupRateCache
:
gocache
.
New
(
time
.
Minute
,
time
.
Minute
),
}
svc2
.
userGroupRateCache
.
Set
(
key
,
1.9
,
time
.
Minute
)
require
.
Equal
(
t
,
1.9
,
svc2
.
getUserGroupRateMultiplier
(
context
.
Background
(),
101
,
202
,
1.4
))
require
.
Equal
(
t
,
1.4
,
svc2
.
getUserGroupRateMultiplier
(
context
.
Background
(),
0
,
202
,
1.4
))
svc2
.
userGroupRateCache
.
Delete
(
key
)
require
.
Equal
(
t
,
1.4
,
svc2
.
getUserGroupRateMultiplier
(
context
.
Background
(),
101
,
202
,
1.4
))
}
func
TestWithWindowCostPrefetch_BatchReadAndContextReuse
(
t
*
testing
.
T
)
{
resetGatewayHotpathStatsForTest
()
windowStart
:=
time
.
Now
()
.
Add
(
-
30
*
time
.
Minute
)
.
Truncate
(
time
.
Hour
)
windowEnd
:=
windowStart
.
Add
(
5
*
time
.
Hour
)
accounts
:=
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
"window_cost_limit"
:
100.0
},
SessionWindowStart
:
&
windowStart
,
SessionWindowEnd
:
&
windowEnd
,
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeSetupToken
,
Extra
:
map
[
string
]
any
{
"window_cost_limit"
:
100.0
},
SessionWindowStart
:
&
windowStart
,
SessionWindowEnd
:
&
windowEnd
,
},
{
ID
:
3
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
"window_cost_limit"
:
100.0
},
},
}
cache
:=
&
sessionLimitCacheHotpathStub
{
batchData
:
map
[
int64
]
float64
{
1
:
11.0
,
},
}
repo
:=
&
usageLogWindowBatchRepoStub
{
batchResult
:
map
[
int64
]
*
usagestats
.
AccountStats
{
2
:
{
StandardCost
:
22.0
},
},
}
svc
:=
&
GatewayService
{
sessionLimitCache
:
cache
,
usageLogRepo
:
repo
,
}
outCtx
:=
svc
.
withWindowCostPrefetch
(
context
.
Background
(),
accounts
)
require
.
NotNil
(
t
,
outCtx
)
cost1
,
ok1
:=
windowCostFromPrefetchContext
(
outCtx
,
1
)
require
.
True
(
t
,
ok1
)
require
.
Equal
(
t
,
11.0
,
cost1
)
cost2
,
ok2
:=
windowCostFromPrefetchContext
(
outCtx
,
2
)
require
.
True
(
t
,
ok2
)
require
.
Equal
(
t
,
22.0
,
cost2
)
_
,
ok3
:=
windowCostFromPrefetchContext
(
outCtx
,
3
)
require
.
False
(
t
,
ok3
)
require
.
Equal
(
t
,
int64
(
1
),
repo
.
batchCalls
.
Load
())
require
.
Equal
(
t
,
22.0
,
cache
.
setData
[
2
])
hit
,
miss
,
batchSQL
,
fallback
,
errCount
:=
GatewayWindowCostPrefetchStats
()
require
.
Equal
(
t
,
int64
(
1
),
hit
)
require
.
Equal
(
t
,
int64
(
1
),
miss
)
require
.
Equal
(
t
,
int64
(
1
),
batchSQL
)
require
.
Equal
(
t
,
int64
(
0
),
fallback
)
require
.
Equal
(
t
,
int64
(
0
),
errCount
)
}
func
TestWithWindowCostPrefetch_AllHitNoSQL
(
t
*
testing
.
T
)
{
resetGatewayHotpathStatsForTest
()
windowStart
:=
time
.
Now
()
.
Add
(
-
30
*
time
.
Minute
)
.
Truncate
(
time
.
Hour
)
windowEnd
:=
windowStart
.
Add
(
5
*
time
.
Hour
)
accounts
:=
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
"window_cost_limit"
:
100.0
},
SessionWindowStart
:
&
windowStart
,
SessionWindowEnd
:
&
windowEnd
,
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeSetupToken
,
Extra
:
map
[
string
]
any
{
"window_cost_limit"
:
100.0
},
SessionWindowStart
:
&
windowStart
,
SessionWindowEnd
:
&
windowEnd
,
},
}
cache
:=
&
sessionLimitCacheHotpathStub
{
batchData
:
map
[
int64
]
float64
{
1
:
11.0
,
2
:
22.0
,
},
}
repo
:=
&
usageLogWindowBatchRepoStub
{}
svc
:=
&
GatewayService
{
sessionLimitCache
:
cache
,
usageLogRepo
:
repo
,
}
outCtx
:=
svc
.
withWindowCostPrefetch
(
context
.
Background
(),
accounts
)
cost1
,
ok1
:=
windowCostFromPrefetchContext
(
outCtx
,
1
)
cost2
,
ok2
:=
windowCostFromPrefetchContext
(
outCtx
,
2
)
require
.
True
(
t
,
ok1
)
require
.
True
(
t
,
ok2
)
require
.
Equal
(
t
,
11.0
,
cost1
)
require
.
Equal
(
t
,
22.0
,
cost2
)
require
.
Equal
(
t
,
int64
(
0
),
repo
.
batchCalls
.
Load
())
require
.
Equal
(
t
,
int64
(
0
),
repo
.
singleCalls
.
Load
())
hit
,
miss
,
batchSQL
,
fallback
,
errCount
:=
GatewayWindowCostPrefetchStats
()
require
.
Equal
(
t
,
int64
(
2
),
hit
)
require
.
Equal
(
t
,
int64
(
0
),
miss
)
require
.
Equal
(
t
,
int64
(
0
),
batchSQL
)
require
.
Equal
(
t
,
int64
(
0
),
fallback
)
require
.
Equal
(
t
,
int64
(
0
),
errCount
)
}
func
TestWithWindowCostPrefetch_BatchErrorFallbackSingleQuery
(
t
*
testing
.
T
)
{
resetGatewayHotpathStatsForTest
()
windowStart
:=
time
.
Now
()
.
Add
(
-
30
*
time
.
Minute
)
.
Truncate
(
time
.
Hour
)
windowEnd
:=
windowStart
.
Add
(
5
*
time
.
Hour
)
accounts
:=
[]
Account
{
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeSetupToken
,
Extra
:
map
[
string
]
any
{
"window_cost_limit"
:
100.0
},
SessionWindowStart
:
&
windowStart
,
SessionWindowEnd
:
&
windowEnd
,
},
}
cache
:=
&
sessionLimitCacheHotpathStub
{}
repo
:=
&
usageLogWindowBatchRepoStub
{
batchErr
:
errors
.
New
(
"batch failed"
),
singleResult
:
map
[
int64
]
*
usagestats
.
AccountStats
{
2
:
{
StandardCost
:
33.0
},
},
}
svc
:=
&
GatewayService
{
sessionLimitCache
:
cache
,
usageLogRepo
:
repo
,
}
outCtx
:=
svc
.
withWindowCostPrefetch
(
context
.
Background
(),
accounts
)
cost
,
ok
:=
windowCostFromPrefetchContext
(
outCtx
,
2
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
33.0
,
cost
)
require
.
Equal
(
t
,
int64
(
1
),
repo
.
batchCalls
.
Load
())
require
.
Equal
(
t
,
int64
(
1
),
repo
.
singleCalls
.
Load
())
_
,
_
,
_
,
fallback
,
errCount
:=
GatewayWindowCostPrefetchStats
()
require
.
Equal
(
t
,
int64
(
1
),
fallback
)
require
.
Equal
(
t
,
int64
(
1
),
errCount
)
}
func
TestGetAvailableModels_UsesShortCacheAndSupportsInvalidation
(
t
*
testing
.
T
)
{
resetGatewayHotpathStatsForTest
()
groupID
:=
int64
(
9
)
repo
:=
&
modelsListAccountRepoStub
{
byGroup
:
map
[
int64
][]
Account
{
groupID
:
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"claude-3-5-sonnet"
:
"claude-3-5-sonnet"
,
"claude-3-5-haiku"
:
"claude-3-5-haiku"
,
},
},
},
{
ID
:
2
,
Platform
:
PlatformGemini
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"gemini-2.5-pro"
:
"gemini-2.5-pro"
,
},
},
},
},
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
modelsListCache
:
gocache
.
New
(
time
.
Minute
,
time
.
Minute
),
modelsListCacheTTL
:
time
.
Minute
,
}
models1
:=
svc
.
GetAvailableModels
(
context
.
Background
(),
&
groupID
,
PlatformAnthropic
)
require
.
Equal
(
t
,
[]
string
{
"claude-3-5-haiku"
,
"claude-3-5-sonnet"
},
models1
)
require
.
Equal
(
t
,
int64
(
1
),
repo
.
listByGroupCalls
.
Load
())
// TTL 内再次请求应命中缓存,不回源。
models2
:=
svc
.
GetAvailableModels
(
context
.
Background
(),
&
groupID
,
PlatformAnthropic
)
require
.
Equal
(
t
,
models1
,
models2
)
require
.
Equal
(
t
,
int64
(
1
),
repo
.
listByGroupCalls
.
Load
())
// 更新仓储数据,但缓存未失效前应继续返回旧值。
repo
.
byGroup
[
groupID
]
=
[]
Account
{
{
ID
:
3
,
Platform
:
PlatformAnthropic
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"claude-3-7-sonnet"
:
"claude-3-7-sonnet"
,
},
},
},
}
models3
:=
svc
.
GetAvailableModels
(
context
.
Background
(),
&
groupID
,
PlatformAnthropic
)
require
.
Equal
(
t
,
[]
string
{
"claude-3-5-haiku"
,
"claude-3-5-sonnet"
},
models3
)
require
.
Equal
(
t
,
int64
(
1
),
repo
.
listByGroupCalls
.
Load
())
svc
.
InvalidateAvailableModelsCache
(
&
groupID
,
PlatformAnthropic
)
models4
:=
svc
.
GetAvailableModels
(
context
.
Background
(),
&
groupID
,
PlatformAnthropic
)
require
.
Equal
(
t
,
[]
string
{
"claude-3-7-sonnet"
},
models4
)
require
.
Equal
(
t
,
int64
(
2
),
repo
.
listByGroupCalls
.
Load
())
hit
,
miss
,
store
:=
GatewayModelsListCacheStats
()
require
.
Equal
(
t
,
int64
(
2
),
hit
)
require
.
Equal
(
t
,
int64
(
2
),
miss
)
require
.
Equal
(
t
,
int64
(
2
),
store
)
}
func
TestGetAvailableModels_ErrorAndGlobalListBranches
(
t
*
testing
.
T
)
{
resetGatewayHotpathStatsForTest
()
errRepo
:=
&
modelsListAccountRepoStub
{
err
:
errors
.
New
(
"db error"
),
}
svcErr
:=
&
GatewayService
{
accountRepo
:
errRepo
,
modelsListCache
:
gocache
.
New
(
time
.
Minute
,
time
.
Minute
),
modelsListCacheTTL
:
time
.
Minute
,
}
require
.
Nil
(
t
,
svcErr
.
GetAvailableModels
(
context
.
Background
(),
nil
,
""
))
okRepo
:=
&
modelsListAccountRepoStub
{
all
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"claude-3-5-sonnet"
:
"claude-3-5-sonnet"
,
},
},
},
{
ID
:
2
,
Platform
:
PlatformGemini
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"gemini-2.5-pro"
:
"gemini-2.5-pro"
,
},
},
},
},
}
svcOK
:=
&
GatewayService
{
accountRepo
:
okRepo
,
modelsListCache
:
gocache
.
New
(
time
.
Minute
,
time
.
Minute
),
modelsListCacheTTL
:
time
.
Minute
,
}
models
:=
svcOK
.
GetAvailableModels
(
context
.
Background
(),
nil
,
""
)
require
.
Equal
(
t
,
[]
string
{
"claude-3-5-sonnet"
,
"gemini-2.5-pro"
},
models
)
require
.
Equal
(
t
,
int64
(
1
),
okRepo
.
listAllCalls
.
Load
())
}
func
TestGatewayHotpathHelpers_CacheTTLAndStickyContext
(
t
*
testing
.
T
)
{
t
.
Run
(
"resolve_user_group_rate_cache_ttl"
,
func
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
defaultUserGroupRateCacheTTL
,
resolveUserGroupRateCacheTTL
(
nil
))
cfg
:=
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
UserGroupRateCacheTTLSeconds
:
45
,
},
}
require
.
Equal
(
t
,
45
*
time
.
Second
,
resolveUserGroupRateCacheTTL
(
cfg
))
})
t
.
Run
(
"resolve_models_list_cache_ttl"
,
func
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
defaultModelsListCacheTTL
,
resolveModelsListCacheTTL
(
nil
))
cfg
:=
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
ModelsListCacheTTLSeconds
:
20
,
},
}
require
.
Equal
(
t
,
20
*
time
.
Second
,
resolveModelsListCacheTTL
(
cfg
))
})
t
.
Run
(
"prefetched_sticky_account_id_from_context"
,
func
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
int64
(
0
),
prefetchedStickyAccountIDFromContext
(
context
.
TODO
(),
nil
))
require
.
Equal
(
t
,
int64
(
0
),
prefetchedStickyAccountIDFromContext
(
context
.
Background
(),
nil
))
ctx
:=
context
.
WithValue
(
context
.
Background
(),
ctxkey
.
PrefetchedStickyAccountID
,
int64
(
123
))
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
PrefetchedStickyGroupID
,
int64
(
0
))
require
.
Equal
(
t
,
int64
(
123
),
prefetchedStickyAccountIDFromContext
(
ctx
,
nil
))
groupID
:=
int64
(
9
)
ctx2
:=
context
.
WithValue
(
context
.
Background
(),
ctxkey
.
PrefetchedStickyAccountID
,
456
)
ctx2
=
context
.
WithValue
(
ctx2
,
ctxkey
.
PrefetchedStickyGroupID
,
groupID
)
require
.
Equal
(
t
,
int64
(
456
),
prefetchedStickyAccountIDFromContext
(
ctx2
,
&
groupID
))
ctx3
:=
context
.
WithValue
(
context
.
Background
(),
ctxkey
.
PrefetchedStickyAccountID
,
"invalid"
)
ctx3
=
context
.
WithValue
(
ctx3
,
ctxkey
.
PrefetchedStickyGroupID
,
groupID
)
require
.
Equal
(
t
,
int64
(
0
),
prefetchedStickyAccountIDFromContext
(
ctx3
,
&
groupID
))
ctx4
:=
context
.
WithValue
(
context
.
Background
(),
ctxkey
.
PrefetchedStickyAccountID
,
int64
(
789
))
ctx4
=
context
.
WithValue
(
ctx4
,
ctxkey
.
PrefetchedStickyGroupID
,
int64
(
10
))
require
.
Equal
(
t
,
int64
(
0
),
prefetchedStickyAccountIDFromContext
(
ctx4
,
&
groupID
))
})
t
.
Run
(
"window_cost_from_prefetch_context"
,
func
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
false
,
func
()
bool
{
_
,
ok
:=
windowCostFromPrefetchContext
(
context
.
TODO
(),
0
)
return
ok
}())
require
.
Equal
(
t
,
false
,
func
()
bool
{
_
,
ok
:=
windowCostFromPrefetchContext
(
context
.
Background
(),
1
)
return
ok
}())
ctx
:=
context
.
WithValue
(
context
.
Background
(),
windowCostPrefetchContextKey
,
map
[
int64
]
float64
{
9
:
12.34
,
})
cost
,
ok
:=
windowCostFromPrefetchContext
(
ctx
,
9
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
12.34
,
cost
)
})
}
func
TestInvalidateAvailableModelsCache_ByDimensions
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{
modelsListCache
:
gocache
.
New
(
time
.
Minute
,
time
.
Minute
),
}
group9
:=
int64
(
9
)
group10
:=
int64
(
10
)
svc
.
modelsListCache
.
Set
(
modelsListCacheKey
(
&
group9
,
PlatformAnthropic
),
[]
string
{
"a"
},
time
.
Minute
)
svc
.
modelsListCache
.
Set
(
modelsListCacheKey
(
&
group9
,
PlatformGemini
),
[]
string
{
"b"
},
time
.
Minute
)
svc
.
modelsListCache
.
Set
(
modelsListCacheKey
(
&
group10
,
PlatformAnthropic
),
[]
string
{
"c"
},
time
.
Minute
)
svc
.
modelsListCache
.
Set
(
"invalid-key"
,
[]
string
{
"d"
},
time
.
Minute
)
t
.
Run
(
"invalidate_group_and_platform"
,
func
(
t
*
testing
.
T
)
{
svc
.
InvalidateAvailableModelsCache
(
&
group9
,
PlatformAnthropic
)
_
,
found
:=
svc
.
modelsListCache
.
Get
(
modelsListCacheKey
(
&
group9
,
PlatformAnthropic
))
require
.
False
(
t
,
found
)
_
,
stillFound
:=
svc
.
modelsListCache
.
Get
(
modelsListCacheKey
(
&
group9
,
PlatformGemini
))
require
.
True
(
t
,
stillFound
)
})
t
.
Run
(
"invalidate_group_only"
,
func
(
t
*
testing
.
T
)
{
svc
.
InvalidateAvailableModelsCache
(
&
group9
,
""
)
_
,
foundA
:=
svc
.
modelsListCache
.
Get
(
modelsListCacheKey
(
&
group9
,
PlatformAnthropic
))
_
,
foundB
:=
svc
.
modelsListCache
.
Get
(
modelsListCacheKey
(
&
group9
,
PlatformGemini
))
require
.
False
(
t
,
foundA
)
require
.
False
(
t
,
foundB
)
_
,
foundOtherGroup
:=
svc
.
modelsListCache
.
Get
(
modelsListCacheKey
(
&
group10
,
PlatformAnthropic
))
require
.
True
(
t
,
foundOtherGroup
)
})
t
.
Run
(
"invalidate_platform_only"
,
func
(
t
*
testing
.
T
)
{
// 重建数据后仅按 platform 失效
svc
.
modelsListCache
.
Set
(
modelsListCacheKey
(
&
group9
,
PlatformAnthropic
),
[]
string
{
"a"
},
time
.
Minute
)
svc
.
modelsListCache
.
Set
(
modelsListCacheKey
(
&
group9
,
PlatformGemini
),
[]
string
{
"b"
},
time
.
Minute
)
svc
.
modelsListCache
.
Set
(
modelsListCacheKey
(
&
group10
,
PlatformAnthropic
),
[]
string
{
"c"
},
time
.
Minute
)
svc
.
InvalidateAvailableModelsCache
(
nil
,
PlatformAnthropic
)
_
,
found9Anthropic
:=
svc
.
modelsListCache
.
Get
(
modelsListCacheKey
(
&
group9
,
PlatformAnthropic
))
_
,
found10Anthropic
:=
svc
.
modelsListCache
.
Get
(
modelsListCacheKey
(
&
group10
,
PlatformAnthropic
))
_
,
found9Gemini
:=
svc
.
modelsListCache
.
Get
(
modelsListCacheKey
(
&
group9
,
PlatformGemini
))
require
.
False
(
t
,
found9Anthropic
)
require
.
False
(
t
,
found10Anthropic
)
require
.
True
(
t
,
found9Gemini
)
})
}
func
TestSelectAccountWithLoadAwareness_StickyReadReuse
(
t
*
testing
.
T
)
{
now
:=
time
.
Now
()
.
Add
(
-
time
.
Minute
)
account
:=
Account
{
ID
:
88
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
4
,
Priority
:
1
,
LastUsedAt
:
&
now
,
}
repo
:=
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
account
}}
concurrency
:=
NewConcurrencyService
(
stubConcurrencyCache
{})
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeStandard
,
Gateway
:
config
.
GatewayConfig
{
Scheduling
:
config
.
GatewaySchedulingConfig
{
LoadBatchEnabled
:
true
,
StickySessionMaxWaiting
:
3
,
StickySessionWaitTimeout
:
time
.
Second
,
FallbackWaitTimeout
:
time
.
Second
,
FallbackMaxWaiting
:
10
,
},
},
}
baseCtx
:=
context
.
WithValue
(
context
.
Background
(),
ctxkey
.
ForcePlatform
,
PlatformAnthropic
)
t
.
Run
(
"without_prefetch_reads_cache_once"
,
func
(
t
*
testing
.
T
)
{
cache
:=
&
stickyGatewayCacheHotpathStub
{
stickyID
:
account
.
ID
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
concurrency
,
userGroupRateCache
:
gocache
.
New
(
time
.
Minute
,
time
.
Minute
),
modelsListCache
:
gocache
.
New
(
time
.
Minute
,
time
.
Minute
),
modelsListCacheTTL
:
time
.
Minute
,
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
baseCtx
,
nil
,
"sess-hash"
,
""
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
account
.
ID
,
result
.
Account
.
ID
)
require
.
Equal
(
t
,
int64
(
1
),
cache
.
getCalls
.
Load
())
})
t
.
Run
(
"with_prefetch_skips_cache_read"
,
func
(
t
*
testing
.
T
)
{
cache
:=
&
stickyGatewayCacheHotpathStub
{
stickyID
:
account
.
ID
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
concurrency
,
userGroupRateCache
:
gocache
.
New
(
time
.
Minute
,
time
.
Minute
),
modelsListCache
:
gocache
.
New
(
time
.
Minute
,
time
.
Minute
),
modelsListCacheTTL
:
time
.
Minute
,
}
ctx
:=
context
.
WithValue
(
baseCtx
,
ctxkey
.
PrefetchedStickyAccountID
,
account
.
ID
)
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
PrefetchedStickyGroupID
,
int64
(
0
))
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
"sess-hash"
,
""
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
account
.
ID
,
result
.
Account
.
ID
)
require
.
Equal
(
t
,
int64
(
0
),
cache
.
getCalls
.
Load
())
})
t
.
Run
(
"with_prefetch_group_mismatch_reads_cache"
,
func
(
t
*
testing
.
T
)
{
cache
:=
&
stickyGatewayCacheHotpathStub
{
stickyID
:
account
.
ID
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
concurrency
,
userGroupRateCache
:
gocache
.
New
(
time
.
Minute
,
time
.
Minute
),
modelsListCache
:
gocache
.
New
(
time
.
Minute
,
time
.
Minute
),
modelsListCacheTTL
:
time
.
Minute
,
}
ctx
:=
context
.
WithValue
(
baseCtx
,
ctxkey
.
PrefetchedStickyAccountID
,
int64
(
999
))
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
PrefetchedStickyGroupID
,
int64
(
77
))
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
"sess-hash"
,
""
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
account
.
ID
,
result
.
Account
.
ID
)
require
.
Equal
(
t
,
int64
(
1
),
cache
.
getCalls
.
Load
())
})
}
backend/internal/service/gateway_multiplatform_test.go
View file @
6bccb8a8
...
@@ -77,6 +77,11 @@ func (m *mockAccountRepoForPlatform) Create(ctx context.Context, account *Accoun
...
@@ -77,6 +77,11 @@ func (m *mockAccountRepoForPlatform) Create(ctx context.Context, account *Accoun
func
(
m
*
mockAccountRepoForPlatform
)
GetByCRSAccountID
(
ctx
context
.
Context
,
crsAccountID
string
)
(
*
Account
,
error
)
{
func
(
m
*
mockAccountRepoForPlatform
)
GetByCRSAccountID
(
ctx
context
.
Context
,
crsAccountID
string
)
(
*
Account
,
error
)
{
return
nil
,
nil
return
nil
,
nil
}
}
func
(
m
*
mockAccountRepoForPlatform
)
FindByExtraField
(
ctx
context
.
Context
,
key
string
,
value
any
)
([]
Account
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
ListCRSAccountIDs
(
ctx
context
.
Context
)
(
map
[
string
]
int64
,
error
)
{
func
(
m
*
mockAccountRepoForPlatform
)
ListCRSAccountIDs
(
ctx
context
.
Context
)
(
map
[
string
]
int64
,
error
)
{
return
nil
,
nil
return
nil
,
nil
}
}
...
...
Prev
1
…
8
9
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment