Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
b9b4db3d
Commit
b9b4db3d
authored
Jan 17, 2026
by
song
Browse files
Merge upstream/main
parents
5a6f60a9
dae0d532
Changes
230
Hide whitespace changes
Inline
Side-by-side
Too many changes to show.
To preserve performance only
230 of 230+
files are displayed.
Plain diff
Email patch
backend/internal/service/api_key_auth_cache_invalidate.go
0 → 100644
View file @
b9b4db3d
package
service
import
"context"
// InvalidateAuthCacheByKey 清除指定 API Key 的认证缓存
func
(
s
*
APIKeyService
)
InvalidateAuthCacheByKey
(
ctx
context
.
Context
,
key
string
)
{
if
key
==
""
{
return
}
cacheKey
:=
s
.
authCacheKey
(
key
)
s
.
deleteAuthCache
(
ctx
,
cacheKey
)
}
// InvalidateAuthCacheByUserID 清除用户相关的 API Key 认证缓存
func
(
s
*
APIKeyService
)
InvalidateAuthCacheByUserID
(
ctx
context
.
Context
,
userID
int64
)
{
if
userID
<=
0
{
return
}
keys
,
err
:=
s
.
apiKeyRepo
.
ListKeysByUserID
(
ctx
,
userID
)
if
err
!=
nil
{
return
}
s
.
deleteAuthCacheByKeys
(
ctx
,
keys
)
}
// InvalidateAuthCacheByGroupID 清除分组相关的 API Key 认证缓存
func
(
s
*
APIKeyService
)
InvalidateAuthCacheByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
{
if
groupID
<=
0
{
return
}
keys
,
err
:=
s
.
apiKeyRepo
.
ListKeysByGroupID
(
ctx
,
groupID
)
if
err
!=
nil
{
return
}
s
.
deleteAuthCacheByKeys
(
ctx
,
keys
)
}
func
(
s
*
APIKeyService
)
deleteAuthCacheByKeys
(
ctx
context
.
Context
,
keys
[]
string
)
{
if
len
(
keys
)
==
0
{
return
}
for
_
,
key
:=
range
keys
{
if
key
==
""
{
continue
}
s
.
deleteAuthCache
(
ctx
,
s
.
authCacheKey
(
key
))
}
}
backend/internal/service/api_key_service.go
View file @
b9b4db3d
...
...
@@ -9,8 +9,11 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/dgraph-io/ristretto"
"golang.org/x/sync/singleflight"
)
var
(
...
...
@@ -20,6 +23,7 @@ var (
ErrAPIKeyTooShort
=
infraerrors
.
BadRequest
(
"API_KEY_TOO_SHORT"
,
"api key must be at least 16 characters"
)
ErrAPIKeyInvalidChars
=
infraerrors
.
BadRequest
(
"API_KEY_INVALID_CHARS"
,
"api key can only contain letters, numbers, underscores, and hyphens"
)
ErrAPIKeyRateLimited
=
infraerrors
.
TooManyRequests
(
"API_KEY_RATE_LIMITED"
,
"too many failed attempts, please try again later"
)
ErrInvalidIPPattern
=
infraerrors
.
BadRequest
(
"INVALID_IP_PATTERN"
,
"invalid IP or CIDR pattern"
)
)
const
(
...
...
@@ -29,9 +33,11 @@ const (
type
APIKeyRepository
interface
{
Create
(
ctx
context
.
Context
,
key
*
APIKey
)
error
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
APIKey
,
error
)
// GetOwnerID 仅获取 API Key 的所有者 ID,用于删除
前的轻量级权限验证
GetOwnerID
(
ctx
context
.
Context
,
id
int64
)
(
int64
,
error
)
// Get
KeyAnd
OwnerID 仅获取 API Key 的
key 与
所有者 ID,用于删除
等轻量场景
Get
KeyAnd
OwnerID
(
ctx
context
.
Context
,
id
int64
)
(
string
,
int64
,
error
)
GetByKey
(
ctx
context
.
Context
,
key
string
)
(
*
APIKey
,
error
)
// GetByKeyForAuth 认证专用查询,返回最小字段集
GetByKeyForAuth
(
ctx
context
.
Context
,
key
string
)
(
*
APIKey
,
error
)
Update
(
ctx
context
.
Context
,
key
*
APIKey
)
error
Delete
(
ctx
context
.
Context
,
id
int64
)
error
...
...
@@ -43,6 +49,8 @@ type APIKeyRepository interface {
SearchAPIKeys
(
ctx
context
.
Context
,
userID
int64
,
keyword
string
,
limit
int
)
([]
APIKey
,
error
)
ClearGroupIDByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
CountByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
ListKeysByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
string
,
error
)
ListKeysByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
string
,
error
)
}
// APIKeyCache defines cache operations for API key service
...
...
@@ -53,20 +61,35 @@ type APIKeyCache interface {
IncrementDailyUsage
(
ctx
context
.
Context
,
apiKey
string
)
error
SetDailyUsageExpiry
(
ctx
context
.
Context
,
apiKey
string
,
ttl
time
.
Duration
)
error
GetAuthCache
(
ctx
context
.
Context
,
key
string
)
(
*
APIKeyAuthCacheEntry
,
error
)
SetAuthCache
(
ctx
context
.
Context
,
key
string
,
entry
*
APIKeyAuthCacheEntry
,
ttl
time
.
Duration
)
error
DeleteAuthCache
(
ctx
context
.
Context
,
key
string
)
error
}
// APIKeyAuthCacheInvalidator 提供认证缓存失效能力
type
APIKeyAuthCacheInvalidator
interface
{
InvalidateAuthCacheByKey
(
ctx
context
.
Context
,
key
string
)
InvalidateAuthCacheByUserID
(
ctx
context
.
Context
,
userID
int64
)
InvalidateAuthCacheByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
}
// CreateAPIKeyRequest 创建API Key请求
type
CreateAPIKeyRequest
struct
{
Name
string
`json:"name"`
GroupID
*
int64
`json:"group_id"`
CustomKey
*
string
`json:"custom_key"`
// 可选的自定义key
Name
string
`json:"name"`
GroupID
*
int64
`json:"group_id"`
CustomKey
*
string
`json:"custom_key"`
// 可选的自定义key
IPWhitelist
[]
string
`json:"ip_whitelist"`
// IP 白名单
IPBlacklist
[]
string
`json:"ip_blacklist"`
// IP 黑名单
}
// UpdateAPIKeyRequest 更新API Key请求
type
UpdateAPIKeyRequest
struct
{
Name
*
string
`json:"name"`
GroupID
*
int64
`json:"group_id"`
Status
*
string
`json:"status"`
Name
*
string
`json:"name"`
GroupID
*
int64
`json:"group_id"`
Status
*
string
`json:"status"`
IPWhitelist
[]
string
`json:"ip_whitelist"`
// IP 白名单(空数组清空)
IPBlacklist
[]
string
`json:"ip_blacklist"`
// IP 黑名单(空数组清空)
}
// APIKeyService API Key服务
...
...
@@ -77,6 +100,9 @@ type APIKeyService struct {
userSubRepo
UserSubscriptionRepository
cache
APIKeyCache
cfg
*
config
.
Config
authCacheL1
*
ristretto
.
Cache
authCfg
apiKeyAuthCacheConfig
authGroup
singleflight
.
Group
}
// NewAPIKeyService 创建API Key服务实例
...
...
@@ -88,7 +114,7 @@ func NewAPIKeyService(
cache
APIKeyCache
,
cfg
*
config
.
Config
,
)
*
APIKeyService
{
return
&
APIKeyService
{
svc
:=
&
APIKeyService
{
apiKeyRepo
:
apiKeyRepo
,
userRepo
:
userRepo
,
groupRepo
:
groupRepo
,
...
...
@@ -96,6 +122,8 @@ func NewAPIKeyService(
cache
:
cache
,
cfg
:
cfg
,
}
svc
.
initAuthCache
(
cfg
)
return
svc
}
// GenerateKey 生成随机API Key
...
...
@@ -186,6 +214,20 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
return
nil
,
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
// 验证 IP 白名单格式
if
len
(
req
.
IPWhitelist
)
>
0
{
if
invalid
:=
ip
.
ValidateIPPatterns
(
req
.
IPWhitelist
);
len
(
invalid
)
>
0
{
return
nil
,
fmt
.
Errorf
(
"%w: %v"
,
ErrInvalidIPPattern
,
invalid
)
}
}
// 验证 IP 黑名单格式
if
len
(
req
.
IPBlacklist
)
>
0
{
if
invalid
:=
ip
.
ValidateIPPatterns
(
req
.
IPBlacklist
);
len
(
invalid
)
>
0
{
return
nil
,
fmt
.
Errorf
(
"%w: %v"
,
ErrInvalidIPPattern
,
invalid
)
}
}
// 验证分组权限(如果指定了分组)
if
req
.
GroupID
!=
nil
{
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
*
req
.
GroupID
)
...
...
@@ -236,17 +278,21 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
// 创建API Key记录
apiKey
:=
&
APIKey
{
UserID
:
userID
,
Key
:
key
,
Name
:
req
.
Name
,
GroupID
:
req
.
GroupID
,
Status
:
StatusActive
,
UserID
:
userID
,
Key
:
key
,
Name
:
req
.
Name
,
GroupID
:
req
.
GroupID
,
Status
:
StatusActive
,
IPWhitelist
:
req
.
IPWhitelist
,
IPBlacklist
:
req
.
IPBlacklist
,
}
if
err
:=
s
.
apiKeyRepo
.
Create
(
ctx
,
apiKey
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"create api key: %w"
,
err
)
}
s
.
InvalidateAuthCacheByKey
(
ctx
,
apiKey
.
Key
)
return
apiKey
,
nil
}
...
...
@@ -282,21 +328,49 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error)
// GetByKey 根据Key字符串获取API Key(用于认证)
func
(
s
*
APIKeyService
)
GetByKey
(
ctx
context
.
Context
,
key
string
)
(
*
APIKey
,
error
)
{
// 尝试从Redis缓存获取
cacheKey
:=
fmt
.
Sprintf
(
"apikey:%s"
,
key
)
// 这里可以添加Redis缓存逻辑,暂时直接查询数据库
apiKey
,
err
:=
s
.
apiKeyRepo
.
GetByKey
(
ctx
,
key
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get api key: %w"
,
err
)
cacheKey
:=
s
.
authCacheKey
(
key
)
if
entry
,
ok
:=
s
.
getAuthCacheEntry
(
ctx
,
cacheKey
);
ok
{
if
apiKey
,
used
,
err
:=
s
.
applyAuthCacheEntry
(
key
,
entry
);
used
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get api key: %w"
,
err
)
}
return
apiKey
,
nil
}
}
// 缓存到Redis(可选,TTL设置为5分钟)
if
s
.
cache
!=
nil
{
// 这里可以序列化并缓存API Key
_
=
cacheKey
// 使用变量避免未使用错误
if
s
.
authCfg
.
singleflight
{
value
,
err
,
_
:=
s
.
authGroup
.
Do
(
cacheKey
,
func
()
(
any
,
error
)
{
return
s
.
loadAuthCacheEntry
(
ctx
,
key
,
cacheKey
)
})
if
err
!=
nil
{
return
nil
,
err
}
entry
,
_
:=
value
.
(
*
APIKeyAuthCacheEntry
)
if
apiKey
,
used
,
err
:=
s
.
applyAuthCacheEntry
(
key
,
entry
);
used
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get api key: %w"
,
err
)
}
return
apiKey
,
nil
}
}
else
{
entry
,
err
:=
s
.
loadAuthCacheEntry
(
ctx
,
key
,
cacheKey
)
if
err
!=
nil
{
return
nil
,
err
}
if
apiKey
,
used
,
err
:=
s
.
applyAuthCacheEntry
(
key
,
entry
);
used
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get api key: %w"
,
err
)
}
return
apiKey
,
nil
}
}
apiKey
,
err
:=
s
.
apiKeyRepo
.
GetByKeyForAuth
(
ctx
,
key
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get api key: %w"
,
err
)
}
apiKey
.
Key
=
key
return
apiKey
,
nil
}
...
...
@@ -312,6 +386,20 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
return
nil
,
ErrInsufficientPerms
}
// 验证 IP 白名单格式
if
len
(
req
.
IPWhitelist
)
>
0
{
if
invalid
:=
ip
.
ValidateIPPatterns
(
req
.
IPWhitelist
);
len
(
invalid
)
>
0
{
return
nil
,
fmt
.
Errorf
(
"%w: %v"
,
ErrInvalidIPPattern
,
invalid
)
}
}
// 验证 IP 黑名单格式
if
len
(
req
.
IPBlacklist
)
>
0
{
if
invalid
:=
ip
.
ValidateIPPatterns
(
req
.
IPBlacklist
);
len
(
invalid
)
>
0
{
return
nil
,
fmt
.
Errorf
(
"%w: %v"
,
ErrInvalidIPPattern
,
invalid
)
}
}
// 更新字段
if
req
.
Name
!=
nil
{
apiKey
.
Name
=
*
req
.
Name
...
...
@@ -344,19 +432,22 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
}
}
// 更新 IP 限制(空数组会清空设置)
apiKey
.
IPWhitelist
=
req
.
IPWhitelist
apiKey
.
IPBlacklist
=
req
.
IPBlacklist
if
err
:=
s
.
apiKeyRepo
.
Update
(
ctx
,
apiKey
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"update api key: %w"
,
err
)
}
s
.
InvalidateAuthCacheByKey
(
ctx
,
apiKey
.
Key
)
return
apiKey
,
nil
}
// Delete 删除API Key
// 优化:使用 GetOwnerID 替代 GetByID 进行权限验证,
// 避免加载完整 APIKey 对象及其关联数据(User、Group),提升删除操作的性能
func
(
s
*
APIKeyService
)
Delete
(
ctx
context
.
Context
,
id
int64
,
userID
int64
)
error
{
// 仅获取所有者 ID 用于权限验证,而非加载完整对象
ownerID
,
err
:=
s
.
apiKeyRepo
.
GetOwnerID
(
ctx
,
id
)
key
,
ownerID
,
err
:=
s
.
apiKeyRepo
.
GetKeyAndOwnerID
(
ctx
,
id
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"get api key: %w"
,
err
)
}
...
...
@@ -366,10 +457,11 @@ func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) erro
return
ErrInsufficientPerms
}
// 清除Redis缓存(使用
own
erID 而非 apiKey.UserID)
// 清除Redis缓存(使用
us
erID 而非 apiKey.UserID)
if
s
.
cache
!=
nil
{
_
=
s
.
cache
.
DeleteCreateAttemptCount
(
ctx
,
own
erID
)
_
=
s
.
cache
.
DeleteCreateAttemptCount
(
ctx
,
us
erID
)
}
s
.
InvalidateAuthCacheByKey
(
ctx
,
key
)
if
err
:=
s
.
apiKeyRepo
.
Delete
(
ctx
,
id
);
err
!=
nil
{
return
fmt
.
Errorf
(
"delete api key: %w"
,
err
)
...
...
backend/internal/service/api_key_service_cache_test.go
0 → 100644
View file @
b9b4db3d
//go:build unit
package
service
import
(
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)
type
authRepoStub
struct
{
getByKeyForAuth
func
(
ctx
context
.
Context
,
key
string
)
(
*
APIKey
,
error
)
listKeysByUserID
func
(
ctx
context
.
Context
,
userID
int64
)
([]
string
,
error
)
listKeysByGroupID
func
(
ctx
context
.
Context
,
groupID
int64
)
([]
string
,
error
)
}
func
(
s
*
authRepoStub
)
Create
(
ctx
context
.
Context
,
key
*
APIKey
)
error
{
panic
(
"unexpected Create call"
)
}
func
(
s
*
authRepoStub
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
APIKey
,
error
)
{
panic
(
"unexpected GetByID call"
)
}
func
(
s
*
authRepoStub
)
GetKeyAndOwnerID
(
ctx
context
.
Context
,
id
int64
)
(
string
,
int64
,
error
)
{
panic
(
"unexpected GetKeyAndOwnerID call"
)
}
func
(
s
*
authRepoStub
)
GetByKey
(
ctx
context
.
Context
,
key
string
)
(
*
APIKey
,
error
)
{
panic
(
"unexpected GetByKey call"
)
}
func
(
s
*
authRepoStub
)
GetByKeyForAuth
(
ctx
context
.
Context
,
key
string
)
(
*
APIKey
,
error
)
{
if
s
.
getByKeyForAuth
==
nil
{
panic
(
"unexpected GetByKeyForAuth call"
)
}
return
s
.
getByKeyForAuth
(
ctx
,
key
)
}
func
(
s
*
authRepoStub
)
Update
(
ctx
context
.
Context
,
key
*
APIKey
)
error
{
panic
(
"unexpected Update call"
)
}
func
(
s
*
authRepoStub
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
panic
(
"unexpected Delete call"
)
}
func
(
s
*
authRepoStub
)
ListByUserID
(
ctx
context
.
Context
,
userID
int64
,
params
pagination
.
PaginationParams
)
([]
APIKey
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected ListByUserID call"
)
}
func
(
s
*
authRepoStub
)
VerifyOwnership
(
ctx
context
.
Context
,
userID
int64
,
apiKeyIDs
[]
int64
)
([]
int64
,
error
)
{
panic
(
"unexpected VerifyOwnership call"
)
}
func
(
s
*
authRepoStub
)
CountByUserID
(
ctx
context
.
Context
,
userID
int64
)
(
int64
,
error
)
{
panic
(
"unexpected CountByUserID call"
)
}
func
(
s
*
authRepoStub
)
ExistsByKey
(
ctx
context
.
Context
,
key
string
)
(
bool
,
error
)
{
panic
(
"unexpected ExistsByKey call"
)
}
func
(
s
*
authRepoStub
)
ListByGroupID
(
ctx
context
.
Context
,
groupID
int64
,
params
pagination
.
PaginationParams
)
([]
APIKey
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected ListByGroupID call"
)
}
func
(
s
*
authRepoStub
)
SearchAPIKeys
(
ctx
context
.
Context
,
userID
int64
,
keyword
string
,
limit
int
)
([]
APIKey
,
error
)
{
panic
(
"unexpected SearchAPIKeys call"
)
}
func
(
s
*
authRepoStub
)
ClearGroupIDByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
panic
(
"unexpected ClearGroupIDByGroupID call"
)
}
func
(
s
*
authRepoStub
)
CountByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
panic
(
"unexpected CountByGroupID call"
)
}
func
(
s
*
authRepoStub
)
ListKeysByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
string
,
error
)
{
if
s
.
listKeysByUserID
==
nil
{
panic
(
"unexpected ListKeysByUserID call"
)
}
return
s
.
listKeysByUserID
(
ctx
,
userID
)
}
func
(
s
*
authRepoStub
)
ListKeysByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
string
,
error
)
{
if
s
.
listKeysByGroupID
==
nil
{
panic
(
"unexpected ListKeysByGroupID call"
)
}
return
s
.
listKeysByGroupID
(
ctx
,
groupID
)
}
type
authCacheStub
struct
{
getAuthCache
func
(
ctx
context
.
Context
,
key
string
)
(
*
APIKeyAuthCacheEntry
,
error
)
setAuthKeys
[]
string
deleteAuthKeys
[]
string
}
func
(
s
*
authCacheStub
)
GetCreateAttemptCount
(
ctx
context
.
Context
,
userID
int64
)
(
int
,
error
)
{
return
0
,
nil
}
func
(
s
*
authCacheStub
)
IncrementCreateAttemptCount
(
ctx
context
.
Context
,
userID
int64
)
error
{
return
nil
}
func
(
s
*
authCacheStub
)
DeleteCreateAttemptCount
(
ctx
context
.
Context
,
userID
int64
)
error
{
return
nil
}
func
(
s
*
authCacheStub
)
IncrementDailyUsage
(
ctx
context
.
Context
,
apiKey
string
)
error
{
return
nil
}
func
(
s
*
authCacheStub
)
SetDailyUsageExpiry
(
ctx
context
.
Context
,
apiKey
string
,
ttl
time
.
Duration
)
error
{
return
nil
}
func
(
s
*
authCacheStub
)
GetAuthCache
(
ctx
context
.
Context
,
key
string
)
(
*
APIKeyAuthCacheEntry
,
error
)
{
if
s
.
getAuthCache
==
nil
{
return
nil
,
redis
.
Nil
}
return
s
.
getAuthCache
(
ctx
,
key
)
}
func
(
s
*
authCacheStub
)
SetAuthCache
(
ctx
context
.
Context
,
key
string
,
entry
*
APIKeyAuthCacheEntry
,
ttl
time
.
Duration
)
error
{
s
.
setAuthKeys
=
append
(
s
.
setAuthKeys
,
key
)
return
nil
}
func
(
s
*
authCacheStub
)
DeleteAuthCache
(
ctx
context
.
Context
,
key
string
)
error
{
s
.
deleteAuthKeys
=
append
(
s
.
deleteAuthKeys
,
key
)
return
nil
}
func
TestAPIKeyService_GetByKey_UsesL2Cache
(
t
*
testing
.
T
)
{
cache
:=
&
authCacheStub
{}
repo
:=
&
authRepoStub
{
getByKeyForAuth
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
APIKey
,
error
)
{
return
nil
,
errors
.
New
(
"unexpected repo call"
)
},
}
cfg
:=
&
config
.
Config
{
APIKeyAuth
:
config
.
APIKeyAuthCacheConfig
{
L2TTLSeconds
:
60
,
NegativeTTLSeconds
:
30
,
},
}
svc
:=
NewAPIKeyService
(
repo
,
nil
,
nil
,
nil
,
cache
,
cfg
)
groupID
:=
int64
(
9
)
cacheEntry
:=
&
APIKeyAuthCacheEntry
{
Snapshot
:
&
APIKeyAuthSnapshot
{
APIKeyID
:
1
,
UserID
:
2
,
GroupID
:
&
groupID
,
Status
:
StatusActive
,
User
:
APIKeyAuthUserSnapshot
{
ID
:
2
,
Status
:
StatusActive
,
Role
:
RoleUser
,
Balance
:
10
,
Concurrency
:
3
,
},
Group
:
&
APIKeyAuthGroupSnapshot
{
ID
:
groupID
,
Name
:
"g"
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
SubscriptionType
:
SubscriptionTypeStandard
,
RateMultiplier
:
1
,
ModelRoutingEnabled
:
true
,
ModelRouting
:
map
[
string
][]
int64
{
"claude-opus-*"
:
{
1
,
2
},
},
},
},
}
cache
.
getAuthCache
=
func
(
ctx
context
.
Context
,
key
string
)
(
*
APIKeyAuthCacheEntry
,
error
)
{
return
cacheEntry
,
nil
}
apiKey
,
err
:=
svc
.
GetByKey
(
context
.
Background
(),
"k1"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1
),
apiKey
.
ID
)
require
.
Equal
(
t
,
int64
(
2
),
apiKey
.
User
.
ID
)
require
.
Equal
(
t
,
groupID
,
apiKey
.
Group
.
ID
)
require
.
True
(
t
,
apiKey
.
Group
.
ModelRoutingEnabled
)
require
.
Equal
(
t
,
map
[
string
][]
int64
{
"claude-opus-*"
:
{
1
,
2
}},
apiKey
.
Group
.
ModelRouting
)
}
func
TestAPIKeyService_GetByKey_NegativeCache
(
t
*
testing
.
T
)
{
cache
:=
&
authCacheStub
{}
repo
:=
&
authRepoStub
{
getByKeyForAuth
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
APIKey
,
error
)
{
return
nil
,
errors
.
New
(
"unexpected repo call"
)
},
}
cfg
:=
&
config
.
Config
{
APIKeyAuth
:
config
.
APIKeyAuthCacheConfig
{
L2TTLSeconds
:
60
,
NegativeTTLSeconds
:
30
,
},
}
svc
:=
NewAPIKeyService
(
repo
,
nil
,
nil
,
nil
,
cache
,
cfg
)
cache
.
getAuthCache
=
func
(
ctx
context
.
Context
,
key
string
)
(
*
APIKeyAuthCacheEntry
,
error
)
{
return
&
APIKeyAuthCacheEntry
{
NotFound
:
true
},
nil
}
_
,
err
:=
svc
.
GetByKey
(
context
.
Background
(),
"missing"
)
require
.
ErrorIs
(
t
,
err
,
ErrAPIKeyNotFound
)
}
func
TestAPIKeyService_GetByKey_CacheMissStoresL2
(
t
*
testing
.
T
)
{
cache
:=
&
authCacheStub
{}
repo
:=
&
authRepoStub
{
getByKeyForAuth
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
APIKey
,
error
)
{
return
&
APIKey
{
ID
:
5
,
UserID
:
7
,
Status
:
StatusActive
,
User
:
&
User
{
ID
:
7
,
Status
:
StatusActive
,
Role
:
RoleUser
,
Balance
:
12
,
Concurrency
:
2
,
},
},
nil
},
}
cfg
:=
&
config
.
Config
{
APIKeyAuth
:
config
.
APIKeyAuthCacheConfig
{
L2TTLSeconds
:
60
,
NegativeTTLSeconds
:
30
,
},
}
svc
:=
NewAPIKeyService
(
repo
,
nil
,
nil
,
nil
,
cache
,
cfg
)
cache
.
getAuthCache
=
func
(
ctx
context
.
Context
,
key
string
)
(
*
APIKeyAuthCacheEntry
,
error
)
{
return
nil
,
redis
.
Nil
}
apiKey
,
err
:=
svc
.
GetByKey
(
context
.
Background
(),
"k2"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
5
),
apiKey
.
ID
)
require
.
Len
(
t
,
cache
.
setAuthKeys
,
1
)
}
func
TestAPIKeyService_GetByKey_UsesL1Cache
(
t
*
testing
.
T
)
{
var
calls
int32
cache
:=
&
authCacheStub
{}
repo
:=
&
authRepoStub
{
getByKeyForAuth
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
APIKey
,
error
)
{
atomic
.
AddInt32
(
&
calls
,
1
)
return
&
APIKey
{
ID
:
21
,
UserID
:
3
,
Status
:
StatusActive
,
User
:
&
User
{
ID
:
3
,
Status
:
StatusActive
,
Role
:
RoleUser
,
Balance
:
5
,
Concurrency
:
2
,
},
},
nil
},
}
cfg
:=
&
config
.
Config
{
APIKeyAuth
:
config
.
APIKeyAuthCacheConfig
{
L1Size
:
1000
,
L1TTLSeconds
:
60
,
},
}
svc
:=
NewAPIKeyService
(
repo
,
nil
,
nil
,
nil
,
cache
,
cfg
)
require
.
NotNil
(
t
,
svc
.
authCacheL1
)
_
,
err
:=
svc
.
GetByKey
(
context
.
Background
(),
"k-l1"
)
require
.
NoError
(
t
,
err
)
svc
.
authCacheL1
.
Wait
()
cacheKey
:=
svc
.
authCacheKey
(
"k-l1"
)
_
,
ok
:=
svc
.
authCacheL1
.
Get
(
cacheKey
)
require
.
True
(
t
,
ok
)
_
,
err
=
svc
.
GetByKey
(
context
.
Background
(),
"k-l1"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int32
(
1
),
atomic
.
LoadInt32
(
&
calls
))
}
func
TestAPIKeyService_InvalidateAuthCacheByUserID
(
t
*
testing
.
T
)
{
cache
:=
&
authCacheStub
{}
repo
:=
&
authRepoStub
{
listKeysByUserID
:
func
(
ctx
context
.
Context
,
userID
int64
)
([]
string
,
error
)
{
return
[]
string
{
"k1"
,
"k2"
},
nil
},
}
cfg
:=
&
config
.
Config
{
APIKeyAuth
:
config
.
APIKeyAuthCacheConfig
{
L2TTLSeconds
:
60
,
NegativeTTLSeconds
:
30
,
},
}
svc
:=
NewAPIKeyService
(
repo
,
nil
,
nil
,
nil
,
cache
,
cfg
)
svc
.
InvalidateAuthCacheByUserID
(
context
.
Background
(),
7
)
require
.
Len
(
t
,
cache
.
deleteAuthKeys
,
2
)
}
func
TestAPIKeyService_InvalidateAuthCacheByGroupID
(
t
*
testing
.
T
)
{
cache
:=
&
authCacheStub
{}
repo
:=
&
authRepoStub
{
listKeysByGroupID
:
func
(
ctx
context
.
Context
,
groupID
int64
)
([]
string
,
error
)
{
return
[]
string
{
"k1"
,
"k2"
},
nil
},
}
cfg
:=
&
config
.
Config
{
APIKeyAuth
:
config
.
APIKeyAuthCacheConfig
{
L2TTLSeconds
:
60
,
},
}
svc
:=
NewAPIKeyService
(
repo
,
nil
,
nil
,
nil
,
cache
,
cfg
)
svc
.
InvalidateAuthCacheByGroupID
(
context
.
Background
(),
9
)
require
.
Len
(
t
,
cache
.
deleteAuthKeys
,
2
)
}
func
TestAPIKeyService_InvalidateAuthCacheByKey
(
t
*
testing
.
T
)
{
cache
:=
&
authCacheStub
{}
repo
:=
&
authRepoStub
{
listKeysByUserID
:
func
(
ctx
context
.
Context
,
userID
int64
)
([]
string
,
error
)
{
return
nil
,
nil
},
}
cfg
:=
&
config
.
Config
{
APIKeyAuth
:
config
.
APIKeyAuthCacheConfig
{
L2TTLSeconds
:
60
,
},
}
svc
:=
NewAPIKeyService
(
repo
,
nil
,
nil
,
nil
,
cache
,
cfg
)
svc
.
InvalidateAuthCacheByKey
(
context
.
Background
(),
"k1"
)
require
.
Len
(
t
,
cache
.
deleteAuthKeys
,
1
)
}
func
TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss
(
t
*
testing
.
T
)
{
cache
:=
&
authCacheStub
{}
repo
:=
&
authRepoStub
{
getByKeyForAuth
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
APIKey
,
error
)
{
return
nil
,
ErrAPIKeyNotFound
},
}
cfg
:=
&
config
.
Config
{
APIKeyAuth
:
config
.
APIKeyAuthCacheConfig
{
L2TTLSeconds
:
60
,
NegativeTTLSeconds
:
30
,
},
}
svc
:=
NewAPIKeyService
(
repo
,
nil
,
nil
,
nil
,
cache
,
cfg
)
cache
.
getAuthCache
=
func
(
ctx
context
.
Context
,
key
string
)
(
*
APIKeyAuthCacheEntry
,
error
)
{
return
nil
,
redis
.
Nil
}
_
,
err
:=
svc
.
GetByKey
(
context
.
Background
(),
"missing"
)
require
.
ErrorIs
(
t
,
err
,
ErrAPIKeyNotFound
)
require
.
Len
(
t
,
cache
.
setAuthKeys
,
1
)
}
func
TestAPIKeyService_GetByKey_SingleflightCollapses
(
t
*
testing
.
T
)
{
var
calls
int32
cache
:=
&
authCacheStub
{}
repo
:=
&
authRepoStub
{
getByKeyForAuth
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
APIKey
,
error
)
{
atomic
.
AddInt32
(
&
calls
,
1
)
time
.
Sleep
(
50
*
time
.
Millisecond
)
return
&
APIKey
{
ID
:
11
,
UserID
:
2
,
Status
:
StatusActive
,
User
:
&
User
{
ID
:
2
,
Status
:
StatusActive
,
Role
:
RoleUser
,
Balance
:
1
,
Concurrency
:
1
,
},
},
nil
},
}
cfg
:=
&
config
.
Config
{
APIKeyAuth
:
config
.
APIKeyAuthCacheConfig
{
Singleflight
:
true
,
},
}
svc
:=
NewAPIKeyService
(
repo
,
nil
,
nil
,
nil
,
cache
,
cfg
)
start
:=
make
(
chan
struct
{})
wg
:=
sync
.
WaitGroup
{}
errs
:=
make
([]
error
,
5
)
for
i
:=
0
;
i
<
5
;
i
++
{
wg
.
Add
(
1
)
go
func
(
idx
int
)
{
defer
wg
.
Done
()
<-
start
_
,
err
:=
svc
.
GetByKey
(
context
.
Background
(),
"k1"
)
errs
[
idx
]
=
err
}(
i
)
}
close
(
start
)
wg
.
Wait
()
for
_
,
err
:=
range
errs
{
require
.
NoError
(
t
,
err
)
}
require
.
Equal
(
t
,
int32
(
1
),
atomic
.
LoadInt32
(
&
calls
))
}
backend/internal/service/api_key_service_delete_test.go
View file @
b9b4db3d
...
...
@@ -20,13 +20,12 @@ import (
// 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。
//
// 设计说明:
// - ownerID: 模拟 GetOwnerID 返回的所有者 ID
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrAPIKeyNotFound)
// - apiKey/getByIDErr: 模拟 GetKeyAndOwnerID 返回的记录与错误
// - deleteErr: 模拟 Delete 返回的错误
// - deletedIDs: 记录被调用删除的 API Key ID,用于断言验证
type
apiKeyRepoStub
struct
{
ownerID
int64
// GetOwnerID 的返回值
owner
Err
error
// GetOwnerID 的错误返回值
apiKey
*
APIKey
// Get
KeyAnd
OwnerID 的返回值
getByID
Err
error
// Get
KeyAnd
OwnerID 的错误返回值
deleteErr
error
// Delete 的错误返回值
deletedIDs
[]
int64
// 记录已删除的 API Key ID 列表
}
...
...
@@ -38,19 +37,34 @@ func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error {
}
func
(
s
*
apiKeyRepoStub
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
APIKey
,
error
)
{
if
s
.
getByIDErr
!=
nil
{
return
nil
,
s
.
getByIDErr
}
if
s
.
apiKey
!=
nil
{
clone
:=
*
s
.
apiKey
return
&
clone
,
nil
}
panic
(
"unexpected GetByID call"
)
}
// GetOwnerID 返回预设的所有者 ID 或错误。
// 这是 Delete 方法调用的第一个仓储方法,用于验证调用者是否为 API Key 的所有者。
func
(
s
*
apiKeyRepoStub
)
GetOwnerID
(
ctx
context
.
Context
,
id
int64
)
(
int64
,
error
)
{
return
s
.
ownerID
,
s
.
ownerErr
func
(
s
*
apiKeyRepoStub
)
GetKeyAndOwnerID
(
ctx
context
.
Context
,
id
int64
)
(
string
,
int64
,
error
)
{
if
s
.
getByIDErr
!=
nil
{
return
""
,
0
,
s
.
getByIDErr
}
if
s
.
apiKey
!=
nil
{
return
s
.
apiKey
.
Key
,
s
.
apiKey
.
UserID
,
nil
}
return
""
,
0
,
ErrAPIKeyNotFound
}
func
(
s
*
apiKeyRepoStub
)
GetByKey
(
ctx
context
.
Context
,
key
string
)
(
*
APIKey
,
error
)
{
panic
(
"unexpected GetByKey call"
)
}
func
(
s
*
apiKeyRepoStub
)
GetByKeyForAuth
(
ctx
context
.
Context
,
key
string
)
(
*
APIKey
,
error
)
{
panic
(
"unexpected GetByKeyForAuth call"
)
}
func
(
s
*
apiKeyRepoStub
)
Update
(
ctx
context
.
Context
,
key
*
APIKey
)
error
{
panic
(
"unexpected Update call"
)
}
...
...
@@ -96,13 +110,22 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int
panic
(
"unexpected CountByGroupID call"
)
}
func
(
s
*
apiKeyRepoStub
)
ListKeysByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
string
,
error
)
{
panic
(
"unexpected ListKeysByUserID call"
)
}
func
(
s
*
apiKeyRepoStub
)
ListKeysByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
string
,
error
)
{
panic
(
"unexpected ListKeysByGroupID call"
)
}
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
//
// 设计说明:
// - invalidated: 记录被清除缓存的用户 ID 列表
type
apiKeyCacheStub
struct
{
invalidated
[]
int64
// 记录调用 DeleteCreateAttemptCount 时传入的用户 ID
invalidated
[]
int64
// 记录调用 DeleteCreateAttemptCount 时传入的用户 ID
deleteAuthKeys
[]
string
// 记录调用 DeleteAuthCache 时传入的缓存 key
}
// GetCreateAttemptCount 返回 0,表示用户未超过创建次数限制
...
...
@@ -132,15 +155,30 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string
return
nil
}
func
(
s
*
apiKeyCacheStub
)
GetAuthCache
(
ctx
context
.
Context
,
key
string
)
(
*
APIKeyAuthCacheEntry
,
error
)
{
return
nil
,
nil
}
func
(
s
*
apiKeyCacheStub
)
SetAuthCache
(
ctx
context
.
Context
,
key
string
,
entry
*
APIKeyAuthCacheEntry
,
ttl
time
.
Duration
)
error
{
return
nil
}
func
(
s
*
apiKeyCacheStub
)
DeleteAuthCache
(
ctx
context
.
Context
,
key
string
)
error
{
s
.
deleteAuthKeys
=
append
(
s
.
deleteAuthKeys
,
key
)
return
nil
}
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
// 预期行为:
// - GetOwnerID 返回所有者 ID 为 1
// - Get
KeyAnd
OwnerID 返回所有者 ID 为 1
// - 调用者 userID 为 2(不匹配)
// - 返回 ErrInsufficientPerms 错误
// - Delete 方法不被调用
// - 缓存不被清除
func
TestApiKeyService_Delete_OwnerMismatch
(
t
*
testing
.
T
)
{
repo
:=
&
apiKeyRepoStub
{
ownerID
:
1
}
repo
:=
&
apiKeyRepoStub
{
apiKey
:
&
APIKey
{
ID
:
10
,
UserID
:
1
,
Key
:
"k"
},
}
cache
:=
&
apiKeyCacheStub
{}
svc
:=
&
APIKeyService
{
apiKeyRepo
:
repo
,
cache
:
cache
}
...
...
@@ -148,17 +186,20 @@ func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
require
.
ErrorIs
(
t
,
err
,
ErrInsufficientPerms
)
require
.
Empty
(
t
,
repo
.
deletedIDs
)
// 验证删除操作未被调用
require
.
Empty
(
t
,
cache
.
invalidated
)
// 验证缓存未被清除
require
.
Empty
(
t
,
cache
.
deleteAuthKeys
)
}
// TestApiKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。
// 预期行为:
// - GetOwnerID 返回所有者 ID 为 7
// - Get
KeyAnd
OwnerID 返回所有者 ID 为 7
// - 调用者 userID 为 7(匹配)
// - Delete 成功执行
// - 缓存被正确清除(使用 ownerID)
// - 返回 nil 错误
func
TestApiKeyService_Delete_Success
(
t
*
testing
.
T
)
{
repo
:=
&
apiKeyRepoStub
{
ownerID
:
7
}
repo
:=
&
apiKeyRepoStub
{
apiKey
:
&
APIKey
{
ID
:
42
,
UserID
:
7
,
Key
:
"k"
},
}
cache
:=
&
apiKeyCacheStub
{}
svc
:=
&
APIKeyService
{
apiKeyRepo
:
repo
,
cache
:
cache
}
...
...
@@ -166,16 +207,17 @@ func TestApiKeyService_Delete_Success(t *testing.T) {
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
[]
int64
{
42
},
repo
.
deletedIDs
)
// 验证正确的 API Key 被删除
require
.
Equal
(
t
,
[]
int64
{
7
},
cache
.
invalidated
)
// 验证所有者的缓存被清除
require
.
Equal
(
t
,
[]
string
{
svc
.
authCacheKey
(
"k"
)},
cache
.
deleteAuthKeys
)
}
// TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
// 预期行为:
// - GetOwnerID 返回 ErrAPIKeyNotFound 错误
// - Get
KeyAnd
OwnerID 返回 ErrAPIKeyNotFound 错误
// - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装)
// - Delete 方法不被调用
// - 缓存不被清除
func
TestApiKeyService_Delete_NotFound
(
t
*
testing
.
T
)
{
repo
:=
&
apiKeyRepoStub
{
owner
Err
:
ErrAPIKeyNotFound
}
repo
:=
&
apiKeyRepoStub
{
getByID
Err
:
ErrAPIKeyNotFound
}
cache
:=
&
apiKeyCacheStub
{}
svc
:=
&
APIKeyService
{
apiKeyRepo
:
repo
,
cache
:
cache
}
...
...
@@ -183,18 +225,19 @@ func TestApiKeyService_Delete_NotFound(t *testing.T) {
require
.
ErrorIs
(
t
,
err
,
ErrAPIKeyNotFound
)
require
.
Empty
(
t
,
repo
.
deletedIDs
)
require
.
Empty
(
t
,
cache
.
invalidated
)
require
.
Empty
(
t
,
cache
.
deleteAuthKeys
)
}
// TestApiKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。
// 预期行为:
// - GetOwnerID 返回正确的所有者 ID
// - Get
KeyAnd
OwnerID 返回正确的所有者 ID
// - 所有权验证通过
// - 缓存被清除(在删除之前)
// - Delete 被调用但返回错误
// - 返回包含 "delete api key" 的错误信息
func
TestApiKeyService_Delete_DeleteFails
(
t
*
testing
.
T
)
{
repo
:=
&
apiKeyRepoStub
{
ownerID
:
3
,
apiKey
:
&
APIKey
{
ID
:
42
,
UserID
:
3
,
Key
:
"k"
}
,
deleteErr
:
errors
.
New
(
"delete failed"
),
}
cache
:=
&
apiKeyCacheStub
{}
...
...
@@ -205,4 +248,5 @@ func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
require
.
ErrorContains
(
t
,
err
,
"delete api key"
)
require
.
Equal
(
t
,
[]
int64
{
3
},
repo
.
deletedIDs
)
// 验证删除操作被调用
require
.
Equal
(
t
,
[]
int64
{
3
},
cache
.
invalidated
)
// 验证缓存已被清除(即使删除失败)
require
.
Equal
(
t
,
[]
string
{
svc
.
authCacheKey
(
"k"
)},
cache
.
deleteAuthKeys
)
}
backend/internal/service/auth_cache_invalidation_test.go
0 → 100644
View file @
b9b4db3d
//go:build unit
package
service
import
(
"context"
"testing"
"github.com/stretchr/testify/require"
)
func
TestUsageService_InvalidateUsageCaches
(
t
*
testing
.
T
)
{
invalidator
:=
&
authCacheInvalidatorStub
{}
svc
:=
&
UsageService
{
authCacheInvalidator
:
invalidator
}
svc
.
invalidateUsageCaches
(
context
.
Background
(),
7
,
false
)
require
.
Empty
(
t
,
invalidator
.
userIDs
)
svc
.
invalidateUsageCaches
(
context
.
Background
(),
7
,
true
)
require
.
Equal
(
t
,
[]
int64
{
7
},
invalidator
.
userIDs
)
}
func
TestRedeemService_InvalidateRedeemCaches_AuthCache
(
t
*
testing
.
T
)
{
invalidator
:=
&
authCacheInvalidatorStub
{}
svc
:=
&
RedeemService
{
authCacheInvalidator
:
invalidator
}
svc
.
invalidateRedeemCaches
(
context
.
Background
(),
11
,
&
RedeemCode
{
Type
:
RedeemTypeBalance
})
svc
.
invalidateRedeemCaches
(
context
.
Background
(),
11
,
&
RedeemCode
{
Type
:
RedeemTypeConcurrency
})
groupID
:=
int64
(
3
)
svc
.
invalidateRedeemCaches
(
context
.
Background
(),
11
,
&
RedeemCode
{
Type
:
RedeemTypeSubscription
,
GroupID
:
&
groupID
})
require
.
Equal
(
t
,
[]
int64
{
11
,
11
,
11
},
invalidator
.
userIDs
)
}
backend/internal/service/auth_service.go
View file @
b9b4db3d
...
...
@@ -52,6 +52,7 @@ type AuthService struct {
emailService
*
EmailService
turnstileService
*
TurnstileService
emailQueueService
*
EmailQueueService
promoService
*
PromoService
}
// NewAuthService 创建认证服务实例
...
...
@@ -62,6 +63,7 @@ func NewAuthService(
emailService
*
EmailService
,
turnstileService
*
TurnstileService
,
emailQueueService
*
EmailQueueService
,
promoService
*
PromoService
,
)
*
AuthService
{
return
&
AuthService
{
userRepo
:
userRepo
,
...
...
@@ -70,16 +72,17 @@ func NewAuthService(
emailService
:
emailService
,
turnstileService
:
turnstileService
,
emailQueueService
:
emailQueueService
,
promoService
:
promoService
,
}
}
// Register 用户注册,返回token和用户
func
(
s
*
AuthService
)
Register
(
ctx
context
.
Context
,
email
,
password
string
)
(
string
,
*
User
,
error
)
{
return
s
.
RegisterWithVerification
(
ctx
,
email
,
password
,
""
)
return
s
.
RegisterWithVerification
(
ctx
,
email
,
password
,
""
,
""
)
}
// RegisterWithVerification 用户注册(支持邮件验证),返回token和用户
func
(
s
*
AuthService
)
RegisterWithVerification
(
ctx
context
.
Context
,
email
,
password
,
verifyCode
string
)
(
string
,
*
User
,
error
)
{
// RegisterWithVerification 用户注册(支持邮件验证
和优惠码
),返回token和用户
func
(
s
*
AuthService
)
RegisterWithVerification
(
ctx
context
.
Context
,
email
,
password
,
verifyCode
,
promoCode
string
)
(
string
,
*
User
,
error
)
{
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
if
s
.
settingService
==
nil
||
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
return
""
,
nil
,
ErrRegDisabled
...
...
@@ -150,6 +153,19 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return
""
,
nil
,
ErrServiceUnavailable
}
// 应用优惠码(如果提供)
if
promoCode
!=
""
&&
s
.
promoService
!=
nil
{
if
err
:=
s
.
promoService
.
ApplyPromoCode
(
ctx
,
user
.
ID
,
promoCode
);
err
!=
nil
{
// 优惠码应用失败不影响注册,只记录日志
log
.
Printf
(
"[Auth] Failed to apply promo code for user %d: %v"
,
user
.
ID
,
err
)
}
else
{
// 重新获取用户信息以获取更新后的余额
if
updatedUser
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
user
.
ID
);
err
==
nil
{
user
=
updatedUser
}
}
}
// 生成token
token
,
err
:=
s
.
GenerateToken
(
user
)
if
err
!=
nil
{
...
...
@@ -341,7 +357,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
// - 如果邮箱已存在:直接登录(不需要本地密码)
// - 如果邮箱不存在:创建新用户并登录
//
// 注意:该函数用于
“终端用户登录 Sub2API 本身”的
场景(不同于上游账号的 OAuth,例如 OpenAI/Gemini)。
// 注意:该函数用于
LinuxDo OAuth 登录
场景(不同于上游账号的 OAuth,例如
Claude/
OpenAI/Gemini)。
// 为了满足现有数据库约束(需要密码哈希),新用户会生成随机密码并进行哈希保存。
func
(
s
*
AuthService
)
LoginOrRegisterOAuth
(
ctx
context
.
Context
,
email
,
username
string
)
(
string
,
*
User
,
error
)
{
email
=
strings
.
TrimSpace
(
email
)
...
...
@@ -360,8 +376,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
user
,
err
:=
s
.
userRepo
.
GetByEmail
(
ctx
,
email
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrUserNotFound
)
{
// OAuth 首次登录视为注册
。
if
s
.
settingService
!
=
nil
&&
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
// OAuth 首次登录视为注册
(fail-close:settingService 未配置时不允许注册)
if
s
.
settingService
=
=
nil
||
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
return
""
,
nil
,
ErrRegDisabled
}
...
...
backend/internal/service/auth_service_register_test.go
View file @
b9b4db3d
...
...
@@ -100,6 +100,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
emailService
,
nil
,
nil
,
nil
,
// promoService
)
}
...
...
@@ -131,7 +132,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi
},
nil
)
// 应返回服务不可用错误,而不是允许绕过验证
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
"any-code"
)
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
"any-code"
,
""
)
require
.
ErrorIs
(
t
,
err
,
ErrServiceUnavailable
)
}
...
...
@@ -143,7 +144,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
SettingKeyEmailVerifyEnabled
:
"true"
,
},
cache
)
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
""
)
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
""
,
""
)
require
.
ErrorIs
(
t
,
err
,
ErrEmailVerifyRequired
)
}
...
...
@@ -157,7 +158,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
SettingKeyEmailVerifyEnabled
:
"true"
,
},
cache
)
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
"wrong"
)
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
"wrong"
,
""
)
require
.
ErrorIs
(
t
,
err
,
ErrInvalidVerifyCode
)
require
.
ErrorContains
(
t
,
err
,
"verify code"
)
}
...
...
backend/internal/service/claude_token_provider.go
0 → 100644
View file @
b9b4db3d
package
service
import
(
"context"
"errors"
"log/slog"
"strconv"
"strings"
"time"
)
// Timing constants for Claude OAuth token management.
const (
	// Refresh the access token once it is within this window of its expiry.
	claudeTokenRefreshSkew = 3 * time.Minute
	// Cache TTL is shortened by this margin so the cached token expires
	// before the real token does.
	claudeTokenCacheSkew = 5 * time.Minute
	// How long to wait before re-reading the cache when another worker
	// currently holds the refresh lock.
	claudeLockWaitTime = 200 * time.Millisecond
)

// ClaudeTokenCache is the token cache interface (reuses the GeminiTokenCache definition).
type ClaudeTokenCache = GeminiTokenCache

// ClaudeTokenProvider manages the access_token of Claude (Anthropic) OAuth accounts.
type ClaudeTokenProvider struct {
	accountRepo  AccountRepository // persistent account storage; used to reload and persist refreshed credentials
	tokenCache   ClaudeTokenCache  // distributed token cache + refresh lock; may be nil (caching disabled)
	oauthService *OAuthService     // performs the actual OAuth refresh; may be nil (refresh disabled)
}
func
NewClaudeTokenProvider
(
accountRepo
AccountRepository
,
tokenCache
ClaudeTokenCache
,
oauthService
*
OAuthService
,
)
*
ClaudeTokenProvider
{
return
&
ClaudeTokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
tokenCache
,
oauthService
:
oauthService
,
}
}
// GetAccessToken 获取有效的 access_token
func
(
p
*
ClaudeTokenProvider
)
GetAccessToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
error
)
{
if
account
==
nil
{
return
""
,
errors
.
New
(
"account is nil"
)
}
if
account
.
Platform
!=
PlatformAnthropic
||
account
.
Type
!=
AccountTypeOAuth
{
return
""
,
errors
.
New
(
"not an anthropic oauth account"
)
}
cacheKey
:=
ClaudeTokenCacheKey
(
account
)
// 1. 先尝试缓存
if
p
.
tokenCache
!=
nil
{
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
slog
.
Debug
(
"claude_token_cache_hit"
,
"account_id"
,
account
.
ID
)
return
token
,
nil
}
else
if
err
!=
nil
{
slog
.
Warn
(
"claude_token_cache_get_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
slog
.
Debug
(
"claude_token_cache_miss"
,
"account_id"
,
account
.
ID
)
// 2. 如果即将过期则刷新
expiresAt
:=
account
.
GetCredentialAsTime
(
"expires_at"
)
needsRefresh
:=
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
claudeTokenRefreshSkew
refreshFailed
:=
false
if
needsRefresh
&&
p
.
tokenCache
!=
nil
{
locked
,
lockErr
:=
p
.
tokenCache
.
AcquireRefreshLock
(
ctx
,
cacheKey
,
30
*
time
.
Second
)
if
lockErr
==
nil
&&
locked
{
defer
func
()
{
_
=
p
.
tokenCache
.
ReleaseRefreshLock
(
ctx
,
cacheKey
)
}()
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
return
token
,
nil
}
// 从数据库获取最新账户信息
fresh
,
err
:=
p
.
accountRepo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
==
nil
&&
fresh
!=
nil
{
account
=
fresh
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
if
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
claudeTokenRefreshSkew
{
if
p
.
oauthService
==
nil
{
slog
.
Warn
(
"claude_oauth_service_not_configured"
,
"account_id"
,
account
.
ID
)
refreshFailed
=
true
// 无法刷新,标记失败
}
else
{
tokenInfo
,
err
:=
p
.
oauthService
.
RefreshAccountToken
(
ctx
,
account
)
if
err
!=
nil
{
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
slog
.
Warn
(
"claude_token_refresh_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
refreshFailed
=
true
// 刷新失败,标记以使用短 TTL
}
else
{
// 构建新 credentials,保留原有字段
newCredentials
:=
make
(
map
[
string
]
any
)
for
k
,
v
:=
range
account
.
Credentials
{
newCredentials
[
k
]
=
v
}
newCredentials
[
"access_token"
]
=
tokenInfo
.
AccessToken
newCredentials
[
"token_type"
]
=
tokenInfo
.
TokenType
newCredentials
[
"expires_in"
]
=
strconv
.
FormatInt
(
tokenInfo
.
ExpiresIn
,
10
)
newCredentials
[
"expires_at"
]
=
strconv
.
FormatInt
(
tokenInfo
.
ExpiresAt
,
10
)
if
tokenInfo
.
RefreshToken
!=
""
{
newCredentials
[
"refresh_token"
]
=
tokenInfo
.
RefreshToken
}
if
tokenInfo
.
Scope
!=
""
{
newCredentials
[
"scope"
]
=
tokenInfo
.
Scope
}
account
.
Credentials
=
newCredentials
if
updateErr
:=
p
.
accountRepo
.
Update
(
ctx
,
account
);
updateErr
!=
nil
{
slog
.
Error
(
"claude_token_provider_update_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
updateErr
)
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
}
}
}
else
if
lockErr
!=
nil
{
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
slog
.
Warn
(
"claude_token_lock_failed_degraded_refresh"
,
"account_id"
,
account
.
ID
,
"error"
,
lockErr
)
// 检查 ctx 是否已取消
if
ctx
.
Err
()
!=
nil
{
return
""
,
ctx
.
Err
()
}
// 从数据库获取最新账户信息
if
p
.
accountRepo
!=
nil
{
fresh
,
err
:=
p
.
accountRepo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
==
nil
&&
fresh
!=
nil
{
account
=
fresh
}
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
if
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
claudeTokenRefreshSkew
{
if
p
.
oauthService
==
nil
{
slog
.
Warn
(
"claude_oauth_service_not_configured"
,
"account_id"
,
account
.
ID
)
refreshFailed
=
true
}
else
{
tokenInfo
,
err
:=
p
.
oauthService
.
RefreshAccountToken
(
ctx
,
account
)
if
err
!=
nil
{
slog
.
Warn
(
"claude_token_refresh_failed_degraded"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
refreshFailed
=
true
}
else
{
// 构建新 credentials,保留原有字段
newCredentials
:=
make
(
map
[
string
]
any
)
for
k
,
v
:=
range
account
.
Credentials
{
newCredentials
[
k
]
=
v
}
newCredentials
[
"access_token"
]
=
tokenInfo
.
AccessToken
newCredentials
[
"token_type"
]
=
tokenInfo
.
TokenType
newCredentials
[
"expires_in"
]
=
strconv
.
FormatInt
(
tokenInfo
.
ExpiresIn
,
10
)
newCredentials
[
"expires_at"
]
=
strconv
.
FormatInt
(
tokenInfo
.
ExpiresAt
,
10
)
if
tokenInfo
.
RefreshToken
!=
""
{
newCredentials
[
"refresh_token"
]
=
tokenInfo
.
RefreshToken
}
if
tokenInfo
.
Scope
!=
""
{
newCredentials
[
"scope"
]
=
tokenInfo
.
Scope
}
account
.
Credentials
=
newCredentials
if
updateErr
:=
p
.
accountRepo
.
Update
(
ctx
,
account
);
updateErr
!=
nil
{
slog
.
Error
(
"claude_token_provider_update_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
updateErr
)
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
}
}
}
else
{
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
time
.
Sleep
(
claudeLockWaitTime
)
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
slog
.
Debug
(
"claude_token_cache_hit_after_wait"
,
"account_id"
,
account
.
ID
)
return
token
,
nil
}
}
}
accessToken
:=
account
.
GetCredential
(
"access_token"
)
if
strings
.
TrimSpace
(
accessToken
)
==
""
{
return
""
,
errors
.
New
(
"access_token not found in credentials"
)
}
// 3. 存入缓存
if
p
.
tokenCache
!=
nil
{
ttl
:=
30
*
time
.
Minute
if
refreshFailed
{
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
ttl
=
time
.
Minute
slog
.
Debug
(
"claude_token_cache_short_ttl"
,
"account_id"
,
account
.
ID
,
"reason"
,
"refresh_failed"
)
}
else
if
expiresAt
!=
nil
{
until
:=
time
.
Until
(
*
expiresAt
)
switch
{
case
until
>
claudeTokenCacheSkew
:
ttl
=
until
-
claudeTokenCacheSkew
case
until
>
0
:
ttl
=
until
default
:
ttl
=
time
.
Minute
}
}
if
err
:=
p
.
tokenCache
.
SetAccessToken
(
ctx
,
cacheKey
,
accessToken
,
ttl
);
err
!=
nil
{
slog
.
Warn
(
"claude_token_cache_set_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
return
accessToken
,
nil
}
backend/internal/service/claude_token_provider_test.go
0 → 100644
View file @
b9b4db3d
//go:build unit
package
service
import
(
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// claudeTokenCacheStub implements ClaudeTokenCache for testing
type claudeTokenCacheStub struct {
	mu     sync.Mutex        // guards tokens
	tokens map[string]string // cacheKey -> cached access token

	// Injected errors for failure-path tests (read-only after setup).
	getErr         error
	setErr         error
	deleteErr      error
	lockAcquired   bool // value AcquireRefreshLock returns when no race/error is simulated
	lockErr        error
	releaseLockErr error

	// Call counters; accessed atomically because tests may spawn goroutines.
	getCalled    int32
	setCalled    int32
	lockCalled   int32
	unlockCalled int32

	// When true, AcquireRefreshLock reports the lock as held by another worker.
	simulateLockRace bool
}

// newClaudeTokenCacheStub returns a stub with an empty token map and a
// lock that is acquired successfully by default.
func newClaudeTokenCacheStub() *claudeTokenCacheStub {
	return &claudeTokenCacheStub{
		tokens:       make(map[string]string),
		lockAcquired: true,
	}
}

// GetAccessToken returns the cached token for cacheKey (empty string if absent),
// or the injected getErr.
func (s *claudeTokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
	atomic.AddInt32(&s.getCalled, 1)
	if s.getErr != nil {
		return "", s.getErr
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.tokens[cacheKey], nil
}

// SetAccessToken stores token under cacheKey (ttl ignored), or returns the
// injected setErr.
func (s *claudeTokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
	atomic.AddInt32(&s.setCalled, 1)
	if s.setErr != nil {
		return s.setErr
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	s.tokens[cacheKey] = token
	return nil
}

// DeleteAccessToken removes the cached token, or returns the injected deleteErr.
func (s *claudeTokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error {
	if s.deleteErr != nil {
		return s.deleteErr
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.tokens, cacheKey)
	return nil
}

// AcquireRefreshLock simulates taking the distributed refresh lock:
// returns lockErr if injected, "not acquired" when simulateLockRace is set,
// otherwise the configured lockAcquired value.
func (s *claudeTokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
	atomic.AddInt32(&s.lockCalled, 1)
	if s.lockErr != nil {
		return false, s.lockErr
	}
	if s.simulateLockRace {
		return false, nil
	}
	return s.lockAcquired, nil
}

// ReleaseRefreshLock counts the call and returns the injected releaseLockErr.
func (s *claudeTokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
	atomic.AddInt32(&s.unlockCalled, 1)
	return s.releaseLockErr
}
// claudeAccountRepoStub is a minimal stub implementing only the methods used by ClaudeTokenProvider
type claudeAccountRepoStub struct {
	account   *Account // returned by GetByID; overwritten by Update
	getErr    error    // injected error for GetByID
	updateErr error    // injected error for Update

	// Call counters; accessed atomically.
	getCalled    int32
	updateCalled int32
}

// GetByID returns the stored account, or the injected getErr.
func (r *claudeAccountRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
	atomic.AddInt32(&r.getCalled, 1)
	if r.getErr != nil {
		return nil, r.getErr
	}
	return r.account, nil
}

// Update records the given account as the new stored account, or returns the
// injected updateErr without modifying state.
func (r *claudeAccountRepoStub) Update(ctx context.Context, account *Account) error {
	atomic.AddInt32(&r.updateCalled, 1)
	if r.updateErr != nil {
		return r.updateErr
	}
	r.account = account
	return nil
}
// claudeOAuthServiceStub implements OAuthService methods for testing
type claudeOAuthServiceStub struct {
	tokenInfo     *TokenInfo // returned by RefreshAccountToken on success
	refreshErr    error      // injected error for RefreshAccountToken
	refreshCalled int32      // atomic call counter
}

// RefreshAccountToken returns the configured tokenInfo, or the injected refreshErr.
func (s *claudeOAuthServiceStub) RefreshAccountToken(ctx context.Context, account *Account) (*TokenInfo, error) {
	atomic.AddInt32(&s.refreshCalled, 1)
	if s.refreshErr != nil {
		return nil, s.refreshErr
	}
	return s.tokenInfo, nil
}
// testClaudeTokenProvider is a test version that uses the stub OAuth service
// NOTE(review): this mirrors ClaudeTokenProvider.GetAccessToken so tests can
// inject claudeOAuthServiceStub; keep its logic in sync with the production code.
type testClaudeTokenProvider struct {
	accountRepo  *claudeAccountRepoStub
	tokenCache   *claudeTokenCacheStub
	oauthService *claudeOAuthServiceStub
}

// GetAccessToken replicates the production token-resolution flow:
// cache lookup -> locked refresh when near expiry -> fall back to credentials,
// then cache the result with a lifetime-derived TTL.
func (p *testClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
	if account == nil {
		return "", errors.New("account is nil")
	}
	if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
		return "", errors.New("not an anthropic oauth account")
	}
	cacheKey := ClaudeTokenCacheKey(account)
	// 1. Check cache
	if p.tokenCache != nil {
		if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
			return token, nil
		}
	}
	// 2. Check if refresh needed
	expiresAt := account.GetCredentialAsTime("expires_at")
	needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
	refreshFailed := false
	if needsRefresh && p.tokenCache != nil {
		locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
		if err == nil && locked {
			defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
			// Check cache again after acquiring lock
			if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
				return token, nil
			}
			// Get fresh account from DB
			fresh, err := p.accountRepo.GetByID(ctx, account.ID)
			if err == nil && fresh != nil {
				account = fresh
			}
			expiresAt = account.GetCredentialAsTime("expires_at")
			if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
				if p.oauthService == nil {
					refreshFailed = true // cannot refresh, mark as failed
				} else {
					tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
					if err != nil {
						refreshFailed = true // refresh failed, mark for short TTL
					} else {
						// Build new credentials
						newCredentials := make(map[string]any)
						for k, v := range account.Credentials {
							newCredentials[k] = v
						}
						newCredentials["access_token"] = tokenInfo.AccessToken
						newCredentials["token_type"] = tokenInfo.TokenType
						newCredentials["expires_at"] = time.Now().Add(time.Duration(tokenInfo.ExpiresIn) * time.Second).Format(time.RFC3339)
						if tokenInfo.RefreshToken != "" {
							newCredentials["refresh_token"] = tokenInfo.RefreshToken
						}
						account.Credentials = newCredentials
						_ = p.accountRepo.Update(ctx, account)
						expiresAt = account.GetCredentialAsTime("expires_at")
					}
				}
			}
		} else if p.tokenCache.simulateLockRace {
			// Wait and retry cache
			time.Sleep(10 * time.Millisecond)
			if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
				return token, nil
			}
		}
	}
	accessToken := account.GetCredential("access_token")
	if accessToken == "" {
		return "", errors.New("access_token not found in credentials")
	}
	// 3. Store in cache
	if p.tokenCache != nil {
		ttl := 30 * time.Minute
		if refreshFailed {
			ttl = time.Minute // use short TTL when refresh failed
		} else if expiresAt != nil {
			until := time.Until(*expiresAt)
			if until > claudeTokenCacheSkew {
				ttl = until - claudeTokenCacheSkew
			} else if until > 0 {
				ttl = until
			} else {
				ttl = time.Minute
			}
		}
		_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
	}
	return accessToken, nil
}
// TestClaudeTokenProvider_CacheHit: a cached token is returned directly,
// without touching credentials or writing back to the cache.
func TestClaudeTokenProvider_CacheHit(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	account := &Account{
		ID:       100,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "db-token",
		},
	}
	cacheKey := ClaudeTokenCacheKey(account)
	cache.tokens[cacheKey] = "cached-token"
	provider := NewClaudeTokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "cached-token", token)
	require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalled))
	require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalled))
}

// TestClaudeTokenProvider_CacheMiss_FromCredentials: on a cache miss with a
// far-future expiry, the credential token is returned and written to the cache.
func TestClaudeTokenProvider_CacheMiss_FromCredentials(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	// Token expires in far future, no refresh needed
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       101,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "credential-token",
			"expires_at":   expiresAt,
		},
	}
	provider := NewClaudeTokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "credential-token", token)
	// Should have stored in cache
	cacheKey := ClaudeTokenCacheKey(account)
	require.Equal(t, "credential-token", cache.tokens[cacheKey])
}
// TestClaudeTokenProvider_TokenRefresh: a token within the refresh skew is
// refreshed through the OAuth service exactly once.
func TestClaudeTokenProvider_TokenRefresh(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	accountRepo := &claudeAccountRepoStub{}
	oauthService := &claudeOAuthServiceStub{
		tokenInfo: &TokenInfo{
			AccessToken:  "refreshed-token",
			RefreshToken: "new-refresh-token",
			TokenType:    "Bearer",
			ExpiresIn:    3600,
			ExpiresAt:    time.Now().Add(time.Hour).Unix(),
		},
	}
	// Token expires soon (within refresh skew)
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       102,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token":  "old-token",
			"refresh_token": "old-refresh-token",
			"expires_at":    expiresAt,
		},
	}
	accountRepo.account = account
	provider := &testClaudeTokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: oauthService,
	}
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "refreshed-token", token)
	require.Equal(t, int32(1), atomic.LoadInt32(&oauthService.refreshCalled))
}

// TestClaudeTokenProvider_LockRaceCondition: when another worker holds the
// refresh lock, the provider waits and still produces a usable token.
func TestClaudeTokenProvider_LockRaceCondition(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	cache.simulateLockRace = true
	accountRepo := &claudeAccountRepoStub{}
	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       103,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "race-token",
			"expires_at":   expiresAt,
		},
	}
	accountRepo.account = account
	// Simulate another worker already refreshed and cached
	cacheKey := ClaudeTokenCacheKey(account)
	go func() {
		time.Sleep(5 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "winner-token"
		cache.mu.Unlock()
	}()
	provider := &testClaudeTokenProvider{
		accountRepo: accountRepo,
		tokenCache:  cache,
	}
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.NotEmpty(t, token)
}
// TestClaudeTokenProvider_NilAccount: a nil account is rejected up front.
func TestClaudeTokenProvider_NilAccount(t *testing.T) {
	provider := NewClaudeTokenProvider(nil, nil, nil)
	token, err := provider.GetAccessToken(context.Background(), nil)
	require.Error(t, err)
	require.Contains(t, err.Error(), "account is nil")
	require.Empty(t, token)
}

// TestClaudeTokenProvider_WrongPlatform: non-Anthropic accounts are rejected.
func TestClaudeTokenProvider_WrongPlatform(t *testing.T) {
	provider := NewClaudeTokenProvider(nil, nil, nil)
	account := &Account{
		ID:       104,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
	}
	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "not an anthropic oauth account")
	require.Empty(t, token)
}

// TestClaudeTokenProvider_WrongAccountType: API-key accounts are rejected.
func TestClaudeTokenProvider_WrongAccountType(t *testing.T) {
	provider := NewClaudeTokenProvider(nil, nil, nil)
	account := &Account{
		ID:       105,
		Platform: PlatformAnthropic,
		Type:     AccountTypeAPIKey,
	}
	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "not an anthropic oauth account")
	require.Empty(t, token)
}

// TestClaudeTokenProvider_SetupTokenType: setup-token accounts are rejected.
func TestClaudeTokenProvider_SetupTokenType(t *testing.T) {
	provider := NewClaudeTokenProvider(nil, nil, nil)
	account := &Account{
		ID:       106,
		Platform: PlatformAnthropic,
		Type:     AccountTypeSetupToken,
	}
	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "not an anthropic oauth account")
	require.Empty(t, token)
}
// TestClaudeTokenProvider_NilCache: with no cache configured, the credential
// token is returned directly.
func TestClaudeTokenProvider_NilCache(t *testing.T) {
	// Token doesn't need refresh
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       107,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "nocache-token",
			"expires_at":   expiresAt,
		},
	}
	provider := NewClaudeTokenProvider(nil, nil, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "nocache-token", token)
}

// TestClaudeTokenProvider_CacheGetError: a cache read failure degrades
// gracefully to the credential token.
func TestClaudeTokenProvider_CacheGetError(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	cache.getErr = errors.New("redis connection failed")
	// Token doesn't need refresh
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       108,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "fallback-token",
			"expires_at":   expiresAt,
		},
	}
	provider := NewClaudeTokenProvider(nil, cache, nil)
	// Should gracefully degrade and return from credentials
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "fallback-token", token)
}

// TestClaudeTokenProvider_CacheSetError: a cache write failure does not
// prevent returning the token.
func TestClaudeTokenProvider_CacheSetError(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	cache.setErr = errors.New("redis write failed")
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       109,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "still-works-token",
			"expires_at":   expiresAt,
		},
	}
	provider := NewClaudeTokenProvider(nil, cache, nil)
	// Should still work even if cache set fails
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "still-works-token", token)
}

// TestClaudeTokenProvider_MissingAccessToken: missing access_token in
// credentials is a hard error.
func TestClaudeTokenProvider_MissingAccessToken(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       110,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"expires_at": expiresAt,
			// missing access_token
		},
	}
	provider := NewClaudeTokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "access_token not found")
	require.Empty(t, token)
}
// TestClaudeTokenProvider_RefreshError: a failed OAuth refresh falls back to
// the existing token instead of erroring out.
func TestClaudeTokenProvider_RefreshError(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	accountRepo := &claudeAccountRepoStub{}
	oauthService := &claudeOAuthServiceStub{
		refreshErr: errors.New("oauth refresh failed"),
	}
	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       111,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token":  "old-token",
			"refresh_token": "old-refresh-token",
			"expires_at":    expiresAt,
		},
	}
	accountRepo.account = account
	provider := &testClaudeTokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: oauthService,
	}
	// Now with fallback behavior, should return existing token even if refresh fails
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "old-token", token) // Fallback to existing token
}

// TestClaudeTokenProvider_OAuthServiceNotConfigured: with no OAuth service,
// the existing token is still returned.
func TestClaudeTokenProvider_OAuthServiceNotConfigured(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	accountRepo := &claudeAccountRepoStub{}
	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       112,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "old-token",
			"expires_at":   expiresAt,
		},
	}
	accountRepo.account = account
	provider := &testClaudeTokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: nil, // not configured
	}
	// Now with fallback behavior, should return existing token even if oauth service not configured
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "old-token", token) // Fallback to existing token
}
// TestClaudeTokenProvider_TTLCalculation: for various remaining lifetimes, the
// token ends up cached (TTL branches: > skew, within skew, near expiry).
func TestClaudeTokenProvider_TTLCalculation(t *testing.T) {
	tests := []struct {
		name      string
		expiresIn time.Duration
	}{
		{
			name:      "far_future_expiry",
			expiresIn: 1 * time.Hour,
		},
		{
			name:      "medium_expiry",
			expiresIn: 10 * time.Minute,
		},
		{
			name:      "near_expiry",
			expiresIn: 6 * time.Minute,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cache := newClaudeTokenCacheStub()
			expiresAt := time.Now().Add(tt.expiresIn).Format(time.RFC3339)
			account := &Account{
				ID:       200,
				Platform: PlatformAnthropic,
				Type:     AccountTypeOAuth,
				Credentials: map[string]any{
					"access_token": "test-token",
					"expires_at":   expiresAt,
				},
			}
			provider := NewClaudeTokenProvider(nil, cache, nil)
			_, err := provider.GetAccessToken(context.Background(), account)
			require.NoError(t, err)
			// Verify token was cached
			cacheKey := ClaudeTokenCacheKey(account)
			require.Equal(t, "test-token", cache.tokens[cacheKey])
		})
	}
}
// TestClaudeTokenProvider_AccountRepoGetError: a failed DB reload keeps using
// the passed-in account and still refreshes successfully.
func TestClaudeTokenProvider_AccountRepoGetError(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	accountRepo := &claudeAccountRepoStub{
		getErr: errors.New("db connection failed"),
	}
	oauthService := &claudeOAuthServiceStub{
		tokenInfo: &TokenInfo{
			AccessToken:  "refreshed-token",
			RefreshToken: "new-refresh",
			TokenType:    "Bearer",
			ExpiresIn:    3600,
		},
	}
	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       113,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token":  "old-token",
			"refresh_token": "old-refresh",
			"expires_at":    expiresAt,
		},
	}
	provider := &testClaudeTokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: oauthService,
	}
	// Should still work, just using the passed-in account
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "refreshed-token", token)
}

// TestClaudeTokenProvider_AccountUpdateError: a failed credential persist does
// not prevent the refreshed token from being returned.
func TestClaudeTokenProvider_AccountUpdateError(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	accountRepo := &claudeAccountRepoStub{
		updateErr: errors.New("db write failed"),
	}
	oauthService := &claudeOAuthServiceStub{
		tokenInfo: &TokenInfo{
			AccessToken:  "refreshed-token",
			RefreshToken: "new-refresh",
			TokenType:    "Bearer",
			ExpiresIn:    3600,
		},
	}
	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       114,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token":  "old-token",
			"refresh_token": "old-refresh",
			"expires_at":    expiresAt,
		},
	}
	accountRepo.account = account
	provider := &testClaudeTokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: oauthService,
	}
	// Should still return token even if update fails
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "refreshed-token", token)
}
func
TestClaudeTokenProvider_RefreshPreservesExistingCredentials
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
accountRepo
:=
&
claudeAccountRepoStub
{}
oauthService
:=
&
claudeOAuthServiceStub
{
tokenInfo
:
&
TokenInfo
{
AccessToken
:
"new-access-token"
,
RefreshToken
:
"new-refresh-token"
,
TokenType
:
"Bearer"
,
ExpiresIn
:
3600
,
},
}
// Token expires soon
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
115
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"old-access-token"
,
"refresh_token"
:
"old-refresh-token"
,
"expires_at"
:
expiresAt
,
"custom_field"
:
"should-be-preserved"
,
"organization"
:
"test-org"
,
},
}
accountRepo
.
account
=
account
provider
:=
&
testClaudeTokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
cache
,
oauthService
:
oauthService
,
}
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"new-access-token"
,
token
)
// Verify existing fields are preserved
require
.
Equal
(
t
,
"should-be-preserved"
,
accountRepo
.
account
.
Credentials
[
"custom_field"
])
require
.
Equal
(
t
,
"test-org"
,
accountRepo
.
account
.
Credentials
[
"organization"
])
// Verify new fields are updated
require
.
Equal
(
t
,
"new-access-token"
,
accountRepo
.
account
.
Credentials
[
"access_token"
])
require
.
Equal
(
t
,
"new-refresh-token"
,
accountRepo
.
account
.
Credentials
[
"refresh_token"
])
}
// TestClaudeTokenProvider_DoubleCheckCacheAfterLock exercises the
// double-checked cache read: after the refresh lock is acquired another
// worker may already have populated the cache, and the provider must still
// return a usable token either way.
func TestClaudeTokenProvider_DoubleCheckCacheAfterLock(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	accountRepo := &claudeAccountRepoStub{}
	oauthService := &claudeOAuthServiceStub{
		tokenInfo: &TokenInfo{
			AccessToken:  "refreshed-token",
			RefreshToken: "new-refresh",
			TokenType:    "Bearer",
			ExpiresIn:    3600,
		},
	}
	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       116,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "old-token",
			"expires_at":   expiresAt,
		},
	}
	accountRepo.account = account
	cacheKey := ClaudeTokenCacheKey(account)
	// After lock is acquired, cache should have the token (simulating another worker)
	// NOTE(review): relies on a 5ms sleep; timing-sensitive on loaded CI hosts.
	go func() {
		time.Sleep(5 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "cached-by-other-worker"
		cache.mu.Unlock()
	}()
	provider := &testClaudeTokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: oauthService,
	}
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	// Either the refreshed token or the other worker's cached value may win
	// the race, so only non-emptiness is asserted.
	require.NotEmpty(t, token)
}
// Tests for real provider - to increase coverage

// TestClaudeTokenProvider_Real_LockFailedWait covers the path where the
// refresh lock cannot be acquired: the provider waits, then picks up the
// token another worker wrote to the cache (or falls back to credentials).
func TestClaudeTokenProvider_Real_LockFailedWait(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	cache.lockAcquired = false // Lock acquisition fails
	// Token expires soon (within refresh skew) to trigger lock attempt
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       300,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "fallback-token",
			"expires_at":   expiresAt,
		},
	}
	// Set token in cache after lock wait period (simulate other worker refreshing)
	// NOTE(review): 100ms sleep makes this timing-sensitive under heavy load.
	cacheKey := ClaudeTokenCacheKey(account)
	go func() {
		time.Sleep(100 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "refreshed-by-other"
		cache.mu.Unlock()
	}()
	provider := NewClaudeTokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	// Either the cached value or the credential fallback is acceptable.
	require.NotEmpty(t, token)
}
// TestClaudeTokenProvider_Real_CacheHitAfterWait covers the lock-contention
// path of the real provider: lock acquisition fails, the provider waits, and
// a token written to the cache by a concurrent "winner" becomes available.
func TestClaudeTokenProvider_Real_CacheHitAfterWait(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	cache.lockAcquired = false // Lock acquisition fails
	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       301,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "original-token",
			"expires_at":   expiresAt,
		},
	}
	cacheKey := ClaudeTokenCacheKey(account)
	// Set token in cache immediately after wait starts
	// NOTE(review): 50ms sleep makes this timing-sensitive on loaded CI hosts.
	go func() {
		time.Sleep(50 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "winner-token"
		cache.mu.Unlock()
	}()
	provider := NewClaudeTokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	// Either the cached winner token or the credential fallback may be returned.
	require.NotEmpty(t, token)
}
func
TestClaudeTokenProvider_Real_NoExpiresAt
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
cache
.
lockAcquired
=
false
// Prevent entering refresh logic
// Token with nil expires_at (no expiry set)
account
:=
&
Account
{
ID
:
302
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"no-expiry-token"
,
},
}
// After lock wait, return token from credentials
provider
:=
NewClaudeTokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"no-expiry-token"
,
token
)
}
func
TestClaudeTokenProvider_Real_WhitespaceToken
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
cacheKey
:=
"claude:account:303"
cache
.
tokens
[
cacheKey
]
=
" "
// Whitespace only - should be treated as empty
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
303
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"real-token"
,
"expires_at"
:
expiresAt
,
},
}
provider
:=
NewClaudeTokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"real-token"
,
token
)
}
func
TestClaudeTokenProvider_Real_EmptyCredentialToken
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
304
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
" "
,
// Whitespace only
"expires_at"
:
expiresAt
,
},
}
provider
:=
NewClaudeTokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"access_token not found"
)
require
.
Empty
(
t
,
token
)
}
func
TestClaudeTokenProvider_Real_LockError
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
cache
.
lockErr
=
errors
.
New
(
"redis lock failed"
)
// Token expires soon (within refresh skew)
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
305
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"fallback-on-lock-error"
,
"expires_at"
:
expiresAt
,
},
}
provider
:=
NewClaudeTokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"fallback-on-lock-error"
,
token
)
}
func
TestClaudeTokenProvider_Real_NilCredentials
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
306
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"expires_at"
:
expiresAt
,
// No access_token
},
}
provider
:=
NewClaudeTokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"access_token not found"
)
require
.
Empty
(
t
,
token
)
}
backend/internal/service/dashboard_aggregation_service.go
0 → 100644
View file @
b9b4db3d
package
service
import
(
"context"
"errors"
"log"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
)
const (
	// Upper bound for one scheduled aggregation pass.
	defaultDashboardAggregationTimeout = 2 * time.Minute
	// Upper bound for a (much longer) manual/startup backfill run.
	defaultDashboardAggregationBackfillTimeout = 30 * time.Minute
	// Minimum spacing between retention cleanup sweeps.
	dashboardAggregationRetentionInterval = 6 * time.Hour
)

var (
	// ErrDashboardBackfillDisabled is returned when backfill is disabled by config.
	ErrDashboardBackfillDisabled = errors.New("仪表盘聚合回填已禁用")
	// ErrDashboardBackfillTooLarge is returned when the backfill span exceeds the configured limit.
	ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大")
)
// DashboardAggregationRepository defines the repository interface for
// dashboard pre-aggregation (aggregation windows, watermark bookkeeping,
// retention cleanup, and partition maintenance).
type DashboardAggregationRepository interface {
	// AggregateRange aggregates usage data for the given start/end window.
	AggregateRange(ctx context.Context, start, end time.Time) error
	// GetAggregationWatermark returns the time up to which data has been aggregated.
	GetAggregationWatermark(ctx context.Context) (time.Time, error)
	// UpdateAggregationWatermark records a new aggregation watermark.
	UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
	// CleanupAggregates removes hourly/daily aggregates older than the respective cutoffs.
	CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
	// CleanupUsageLogs removes raw usage logs older than cutoff.
	CleanupUsageLogs(ctx context.Context, cutoff time.Time) error
	// EnsureUsageLogsPartitions ensures usage-log partitions exist around now.
	EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error
}
// DashboardAggregationService runs the scheduled aggregation job and
// on-demand backfills.
type DashboardAggregationService struct {
	repo        DashboardAggregationRepository
	timingWheel *TimingWheelService
	cfg         config.DashboardAggregationConfig
	// running is a 0/1 CAS guard ensuring a single aggregation/backfill
	// job runs at a time.
	running int32
	// lastRetentionCleanup holds the time.Time of the last fully
	// successful retention cleanup sweep.
	lastRetentionCleanup atomic.Value // time.Time
}
// NewDashboardAggregationService 创建聚合服务。
func
NewDashboardAggregationService
(
repo
DashboardAggregationRepository
,
timingWheel
*
TimingWheelService
,
cfg
*
config
.
Config
)
*
DashboardAggregationService
{
var
aggCfg
config
.
DashboardAggregationConfig
if
cfg
!=
nil
{
aggCfg
=
cfg
.
DashboardAgg
}
return
&
DashboardAggregationService
{
repo
:
repo
,
timingWheel
:
timingWheel
,
cfg
:
aggCfg
,
}
}
// Start launches the recurring aggregation job; configuration changes take
// effect only after a restart.
func (s *DashboardAggregationService) Start() {
	// Nil receiver or missing dependencies: silently do nothing.
	if s == nil || s.repo == nil || s.timingWheel == nil {
		return
	}
	if !s.cfg.Enabled {
		log.Printf("[DashboardAggregation] 聚合作业已禁用")
		return
	}
	interval := time.Duration(s.cfg.IntervalSeconds) * time.Second
	if interval <= 0 {
		interval = time.Minute // sane default when unset or invalid
	}
	// Optionally recompute the trailing N days in the background on startup.
	if s.cfg.RecomputeDays > 0 {
		go s.recomputeRecentDays()
	}
	s.timingWheel.ScheduleRecurring("dashboard:aggregation", interval, func() {
		s.runScheduledAggregation()
	})
	log.Printf("[DashboardAggregation] 聚合作业启动 (interval=%v, lookback=%ds)", interval, s.cfg.LookbackSeconds)
	if !s.cfg.BackfillEnabled {
		log.Printf("[DashboardAggregation] 回填已禁用,如需补齐保留窗口以外历史数据请手动回填")
	}
}
// TriggerBackfill 触发回填(异步)。
func
(
s
*
DashboardAggregationService
)
TriggerBackfill
(
start
,
end
time
.
Time
)
error
{
if
s
==
nil
||
s
.
repo
==
nil
{
return
errors
.
New
(
"聚合服务未初始化"
)
}
if
!
s
.
cfg
.
BackfillEnabled
{
log
.
Printf
(
"[DashboardAggregation] 回填被拒绝: backfill_enabled=false"
)
return
ErrDashboardBackfillDisabled
}
if
!
end
.
After
(
start
)
{
return
errors
.
New
(
"回填时间范围无效"
)
}
if
s
.
cfg
.
BackfillMaxDays
>
0
{
maxRange
:=
time
.
Duration
(
s
.
cfg
.
BackfillMaxDays
)
*
24
*
time
.
Hour
if
end
.
Sub
(
start
)
>
maxRange
{
return
ErrDashboardBackfillTooLarge
}
}
go
func
()
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
defaultDashboardAggregationBackfillTimeout
)
defer
cancel
()
if
err
:=
s
.
backfillRange
(
ctx
,
start
,
end
);
err
!=
nil
{
log
.
Printf
(
"[DashboardAggregation] 回填失败: %v"
,
err
)
}
}()
return
nil
}
func
(
s
*
DashboardAggregationService
)
recomputeRecentDays
()
{
days
:=
s
.
cfg
.
RecomputeDays
if
days
<=
0
{
return
}
now
:=
time
.
Now
()
.
UTC
()
start
:=
now
.
AddDate
(
0
,
0
,
-
days
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
defaultDashboardAggregationBackfillTimeout
)
defer
cancel
()
if
err
:=
s
.
backfillRange
(
ctx
,
start
,
now
);
err
!=
nil
{
log
.
Printf
(
"[DashboardAggregation] 启动重算失败: %v"
,
err
)
return
}
}
// runScheduledAggregation performs one aggregation pass: it derives the
// window from the stored watermark (minus a lookback overlap), aggregates,
// advances the watermark, and opportunistically runs retention cleanup.
// Concurrent runs are prevented by the `running` CAS guard.
func (s *DashboardAggregationService) runScheduledAggregation() {
	if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
		return // another aggregation/backfill is already in flight
	}
	defer atomic.StoreInt32(&s.running, 0)
	jobStart := time.Now().UTC()
	ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationTimeout)
	defer cancel()
	now := time.Now().UTC()
	last, err := s.repo.GetAggregationWatermark(ctx)
	if err != nil {
		// Fall back to epoch so the retention-bounded branch below applies.
		log.Printf("[DashboardAggregation] 读取水位失败: %v", err)
		last = time.Unix(0, 0).UTC()
	}
	lookback := time.Duration(s.cfg.LookbackSeconds) * time.Second
	epoch := time.Unix(0, 0).UTC()
	// Re-aggregate a lookback window behind the watermark to absorb late data.
	start := last.Add(-lookback)
	if !last.After(epoch) {
		// No watermark yet: bound the first pass by the usage-log retention
		// window (non-positive config clamped to one day).
		retentionDays := s.cfg.Retention.UsageLogsDays
		if retentionDays <= 0 {
			retentionDays = 1
		}
		start = truncateToDayUTC(now.AddDate(0, 0, -retentionDays))
	} else if start.After(now) {
		// Clock skew / future watermark: fall back to a plain lookback window.
		start = now.Add(-lookback)
	}
	if err := s.aggregateRange(ctx, start, now); err != nil {
		log.Printf("[DashboardAggregation] 聚合失败: %v", err)
		return // watermark intentionally not advanced on failure
	}
	updateErr := s.repo.UpdateAggregationWatermark(ctx, now)
	if updateErr != nil {
		log.Printf("[DashboardAggregation] 更新水位失败: %v", updateErr)
	}
	log.Printf("[DashboardAggregation] 聚合完成 (start=%s end=%s duration=%s watermark_updated=%t)",
		start.Format(time.RFC3339),
		now.Format(time.RFC3339),
		time.Since(jobStart).String(),
		updateErr == nil,
	)
	s.maybeCleanupRetention(ctx, now)
}
// backfillRange aggregates [start, end) in day-sized windows, then advances
// the watermark to end and runs retention cleanup. It refuses to run while
// another aggregation job holds the `running` guard.
// NOTE(review): the final watermark update can move the watermark backwards
// when backfilling a historical range — confirm this is intended.
func (s *DashboardAggregationService) backfillRange(ctx context.Context, start, end time.Time) error {
	if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
		return errors.New("聚合作业正在运行")
	}
	defer atomic.StoreInt32(&s.running, 0)
	jobStart := time.Now().UTC()
	startUTC := start.UTC()
	endUTC := end.UTC()
	if !endUTC.After(startUTC) {
		return errors.New("回填时间范围无效")
	}
	// Walk the range in 24h windows aligned to the UTC day of the start.
	cursor := truncateToDayUTC(startUTC)
	for cursor.Before(endUTC) {
		windowEnd := cursor.Add(24 * time.Hour)
		if windowEnd.After(endUTC) {
			windowEnd = endUTC // final partial window
		}
		if err := s.aggregateRange(ctx, cursor, windowEnd); err != nil {
			return err
		}
		cursor = windowEnd
	}
	updateErr := s.repo.UpdateAggregationWatermark(ctx, endUTC)
	if updateErr != nil {
		log.Printf("[DashboardAggregation] 更新水位失败: %v", updateErr)
	}
	log.Printf("[DashboardAggregation] 回填聚合完成 (start=%s end=%s duration=%s watermark_updated=%t)",
		startUTC.Format(time.RFC3339),
		endUTC.Format(time.RFC3339),
		time.Since(jobStart).String(),
		updateErr == nil,
	)
	s.maybeCleanupRetention(ctx, endUTC)
	return nil
}
func
(
s
*
DashboardAggregationService
)
aggregateRange
(
ctx
context
.
Context
,
start
,
end
time
.
Time
)
error
{
if
!
end
.
After
(
start
)
{
return
nil
}
if
err
:=
s
.
repo
.
EnsureUsageLogsPartitions
(
ctx
,
end
);
err
!=
nil
{
log
.
Printf
(
"[DashboardAggregation] 分区检查失败: %v"
,
err
)
}
return
s
.
repo
.
AggregateRange
(
ctx
,
start
,
end
)
}
func
(
s
*
DashboardAggregationService
)
maybeCleanupRetention
(
ctx
context
.
Context
,
now
time
.
Time
)
{
lastAny
:=
s
.
lastRetentionCleanup
.
Load
()
if
lastAny
!=
nil
{
if
last
,
ok
:=
lastAny
.
(
time
.
Time
);
ok
&&
now
.
Sub
(
last
)
<
dashboardAggregationRetentionInterval
{
return
}
}
hourlyCutoff
:=
now
.
AddDate
(
0
,
0
,
-
s
.
cfg
.
Retention
.
HourlyDays
)
dailyCutoff
:=
now
.
AddDate
(
0
,
0
,
-
s
.
cfg
.
Retention
.
DailyDays
)
usageCutoff
:=
now
.
AddDate
(
0
,
0
,
-
s
.
cfg
.
Retention
.
UsageLogsDays
)
aggErr
:=
s
.
repo
.
CleanupAggregates
(
ctx
,
hourlyCutoff
,
dailyCutoff
)
if
aggErr
!=
nil
{
log
.
Printf
(
"[DashboardAggregation] 聚合保留清理失败: %v"
,
aggErr
)
}
usageErr
:=
s
.
repo
.
CleanupUsageLogs
(
ctx
,
usageCutoff
)
if
usageErr
!=
nil
{
log
.
Printf
(
"[DashboardAggregation] usage_logs 保留清理失败: %v"
,
usageErr
)
}
if
aggErr
==
nil
&&
usageErr
==
nil
{
s
.
lastRetentionCleanup
.
Store
(
now
)
}
}
func
truncateToDayUTC
(
t
time
.
Time
)
time
.
Time
{
t
=
t
.
UTC
()
return
time
.
Date
(
t
.
Year
(),
t
.
Month
(),
t
.
Day
(),
0
,
0
,
0
,
0
,
time
.
UTC
)
}
backend/internal/service/dashboard_aggregation_service_test.go
0 → 100644
View file @
b9b4db3d
package
service
import
(
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// dashboardAggregationRepoTestStub is an in-memory DashboardAggregationRepository
// that records AggregateRange calls and returns configurable errors.
type dashboardAggregationRepoTestStub struct {
	aggregateCalls       int       // number of AggregateRange invocations
	lastStart            time.Time // start of the most recent AggregateRange window
	lastEnd              time.Time // end of the most recent AggregateRange window
	watermark            time.Time // value returned by GetAggregationWatermark
	aggregateErr         error     // forced AggregateRange result
	cleanupAggregatesErr error     // forced CleanupAggregates result
	cleanupUsageErr      error     // forced CleanupUsageLogs result
}

// AggregateRange records the call and its window, then returns the forced error.
func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, start, end time.Time) error {
	s.aggregateCalls++
	s.lastStart = start
	s.lastEnd = end
	return s.aggregateErr
}

// GetAggregationWatermark returns the configured watermark and never fails.
func (s *dashboardAggregationRepoTestStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
	return s.watermark, nil
}

// UpdateAggregationWatermark is a no-op.
func (s *dashboardAggregationRepoTestStub) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
	return nil
}

// CleanupAggregates returns the forced error, if any.
func (s *dashboardAggregationRepoTestStub) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
	return s.cleanupAggregatesErr
}

// CleanupUsageLogs returns the forced error, if any.
func (s *dashboardAggregationRepoTestStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
	return s.cleanupUsageErr
}

// EnsureUsageLogsPartitions is a no-op.
func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
	return nil
}
// TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart
// verifies that with an epoch (never-set) watermark the aggregation window
// starts at the retention boundary: UsageLogsDays before "now", day-truncated.
func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart(t *testing.T) {
	repo := &dashboardAggregationRepoTestStub{watermark: time.Unix(0, 0).UTC()}
	svc := &DashboardAggregationService{
		repo: repo,
		cfg: config.DashboardAggregationConfig{
			Enabled:         true,
			IntervalSeconds: 60,
			LookbackSeconds: 120,
			Retention: config.DashboardAggregationRetentionConfig{
				UsageLogsDays: 1,
				HourlyDays:    1,
				DailyDays:     1,
			},
		},
	}
	svc.runScheduledAggregation()
	require.Equal(t, 1, repo.aggregateCalls)
	require.False(t, repo.lastEnd.IsZero())
	// With UsageLogsDays == 1, start must equal truncateToDayUTC(end - 1 day).
	require.Equal(t, truncateToDayUTC(repo.lastEnd.AddDate(0, 0, -1)), repo.lastStart)
}
// TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord
// verifies that a failed cleanup does not advance lastRetentionCleanup,
// so the next eligible run retries the sweep.
func TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord(t *testing.T) {
	repo := &dashboardAggregationRepoTestStub{cleanupAggregatesErr: errors.New("清理失败")}
	svc := &DashboardAggregationService{
		repo: repo,
		cfg: config.DashboardAggregationConfig{
			Retention: config.DashboardAggregationRetentionConfig{
				UsageLogsDays: 1,
				HourlyDays:    1,
				DailyDays:     1,
			},
		},
	}
	svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
	require.Nil(t, svc.lastRetentionCleanup.Load())
}
// TestDashboardAggregationService_TriggerBackfill_TooLarge verifies that a
// backfill window wider than BackfillMaxDays is rejected synchronously with
// ErrDashboardBackfillTooLarge and no aggregation is started.
func TestDashboardAggregationService_TriggerBackfill_TooLarge(t *testing.T) {
	repo := &dashboardAggregationRepoTestStub{}
	svc := &DashboardAggregationService{
		repo: repo,
		cfg: config.DashboardAggregationConfig{
			BackfillEnabled: true,
			BackfillMaxDays: 1,
		},
	}
	start := time.Now().AddDate(0, 0, -3) // 3-day span exceeds the 1-day cap
	end := time.Now()
	err := svc.TriggerBackfill(start, end)
	require.ErrorIs(t, err, ErrDashboardBackfillTooLarge)
	require.Equal(t, 0, repo.aggregateCalls)
}
backend/internal/service/dashboard_service.go
View file @
b9b4db3d
...
...
@@ -2,47 +2,307 @@ package service
import
(
"context"
"encoding/json"
"errors"
"fmt"
"log"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
)
// DashboardService provides aggregated statistics for admin dashboard.
const
(
defaultDashboardStatsFreshTTL
=
15
*
time
.
Second
defaultDashboardStatsCacheTTL
=
30
*
time
.
Second
defaultDashboardStatsRefreshTimeout
=
30
*
time
.
Second
)
// ErrDashboardStatsCacheMiss 标记仪表盘缓存未命中。
var
ErrDashboardStatsCacheMiss
=
errors
.
New
(
"仪表盘缓存未命中"
)
// DashboardStatsCache 定义仪表盘统计缓存接口。
type
DashboardStatsCache
interface
{
GetDashboardStats
(
ctx
context
.
Context
)
(
string
,
error
)
SetDashboardStats
(
ctx
context
.
Context
,
data
string
,
ttl
time
.
Duration
)
error
DeleteDashboardStats
(
ctx
context
.
Context
)
error
}
type
dashboardStatsRangeFetcher
interface
{
GetDashboardStatsWithRange
(
ctx
context
.
Context
,
start
,
end
time
.
Time
)
(
*
usagestats
.
DashboardStats
,
error
)
}
type
dashboardStatsCacheEntry
struct
{
Stats
*
usagestats
.
DashboardStats
`json:"stats"`
UpdatedAt
int64
`json:"updated_at"`
}
// DashboardService 提供管理员仪表盘统计服务。
type
DashboardService
struct
{
usageRepo
UsageLogRepository
usageRepo
UsageLogRepository
aggRepo
DashboardAggregationRepository
cache
DashboardStatsCache
cacheFreshTTL
time
.
Duration
cacheTTL
time
.
Duration
refreshTimeout
time
.
Duration
refreshing
int32
aggEnabled
bool
aggInterval
time
.
Duration
aggLookback
time
.
Duration
aggUsageDays
int
}
func
NewDashboardService
(
usageRepo
UsageLogRepository
)
*
DashboardService
{
func
NewDashboardService
(
usageRepo
UsageLogRepository
,
aggRepo
DashboardAggregationRepository
,
cache
DashboardStatsCache
,
cfg
*
config
.
Config
)
*
DashboardService
{
freshTTL
:=
defaultDashboardStatsFreshTTL
cacheTTL
:=
defaultDashboardStatsCacheTTL
refreshTimeout
:=
defaultDashboardStatsRefreshTimeout
aggEnabled
:=
true
aggInterval
:=
time
.
Minute
aggLookback
:=
2
*
time
.
Minute
aggUsageDays
:=
90
if
cfg
!=
nil
{
if
!
cfg
.
Dashboard
.
Enabled
{
cache
=
nil
}
if
cfg
.
Dashboard
.
StatsFreshTTLSeconds
>
0
{
freshTTL
=
time
.
Duration
(
cfg
.
Dashboard
.
StatsFreshTTLSeconds
)
*
time
.
Second
}
if
cfg
.
Dashboard
.
StatsTTLSeconds
>
0
{
cacheTTL
=
time
.
Duration
(
cfg
.
Dashboard
.
StatsTTLSeconds
)
*
time
.
Second
}
if
cfg
.
Dashboard
.
StatsRefreshTimeoutSeconds
>
0
{
refreshTimeout
=
time
.
Duration
(
cfg
.
Dashboard
.
StatsRefreshTimeoutSeconds
)
*
time
.
Second
}
aggEnabled
=
cfg
.
DashboardAgg
.
Enabled
if
cfg
.
DashboardAgg
.
IntervalSeconds
>
0
{
aggInterval
=
time
.
Duration
(
cfg
.
DashboardAgg
.
IntervalSeconds
)
*
time
.
Second
}
if
cfg
.
DashboardAgg
.
LookbackSeconds
>
0
{
aggLookback
=
time
.
Duration
(
cfg
.
DashboardAgg
.
LookbackSeconds
)
*
time
.
Second
}
if
cfg
.
DashboardAgg
.
Retention
.
UsageLogsDays
>
0
{
aggUsageDays
=
cfg
.
DashboardAgg
.
Retention
.
UsageLogsDays
}
}
if
aggRepo
==
nil
{
aggEnabled
=
false
}
return
&
DashboardService
{
usageRepo
:
usageRepo
,
usageRepo
:
usageRepo
,
aggRepo
:
aggRepo
,
cache
:
cache
,
cacheFreshTTL
:
freshTTL
,
cacheTTL
:
cacheTTL
,
refreshTimeout
:
refreshTimeout
,
aggEnabled
:
aggEnabled
,
aggInterval
:
aggInterval
,
aggLookback
:
aggLookback
,
aggUsageDays
:
aggUsageDays
,
}
}
func
(
s
*
DashboardService
)
GetDashboardStats
(
ctx
context
.
Context
)
(
*
usagestats
.
DashboardStats
,
error
)
{
stats
,
err
:=
s
.
usageRepo
.
GetDashboardStats
(
ctx
)
if
s
.
cache
!=
nil
{
cached
,
fresh
,
err
:=
s
.
getCachedDashboardStats
(
ctx
)
if
err
==
nil
&&
cached
!=
nil
{
s
.
refreshAggregationStaleness
(
cached
)
if
!
fresh
{
s
.
refreshDashboardStatsAsync
()
}
return
cached
,
nil
}
if
err
!=
nil
&&
!
errors
.
Is
(
err
,
ErrDashboardStatsCacheMiss
)
{
log
.
Printf
(
"[Dashboard] 仪表盘缓存读取失败: %v"
,
err
)
}
}
stats
,
err
:=
s
.
refreshDashboardStats
(
ctx
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get dashboard stats: %w"
,
err
)
}
return
stats
,
nil
}
func
(
s
*
DashboardService
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
int64
)
([]
usagestats
.
TrendDataPoint
,
error
)
{
trend
,
err
:=
s
.
usageRepo
.
GetUsageTrendWithFilters
(
ctx
,
startTime
,
endTime
,
granularity
,
userID
,
apiKeyID
)
func
(
s
*
DashboardService
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
stream
*
bool
)
([]
usagestats
.
TrendDataPoint
,
error
)
{
trend
,
err
:=
s
.
usageRepo
.
GetUsageTrendWithFilters
(
ctx
,
startTime
,
endTime
,
granularity
,
userID
,
apiKeyID
,
accountID
,
groupID
,
model
,
stream
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get usage trend with filters: %w"
,
err
)
}
return
trend
,
nil
}
func
(
s
*
DashboardService
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
int64
)
([]
usagestats
.
ModelStat
,
error
)
{
stats
,
err
:=
s
.
usageRepo
.
GetModelStatsWithFilters
(
ctx
,
startTime
,
endTime
,
userID
,
apiKeyID
,
0
)
func
(
s
*
DashboardService
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
stream
*
bool
)
([]
usagestats
.
ModelStat
,
error
)
{
stats
,
err
:=
s
.
usageRepo
.
GetModelStatsWithFilters
(
ctx
,
startTime
,
endTime
,
userID
,
apiKeyID
,
accountID
,
groupID
,
stream
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get model stats with filters: %w"
,
err
)
}
return
stats
,
nil
}
func
(
s
*
DashboardService
)
getCachedDashboardStats
(
ctx
context
.
Context
)
(
*
usagestats
.
DashboardStats
,
bool
,
error
)
{
data
,
err
:=
s
.
cache
.
GetDashboardStats
(
ctx
)
if
err
!=
nil
{
return
nil
,
false
,
err
}
var
entry
dashboardStatsCacheEntry
if
err
:=
json
.
Unmarshal
([]
byte
(
data
),
&
entry
);
err
!=
nil
{
s
.
evictDashboardStatsCache
(
err
)
return
nil
,
false
,
ErrDashboardStatsCacheMiss
}
if
entry
.
Stats
==
nil
{
s
.
evictDashboardStatsCache
(
errors
.
New
(
"仪表盘缓存缺少统计数据"
))
return
nil
,
false
,
ErrDashboardStatsCacheMiss
}
age
:=
time
.
Since
(
time
.
Unix
(
entry
.
UpdatedAt
,
0
))
return
entry
.
Stats
,
age
<=
s
.
cacheFreshTTL
,
nil
}
func
(
s
*
DashboardService
)
refreshDashboardStats
(
ctx
context
.
Context
)
(
*
usagestats
.
DashboardStats
,
error
)
{
stats
,
err
:=
s
.
fetchDashboardStats
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
s
.
applyAggregationStatus
(
ctx
,
stats
)
cacheCtx
,
cancel
:=
s
.
cacheOperationContext
()
defer
cancel
()
s
.
saveDashboardStatsCache
(
cacheCtx
,
stats
)
return
stats
,
nil
}
func
(
s
*
DashboardService
)
refreshDashboardStatsAsync
()
{
if
s
.
cache
==
nil
{
return
}
if
!
atomic
.
CompareAndSwapInt32
(
&
s
.
refreshing
,
0
,
1
)
{
return
}
go
func
()
{
defer
atomic
.
StoreInt32
(
&
s
.
refreshing
,
0
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
s
.
refreshTimeout
)
defer
cancel
()
stats
,
err
:=
s
.
fetchDashboardStats
(
ctx
)
if
err
!=
nil
{
log
.
Printf
(
"[Dashboard] 仪表盘缓存异步刷新失败: %v"
,
err
)
return
}
s
.
applyAggregationStatus
(
ctx
,
stats
)
cacheCtx
,
cancel
:=
s
.
cacheOperationContext
()
defer
cancel
()
s
.
saveDashboardStatsCache
(
cacheCtx
,
stats
)
}()
}
func
(
s
*
DashboardService
)
fetchDashboardStats
(
ctx
context
.
Context
)
(
*
usagestats
.
DashboardStats
,
error
)
{
if
!
s
.
aggEnabled
{
if
fetcher
,
ok
:=
s
.
usageRepo
.
(
dashboardStatsRangeFetcher
);
ok
{
now
:=
time
.
Now
()
.
UTC
()
start
:=
truncateToDayUTC
(
now
.
AddDate
(
0
,
0
,
-
s
.
aggUsageDays
))
return
fetcher
.
GetDashboardStatsWithRange
(
ctx
,
start
,
now
)
}
}
return
s
.
usageRepo
.
GetDashboardStats
(
ctx
)
}
func
(
s
*
DashboardService
)
saveDashboardStatsCache
(
ctx
context
.
Context
,
stats
*
usagestats
.
DashboardStats
)
{
if
s
.
cache
==
nil
||
stats
==
nil
{
return
}
entry
:=
dashboardStatsCacheEntry
{
Stats
:
stats
,
UpdatedAt
:
time
.
Now
()
.
Unix
(),
}
data
,
err
:=
json
.
Marshal
(
entry
)
if
err
!=
nil
{
log
.
Printf
(
"[Dashboard] 仪表盘缓存序列化失败: %v"
,
err
)
return
}
if
err
:=
s
.
cache
.
SetDashboardStats
(
ctx
,
string
(
data
),
s
.
cacheTTL
);
err
!=
nil
{
log
.
Printf
(
"[Dashboard] 仪表盘缓存写入失败: %v"
,
err
)
}
}
func
(
s
*
DashboardService
)
evictDashboardStatsCache
(
reason
error
)
{
if
s
.
cache
==
nil
{
return
}
cacheCtx
,
cancel
:=
s
.
cacheOperationContext
()
defer
cancel
()
if
err
:=
s
.
cache
.
DeleteDashboardStats
(
cacheCtx
);
err
!=
nil
{
log
.
Printf
(
"[Dashboard] 仪表盘缓存清理失败: %v"
,
err
)
}
if
reason
!=
nil
{
log
.
Printf
(
"[Dashboard] 仪表盘缓存异常,已清理: %v"
,
reason
)
}
}
func
(
s
*
DashboardService
)
cacheOperationContext
()
(
context
.
Context
,
context
.
CancelFunc
)
{
return
context
.
WithTimeout
(
context
.
Background
(),
s
.
refreshTimeout
)
}
func
(
s
*
DashboardService
)
applyAggregationStatus
(
ctx
context
.
Context
,
stats
*
usagestats
.
DashboardStats
)
{
if
stats
==
nil
{
return
}
updatedAt
:=
s
.
fetchAggregationUpdatedAt
(
ctx
)
stats
.
StatsUpdatedAt
=
updatedAt
.
UTC
()
.
Format
(
time
.
RFC3339
)
stats
.
StatsStale
=
s
.
isAggregationStale
(
updatedAt
,
time
.
Now
()
.
UTC
())
}
func
(
s
*
DashboardService
)
refreshAggregationStaleness
(
stats
*
usagestats
.
DashboardStats
)
{
if
stats
==
nil
{
return
}
updatedAt
:=
parseStatsUpdatedAt
(
stats
.
StatsUpdatedAt
)
stats
.
StatsStale
=
s
.
isAggregationStale
(
updatedAt
,
time
.
Now
()
.
UTC
())
}
func
(
s
*
DashboardService
)
fetchAggregationUpdatedAt
(
ctx
context
.
Context
)
time
.
Time
{
if
s
.
aggRepo
==
nil
{
return
time
.
Unix
(
0
,
0
)
.
UTC
()
}
updatedAt
,
err
:=
s
.
aggRepo
.
GetAggregationWatermark
(
ctx
)
if
err
!=
nil
{
log
.
Printf
(
"[Dashboard] 读取聚合水位失败: %v"
,
err
)
return
time
.
Unix
(
0
,
0
)
.
UTC
()
}
if
updatedAt
.
IsZero
()
{
return
time
.
Unix
(
0
,
0
)
.
UTC
()
}
return
updatedAt
.
UTC
()
}
func
(
s
*
DashboardService
)
isAggregationStale
(
updatedAt
,
now
time
.
Time
)
bool
{
if
!
s
.
aggEnabled
{
return
true
}
epoch
:=
time
.
Unix
(
0
,
0
)
.
UTC
()
if
!
updatedAt
.
After
(
epoch
)
{
return
true
}
threshold
:=
s
.
aggInterval
+
s
.
aggLookback
return
now
.
Sub
(
updatedAt
)
>
threshold
}
func
parseStatsUpdatedAt
(
raw
string
)
time
.
Time
{
if
raw
==
""
{
return
time
.
Unix
(
0
,
0
)
.
UTC
()
}
parsed
,
err
:=
time
.
Parse
(
time
.
RFC3339
,
raw
)
if
err
!=
nil
{
return
time
.
Unix
(
0
,
0
)
.
UTC
()
}
return
parsed
.
UTC
()
}
func
(
s
*
DashboardService
)
GetAPIKeyUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
limit
int
)
([]
usagestats
.
APIKeyUsageTrendPoint
,
error
)
{
trend
,
err
:=
s
.
usageRepo
.
GetAPIKeyUsageTrend
(
ctx
,
startTime
,
endTime
,
granularity
,
limit
)
if
err
!=
nil
{
...
...
backend/internal/service/dashboard_service_test.go
0 → 100644
View file @
b9b4db3d
package
service
import
(
"context"
"encoding/json"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/stretchr/testify/require"
)
// usageRepoStub embeds UsageLogRepository (unimplemented methods panic if
// called) and overrides the dashboard-stats fetchers with canned data.
type usageRepoStub struct {
	UsageLogRepository
	stats      *usagestats.DashboardStats // result of GetDashboardStats
	rangeStats *usagestats.DashboardStats // result of GetDashboardStatsWithRange (falls back to stats)
	err        error                      // forced GetDashboardStats error
	rangeErr   error                      // forced GetDashboardStatsWithRange error
	calls      int32                      // GetDashboardStats call count (atomic)
	rangeCalls int32                      // GetDashboardStatsWithRange call count (atomic)
	rangeStart time.Time                  // last requested range start
	rangeEnd   time.Time                  // last requested range end
	onCall     chan struct{}              // optional per-call notification (non-blocking send)
}

// GetDashboardStats counts the call, optionally signals onCall without
// blocking, and returns the canned stats or forced error.
func (s *usageRepoStub) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
	atomic.AddInt32(&s.calls, 1)
	if s.onCall != nil {
		select {
		case s.onCall <- struct{}{}:
		default: // never block the caller if nobody is listening
		}
	}
	if s.err != nil {
		return nil, s.err
	}
	return s.stats, nil
}

// GetDashboardStatsWithRange records the requested window and returns
// rangeStats when set, otherwise stats.
// NOTE(review): rangeStart/rangeEnd are written without synchronization,
// unlike the atomic counters — fine only if tests read them after the call
// completes.
func (s *usageRepoStub) GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*usagestats.DashboardStats, error) {
	atomic.AddInt32(&s.rangeCalls, 1)
	s.rangeStart = start
	s.rangeEnd = end
	if s.rangeErr != nil {
		return nil, s.rangeErr
	}
	if s.rangeStats != nil {
		return s.rangeStats, nil
	}
	return s.stats, nil
}
// dashboardCacheStub is a DashboardStatsCache test double with optional
// per-method hooks and atomic call counters.
type dashboardCacheStub struct {
	get func(ctx context.Context) (string, error)                   // optional Get hook
	set func(ctx context.Context, data string, ttl time.Duration) error // optional Set hook
	del func(ctx context.Context) error                             // optional Delete hook
	getCalls  int32
	setCalls  int32
	delCalls  int32
	lastSetMu sync.Mutex // guards lastSet (Set may run from the async refresher goroutine)
	lastSet   string     // payload most recently passed to SetDashboardStats
}

// GetDashboardStats counts the call; without a hook it reports a cache miss.
func (c *dashboardCacheStub) GetDashboardStats(ctx context.Context) (string, error) {
	atomic.AddInt32(&c.getCalls, 1)
	if c.get != nil {
		return c.get(ctx)
	}
	return "", ErrDashboardStatsCacheMiss
}

// SetDashboardStats counts the call, remembers the payload, then runs the hook.
func (c *dashboardCacheStub) SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error {
	atomic.AddInt32(&c.setCalls, 1)
	c.lastSetMu.Lock()
	c.lastSet = data
	c.lastSetMu.Unlock()
	if c.set != nil {
		return c.set(ctx, data, ttl)
	}
	return nil
}

// DeleteDashboardStats counts the call, then runs the hook if present.
func (c *dashboardCacheStub) DeleteDashboardStats(ctx context.Context) error {
	atomic.AddInt32(&c.delCalls, 1)
	if c.del != nil {
		return c.del(ctx)
	}
	return nil
}
// dashboardAggregationRepoStub is a minimal DashboardAggregationRepository
// whose only configurable behavior is the watermark lookup.
type dashboardAggregationRepoStub struct {
	watermark time.Time // returned by GetAggregationWatermark
	err       error     // forced GetAggregationWatermark error
}

// AggregateRange is a no-op.
func (s *dashboardAggregationRepoStub) AggregateRange(ctx context.Context, start, end time.Time) error {
	return nil
}

// GetAggregationWatermark returns the forced error or the canned watermark.
func (s *dashboardAggregationRepoStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
	if s.err != nil {
		return time.Time{}, s.err
	}
	return s.watermark, nil
}

// UpdateAggregationWatermark is a no-op.
func (s *dashboardAggregationRepoStub) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
	return nil
}

// CleanupAggregates is a no-op.
func (s *dashboardAggregationRepoStub) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
	return nil
}

// CleanupUsageLogs is a no-op.
func (s *dashboardAggregationRepoStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
	return nil
}

// EnsureUsageLogsPartitions is a no-op.
func (s *dashboardAggregationRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
	return nil
}
func
(
c
*
dashboardCacheStub
)
readLastEntry
(
t
*
testing
.
T
)
dashboardStatsCacheEntry
{
t
.
Helper
()
c
.
lastSetMu
.
Lock
()
data
:=
c
.
lastSet
c
.
lastSetMu
.
Unlock
()
var
entry
dashboardStatsCacheEntry
err
:=
json
.
Unmarshal
([]
byte
(
data
),
&
entry
)
require
.
NoError
(
t
,
err
)
return
entry
}
// TestDashboardService_CacheHitFresh verifies that a fresh cache entry is
// returned as-is: the repository is never queried and nothing is re-written
// to the cache.
func TestDashboardService_CacheHitFresh(t *testing.T) {
	stats := &usagestats.DashboardStats{
		TotalUsers:     10,
		StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
		StatsStale:     true,
	}
	// Entry timestamped "now" so it is considered fresh.
	entry := dashboardStatsCacheEntry{Stats: stats, UpdatedAt: time.Now().Unix()}
	payload, err := json.Marshal(entry)
	require.NoError(t, err)
	cache := &dashboardCacheStub{
		get: func(ctx context.Context) (string, error) { return string(payload), nil },
	}
	// Repo returns different stats so a repo hit would be detectable.
	repo := &usageRepoStub{stats: &usagestats.DashboardStats{TotalUsers: 99}}
	aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
	cfg := &config.Config{
		Dashboard: config.DashboardCacheConfig{Enabled: true},
		DashboardAgg: config.DashboardAggregationConfig{
			Enabled: true,
		},
	}
	svc := NewDashboardService(repo, aggRepo, cache, cfg)
	got, err := svc.GetDashboardStats(context.Background())
	require.NoError(t, err)
	require.Equal(t, stats, got)
	require.Equal(t, int32(0), atomic.LoadInt32(&repo.calls))
	require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalls))
	require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalls))
}
// TestDashboardService_CacheMiss_StoresCache verifies that a cache miss falls
// through to the repository and that the fetched stats are written back to the
// cache together with a roughly-current timestamp.
func TestDashboardService_CacheMiss_StoresCache(t *testing.T) {
	stats := &usagestats.DashboardStats{
		TotalUsers:     7,
		StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
		StatsStale:     true,
	}
	cache := &dashboardCacheStub{
		get: func(ctx context.Context) (string, error) {
			return "", ErrDashboardStatsCacheMiss
		},
	}
	repo := &usageRepoStub{stats: stats}
	aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
	cfg := &config.Config{
		Dashboard: config.DashboardCacheConfig{Enabled: true},
		DashboardAgg: config.DashboardAggregationConfig{
			Enabled: true,
		},
	}
	svc := NewDashboardService(repo, aggRepo, cache, cfg)
	got, err := svc.GetDashboardStats(context.Background())
	require.NoError(t, err)
	require.Equal(t, stats, got)
	require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls))
	require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalls))
	require.Equal(t, int32(1), atomic.LoadInt32(&cache.setCalls))
	// The write-back payload must carry the repo stats and a fresh timestamp.
	entry := cache.readLastEntry(t)
	require.Equal(t, stats, entry.Stats)
	require.WithinDuration(t, time.Now(), time.Unix(entry.UpdatedAt, 0), time.Second)
}
// TestDashboardService_CacheDisabled_SkipsCache verifies that when the
// dashboard cache is disabled the service goes straight to the repository and
// never touches the cache.
func TestDashboardService_CacheDisabled_SkipsCache(t *testing.T) {
	stats := &usagestats.DashboardStats{
		TotalUsers:     3,
		StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
		StatsStale:     true,
	}
	cache := &dashboardCacheStub{
		get: func(ctx context.Context) (string, error) {
			return "", nil
		},
	}
	repo := &usageRepoStub{stats: stats}
	aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
	cfg := &config.Config{
		Dashboard: config.DashboardCacheConfig{Enabled: false},
		DashboardAgg: config.DashboardAggregationConfig{
			Enabled: true,
		},
	}
	svc := NewDashboardService(repo, aggRepo, cache, cfg)
	got, err := svc.GetDashboardStats(context.Background())
	require.NoError(t, err)
	require.Equal(t, stats, got)
	require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls))
	require.Equal(t, int32(0), atomic.LoadInt32(&cache.getCalls))
	require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalls))
}
// TestDashboardService_CacheHitStale_TriggersAsyncRefresh verifies that a
// stale (but parseable) cache entry is served immediately while a background
// refresh hits the repository and re-populates the cache.
func TestDashboardService_CacheHitStale_TriggersAsyncRefresh(t *testing.T) {
	staleStats := &usagestats.DashboardStats{
		TotalUsers:     11,
		StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
		StatsStale:     true,
	}
	// Timestamp well past the fresh TTL so the entry counts as stale.
	entry := dashboardStatsCacheEntry{
		Stats:     staleStats,
		UpdatedAt: time.Now().Add(-defaultDashboardStatsFreshTTL * 2).Unix(),
	}
	payload, err := json.Marshal(entry)
	require.NoError(t, err)
	cache := &dashboardCacheStub{
		get: func(ctx context.Context) (string, error) {
			return string(payload), nil
		},
	}
	// refreshCh signals that the async refresh reached the repository.
	refreshCh := make(chan struct{}, 1)
	repo := &usageRepoStub{
		stats:  &usagestats.DashboardStats{TotalUsers: 22},
		onCall: refreshCh,
	}
	aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
	cfg := &config.Config{
		Dashboard: config.DashboardCacheConfig{Enabled: true},
		DashboardAgg: config.DashboardAggregationConfig{
			Enabled: true,
		},
	}
	svc := NewDashboardService(repo, aggRepo, cache, cfg)
	got, err := svc.GetDashboardStats(context.Background())
	require.NoError(t, err)
	// The caller still gets the stale snapshot immediately.
	require.Equal(t, staleStats, got)
	select {
	case <-refreshCh:
	case <-time.After(1 * time.Second):
		t.Fatal("等待异步刷新超时")
	}
	// The background refresh must eventually write the new stats back.
	require.Eventually(t, func() bool {
		return atomic.LoadInt32(&cache.setCalls) >= 1
	}, 1*time.Second, 10*time.Millisecond)
}
// TestDashboardService_CacheParseError_EvictsAndRefetches verifies that an
// unparseable cache payload is evicted (delete is called once) and the stats
// are re-fetched from the repository.
func TestDashboardService_CacheParseError_EvictsAndRefetches(t *testing.T) {
	cache := &dashboardCacheStub{
		get: func(ctx context.Context) (string, error) {
			return "not-json", nil
		},
	}
	stats := &usagestats.DashboardStats{TotalUsers: 9}
	repo := &usageRepoStub{stats: stats}
	aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
	cfg := &config.Config{
		Dashboard: config.DashboardCacheConfig{Enabled: true},
		DashboardAgg: config.DashboardAggregationConfig{
			Enabled: true,
		},
	}
	svc := NewDashboardService(repo, aggRepo, cache, cfg)
	got, err := svc.GetDashboardStats(context.Background())
	require.NoError(t, err)
	require.Equal(t, stats, got)
	require.Equal(t, int32(1), atomic.LoadInt32(&cache.delCalls))
	require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls))
}
// TestDashboardService_CacheParseError_RepoFailure verifies that when the
// cached payload is corrupt AND the repository fails, the corrupt entry is
// still evicted and the repository error is surfaced to the caller.
func TestDashboardService_CacheParseError_RepoFailure(t *testing.T) {
	cache := &dashboardCacheStub{
		get: func(ctx context.Context) (string, error) {
			return "not-json", nil
		},
	}
	repo := &usageRepoStub{err: errors.New("db down")}
	aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
	cfg := &config.Config{
		Dashboard: config.DashboardCacheConfig{Enabled: true},
		DashboardAgg: config.DashboardAggregationConfig{
			Enabled: true,
		},
	}
	svc := NewDashboardService(repo, aggRepo, cache, cfg)
	_, err := svc.GetDashboardStats(context.Background())
	require.Error(t, err)
	require.Equal(t, int32(1), atomic.LoadInt32(&cache.delCalls))
}
// TestDashboardService_StatsUpdatedAtEpochWhenMissing verifies that when the
// repository reports no update timestamp, the service falls back to the Unix
// epoch and marks the stats as stale.
func TestDashboardService_StatsUpdatedAtEpochWhenMissing(t *testing.T) {
	stats := &usagestats.DashboardStats{}
	repo := &usageRepoStub{stats: stats}
	aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
	cfg := &config.Config{Dashboard: config.DashboardCacheConfig{Enabled: false}}
	svc := NewDashboardService(repo, aggRepo, nil, cfg)
	got, err := svc.GetDashboardStats(context.Background())
	require.NoError(t, err)
	require.Equal(t, "1970-01-01T00:00:00Z", got.StatsUpdatedAt)
	require.True(t, got.StatsStale)
}
// TestDashboardService_StatsStaleFalseWhenFresh verifies that a recent
// aggregation watermark (within interval+lookback) yields StatsStale == false
// and is echoed as the StatsUpdatedAt timestamp.
func TestDashboardService_StatsStaleFalseWhenFresh(t *testing.T) {
	aggNow := time.Now().UTC().Truncate(time.Second)
	stats := &usagestats.DashboardStats{}
	repo := &usageRepoStub{stats: stats}
	aggRepo := &dashboardAggregationRepoStub{watermark: aggNow}
	cfg := &config.Config{
		Dashboard: config.DashboardCacheConfig{Enabled: false},
		DashboardAgg: config.DashboardAggregationConfig{
			Enabled:         true,
			IntervalSeconds: 60,
			LookbackSeconds: 120,
		},
	}
	svc := NewDashboardService(repo, aggRepo, nil, cfg)
	got, err := svc.GetDashboardStats(context.Background())
	require.NoError(t, err)
	require.Equal(t, aggNow.Format(time.RFC3339), got.StatsUpdatedAt)
	require.False(t, got.StatsStale)
}
// TestDashboardService_AggDisabled_UsesUsageLogsFallback verifies that with
// aggregation disabled the service queries raw usage logs over the retention
// window instead of the aggregated-stats path.
func TestDashboardService_AggDisabled_UsesUsageLogsFallback(t *testing.T) {
	expected := &usagestats.DashboardStats{TotalUsers: 42}
	repo := &usageRepoStub{
		rangeStats: expected,
		// The aggregated path would surface this error; it must not be hit.
		err: errors.New("should not call aggregated stats"),
	}
	cfg := &config.Config{
		Dashboard: config.DashboardCacheConfig{Enabled: false},
		DashboardAgg: config.DashboardAggregationConfig{
			Enabled: false,
			Retention: config.DashboardAggregationRetentionConfig{
				UsageLogsDays: 7,
			},
		},
	}
	svc := NewDashboardService(repo, nil, nil, cfg)
	got, err := svc.GetDashboardStats(context.Background())
	require.NoError(t, err)
	require.Equal(t, int64(42), got.TotalUsers)
	require.Equal(t, int32(0), atomic.LoadInt32(&repo.calls))
	require.Equal(t, int32(1), atomic.LoadInt32(&repo.rangeCalls))
	require.False(t, repo.rangeEnd.IsZero())
	// The range must span exactly UsageLogsDays, day-aligned in UTC.
	require.Equal(t, truncateToDayUTC(repo.rangeEnd.AddDate(0, 0, -7)), repo.rangeStart)
}
backend/internal/service/domain_constants.go
View file @
b9b4db3d
...
...
@@ -38,6 +38,12 @@ const (
RedeemTypeSubscription
=
"subscription"
)
// PromoCode status constants
const
(
PromoCodeStatusActive
=
"active"
PromoCodeStatusDisabled
=
"disabled"
)
// Admin adjustment type constants
const
(
AdjustmentTypeAdminBalance
=
"admin_balance"
// 管理员调整余额
...
...
@@ -57,6 +63,9 @@ const (
SubscriptionStatusSuspended
=
"suspended"
)
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。
const
LinuxDoConnectSyntheticEmailDomain
=
"@linuxdo-connect.invalid"
// Setting keys
const
(
// 注册设置
...
...
@@ -77,6 +86,12 @@ const (
SettingKeyTurnstileSiteKey
=
"turnstile_site_key"
// Turnstile Site Key
SettingKeyTurnstileSecretKey
=
"turnstile_secret_key"
// Turnstile Secret Key
// LinuxDo Connect OAuth 登录设置
SettingKeyLinuxDoConnectEnabled
=
"linuxdo_connect_enabled"
SettingKeyLinuxDoConnectClientID
=
"linuxdo_connect_client_id"
SettingKeyLinuxDoConnectClientSecret
=
"linuxdo_connect_client_secret"
SettingKeyLinuxDoConnectRedirectURL
=
"linuxdo_connect_redirect_url"
// OEM设置
SettingKeySiteName
=
"site_name"
// 网站名称
SettingKeySiteLogo
=
"site_logo"
// 网站Logo (base64)
...
...
@@ -84,6 +99,7 @@ const (
SettingKeyAPIBaseURL
=
"api_base_url"
// API端点地址(用于客户端配置和导入)
SettingKeyContactInfo
=
"contact_info"
// 客服联系方式
SettingKeyDocURL
=
"doc_url"
// 文档链接
SettingKeyHomeContent
=
"home_content"
// 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src)
// 默认配置
SettingKeyDefaultConcurrency
=
"default_concurrency"
// 新用户默认并发量
...
...
@@ -106,16 +122,38 @@ const (
SettingKeyEnableIdentityPatch
=
"enable_identity_patch"
SettingKeyIdentityPatchPrompt
=
"identity_patch_prompt"
// LinuxDo Connect OAuth 登录(终端用户 SSO)
SettingKeyLinuxDoConnectEnabled
=
"linuxdo_connect_enabled"
SettingKeyLinuxDoConnectClientID
=
"linuxdo_connect_client_id"
SettingKeyLinuxDoConnectClientSecret
=
"linuxdo_connect_client_secret"
SettingKeyLinuxDoConnectRedirectURL
=
"linuxdo_connect_redirect_url"
)
// =========================
// Ops Monitoring (vNext)
// =========================
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。
// 目的:避免第三方登录返回的用户标识与本地真实邮箱发生碰撞,进而造成账号被接管的风险。
const
LinuxDoConnectSyntheticEmailDomain
=
"@linuxdo-connect.invalid"
// SettingKeyOpsMonitoringEnabled is a DB-backed soft switch to enable/disable ops module at runtime.
SettingKeyOpsMonitoringEnabled
=
"ops_monitoring_enabled"
// SettingKeyOpsRealtimeMonitoringEnabled controls realtime features (e.g. WS/QPS push).
SettingKeyOpsRealtimeMonitoringEnabled
=
"ops_realtime_monitoring_enabled"
// SettingKeyOpsQueryModeDefault controls the default query mode for ops dashboard (auto/raw/preagg).
SettingKeyOpsQueryModeDefault
=
"ops_query_mode_default"
// SettingKeyOpsEmailNotificationConfig stores JSON config for ops email notifications.
SettingKeyOpsEmailNotificationConfig
=
"ops_email_notification_config"
// SettingKeyOpsAlertRuntimeSettings stores JSON config for ops alert evaluator runtime settings.
SettingKeyOpsAlertRuntimeSettings
=
"ops_alert_runtime_settings"
// SettingKeyOpsMetricsIntervalSeconds controls the ops metrics collector interval (>=60).
SettingKeyOpsMetricsIntervalSeconds
=
"ops_metrics_interval_seconds"
// SettingKeyOpsAdvancedSettings stores JSON config for ops advanced settings (data retention, aggregation).
SettingKeyOpsAdvancedSettings
=
"ops_advanced_settings"
// =========================
// Stream Timeout Handling
// =========================
// SettingKeyStreamTimeoutSettings stores JSON config for stream timeout handling.
SettingKeyStreamTimeoutSettings
=
"stream_timeout_settings"
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
const
AdminAPIKeyPrefix
=
"admin-"
backend/internal/service/gateway_multiplatform_test.go
View file @
b9b4db3d
...
...
@@ -9,6 +9,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
...
...
@@ -23,9 +24,11 @@ type mockAccountRepoForPlatform struct {
accounts
[]
Account
accountsByID
map
[
int64
]
*
Account
listPlatformFunc
func
(
ctx
context
.
Context
,
platform
string
)
([]
Account
,
error
)
getByIDCalls
int
}
func
(
m
*
mockAccountRepoForPlatform
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Account
,
error
)
{
m
.
getByIDCalls
++
if
acc
,
ok
:=
m
.
accountsByID
[
id
];
ok
{
return
acc
,
nil
}
...
...
@@ -142,6 +145,9 @@ func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int6
func
(
m
*
mockAccountRepoForPlatform
)
SetAntigravityQuotaScopeLimit
(
ctx
context
.
Context
,
id
int64
,
scope
AntigravityQuotaScope
,
resetAt
time
.
Time
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
SetModelRateLimit
(
ctx
context
.
Context
,
id
int64
,
scope
string
,
resetAt
time
.
Time
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
SetOverloaded
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
)
error
{
return
nil
}
...
...
@@ -157,6 +163,9 @@ func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int6
func
(
m
*
mockAccountRepoForPlatform
)
ClearAntigravityQuotaScopes
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
ClearModelRateLimits
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
UpdateSessionWindow
(
ctx
context
.
Context
,
id
int64
,
start
,
end
*
time
.
Time
,
status
string
)
error
{
return
nil
}
...
...
@@ -194,6 +203,56 @@ func (m *mockGatewayCacheForPlatform) RefreshSessionTTL(ctx context.Context, gro
return
nil
}
// mockGroupRepoForGateway is a GroupRepository test double backed by an
// in-memory map. It counts GetByID/GetByIDLite calls so tests can assert on
// which resolution path the gateway took; all other methods are no-op stubs.
type mockGroupRepoForGateway struct {
	groups           map[int64]*Group
	getByIDCalls     int
	getByIDLiteCalls int
}

// GetByID returns the mapped group or ErrGroupNotFound, counting the call.
func (m *mockGroupRepoForGateway) GetByID(ctx context.Context, id int64) (*Group, error) {
	m.getByIDCalls++
	if g, ok := m.groups[id]; ok {
		return g, nil
	}
	return nil, ErrGroupNotFound
}

// GetByIDLite mirrors GetByID but tracks the lightweight lookup path.
func (m *mockGroupRepoForGateway) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
	m.getByIDLiteCalls++
	if g, ok := m.groups[id]; ok {
		return g, nil
	}
	return nil, ErrGroupNotFound
}

func (m *mockGroupRepoForGateway) Create(ctx context.Context, group *Group) error { return nil }
func (m *mockGroupRepoForGateway) Update(ctx context.Context, group *Group) error { return nil }
func (m *mockGroupRepoForGateway) Delete(ctx context.Context, id int64) error     { return nil }
func (m *mockGroupRepoForGateway) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
	return nil, nil
}
func (m *mockGroupRepoForGateway) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
	return nil, nil, nil
}
func (m *mockGroupRepoForGateway) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
	return nil, nil, nil
}
func (m *mockGroupRepoForGateway) ListActive(ctx context.Context) ([]Group, error) { return nil, nil }
func (m *mockGroupRepoForGateway) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) {
	return nil, nil
}
func (m *mockGroupRepoForGateway) ExistsByName(ctx context.Context, name string) (bool, error) {
	return false, nil
}
func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
	return 0, nil
}
func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
	return 0, nil
}
// ptr returns a pointer to a copy of v; handy for populating optional
// pointer-typed struct fields in test fixtures.
func ptr[T any](v T) *T {
	p := new(T)
	*p = v
	return p
}
...
...
@@ -900,6 +959,74 @@ func (m *mockConcurrencyService) GetAccountWaitingCount(ctx context.Context, acc
return
m
.
accountWaitCounts
[
accountID
],
nil
}
// mockConcurrencyCache is a ConcurrencyCache test double. Slot acquisition
// always succeeds; it counts account-slot acquisitions and batch load queries
// so tests can verify which scheduling path ran.
type mockConcurrencyCache struct {
	acquireAccountCalls int
	loadBatchCalls      int
}

// AcquireAccountSlot always grants the slot and counts the attempt.
func (m *mockConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
	m.acquireAccountCalls++
	return true, nil
}
func (m *mockConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
	return nil
}
func (m *mockConcurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
	return 0, nil
}
func (m *mockConcurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
	return true, nil
}
func (m *mockConcurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
	return nil
}
func (m *mockConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
	return 0, nil
}
func (m *mockConcurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
	return true, nil
}
func (m *mockConcurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
	return nil
}
func (m *mockConcurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
	return 0, nil
}
func (m *mockConcurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
	return true, nil
}
func (m *mockConcurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
	return nil
}

// GetAccountsLoadBatch reports zero load for every account and counts the call.
func (m *mockConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
	m.loadBatchCalls++
	result := make(map[int64]*AccountLoadInfo, len(accounts))
	for _, acc := range accounts {
		result[acc.ID] = &AccountLoadInfo{
			AccountID:          acc.ID,
			CurrentConcurrency: 0,
			WaitingCount:       0,
			LoadRate:           0,
		}
	}
	return result, nil
}
func (m *mockConcurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
	return nil
}
// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection
func
TestGatewayService_SelectAccountWithLoadAwareness
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
...
...
@@ -928,13 +1055,67 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService
:
nil
,
// No concurrency service
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
)
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
1
),
result
.
Account
.
ID
,
"应选择优先级最高的账号"
)
})
t
.
Run
(
"模型路由-无ConcurrencyService也生效"
,
func
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
1
)
sessionHash
:=
"sticky"
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
,
AccountGroups
:
[]
AccountGroup
{{
GroupID
:
groupID
}}},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
,
AccountGroups
:
[]
AccountGroup
{{
GroupID
:
groupID
}}},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
sessionHash
:
1
},
}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
Hydrated
:
true
,
ModelRoutingEnabled
:
true
,
ModelRouting
:
map
[
string
][]
int64
{
"claude-a"
:
{
1
},
"claude-b"
:
{
2
},
},
},
},
}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
true
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
nil
,
// legacy path
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
&
groupID
,
sessionHash
,
"claude-b"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
Account
.
ID
,
"切换到 claude-b 时应按模型路由切换账号"
)
require
.
Equal
(
t
,
int64
(
2
),
cache
.
sessionBindings
[
sessionHash
],
"粘性绑定应更新为路由选择的账号"
)
})
t
.
Run
(
"无ConcurrencyService-降级到传统选择"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
...
...
@@ -959,7 +1140,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService
:
nil
,
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
)
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
...
...
@@ -991,13 +1172,85 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
}
excludedIDs
:=
map
[
int64
]
struct
{}{
1
:
{}}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
excludedIDs
)
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
excludedIDs
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
Account
.
ID
,
"不应选择被排除的账号"
)
})
t
.
Run
(
"粘性命中-不调用GetByID"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
"sticky"
:
1
},
}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
true
concurrencyCache
:=
&
mockConcurrencyCache
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
"sticky"
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
1
),
result
.
Account
.
ID
)
require
.
Equal
(
t
,
0
,
repo
.
getByIDCalls
,
"粘性命中不应调用GetByID"
)
require
.
Equal
(
t
,
0
,
concurrencyCache
.
loadBatchCalls
,
"粘性命中应在负载批量查询前返回"
)
})
t
.
Run
(
"粘性账号不在候选集-回退负载感知选择"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
"sticky"
:
1
},
}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
true
concurrencyCache
:=
&
mockConcurrencyCache
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
"sticky"
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
Account
.
ID
,
"粘性账号不在候选集时应回退到可用账号"
)
require
.
Equal
(
t
,
0
,
repo
.
getByIDCalls
,
"粘性账号缺失不应回退到GetByID"
)
require
.
Equal
(
t
,
1
,
concurrencyCache
.
loadBatchCalls
,
"应继续进行负载批量查询"
)
})
t
.
Run
(
"无可用账号-返回错误"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{},
...
...
@@ -1016,9 +1269,264 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService
:
nil
,
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
)
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
result
)
require
.
Contains
(
t
,
err
.
Error
(),
"no available accounts"
)
})
t
.
Run
(
"过滤不可调度账号-限流账号被跳过"
,
func
(
t
*
testing
.
T
)
{
now
:=
time
.
Now
()
resetAt
:=
now
.
Add
(
10
*
time
.
Minute
)
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
,
RateLimitResetAt
:
&
resetAt
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
false
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
nil
,
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
Account
.
ID
,
"应跳过限流账号,选择可用账号"
)
})
t
.
Run
(
"过滤不可调度账号-过载账号被跳过"
,
func
(
t
*
testing
.
T
)
{
now
:=
time
.
Now
()
overloadUntil
:=
now
.
Add
(
10
*
time
.
Minute
)
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
,
OverloadUntil
:
&
overloadUntil
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
false
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
nil
,
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
Account
.
ID
,
"应跳过过载账号,选择可用账号"
)
})
}
// TestGatewayService_GroupResolution_ReusesContextGroup verifies that a
// hydrated group already stored on the context is reused verbatim — neither
// GetByID nor GetByIDLite should be called on the group repository.
func TestGatewayService_GroupResolution_ReusesContextGroup(t *testing.T) {
	ctx := context.Background()
	groupID := int64(42)
	group := &Group{
		ID:       groupID,
		Platform: PlatformAnthropic,
		Status:   StatusActive,
		Hydrated: true,
	}
	ctx = context.WithValue(ctx, ctxkey.Group, group)
	repo := &mockAccountRepoForPlatform{
		accounts: []Account{
			{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
		},
		accountsByID: map[int64]*Account{},
	}
	for i := range repo.accounts {
		repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
	}
	groupRepo := &mockGroupRepoForGateway{
		groups: map[int64]*Group{groupID: group},
	}
	svc := &GatewayService{
		accountRepo: repo,
		groupRepo:   groupRepo,
		cfg:         testConfig(),
	}
	account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
	require.NoError(t, err)
	require.NotNil(t, account)
	require.Equal(t, 0, groupRepo.getByIDCalls)
	require.Equal(t, 0, groupRepo.getByIDLiteCalls)
}
// TestGatewayService_GroupResolution_IgnoresInvalidContextGroup verifies that
// a non-hydrated group on the context is ignored and the group is re-fetched
// once via the lightweight GetByIDLite path.
func TestGatewayService_GroupResolution_IgnoresInvalidContextGroup(t *testing.T) {
	ctx := context.Background()
	groupID := int64(42)
	// Not hydrated — must be treated as unusable.
	ctxGroup := &Group{
		ID:       groupID,
		Platform: PlatformAnthropic,
		Status:   StatusActive,
	}
	ctx = context.WithValue(ctx, ctxkey.Group, ctxGroup)
	repo := &mockAccountRepoForPlatform{
		accounts: []Account{
			{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
		},
		accountsByID: map[int64]*Account{},
	}
	for i := range repo.accounts {
		repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
	}
	group := &Group{
		ID:       groupID,
		Platform: PlatformAnthropic,
		Status:   StatusActive,
		Hydrated: true,
	}
	groupRepo := &mockGroupRepoForGateway{
		groups: map[int64]*Group{groupID: group},
	}
	svc := &GatewayService{
		accountRepo: repo,
		groupRepo:   groupRepo,
		cfg:         testConfig(),
	}
	account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
	require.NoError(t, err)
	require.NotNil(t, account)
	require.Equal(t, 0, groupRepo.getByIDCalls)
	require.Equal(t, 1, groupRepo.getByIDLiteCalls)
}
// TestGatewayService_GroupContext_OverwritesInvalidContextGroup verifies that
// withGroupContext replaces a stale (non-hydrated) context group with the
// hydrated one, and that the exact pointer is stored.
func TestGatewayService_GroupContext_OverwritesInvalidContextGroup(t *testing.T) {
	groupID := int64(42)
	invalidGroup := &Group{ID: groupID, Platform: PlatformAnthropic, Status: StatusActive}
	hydratedGroup := &Group{
		ID:       groupID,
		Platform: PlatformAnthropic,
		Status:   StatusActive,
		Hydrated: true,
	}
	ctx := context.WithValue(context.Background(), ctxkey.Group, invalidGroup)
	svc := &GatewayService{}
	ctx = svc.withGroupContext(ctx, hydratedGroup)
	got, ok := ctx.Value(ctxkey.Group).(*Group)
	require.True(t, ok)
	require.Same(t, hydratedGroup, got)
}
// TestGatewayService_GroupResolution_FallbackUsesLiteOnce verifies that when a
// hydrated context group routes to its fallback group, the fallback is loaded
// exactly once via GetByIDLite (never the full GetByID).
func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) {
	ctx := context.Background()
	groupID := int64(10)
	fallbackID := int64(11)
	group := &Group{
		ID:              groupID,
		Platform:        PlatformAnthropic,
		Status:          StatusActive,
		ClaudeCodeOnly:  true,
		FallbackGroupID: &fallbackID,
		Hydrated:        true,
	}
	fallbackGroup := &Group{
		ID:       fallbackID,
		Platform: PlatformAnthropic,
		Status:   StatusActive,
		Hydrated: true,
	}
	ctx = context.WithValue(ctx, ctxkey.Group, group)
	repo := &mockAccountRepoForPlatform{
		accounts: []Account{
			{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
		},
		accountsByID: map[int64]*Account{},
	}
	for i := range repo.accounts {
		repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
	}
	// Only the fallback group is resolvable from the repo.
	groupRepo := &mockGroupRepoForGateway{
		groups: map[int64]*Group{fallbackID: fallbackGroup},
	}
	svc := &GatewayService{
		accountRepo: repo,
		groupRepo:   groupRepo,
		cfg:         testConfig(),
	}
	account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
	require.NoError(t, err)
	require.NotNil(t, account)
	require.Equal(t, 0, groupRepo.getByIDCalls)
	require.Equal(t, 1, groupRepo.getByIDLiteCalls)
}
// TestGatewayService_ResolveGatewayGroup_DetectsFallbackCycle verifies that
// two groups pointing at each other as fallbacks are rejected with a
// "fallback group cycle" error instead of looping forever.
func TestGatewayService_ResolveGatewayGroup_DetectsFallbackCycle(t *testing.T) {
	ctx := context.Background()
	groupID := int64(10)
	fallbackID := int64(11)
	group := &Group{
		ID:              groupID,
		Platform:        PlatformAnthropic,
		Status:          StatusActive,
		ClaudeCodeOnly:  true,
		FallbackGroupID: &fallbackID,
	}
	fallbackGroup := &Group{
		ID:              fallbackID,
		Platform:        PlatformAnthropic,
		Status:          StatusActive,
		ClaudeCodeOnly:  true,
		FallbackGroupID: &groupID,
	}
	groupRepo := &mockGroupRepoForGateway{
		groups: map[int64]*Group{
			groupID:    group,
			fallbackID: fallbackGroup,
		},
	}
	svc := &GatewayService{
		groupRepo: groupRepo,
	}
	gotGroup, gotID, err := svc.resolveGatewayGroup(ctx, &groupID)
	require.Error(t, err)
	require.Nil(t, gotGroup)
	require.Nil(t, gotID)
	require.Contains(t, err.Error(), "fallback group cycle")
}
backend/internal/service/gateway_service.go
View file @
b9b4db3d
...
...
@@ -13,6 +13,7 @@ import (
"log"
mathrand
"math/rand"
"net/http"
"os"
"regexp"
"sort"
"strings"
...
...
@@ -39,6 +40,21 @@ const (
maxCacheControlBlocks
=
4
// Anthropic API 允许的最大 cache_control 块数量
)
func
(
s
*
GatewayService
)
debugModelRoutingEnabled
()
bool
{
v
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
os
.
Getenv
(
"SUB2API_DEBUG_MODEL_ROUTING"
)))
return
v
==
"1"
||
v
==
"true"
||
v
==
"yes"
||
v
==
"on"
}
// shortSessionHash returns an abbreviated form of sessionHash suitable for
// log output: at most its first 8 characters. An empty input yields "".
func shortSessionHash(sessionHash string) string {
	const maxLen = 8
	if len(sessionHash) > maxLen {
		return sessionHash[:maxLen]
	}
	return sessionHash
}
// sseDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var
(
...
...
@@ -152,6 +168,7 @@ type GatewayService struct {
userSubRepo
UserSubscriptionRepository
cache
GatewayCache
cfg
*
config
.
Config
schedulerSnapshot
*
SchedulerSnapshotService
billingService
*
BillingService
rateLimitService
*
RateLimitService
billingCacheService
*
BillingCacheService
...
...
@@ -159,6 +176,8 @@ type GatewayService struct {
httpUpstream
HTTPUpstream
deferredService
*
DeferredService
concurrencyService
*
ConcurrencyService
claudeTokenProvider
*
ClaudeTokenProvider
sessionLimitCache
SessionLimitCache
// 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
}
// NewGatewayService creates a new GatewayService
...
...
@@ -170,6 +189,7 @@ func NewGatewayService(
userSubRepo
UserSubscriptionRepository
,
cache
GatewayCache
,
cfg
*
config
.
Config
,
schedulerSnapshot
*
SchedulerSnapshotService
,
concurrencyService
*
ConcurrencyService
,
billingService
*
BillingService
,
rateLimitService
*
RateLimitService
,
...
...
@@ -177,6 +197,8 @@ func NewGatewayService(
identityService
*
IdentityService
,
httpUpstream
HTTPUpstream
,
deferredService
*
DeferredService
,
claudeTokenProvider
*
ClaudeTokenProvider
,
sessionLimitCache
SessionLimitCache
,
)
*
GatewayService
{
return
&
GatewayService
{
accountRepo
:
accountRepo
,
...
...
@@ -186,6 +208,7 @@ func NewGatewayService(
userSubRepo
:
userSubRepo
,
cache
:
cache
,
cfg
:
cfg
,
schedulerSnapshot
:
schedulerSnapshot
,
concurrencyService
:
concurrencyService
,
billingService
:
billingService
,
rateLimitService
:
rateLimitService
,
...
...
@@ -193,6 +216,8 @@ func NewGatewayService(
identityService
:
identityService
,
httpUpstream
:
httpUpstream
,
deferredService
:
deferredService
,
claudeTokenProvider
:
claudeTokenProvider
,
sessionLimitCache
:
sessionLimitCache
,
}
}
...
...
@@ -362,27 +387,13 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
if
hasForcePlatform
&&
forcePlatform
!=
""
{
platform
=
forcePlatform
}
else
if
groupID
!=
nil
{
// 根据分组 platform 决定查询哪种账号
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
*
groupID
)
group
,
resolvedGroupID
,
err
:=
s
.
resolveGatewayGroup
(
ctx
,
groupID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get group failed: %w"
,
err
)
return
nil
,
err
}
groupID
=
resolvedGroupID
ctx
=
s
.
withGroupContext
(
ctx
,
group
)
platform
=
group
.
Platform
// 检查 Claude Code 客户端限制
if
group
.
ClaudeCodeOnly
{
isClaudeCode
:=
IsClaudeCodeClient
(
ctx
)
if
!
isClaudeCode
{
// 非 Claude Code 客户端,检查是否有降级分组
if
group
.
FallbackGroupID
!=
nil
{
// 使用降级分组重新调度
fallbackGroupID
:=
*
group
.
FallbackGroupID
return
s
.
SelectAccountForModelWithExclusions
(
ctx
,
&
fallbackGroupID
,
sessionHash
,
requestedModel
,
excludedIDs
)
}
// 无降级分组,拒绝访问
return
nil
,
ErrClaudeCodeOnly
}
}
}
else
{
// 无分组时只使用原生 anthropic 平台
platform
=
PlatformAnthropic
...
...
@@ -400,8 +411,12 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
}
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
func
(
s
*
GatewayService
)
SelectAccountWithLoadAwareness
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{})
(
*
AccountSelectionResult
,
error
)
{
// metadataUserID: 原始 metadata.user_id 字段(用于提取会话 UUID 进行会话数量限制)
func
(
s
*
GatewayService
)
SelectAccountWithLoadAwareness
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{},
metadataUserID
string
)
(
*
AccountSelectionResult
,
error
)
{
cfg
:=
s
.
schedulingConfig
()
// 提取会话 UUID(用于会话数量限制)
sessionUUID
:=
extractSessionUUID
(
metadataUserID
)
var
stickyAccountID
int64
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
if
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
);
err
==
nil
{
...
...
@@ -410,10 +425,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
// 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组)
groupID
,
err
:=
s
.
checkClaudeCodeRestriction
(
ctx
,
groupID
)
group
,
groupID
,
err
:=
s
.
checkClaudeCodeRestriction
(
ctx
,
groupID
)
if
err
!=
nil
{
return
nil
,
err
}
ctx
=
s
.
withGroupContext
(
ctx
,
group
)
if
s
.
debugModelRoutingEnabled
()
&&
requestedModel
!=
""
{
groupPlatform
:=
""
if
group
!=
nil
{
groupPlatform
=
group
.
Platform
}
log
.
Printf
(
"[ModelRoutingDebug] select entry: group_id=%v group_platform=%s model=%s session=%s sticky_account=%d load_batch=%v concurrency=%v"
,
derefGroupID
(
groupID
),
groupPlatform
,
requestedModel
,
shortSessionHash
(
sessionHash
),
stickyAccountID
,
cfg
.
LoadBatchEnabled
,
s
.
concurrencyService
!=
nil
)
}
if
s
.
concurrencyService
==
nil
||
!
cfg
.
LoadBatchEnabled
{
account
,
err
:=
s
.
SelectAccountForModelWithExclusions
(
ctx
,
groupID
,
sessionHash
,
requestedModel
,
excludedIDs
)
...
...
@@ -453,11 +478,14 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
},
nil
}
platform
,
hasForcePlatform
,
err
:=
s
.
resolvePlatform
(
ctx
,
groupID
)
platform
,
hasForcePlatform
,
err
:=
s
.
resolvePlatform
(
ctx
,
groupID
,
group
)
if
err
!=
nil
{
return
nil
,
err
}
preferOAuth
:=
platform
==
PlatformGemini
if
s
.
debugModelRoutingEnabled
()
&&
platform
==
PlatformAnthropic
&&
requestedModel
!=
""
{
log
.
Printf
(
"[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
platform
)
}
accounts
,
useMixed
,
err
:=
s
.
listSchedulableAccounts
(
ctx
,
groupID
,
platform
,
hasForcePlatform
)
if
err
!=
nil
{
...
...
@@ -475,23 +503,242 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return
excluded
}
// ============ Layer 1: 粘性会话优先 ============
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
// 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用)
accountByID
:=
make
(
map
[
int64
]
*
Account
,
len
(
accounts
))
for
i
:=
range
accounts
{
accountByID
[
accounts
[
i
]
.
ID
]
=
&
accounts
[
i
]
}
// 获取模型路由配置(仅 anthropic 平台)
var
routingAccountIDs
[]
int64
if
group
!=
nil
&&
requestedModel
!=
""
&&
group
.
Platform
==
PlatformAnthropic
{
routingAccountIDs
=
group
.
GetRoutingAccountIDs
(
requestedModel
)
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] context group routing: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v session=%s sticky_account=%d"
,
group
.
ID
,
requestedModel
,
group
.
ModelRoutingEnabled
,
len
(
group
.
ModelRouting
),
routingAccountIDs
,
shortSessionHash
(
sessionHash
),
stickyAccountID
)
if
len
(
routingAccountIDs
)
==
0
&&
group
.
ModelRoutingEnabled
&&
len
(
group
.
ModelRouting
)
>
0
{
keys
:=
make
([]
string
,
0
,
len
(
group
.
ModelRouting
))
for
k
:=
range
group
.
ModelRouting
{
keys
=
append
(
keys
,
k
)
}
sort
.
Strings
(
keys
)
const
maxKeys
=
20
if
len
(
keys
)
>
maxKeys
{
keys
=
keys
[
:
maxKeys
]
}
log
.
Printf
(
"[ModelRoutingDebug] context group routing miss: group_id=%d model=%s patterns(sample)=%v"
,
group
.
ID
,
requestedModel
,
keys
)
}
}
}
// ============ Layer 1: 模型路由优先选择(优先级高于粘性会话) ============
if
len
(
routingAccountIDs
)
>
0
&&
s
.
concurrencyService
!=
nil
{
// 1. 过滤出路由列表中可调度的账号
var
routingCandidates
[]
*
Account
var
filteredExcluded
,
filteredMissing
,
filteredUnsched
,
filteredPlatform
,
filteredModelScope
,
filteredModelMapping
,
filteredWindowCost
int
for
_
,
routingAccountID
:=
range
routingAccountIDs
{
if
isExcluded
(
routingAccountID
)
{
filteredExcluded
++
continue
}
account
,
ok
:=
accountByID
[
routingAccountID
]
if
!
ok
||
!
account
.
IsSchedulable
()
{
if
!
ok
{
filteredMissing
++
}
else
{
filteredUnsched
++
}
continue
}
if
!
s
.
isAccountAllowedForPlatform
(
account
,
platform
,
useMixed
)
{
filteredPlatform
++
continue
}
if
!
account
.
IsSchedulableForModel
(
requestedModel
)
{
filteredModelScope
++
continue
}
if
requestedModel
!=
""
&&
!
s
.
isModelSupportedByAccount
(
account
,
requestedModel
)
{
filteredModelMapping
++
continue
}
// 窗口费用检查(非粘性会话路径)
if
!
s
.
isAccountSchedulableForWindowCost
(
ctx
,
account
,
false
)
{
filteredWindowCost
++
continue
}
routingCandidates
=
append
(
routingCandidates
,
account
)
}
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d window_cost=%d)"
,
derefGroupID
(
groupID
),
requestedModel
,
len
(
routingAccountIDs
),
len
(
routingCandidates
),
filteredExcluded
,
filteredMissing
,
filteredUnsched
,
filteredPlatform
,
filteredModelScope
,
filteredModelMapping
,
filteredWindowCost
)
}
if
len
(
routingCandidates
)
>
0
{
// 1.5. 在路由账号范围内检查粘性会话
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
stickyAccountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
if
err
==
nil
&&
stickyAccountID
>
0
&&
containsInt64
(
routingAccountIDs
,
stickyAccountID
)
&&
!
isExcluded
(
stickyAccountID
)
{
// 粘性账号在路由列表中,优先使用
if
stickyAccount
,
ok
:=
accountByID
[
stickyAccountID
];
ok
{
if
stickyAccount
.
IsSchedulable
()
&&
s
.
isAccountAllowedForPlatform
(
stickyAccount
,
platform
,
useMixed
)
&&
stickyAccount
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
stickyAccount
,
requestedModel
))
&&
s
.
isAccountSchedulableForWindowCost
(
ctx
,
stickyAccount
,
true
)
{
// 粘性会话窗口费用检查
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
stickyAccountID
,
stickyAccount
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
// 会话数量限制检查
if
!
s
.
checkAndRegisterSession
(
ctx
,
stickyAccount
,
sessionUUID
)
{
result
.
ReleaseFunc
()
// 释放槽位
// 继续到负载感知选择
}
else
{
_
=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
)
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
stickyAccountID
)
}
return
&
AccountSelectionResult
{
Account
:
stickyAccount
,
Acquired
:
true
,
ReleaseFunc
:
result
.
ReleaseFunc
,
},
nil
}
}
waitingCount
,
_
:=
s
.
concurrencyService
.
GetAccountWaitingCount
(
ctx
,
stickyAccountID
)
if
waitingCount
<
cfg
.
StickySessionMaxWaiting
{
return
&
AccountSelectionResult
{
Account
:
stickyAccount
,
WaitPlan
:
&
AccountWaitPlan
{
AccountID
:
stickyAccountID
,
MaxConcurrency
:
stickyAccount
.
Concurrency
,
Timeout
:
cfg
.
StickySessionWaitTimeout
,
MaxWaiting
:
cfg
.
StickySessionMaxWaiting
,
},
},
nil
}
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
}
}
}
}
// 2. 批量获取负载信息
routingLoads
:=
make
([]
AccountWithConcurrency
,
0
,
len
(
routingCandidates
))
for
_
,
acc
:=
range
routingCandidates
{
routingLoads
=
append
(
routingLoads
,
AccountWithConcurrency
{
ID
:
acc
.
ID
,
MaxConcurrency
:
acc
.
Concurrency
,
})
}
routingLoadMap
,
_
:=
s
.
concurrencyService
.
GetAccountsLoadBatch
(
ctx
,
routingLoads
)
// 3. 按负载感知排序
type
accountWithLoad
struct
{
account
*
Account
loadInfo
*
AccountLoadInfo
}
var
routingAvailable
[]
accountWithLoad
for
_
,
acc
:=
range
routingCandidates
{
loadInfo
:=
routingLoadMap
[
acc
.
ID
]
if
loadInfo
==
nil
{
loadInfo
=
&
AccountLoadInfo
{
AccountID
:
acc
.
ID
}
}
if
loadInfo
.
LoadRate
<
100
{
routingAvailable
=
append
(
routingAvailable
,
accountWithLoad
{
account
:
acc
,
loadInfo
:
loadInfo
})
}
}
if
len
(
routingAvailable
)
>
0
{
// 排序:优先级 > 负载率 > 最后使用时间
sort
.
SliceStable
(
routingAvailable
,
func
(
i
,
j
int
)
bool
{
a
,
b
:=
routingAvailable
[
i
],
routingAvailable
[
j
]
if
a
.
account
.
Priority
!=
b
.
account
.
Priority
{
return
a
.
account
.
Priority
<
b
.
account
.
Priority
}
if
a
.
loadInfo
.
LoadRate
!=
b
.
loadInfo
.
LoadRate
{
return
a
.
loadInfo
.
LoadRate
<
b
.
loadInfo
.
LoadRate
}
switch
{
case
a
.
account
.
LastUsedAt
==
nil
&&
b
.
account
.
LastUsedAt
!=
nil
:
return
true
case
a
.
account
.
LastUsedAt
!=
nil
&&
b
.
account
.
LastUsedAt
==
nil
:
return
false
case
a
.
account
.
LastUsedAt
==
nil
&&
b
.
account
.
LastUsedAt
==
nil
:
return
false
default
:
return
a
.
account
.
LastUsedAt
.
Before
(
*
b
.
account
.
LastUsedAt
)
}
})
// 4. 尝试获取槽位
for
_
,
item
:=
range
routingAvailable
{
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
item
.
account
.
ID
,
item
.
account
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
// 会话数量限制检查
if
!
s
.
checkAndRegisterSession
(
ctx
,
item
.
account
,
sessionUUID
)
{
result
.
ReleaseFunc
()
// 释放槽位,继续尝试下一个账号
continue
}
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
_
=
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
item
.
account
.
ID
,
stickySessionTTL
)
}
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
item
.
account
.
ID
)
}
return
&
AccountSelectionResult
{
Account
:
item
.
account
,
Acquired
:
true
,
ReleaseFunc
:
result
.
ReleaseFunc
,
},
nil
}
}
// 5. 所有路由账号槽位满,返回等待计划(选择负载最低的)
acc
:=
routingAvailable
[
0
]
.
account
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
acc
.
ID
)
}
return
&
AccountSelectionResult
{
Account
:
acc
,
WaitPlan
:
&
AccountWaitPlan
{
AccountID
:
acc
.
ID
,
MaxConcurrency
:
acc
.
Concurrency
,
Timeout
:
cfg
.
StickySessionWaitTimeout
,
MaxWaiting
:
cfg
.
StickySessionMaxWaiting
,
},
},
nil
}
// 路由列表中的账号都不可用(负载率 >= 100),继续到 Layer 2 回退
log
.
Printf
(
"[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection"
,
requestedModel
)
}
}
// ============ Layer 1.5: 粘性会话(仅在无模型路由配置时生效) ============
if
len
(
routingAccountIDs
)
==
0
&&
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
if
err
==
nil
&&
accountID
>
0
&&
!
isExcluded
(
accountID
)
{
account
,
err
:=
s
.
account
Repo
.
GetByID
(
ctx
,
accountID
)
if
err
==
nil
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
account
,
ok
:=
account
ByID
[
accountID
]
if
ok
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
s
.
isAccountAllowedForPlatform
(
account
,
platform
,
useMixed
)
&&
account
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
&&
s
.
isAccountSchedulableForWindowCost
(
ctx
,
account
,
true
)
{
// 粘性会话窗口费用检查
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
_
=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
)
return
&
AccountSelectionResult
{
Account
:
account
,
Acquired
:
true
,
ReleaseFunc
:
result
.
ReleaseFunc
,
},
nil
// 会话数量限制检查
if
!
s
.
checkAndRegisterSession
(
ctx
,
account
,
sessionUUID
)
{
result
.
ReleaseFunc
()
// 释放槽位,继续到 Layer 2
}
else
{
_
=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
)
return
&
AccountSelectionResult
{
Account
:
account
,
Acquired
:
true
,
ReleaseFunc
:
result
.
ReleaseFunc
,
},
nil
}
}
waitingCount
,
_
:=
s
.
concurrencyService
.
GetAccountWaitingCount
(
ctx
,
accountID
)
...
...
@@ -517,6 +764,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if
isExcluded
(
acc
.
ID
)
{
continue
}
// Scheduler snapshots can be temporarily stale (bucket rebuild is throttled);
// re-check schedulability here so recently rate-limited/overloaded accounts
// are not selected again before the bucket is rebuilt.
if
!
acc
.
IsSchedulable
()
{
continue
}
if
!
s
.
isAccountAllowedForPlatform
(
acc
,
platform
,
useMixed
)
{
continue
}
...
...
@@ -526,6 +779,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if
requestedModel
!=
""
&&
!
s
.
isModelSupportedByAccount
(
acc
,
requestedModel
)
{
continue
}
// 窗口费用检查(非粘性会话路径)
if
!
s
.
isAccountSchedulableForWindowCost
(
ctx
,
acc
,
false
)
{
continue
}
candidates
=
append
(
candidates
,
acc
)
}
...
...
@@ -543,7 +800,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
loadMap
,
err
:=
s
.
concurrencyService
.
GetAccountsLoadBatch
(
ctx
,
accountLoads
)
if
err
!=
nil
{
if
result
,
ok
:=
s
.
tryAcquireByLegacyOrder
(
ctx
,
candidates
,
groupID
,
sessionHash
,
preferOAuth
);
ok
{
if
result
,
ok
:=
s
.
tryAcquireByLegacyOrder
(
ctx
,
candidates
,
groupID
,
sessionHash
,
preferOAuth
,
sessionUUID
);
ok
{
return
result
,
nil
}
}
else
{
...
...
@@ -592,6 +849,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
for
_
,
item
:=
range
available
{
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
item
.
account
.
ID
,
item
.
account
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
// 会话数量限制检查
if
!
s
.
checkAndRegisterSession
(
ctx
,
item
.
account
,
sessionUUID
)
{
result
.
ReleaseFunc
()
// 释放槽位,继续尝试下一个账号
continue
}
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
_
=
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
item
.
account
.
ID
,
stickySessionTTL
)
}
...
...
@@ -621,13 +883,18 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return
nil
,
errors
.
New
(
"no available accounts"
)
}
func
(
s
*
GatewayService
)
tryAcquireByLegacyOrder
(
ctx
context
.
Context
,
candidates
[]
*
Account
,
groupID
*
int64
,
sessionHash
string
,
preferOAuth
bool
)
(
*
AccountSelectionResult
,
bool
)
{
func
(
s
*
GatewayService
)
tryAcquireByLegacyOrder
(
ctx
context
.
Context
,
candidates
[]
*
Account
,
groupID
*
int64
,
sessionHash
string
,
preferOAuth
bool
,
sessionUUID
string
)
(
*
AccountSelectionResult
,
bool
)
{
ordered
:=
append
([]
*
Account
(
nil
),
candidates
...
)
sortAccountsByPriorityAndLastUsed
(
ordered
,
preferOAuth
)
for
_
,
acc
:=
range
ordered
{
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
acc
.
ID
,
acc
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
// 会话数量限制检查
if
!
s
.
checkAndRegisterSession
(
ctx
,
acc
,
sessionUUID
)
{
result
.
ReleaseFunc
()
// 释放槽位,继续尝试下一个账号
continue
}
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
_
=
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
acc
.
ID
,
stickySessionTTL
)
}
...
...
@@ -656,51 +923,123 @@ func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
}
}
// checkClaudeCodeRestriction 检查分组的 Claude Code 客户端限制
// 如果分组启用了 claude_code_only 且请求不是来自 Claude Code 客户端:
// - 有降级分组:返回降级分组的 ID
// - 无降级分组:返回 ErrClaudeCodeOnly 错误
func
(
s
*
GatewayService
)
checkClaudeCodeRestriction
(
ctx
context
.
Context
,
groupID
*
int64
)
(
*
int64
,
error
)
{
if
groupID
==
nil
{
return
groupID
,
nil
func
(
s
*
GatewayService
)
withGroupContext
(
ctx
context
.
Context
,
group
*
Group
)
context
.
Context
{
if
!
IsGroupContextValid
(
group
)
{
return
ctx
}
if
existing
,
ok
:=
ctx
.
Value
(
ctxkey
.
Group
)
.
(
*
Group
);
ok
&&
existing
!=
nil
&&
existing
.
ID
==
group
.
ID
&&
IsGroupContextValid
(
existing
)
{
return
ctx
}
return
context
.
WithValue
(
ctx
,
ctxkey
.
Group
,
group
)
}
// 强制平台模式不检查 Claude Code 限制
if
_
,
hasForcePlatform
:=
ctx
.
Value
(
ctxkey
.
ForcePlatform
)
.
(
string
);
hasForcePlatform
{
return
group
ID
,
nil
func
(
s
*
GatewayService
)
groupFromContext
(
ctx
context
.
Context
,
groupID
int64
)
*
Group
{
if
group
,
ok
:=
ctx
.
Value
(
ctxkey
.
Group
)
.
(
*
Group
);
ok
&&
IsGroupContextValid
(
group
)
&&
group
.
ID
==
groupID
{
return
group
}
return
nil
}
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
*
groupID
)
func
(
s
*
GatewayService
)
resolveGroupByID
(
ctx
context
.
Context
,
groupID
int64
)
(
*
Group
,
error
)
{
if
group
:=
s
.
groupFromContext
(
ctx
,
groupID
);
group
!=
nil
{
return
group
,
nil
}
group
,
err
:=
s
.
groupRepo
.
GetByIDLite
(
ctx
,
groupID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get group failed: %w"
,
err
)
}
return
group
,
nil
}
if
!
group
.
ClaudeCodeOnly
{
return
groupID
,
nil
func
(
s
*
GatewayService
)
routingAccountIDsForRequest
(
ctx
context
.
Context
,
groupID
*
int64
,
requestedModel
string
,
platform
string
)
[]
int64
{
if
groupID
==
nil
||
requestedModel
==
""
||
platform
!=
PlatformAnthropic
{
return
nil
}
group
,
err
:=
s
.
resolveGroupByID
(
ctx
,
*
groupID
)
if
err
!=
nil
||
group
==
nil
{
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] resolve group failed: group_id=%v model=%s platform=%s err=%v"
,
derefGroupID
(
groupID
),
requestedModel
,
platform
,
err
)
}
return
nil
}
// Preserve existing behavior: model routing only applies to anthropic groups.
if
group
.
Platform
!=
PlatformAnthropic
{
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] skip: non-anthropic group platform: group_id=%d group_platform=%s model=%s"
,
group
.
ID
,
group
.
Platform
,
requestedModel
)
}
return
nil
}
ids
:=
group
.
GetRoutingAccountIDs
(
requestedModel
)
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] routing lookup: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v"
,
group
.
ID
,
requestedModel
,
group
.
ModelRoutingEnabled
,
len
(
group
.
ModelRouting
),
ids
)
}
return
ids
}
// 分组启用了 Claude Code 限制
if
IsClaudeCodeClient
(
ctx
)
{
return
groupID
,
nil
func
(
s
*
GatewayService
)
resolveGatewayGroup
(
ctx
context
.
Context
,
groupID
*
int64
)
(
*
Group
,
*
int64
,
error
)
{
if
groupID
==
nil
{
return
nil
,
nil
,
nil
}
// 非 Claude Code 客户端,检查降级分组
if
group
.
FallbackGroupID
!=
nil
{
return
group
.
FallbackGroupID
,
nil
currentID
:=
*
groupID
visited
:=
map
[
int64
]
struct
{}{}
for
{
if
_
,
seen
:=
visited
[
currentID
];
seen
{
return
nil
,
nil
,
fmt
.
Errorf
(
"fallback group cycle detected"
)
}
visited
[
currentID
]
=
struct
{}{}
group
,
err
:=
s
.
resolveGroupByID
(
ctx
,
currentID
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
if
!
group
.
ClaudeCodeOnly
||
IsClaudeCodeClient
(
ctx
)
{
return
group
,
&
currentID
,
nil
}
if
group
.
FallbackGroupID
==
nil
{
return
nil
,
nil
,
ErrClaudeCodeOnly
}
currentID
=
*
group
.
FallbackGroupID
}
}
// checkClaudeCodeRestriction 检查分组的 Claude Code 客户端限制
// 如果分组启用了 claude_code_only 且请求不是来自 Claude Code 客户端:
// - 有降级分组:返回降级分组的 ID
// - 无降级分组:返回 ErrClaudeCodeOnly 错误
func
(
s
*
GatewayService
)
checkClaudeCodeRestriction
(
ctx
context
.
Context
,
groupID
*
int64
)
(
*
Group
,
*
int64
,
error
)
{
if
groupID
==
nil
{
return
nil
,
groupID
,
nil
}
// 强制平台模式不检查 Claude Code 限制
if
_
,
hasForcePlatform
:=
ctx
.
Value
(
ctxkey
.
ForcePlatform
)
.
(
string
);
hasForcePlatform
{
return
nil
,
groupID
,
nil
}
return
nil
,
ErrClaudeCodeOnly
group
,
resolvedID
,
err
:=
s
.
resolveGatewayGroup
(
ctx
,
groupID
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
return
group
,
resolvedID
,
nil
}
func
(
s
*
GatewayService
)
resolvePlatform
(
ctx
context
.
Context
,
groupID
*
int64
)
(
string
,
bool
,
error
)
{
func
(
s
*
GatewayService
)
resolvePlatform
(
ctx
context
.
Context
,
groupID
*
int64
,
group
*
Group
)
(
string
,
bool
,
error
)
{
forcePlatform
,
hasForcePlatform
:=
ctx
.
Value
(
ctxkey
.
ForcePlatform
)
.
(
string
)
if
hasForcePlatform
&&
forcePlatform
!=
""
{
return
forcePlatform
,
true
,
nil
}
if
group
!=
nil
{
return
group
.
Platform
,
false
,
nil
}
if
groupID
!=
nil
{
group
,
err
:=
s
.
groupRepo
.
Get
ByID
(
ctx
,
*
groupID
)
group
,
err
:=
s
.
resolveGroup
ByID
(
ctx
,
*
groupID
)
if
err
!=
nil
{
return
""
,
false
,
fmt
.
Errorf
(
"get group failed: %w"
,
err
)
return
""
,
false
,
err
}
return
group
.
Platform
,
false
,
nil
}
...
...
@@ -708,6 +1047,9 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (s
}
func
(
s
*
GatewayService
)
listSchedulableAccounts
(
ctx
context
.
Context
,
groupID
*
int64
,
platform
string
,
hasForcePlatform
bool
)
([]
Account
,
bool
,
error
)
{
if
s
.
schedulerSnapshot
!=
nil
{
return
s
.
schedulerSnapshot
.
ListSchedulableAccounts
(
ctx
,
groupID
,
platform
,
hasForcePlatform
)
}
useMixed
:=
(
platform
==
PlatformAnthropic
||
platform
==
PlatformGemini
)
&&
!
hasForcePlatform
if
useMixed
{
platforms
:=
[]
string
{
platform
,
PlatformAntigravity
}
...
...
@@ -784,6 +1126,114 @@ func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID in
return
s
.
concurrencyService
.
AcquireAccountSlot
(
ctx
,
accountID
,
maxConcurrency
)
}
// isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度
// 仅适用于 Anthropic OAuth/SetupToken 账号
// 返回 true 表示可调度,false 表示不可调度
func
(
s
*
GatewayService
)
isAccountSchedulableForWindowCost
(
ctx
context
.
Context
,
account
*
Account
,
isSticky
bool
)
bool
{
// 只检查 Anthropic OAuth/SetupToken 账号
if
!
account
.
IsAnthropicOAuthOrSetupToken
()
{
return
true
}
limit
:=
account
.
GetWindowCostLimit
()
if
limit
<=
0
{
return
true
// 未启用窗口费用限制
}
// 尝试从缓存获取窗口费用
var
currentCost
float64
if
s
.
sessionLimitCache
!=
nil
{
if
cost
,
hit
,
err
:=
s
.
sessionLimitCache
.
GetWindowCost
(
ctx
,
account
.
ID
);
err
==
nil
&&
hit
{
currentCost
=
cost
goto
checkSchedulability
}
}
// 缓存未命中,从数据库查询
{
var
startTime
time
.
Time
if
account
.
SessionWindowStart
!=
nil
{
startTime
=
*
account
.
SessionWindowStart
}
else
{
startTime
=
time
.
Now
()
.
Add
(
-
5
*
time
.
Hour
)
}
stats
,
err
:=
s
.
usageLogRepo
.
GetAccountWindowStats
(
ctx
,
account
.
ID
,
startTime
)
if
err
!=
nil
{
// 失败开放:查询失败时允许调度
return
true
}
// 使用标准费用(不含账号倍率)
currentCost
=
stats
.
StandardCost
// 设置缓存(忽略错误)
if
s
.
sessionLimitCache
!=
nil
{
_
=
s
.
sessionLimitCache
.
SetWindowCost
(
ctx
,
account
.
ID
,
currentCost
)
}
}
checkSchedulability
:
schedulability
:=
account
.
CheckWindowCostSchedulability
(
currentCost
)
switch
schedulability
{
case
WindowCostSchedulable
:
return
true
case
WindowCostStickyOnly
:
return
isSticky
case
WindowCostNotSchedulable
:
return
false
}
return
true
}
// checkAndRegisterSession 检查并注册会话,用于会话数量限制
// 仅适用于 Anthropic OAuth/SetupToken 账号
// 返回 true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话)
func
(
s
*
GatewayService
)
checkAndRegisterSession
(
ctx
context
.
Context
,
account
*
Account
,
sessionUUID
string
)
bool
{
// 只检查 Anthropic OAuth/SetupToken 账号
if
!
account
.
IsAnthropicOAuthOrSetupToken
()
{
return
true
}
maxSessions
:=
account
.
GetMaxSessions
()
if
maxSessions
<=
0
||
sessionUUID
==
""
{
return
true
// 未启用会话限制或无会话ID
}
if
s
.
sessionLimitCache
==
nil
{
return
true
// 缓存不可用时允许通过
}
idleTimeout
:=
time
.
Duration
(
account
.
GetSessionIdleTimeoutMinutes
())
*
time
.
Minute
allowed
,
err
:=
s
.
sessionLimitCache
.
RegisterSession
(
ctx
,
account
.
ID
,
sessionUUID
,
maxSessions
,
idleTimeout
)
if
err
!=
nil
{
// 失败开放:缓存错误时允许通过
return
true
}
return
allowed
}
// extractSessionUUID 从 metadata.user_id 中提取会话 UUID
// 格式: user_{64位hex}_account__session_{uuid}
func
extractSessionUUID
(
metadataUserID
string
)
string
{
if
metadataUserID
==
""
{
return
""
}
if
match
:=
sessionIDRegex
.
FindStringSubmatch
(
metadataUserID
);
len
(
match
)
>
1
{
return
match
[
1
]
}
return
""
}
// getSchedulableAccount fetches an account by ID, preferring the scheduler
// snapshot service when one is configured and otherwise falling back to a
// direct account-repository lookup.
func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
	if s.schedulerSnapshot != nil {
		return s.schedulerSnapshot.GetAccount(ctx, accountID)
	}
	return s.accountRepo.GetByID(ctx, accountID)
}
func
sortAccountsByPriorityAndLastUsed
(
accounts
[]
*
Account
,
preferOAuth
bool
)
{
sort
.
SliceStable
(
accounts
,
func
(
i
,
j
int
)
bool
{
a
,
b
:=
accounts
[
i
],
accounts
[
j
]
...
...
@@ -859,12 +1309,122 @@ func shuffleWithinPriority(accounts []*Account) {
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func
(
s
*
GatewayService
)
selectAccountForModelWithPlatform
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{},
platform
string
)
(
*
Account
,
error
)
{
preferOAuth
:=
platform
==
PlatformGemini
routingAccountIDs
:=
s
.
routingAccountIDsForRequest
(
ctx
,
groupID
,
requestedModel
,
platform
)
var
accounts
[]
Account
accountsLoaded
:=
false
// ============ Model Routing (legacy path): apply before sticky session ============
// When load-awareness is disabled (e.g. concurrency service not configured), we still honor model routing
// so switching model can switch upstream account within the same sticky session.
if
len
(
routingAccountIDs
)
>
0
{
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] legacy routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v"
,
derefGroupID
(
groupID
),
requestedModel
,
platform
,
shortSessionHash
(
sessionHash
),
routingAccountIDs
)
}
// 1) Sticky session only applies if the bound account is within the routing set.
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
if
err
==
nil
&&
accountID
>
0
&&
containsInt64
(
routingAccountIDs
,
accountID
)
{
if
_
,
excluded
:=
excludedIDs
[
accountID
];
!
excluded
{
account
,
err
:=
s
.
getSchedulableAccount
(
ctx
,
accountID
)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if
err
==
nil
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
account
.
Platform
==
platform
&&
account
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
if
err
:=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
);
err
!=
nil
{
log
.
Printf
(
"refresh session ttl failed: session=%s err=%v"
,
sessionHash
,
err
)
}
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
accountID
)
}
return
account
,
nil
}
}
}
}
// 2) Select an account from the routed candidates.
forcePlatform
,
hasForcePlatform
:=
ctx
.
Value
(
ctxkey
.
ForcePlatform
)
.
(
string
)
if
hasForcePlatform
&&
forcePlatform
==
""
{
hasForcePlatform
=
false
}
var
err
error
accounts
,
_
,
err
=
s
.
listSchedulableAccounts
(
ctx
,
groupID
,
platform
,
hasForcePlatform
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
}
accountsLoaded
=
true
routingSet
:=
make
(
map
[
int64
]
struct
{},
len
(
routingAccountIDs
))
for
_
,
id
:=
range
routingAccountIDs
{
if
id
>
0
{
routingSet
[
id
]
=
struct
{}{}
}
}
var
selected
*
Account
for
i
:=
range
accounts
{
acc
:=
&
accounts
[
i
]
if
_
,
ok
:=
routingSet
[
acc
.
ID
];
!
ok
{
continue
}
if
_
,
excluded
:=
excludedIDs
[
acc
.
ID
];
excluded
{
continue
}
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
if
!
acc
.
IsSchedulable
()
{
continue
}
if
!
acc
.
IsSchedulableForModel
(
requestedModel
)
{
continue
}
if
requestedModel
!=
""
&&
!
s
.
isModelSupportedByAccount
(
acc
,
requestedModel
)
{
continue
}
if
selected
==
nil
{
selected
=
acc
continue
}
if
acc
.
Priority
<
selected
.
Priority
{
selected
=
acc
}
else
if
acc
.
Priority
==
selected
.
Priority
{
switch
{
case
acc
.
LastUsedAt
==
nil
&&
selected
.
LastUsedAt
!=
nil
:
selected
=
acc
case
acc
.
LastUsedAt
!=
nil
&&
selected
.
LastUsedAt
==
nil
:
// keep selected (never used is preferred)
case
acc
.
LastUsedAt
==
nil
&&
selected
.
LastUsedAt
==
nil
:
if
preferOAuth
&&
acc
.
Type
!=
selected
.
Type
&&
acc
.
Type
==
AccountTypeOAuth
{
selected
=
acc
}
default
:
if
acc
.
LastUsedAt
.
Before
(
*
selected
.
LastUsedAt
)
{
selected
=
acc
}
}
}
}
if
selected
!=
nil
{
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
if
err
:=
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
selected
.
ID
,
stickySessionTTL
);
err
!=
nil
{
log
.
Printf
(
"set session account failed: session=%s account_id=%d err=%v"
,
sessionHash
,
selected
.
ID
,
err
)
}
}
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] legacy routed select: group_id=%v model=%s session=%s account=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
selected
.
ID
)
}
return
selected
,
nil
}
log
.
Printf
(
"[ModelRouting] No routed accounts available for model=%s, falling back to normal selection"
,
requestedModel
)
}
// 1. 查询粘性会话
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
if
err
==
nil
&&
accountID
>
0
{
if
_
,
excluded
:=
excludedIDs
[
accountID
];
!
excluded
{
account
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
accountID
)
account
,
err
:=
s
.
getSchedulableAccount
(
ctx
,
accountID
)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if
err
==
nil
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
account
.
Platform
==
platform
&&
account
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
if
err
:=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
);
err
!=
nil
{
...
...
@@ -877,18 +1437,16 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
}
// 2. 获取可调度账号列表(单平台)
var
accounts
[]
Account
var
err
error
if
s
.
cfg
.
RunMode
==
config
.
RunModeSimple
{
// 简易模式:忽略 groupID,查询所有可用账号
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatform
(
ctx
,
platform
)
}
else
if
groupID
!=
nil
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByGroupIDAndPlatform
(
ctx
,
*
groupID
,
platform
)
}
else
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatform
(
ctx
,
platform
)
}
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
if
!
accountsLoaded
{
forcePlatform
,
hasForcePlatform
:=
ctx
.
Value
(
ctxkey
.
ForcePlatform
)
.
(
string
)
if
hasForcePlatform
&&
forcePlatform
==
""
{
hasForcePlatform
=
false
}
var
err
error
accounts
,
_
,
err
=
s
.
listSchedulableAccounts
(
ctx
,
groupID
,
platform
,
hasForcePlatform
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
}
}
// 3. 按优先级+最久未用选择(考虑模型支持)
...
...
@@ -898,6 +1456,11 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if
_
,
excluded
:=
excludedIDs
[
acc
.
ID
];
excluded
{
continue
}
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
if
!
acc
.
IsSchedulable
()
{
continue
}
if
!
acc
.
IsSchedulableForModel
(
requestedModel
)
{
continue
}
...
...
@@ -948,15 +1511,123 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
// selectAccountWithMixedScheduling 选择账户(支持混合调度)
// 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户
func
(
s
*
GatewayService
)
selectAccountWithMixedScheduling
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{},
nativePlatform
string
)
(
*
Account
,
error
)
{
platforms
:=
[]
string
{
nativePlatform
,
PlatformAntigravity
}
preferOAuth
:=
nativePlatform
==
PlatformGemini
routingAccountIDs
:=
s
.
routingAccountIDsForRequest
(
ctx
,
groupID
,
requestedModel
,
nativePlatform
)
var
accounts
[]
Account
accountsLoaded
:=
false
// ============ Model Routing (legacy path): apply before sticky session ============
if
len
(
routingAccountIDs
)
>
0
{
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] legacy mixed routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v"
,
derefGroupID
(
groupID
),
requestedModel
,
nativePlatform
,
shortSessionHash
(
sessionHash
),
routingAccountIDs
)
}
// 1) Sticky session only applies if the bound account is within the routing set.
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
if
err
==
nil
&&
accountID
>
0
&&
containsInt64
(
routingAccountIDs
,
accountID
)
{
if
_
,
excluded
:=
excludedIDs
[
accountID
];
!
excluded
{
account
,
err
:=
s
.
getSchedulableAccount
(
ctx
,
accountID
)
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
if
err
==
nil
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
account
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
if
account
.
Platform
==
nativePlatform
||
(
account
.
Platform
==
PlatformAntigravity
&&
account
.
IsMixedSchedulingEnabled
())
{
if
err
:=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
);
err
!=
nil
{
log
.
Printf
(
"refresh session ttl failed: session=%s err=%v"
,
sessionHash
,
err
)
}
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
accountID
)
}
return
account
,
nil
}
}
}
}
}
// 2) Select an account from the routed candidates.
var
err
error
accounts
,
_
,
err
=
s
.
listSchedulableAccounts
(
ctx
,
groupID
,
nativePlatform
,
false
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
}
accountsLoaded
=
true
routingSet
:=
make
(
map
[
int64
]
struct
{},
len
(
routingAccountIDs
))
for
_
,
id
:=
range
routingAccountIDs
{
if
id
>
0
{
routingSet
[
id
]
=
struct
{}{}
}
}
var
selected
*
Account
for
i
:=
range
accounts
{
acc
:=
&
accounts
[
i
]
if
_
,
ok
:=
routingSet
[
acc
.
ID
];
!
ok
{
continue
}
if
_
,
excluded
:=
excludedIDs
[
acc
.
ID
];
excluded
{
continue
}
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
if
!
acc
.
IsSchedulable
()
{
continue
}
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
if
acc
.
Platform
==
PlatformAntigravity
&&
!
acc
.
IsMixedSchedulingEnabled
()
{
continue
}
if
!
acc
.
IsSchedulableForModel
(
requestedModel
)
{
continue
}
if
requestedModel
!=
""
&&
!
s
.
isModelSupportedByAccount
(
acc
,
requestedModel
)
{
continue
}
if
selected
==
nil
{
selected
=
acc
continue
}
if
acc
.
Priority
<
selected
.
Priority
{
selected
=
acc
}
else
if
acc
.
Priority
==
selected
.
Priority
{
switch
{
case
acc
.
LastUsedAt
==
nil
&&
selected
.
LastUsedAt
!=
nil
:
selected
=
acc
case
acc
.
LastUsedAt
!=
nil
&&
selected
.
LastUsedAt
==
nil
:
// keep selected (never used is preferred)
case
acc
.
LastUsedAt
==
nil
&&
selected
.
LastUsedAt
==
nil
:
if
preferOAuth
&&
acc
.
Platform
==
PlatformGemini
&&
selected
.
Platform
==
PlatformGemini
&&
acc
.
Type
!=
selected
.
Type
&&
acc
.
Type
==
AccountTypeOAuth
{
selected
=
acc
}
default
:
if
acc
.
LastUsedAt
.
Before
(
*
selected
.
LastUsedAt
)
{
selected
=
acc
}
}
}
}
if
selected
!=
nil
{
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
if
err
:=
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
selected
.
ID
,
stickySessionTTL
);
err
!=
nil
{
log
.
Printf
(
"set session account failed: session=%s account_id=%d err=%v"
,
sessionHash
,
selected
.
ID
,
err
)
}
}
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] legacy mixed routed select: group_id=%v model=%s session=%s account=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
selected
.
ID
)
}
return
selected
,
nil
}
log
.
Printf
(
"[ModelRouting] No routed accounts available for model=%s, falling back to normal selection"
,
requestedModel
)
}
// 1. 查询粘性会话
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
if
err
==
nil
&&
accountID
>
0
{
if
_
,
excluded
:=
excludedIDs
[
accountID
];
!
excluded
{
account
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
accountID
)
account
,
err
:=
s
.
getSchedulableAccount
(
ctx
,
accountID
)
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
if
err
==
nil
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
account
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
if
account
.
Platform
==
nativePlatform
||
(
account
.
Platform
==
PlatformAntigravity
&&
account
.
IsMixedSchedulingEnabled
())
{
...
...
@@ -971,15 +1642,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
// 2. 获取可调度账号列表
var
accounts
[]
Account
var
err
error
if
groupID
!=
nil
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByGroupIDAndPlatforms
(
ctx
,
*
groupID
,
platforms
)
}
else
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatforms
(
ctx
,
platforms
)
}
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
if
!
accountsLoaded
{
var
err
error
accounts
,
_
,
err
=
s
.
listSchedulableAccounts
(
ctx
,
groupID
,
nativePlatform
,
false
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
}
}
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
...
...
@@ -989,6 +1657,11 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if
_
,
excluded
:=
excludedIDs
[
acc
.
ID
];
excluded
{
continue
}
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
if
!
acc
.
IsSchedulable
()
{
continue
}
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
if
acc
.
Platform
==
PlatformAntigravity
&&
!
acc
.
IsMixedSchedulingEnabled
()
{
continue
...
...
@@ -1075,6 +1748,16 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
}
func
(
s
*
GatewayService
)
getOAuthToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
string
,
error
)
{
// 对于 Anthropic OAuth 账号,使用 ClaudeTokenProvider 获取缓存的 token
if
account
.
Platform
==
PlatformAnthropic
&&
account
.
Type
==
AccountTypeOAuth
&&
s
.
claudeTokenProvider
!=
nil
{
accessToken
,
err
:=
s
.
claudeTokenProvider
.
GetAccessToken
(
ctx
,
account
)
if
err
!=
nil
{
return
""
,
""
,
err
}
return
accessToken
,
"oauth"
,
nil
}
// 其他情况(Gemini 有自己的 TokenProvider,setup-token 类型等)直接从账号读取
accessToken
:=
account
.
GetCredential
(
"access_token"
)
if
accessToken
==
""
{
return
""
,
""
,
errors
.
New
(
"access_token not found in credentials"
)
...
...
@@ -1239,6 +1922,9 @@ func enforceCacheControlLimit(body []byte) []byte {
return
body
}
// 清理 thinking 块中的非法 cache_control(thinking 块不支持该字段)
removeCacheControlFromThinkingBlocks
(
data
)
// 计算当前 cache_control 块数量
count
:=
countCacheControlBlocks
(
data
)
if
count
<=
maxCacheControlBlocks
{
...
...
@@ -1266,6 +1952,7 @@ func enforceCacheControlLimit(body []byte) []byte {
}
// countCacheControlBlocks 统计 system 和 messages 中的 cache_control 块数量
// 注意:thinking 块不支持 cache_control,统计时跳过
func
countCacheControlBlocks
(
data
map
[
string
]
any
)
int
{
count
:=
0
...
...
@@ -1273,6 +1960,10 @@ func countCacheControlBlocks(data map[string]any) int {
if
system
,
ok
:=
data
[
"system"
]
.
([]
any
);
ok
{
for
_
,
item
:=
range
system
{
if
m
,
ok
:=
item
.
(
map
[
string
]
any
);
ok
{
// thinking 块不支持 cache_control,跳过
if
blockType
,
_
:=
m
[
"type"
]
.
(
string
);
blockType
==
"thinking"
{
continue
}
if
_
,
has
:=
m
[
"cache_control"
];
has
{
count
++
}
...
...
@@ -1287,6 +1978,10 @@ func countCacheControlBlocks(data map[string]any) int {
if
content
,
ok
:=
msgMap
[
"content"
]
.
([]
any
);
ok
{
for
_
,
item
:=
range
content
{
if
m
,
ok
:=
item
.
(
map
[
string
]
any
);
ok
{
// thinking 块不支持 cache_control,跳过
if
blockType
,
_
:=
m
[
"type"
]
.
(
string
);
blockType
==
"thinking"
{
continue
}
if
_
,
has
:=
m
[
"cache_control"
];
has
{
count
++
}
...
...
@@ -1302,6 +1997,7 @@ func countCacheControlBlocks(data map[string]any) int {
// removeCacheControlFromMessages 从 messages 中移除一个 cache_control(从头开始)
// 返回 true 表示成功移除,false 表示没有可移除的
// 注意:跳过 thinking 块(它不支持 cache_control)
func
removeCacheControlFromMessages
(
data
map
[
string
]
any
)
bool
{
messages
,
ok
:=
data
[
"messages"
]
.
([]
any
)
if
!
ok
{
...
...
@@ -1319,6 +2015,10 @@ func removeCacheControlFromMessages(data map[string]any) bool {
}
for
_
,
item
:=
range
content
{
if
m
,
ok
:=
item
.
(
map
[
string
]
any
);
ok
{
// thinking 块不支持 cache_control,跳过
if
blockType
,
_
:=
m
[
"type"
]
.
(
string
);
blockType
==
"thinking"
{
continue
}
if
_
,
has
:=
m
[
"cache_control"
];
has
{
delete
(
m
,
"cache_control"
)
return
true
...
...
@@ -1331,6 +2031,7 @@ func removeCacheControlFromMessages(data map[string]any) bool {
// removeCacheControlFromSystem 从 system 中移除一个 cache_control(从尾部开始,保护注入的 prompt)
// 返回 true 表示成功移除,false 表示没有可移除的
// 注意:跳过 thinking 块(它不支持 cache_control)
func
removeCacheControlFromSystem
(
data
map
[
string
]
any
)
bool
{
system
,
ok
:=
data
[
"system"
]
.
([]
any
)
if
!
ok
{
...
...
@@ -1340,6 +2041,10 @@ func removeCacheControlFromSystem(data map[string]any) bool {
// 从尾部开始移除,保护开头注入的 Claude Code prompt
for
i
:=
len
(
system
)
-
1
;
i
>=
0
;
i
--
{
if
m
,
ok
:=
system
[
i
]
.
(
map
[
string
]
any
);
ok
{
// thinking 块不支持 cache_control,跳过
if
blockType
,
_
:=
m
[
"type"
]
.
(
string
);
blockType
==
"thinking"
{
continue
}
if
_
,
has
:=
m
[
"cache_control"
];
has
{
delete
(
m
,
"cache_control"
)
return
true
...
...
@@ -1349,6 +2054,44 @@ func removeCacheControlFromSystem(data map[string]any) bool {
return
false
}
// removeCacheControlFromThinkingBlocks 强制清理所有 thinking 块中的非法 cache_control
// thinking 块不支持 cache_control 字段,这个函数确保所有 thinking 块都不含该字段
func
removeCacheControlFromThinkingBlocks
(
data
map
[
string
]
any
)
{
// 清理 system 中的 thinking 块
if
system
,
ok
:=
data
[
"system"
]
.
([]
any
);
ok
{
for
_
,
item
:=
range
system
{
if
m
,
ok
:=
item
.
(
map
[
string
]
any
);
ok
{
if
blockType
,
_
:=
m
[
"type"
]
.
(
string
);
blockType
==
"thinking"
{
if
_
,
has
:=
m
[
"cache_control"
];
has
{
delete
(
m
,
"cache_control"
)
log
.
Printf
(
"[Warning] Removed illegal cache_control from thinking block in system"
)
}
}
}
}
}
// 清理 messages 中的 thinking 块
if
messages
,
ok
:=
data
[
"messages"
]
.
([]
any
);
ok
{
for
msgIdx
,
msg
:=
range
messages
{
if
msgMap
,
ok
:=
msg
.
(
map
[
string
]
any
);
ok
{
if
content
,
ok
:=
msgMap
[
"content"
]
.
([]
any
);
ok
{
for
contentIdx
,
item
:=
range
content
{
if
m
,
ok
:=
item
.
(
map
[
string
]
any
);
ok
{
if
blockType
,
_
:=
m
[
"type"
]
.
(
string
);
blockType
==
"thinking"
{
if
_
,
has
:=
m
[
"cache_control"
];
has
{
delete
(
m
,
"cache_control"
)
log
.
Printf
(
"[Warning] Removed illegal cache_control from thinking block in messages[%d].content[%d]"
,
msgIdx
,
contentIdx
)
}
}
}
}
}
}
}
}
}
// Forward 转发请求到Claude API
func
(
s
*
GatewayService
)
Forward
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
parsed
*
ParsedRequest
)
(
*
ForwardResult
,
error
)
{
startTime
:=
time
.
Now
()
...
...
@@ -1402,6 +2145,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
for
attempt
:=
1
;
attempt
<=
maxRetryAttempts
;
attempt
++
{
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
upstreamReq
,
err
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
body
,
token
,
tokenType
,
reqModel
)
// Capture upstream request body for ops retry of this attempt.
c
.
Set
(
OpsUpstreamRequestBodyKey
,
string
(
body
))
if
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -1412,7 +2158,25 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
if
resp
!=
nil
&&
resp
.
Body
!=
nil
{
_
=
resp
.
Body
.
Close
()
}
return
nil
,
fmt
.
Errorf
(
"upstream request failed: %w"
,
err
)
// Ensure the client receives an error response (handlers assume Forward writes on non-failover errors).
safeErr
:=
sanitizeUpstreamErrorMessage
(
err
.
Error
())
setOpsUpstreamError
(
c
,
0
,
safeErr
,
""
)
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
0
,
Kind
:
"request_error"
,
Message
:
safeErr
,
})
c
.
JSON
(
http
.
StatusBadGateway
,
gin
.
H
{
"type"
:
"error"
,
"error"
:
gin
.
H
{
"type"
:
"upstream_error"
,
"message"
:
"Upstream request failed"
,
},
})
return
nil
,
fmt
.
Errorf
(
"upstream request failed: %s"
,
safeErr
)
}
// 优先检测thinking block签名错误(400)并重试一次
...
...
@@ -1422,6 +2186,22 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
_
=
resp
.
Body
.
Close
()
if
s
.
isThinkingBlockSignatureError
(
respBody
)
{
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"signature_error"
,
Message
:
extractUpstreamErrorMessage
(
respBody
),
Detail
:
func
()
string
{
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
return
truncateString
(
string
(
respBody
),
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
)
}
return
""
}(),
})
looksLikeToolSignatureError
:=
func
(
msg
string
)
bool
{
m
:=
strings
.
ToLower
(
msg
)
return
strings
.
Contains
(
m
,
"tool_use"
)
||
...
...
@@ -1458,6 +2238,21 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
retryRespBody
,
retryReadErr
:=
io
.
ReadAll
(
io
.
LimitReader
(
retryResp
.
Body
,
2
<<
20
))
_
=
retryResp
.
Body
.
Close
()
if
retryReadErr
==
nil
&&
retryResp
.
StatusCode
==
400
&&
s
.
isThinkingBlockSignatureError
(
retryRespBody
)
{
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
retryResp
.
StatusCode
,
UpstreamRequestID
:
retryResp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"signature_retry_thinking"
,
Message
:
extractUpstreamErrorMessage
(
retryRespBody
),
Detail
:
func
()
string
{
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
return
truncateString
(
string
(
retryRespBody
),
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
)
}
return
""
}(),
})
msg2
:=
extractUpstreamErrorMessage
(
retryRespBody
)
if
looksLikeToolSignatureError
(
msg2
)
&&
time
.
Since
(
retryStart
)
<
maxRetryElapsed
{
log
.
Printf
(
"Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded"
,
account
.
ID
)
...
...
@@ -1472,6 +2267,14 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
if
retryResp2
!=
nil
&&
retryResp2
.
Body
!=
nil
{
_
=
retryResp2
.
Body
.
Close
()
}
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
0
,
Kind
:
"signature_retry_tools_request_error"
,
Message
:
sanitizeUpstreamErrorMessage
(
retryErr2
.
Error
()),
})
log
.
Printf
(
"Account %d: tool-downgrade signature retry failed: %v"
,
account
.
ID
,
retryErr2
)
}
else
{
log
.
Printf
(
"Account %d: tool-downgrade signature retry build failed: %v"
,
account
.
ID
,
buildErr2
)
...
...
@@ -1521,9 +2324,25 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
break
}
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
_
=
resp
.
Body
.
Close
()
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"retry"
,
Message
:
extractUpstreamErrorMessage
(
respBody
),
Detail
:
func
()
string
{
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
return
truncateString
(
string
(
respBody
),
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
)
}
return
""
}(),
})
log
.
Printf
(
"Account %d: upstream error %d, retry %d/%d after %v (elapsed=%v/%v)"
,
account
.
ID
,
resp
.
StatusCode
,
attempt
,
maxRetryAttempts
,
delay
,
elapsed
,
maxRetryElapsed
)
_
=
resp
.
Body
.
Close
()
if
err
:=
sleepWithContext
(
ctx
,
delay
);
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -1551,7 +2370,26 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 处理重试耗尽的情况
if
resp
.
StatusCode
>=
400
&&
s
.
shouldRetryUpstreamError
(
account
,
resp
.
StatusCode
)
{
if
s
.
shouldFailoverUpstreamError
(
resp
.
StatusCode
)
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
_
=
resp
.
Body
.
Close
()
resp
.
Body
=
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
))
s
.
handleRetryExhaustedSideEffects
(
ctx
,
resp
,
account
)
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"retry_exhausted_failover"
,
Message
:
extractUpstreamErrorMessage
(
respBody
),
Detail
:
func
()
string
{
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
return
truncateString
(
string
(
respBody
),
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
)
}
return
""
}(),
})
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
resp
.
StatusCode
}
}
return
s
.
handleRetryExhaustedError
(
ctx
,
resp
,
c
,
account
)
...
...
@@ -1559,7 +2397,25 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 处理可切换账号的错误
if
resp
.
StatusCode
>=
400
&&
s
.
shouldFailoverUpstreamError
(
resp
.
StatusCode
)
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
_
=
resp
.
Body
.
Close
()
resp
.
Body
=
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
))
s
.
handleFailoverSideEffects
(
ctx
,
resp
,
account
)
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"failover"
,
Message
:
extractUpstreamErrorMessage
(
respBody
),
Detail
:
func
()
string
{
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
return
truncateString
(
string
(
respBody
),
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
)
}
return
""
}(),
})
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
resp
.
StatusCode
}
}
...
...
@@ -1576,6 +2432,27 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
resp
.
Body
=
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
))
if
s
.
shouldFailoverOn400
(
respBody
)
{
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
upstreamDetail
:=
""
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
maxBytes
:=
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
if
maxBytes
<=
0
{
maxBytes
=
2048
}
upstreamDetail
=
truncateString
(
string
(
respBody
),
maxBytes
)
}
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"failover_on_400"
,
Message
:
upstreamMsg
,
Detail
:
upstreamDetail
,
})
if
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
log
.
Printf
(
"Account %d: 400 error, attempting failover: %s"
,
...
...
@@ -1872,7 +2749,30 @@ func extractUpstreamErrorMessage(body []byte) string {
}
func
(
s
*
GatewayService
)
handleErrorResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
)
(
*
ForwardResult
,
error
)
{
body
,
_
:=
io
.
ReadAll
(
resp
.
Body
)
body
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
body
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
// Enrich Ops error logs with upstream status + message, and optionally a truncated body snippet.
upstreamDetail
:=
""
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
maxBytes
:=
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
if
maxBytes
<=
0
{
maxBytes
=
2048
}
upstreamDetail
=
truncateString
(
string
(
body
),
maxBytes
)
}
setOpsUpstreamError
(
c
,
resp
.
StatusCode
,
upstreamMsg
,
upstreamDetail
)
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"http_error"
,
Message
:
upstreamMsg
,
Detail
:
upstreamDetail
,
})
// 处理上游错误,标记账号状态
shouldDisable
:=
false
...
...
@@ -1883,24 +2783,33 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
resp
.
StatusCode
}
}
// 记录上游错误响应体摘要便于排障(可选:由配置控制;不回显到客户端)
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
log
.
Printf
(
"Upstream error %d (account=%d platform=%s type=%s): %s"
,
resp
.
StatusCode
,
account
.
ID
,
account
.
Platform
,
account
.
Type
,
truncateForLog
(
body
,
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
),
)
}
// 根据状态码返回适当的自定义错误响应(不透传上游详细信息)
var
errType
,
errMsg
string
var
statusCode
int
switch
resp
.
StatusCode
{
case
400
:
// 仅记录上游错误摘要(避免输出请求内容);需要时可通过配置打开
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
log
.
Printf
(
"Upstream 400 error (account=%d platform=%s type=%s): %s"
,
account
.
ID
,
account
.
Platform
,
account
.
Type
,
truncateForLog
(
body
,
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
),
)
}
c
.
Data
(
http
.
StatusBadRequest
,
"application/json"
,
body
)
return
nil
,
fmt
.
Errorf
(
"upstream error: %d"
,
resp
.
StatusCode
)
summary
:=
upstreamMsg
if
summary
==
""
{
summary
=
truncateForLog
(
body
,
512
)
}
if
summary
==
""
{
return
nil
,
fmt
.
Errorf
(
"upstream error: %d"
,
resp
.
StatusCode
)
}
return
nil
,
fmt
.
Errorf
(
"upstream error: %d message=%s"
,
resp
.
StatusCode
,
summary
)
case
401
:
statusCode
=
http
.
StatusBadGateway
errType
=
"upstream_error"
...
...
@@ -1936,11 +2845,14 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
},
})
return
nil
,
fmt
.
Errorf
(
"upstream error: %d"
,
resp
.
StatusCode
)
if
upstreamMsg
==
""
{
return
nil
,
fmt
.
Errorf
(
"upstream error: %d"
,
resp
.
StatusCode
)
}
return
nil
,
fmt
.
Errorf
(
"upstream error: %d message=%s"
,
resp
.
StatusCode
,
upstreamMsg
)
}
func
(
s
*
GatewayService
)
handleRetryExhaustedSideEffects
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
account
*
Account
)
{
body
,
_
:=
io
.
ReadAll
(
resp
.
Body
)
body
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
)
)
statusCode
:=
resp
.
StatusCode
// OAuth/Setup Token 账号的 403:标记账号异常
...
...
@@ -1954,7 +2866,7 @@ func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, re
}
func
(
s
*
GatewayService
)
handleFailoverSideEffects
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
account
*
Account
)
{
body
,
_
:=
io
.
ReadAll
(
resp
.
Body
)
body
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
)
)
s
.
rateLimitService
.
HandleUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
body
)
}
...
...
@@ -1962,8 +2874,45 @@ func (s *GatewayService) handleFailoverSideEffects(ctx context.Context, resp *ht
// OAuth 403:标记账号异常
// API Key 未配置错误码:仅返回错误,不标记账号
func
(
s
*
GatewayService
)
handleRetryExhaustedError
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
)
(
*
ForwardResult
,
error
)
{
// Capture upstream error body before side-effects consume the stream.
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
_
=
resp
.
Body
.
Close
()
resp
.
Body
=
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
))
s
.
handleRetryExhaustedSideEffects
(
ctx
,
resp
,
account
)
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
upstreamDetail
:=
""
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
maxBytes
:=
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
if
maxBytes
<=
0
{
maxBytes
=
2048
}
upstreamDetail
=
truncateString
(
string
(
respBody
),
maxBytes
)
}
setOpsUpstreamError
(
c
,
resp
.
StatusCode
,
upstreamMsg
,
upstreamDetail
)
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"retry_exhausted"
,
Message
:
upstreamMsg
,
Detail
:
upstreamDetail
,
})
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
log
.
Printf
(
"Upstream error %d retries_exhausted (account=%d platform=%s type=%s): %s"
,
resp
.
StatusCode
,
account
.
ID
,
account
.
Platform
,
account
.
Type
,
truncateForLog
(
respBody
,
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
),
)
}
// 返回统一的重试耗尽错误响应
c
.
JSON
(
http
.
StatusBadGateway
,
gin
.
H
{
"type"
:
"error"
,
...
...
@@ -1973,7 +2922,10 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
},
})
return
nil
,
fmt
.
Errorf
(
"upstream error: %d (retries exhausted)"
,
resp
.
StatusCode
)
if
upstreamMsg
==
""
{
return
nil
,
fmt
.
Errorf
(
"upstream error: %d (retries exhausted)"
,
resp
.
StatusCode
)
}
return
nil
,
fmt
.
Errorf
(
"upstream error: %d (retries exhausted) message=%s"
,
resp
.
StatusCode
,
upstreamMsg
)
}
// streamingResult 流式响应结果
...
...
@@ -2154,6 +3106,10 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
true
},
nil
}
log
.
Printf
(
"Stream data interval timeout: account=%d model=%s interval=%s"
,
account
.
ID
,
originalModel
,
streamInterval
)
// 处理流超时,可能标记账户为临时不可调度或错误状态
if
s
.
rateLimitService
!=
nil
{
s
.
rateLimitService
.
HandleStreamTimeout
(
ctx
,
account
,
originalModel
)
}
sendErrorEvent
(
"stream_timeout"
)
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream data interval timeout"
)
}
...
...
@@ -2307,6 +3263,7 @@ type RecordUsageInput struct {
Account
*
Account
Subscription
*
UserSubscription
// 可选:订阅信息
UserAgent
string
// 请求的 User-Agent
IPAddress
string
// 请求的客户端 IP 地址
}
// RecordUsage 记录使用量并扣费(或更新订阅用量)
...
...
@@ -2366,30 +3323,32 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
if
result
.
ImageSize
!=
""
{
imageSize
=
&
result
.
ImageSize
}
accountRateMultiplier
:=
account
.
BillingRateMultiplier
()
usageLog
:=
&
UsageLog
{
UserID
:
user
.
ID
,
APIKeyID
:
apiKey
.
ID
,
AccountID
:
account
.
ID
,
RequestID
:
result
.
RequestID
,
Model
:
result
.
Model
,
InputTokens
:
result
.
Usage
.
InputTokens
,
OutputTokens
:
result
.
Usage
.
OutputTokens
,
CacheCreationTokens
:
result
.
Usage
.
CacheCreationInputTokens
,
CacheReadTokens
:
result
.
Usage
.
CacheReadInputTokens
,
InputCost
:
cost
.
InputCost
,
OutputCost
:
cost
.
OutputCost
,
CacheCreationCost
:
cost
.
CacheCreationCost
,
CacheReadCost
:
cost
.
CacheReadCost
,
TotalCost
:
cost
.
TotalCost
,
ActualCost
:
cost
.
ActualCost
,
RateMultiplier
:
multiplier
,
BillingType
:
billingType
,
Stream
:
result
.
Stream
,
DurationMs
:
&
durationMs
,
FirstTokenMs
:
result
.
FirstTokenMs
,
ImageCount
:
result
.
ImageCount
,
ImageSize
:
imageSize
,
CreatedAt
:
time
.
Now
(),
UserID
:
user
.
ID
,
APIKeyID
:
apiKey
.
ID
,
AccountID
:
account
.
ID
,
RequestID
:
result
.
RequestID
,
Model
:
result
.
Model
,
InputTokens
:
result
.
Usage
.
InputTokens
,
OutputTokens
:
result
.
Usage
.
OutputTokens
,
CacheCreationTokens
:
result
.
Usage
.
CacheCreationInputTokens
,
CacheReadTokens
:
result
.
Usage
.
CacheReadInputTokens
,
InputCost
:
cost
.
InputCost
,
OutputCost
:
cost
.
OutputCost
,
CacheCreationCost
:
cost
.
CacheCreationCost
,
CacheReadCost
:
cost
.
CacheReadCost
,
TotalCost
:
cost
.
TotalCost
,
ActualCost
:
cost
.
ActualCost
,
RateMultiplier
:
multiplier
,
AccountRateMultiplier
:
&
accountRateMultiplier
,
BillingType
:
billingType
,
Stream
:
result
.
Stream
,
DurationMs
:
&
durationMs
,
FirstTokenMs
:
result
.
FirstTokenMs
,
ImageCount
:
result
.
ImageCount
,
ImageSize
:
imageSize
,
CreatedAt
:
time
.
Now
(),
}
// 添加 UserAgent
...
...
@@ -2397,6 +3356,11 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
usageLog
.
UserAgent
=
&
input
.
UserAgent
}
// 添加 IPAddress
if
input
.
IPAddress
!=
""
{
usageLog
.
IPAddress
=
&
input
.
IPAddress
}
// 添加分组和订阅关联
if
apiKey
.
GroupID
!=
nil
{
usageLog
.
GroupID
=
apiKey
.
GroupID
...
...
@@ -2497,6 +3461,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
// 发送请求
resp
,
err
:=
s
.
httpUpstream
.
Do
(
upstreamReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
err
!=
nil
{
setOpsUpstreamError
(
c
,
0
,
sanitizeUpstreamErrorMessage
(
err
.
Error
()),
""
)
s
.
countTokensError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Request failed"
)
return
fmt
.
Errorf
(
"upstream request failed: %w"
,
err
)
}
...
...
@@ -2534,6 +3499,18 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
// 标记账号状态(429/529等)
s
.
rateLimitService
.
HandleUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
)
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
upstreamDetail
:=
""
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
maxBytes
:=
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
if
maxBytes
<=
0
{
maxBytes
=
2048
}
upstreamDetail
=
truncateString
(
string
(
respBody
),
maxBytes
)
}
setOpsUpstreamError
(
c
,
resp
.
StatusCode
,
upstreamMsg
,
upstreamDetail
)
// 记录上游错误摘要便于排障(不回显请求内容)
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
log
.
Printf
(
...
...
@@ -2555,7 +3532,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
errMsg
=
"Service overloaded"
}
s
.
countTokensError
(
c
,
resp
.
StatusCode
,
"upstream_error"
,
errMsg
)
return
fmt
.
Errorf
(
"upstream error: %d"
,
resp
.
StatusCode
)
if
upstreamMsg
==
""
{
return
fmt
.
Errorf
(
"upstream error: %d"
,
resp
.
StatusCode
)
}
return
fmt
.
Errorf
(
"upstream error: %d message=%s"
,
resp
.
StatusCode
,
upstreamMsg
)
}
// 透传成功响应
...
...
backend/internal/service/gemini_messages_compat_service.go
View file @
b9b4db3d
...
...
@@ -40,6 +40,7 @@ type GeminiMessagesCompatService struct {
accountRepo
AccountRepository
groupRepo
GroupRepository
cache
GatewayCache
schedulerSnapshot
*
SchedulerSnapshotService
tokenProvider
*
GeminiTokenProvider
rateLimitService
*
RateLimitService
httpUpstream
HTTPUpstream
...
...
@@ -51,6 +52,7 @@ func NewGeminiMessagesCompatService(
accountRepo
AccountRepository
,
groupRepo
GroupRepository
,
cache
GatewayCache
,
schedulerSnapshot
*
SchedulerSnapshotService
,
tokenProvider
*
GeminiTokenProvider
,
rateLimitService
*
RateLimitService
,
httpUpstream
HTTPUpstream
,
...
...
@@ -61,6 +63,7 @@ func NewGeminiMessagesCompatService(
accountRepo
:
accountRepo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
schedulerSnapshot
:
schedulerSnapshot
,
tokenProvider
:
tokenProvider
,
rateLimitService
:
rateLimitService
,
httpUpstream
:
httpUpstream
,
...
...
@@ -86,9 +89,15 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
platform
=
forcePlatform
}
else
if
groupID
!=
nil
{
// 根据分组 platform 决定查询哪种账号
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
*
groupID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get group failed: %w"
,
err
)
var
group
*
Group
if
ctxGroup
,
ok
:=
ctx
.
Value
(
ctxkey
.
Group
)
.
(
*
Group
);
ok
&&
IsGroupContextValid
(
ctxGroup
)
&&
ctxGroup
.
ID
==
*
groupID
{
group
=
ctxGroup
}
else
{
var
err
error
group
,
err
=
s
.
groupRepo
.
GetByIDLite
(
ctx
,
*
groupID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get group failed: %w"
,
err
)
}
}
platform
=
group
.
Platform
}
else
{
...
...
@@ -99,12 +108,6 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
// 注意:强制平台模式不走混合调度
useMixedScheduling
:=
platform
==
PlatformGemini
&&
!
hasForcePlatform
var
queryPlatforms
[]
string
if
useMixedScheduling
{
queryPlatforms
=
[]
string
{
PlatformGemini
,
PlatformAntigravity
}
}
else
{
queryPlatforms
=
[]
string
{
platform
}
}
cacheKey
:=
"gemini:"
+
sessionHash
...
...
@@ -112,7 +115,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
cacheKey
)
if
err
==
nil
&&
accountID
>
0
{
if
_
,
excluded
:=
excludedIDs
[
accountID
];
!
excluded
{
account
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
accountID
)
account
,
err
:=
s
.
getSchedulableAccount
(
ctx
,
accountID
)
// 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度
if
err
==
nil
&&
account
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
valid
:=
false
...
...
@@ -143,22 +146,16 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
}
// 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
var
accounts
[]
Account
var
err
error
if
groupID
!=
nil
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByGroupIDAndPlatforms
(
ctx
,
*
groupID
,
queryPlatforms
)
accounts
,
err
:=
s
.
listSchedulableAccountsOnce
(
ctx
,
groupID
,
platform
,
hasForcePlatform
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
}
// 强制平台模式下,分组中找不到账户时回退查询全部
if
len
(
accounts
)
==
0
&&
groupID
!=
nil
&&
hasForcePlatform
{
accounts
,
err
=
s
.
listSchedulableAccountsOnce
(
ctx
,
nil
,
platform
,
hasForcePlatform
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
}
// 强制平台模式下,分组中找不到账户时回退查询全部
if
len
(
accounts
)
==
0
&&
hasForcePlatform
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatforms
(
ctx
,
queryPlatforms
)
}
}
else
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatforms
(
ctx
,
queryPlatforms
)
}
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
}
var
selected
*
Account
...
...
@@ -239,6 +236,31 @@ func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *Antigravit
return
s
.
antigravityGatewayService
}
func
(
s
*
GeminiMessagesCompatService
)
getSchedulableAccount
(
ctx
context
.
Context
,
accountID
int64
)
(
*
Account
,
error
)
{
if
s
.
schedulerSnapshot
!=
nil
{
return
s
.
schedulerSnapshot
.
GetAccount
(
ctx
,
accountID
)
}
return
s
.
accountRepo
.
GetByID
(
ctx
,
accountID
)
}
func
(
s
*
GeminiMessagesCompatService
)
listSchedulableAccountsOnce
(
ctx
context
.
Context
,
groupID
*
int64
,
platform
string
,
hasForcePlatform
bool
)
([]
Account
,
error
)
{
if
s
.
schedulerSnapshot
!=
nil
{
accounts
,
_
,
err
:=
s
.
schedulerSnapshot
.
ListSchedulableAccounts
(
ctx
,
groupID
,
platform
,
hasForcePlatform
)
return
accounts
,
err
}
useMixedScheduling
:=
platform
==
PlatformGemini
&&
!
hasForcePlatform
queryPlatforms
:=
[]
string
{
platform
}
if
useMixedScheduling
{
queryPlatforms
=
[]
string
{
platform
,
PlatformAntigravity
}
}
if
groupID
!=
nil
{
return
s
.
accountRepo
.
ListSchedulableByGroupIDAndPlatforms
(
ctx
,
*
groupID
,
queryPlatforms
)
}
return
s
.
accountRepo
.
ListSchedulableByPlatforms
(
ctx
,
queryPlatforms
)
}
func
(
s
*
GeminiMessagesCompatService
)
validateUpstreamBaseURL
(
raw
string
)
(
string
,
error
)
{
if
s
.
cfg
!=
nil
&&
!
s
.
cfg
.
Security
.
URLAllowlist
.
Enabled
{
normalized
,
err
:=
urlvalidator
.
ValidateURLFormat
(
raw
,
s
.
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
)
...
...
@@ -260,13 +282,7 @@ func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (strin
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户
func
(
s
*
GeminiMessagesCompatService
)
HasAntigravityAccounts
(
ctx
context
.
Context
,
groupID
*
int64
)
(
bool
,
error
)
{
var
accounts
[]
Account
var
err
error
if
groupID
!=
nil
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByGroupIDAndPlatform
(
ctx
,
*
groupID
,
PlatformAntigravity
)
}
else
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatform
(
ctx
,
PlatformAntigravity
)
}
accounts
,
err
:=
s
.
listSchedulableAccountsOnce
(
ctx
,
groupID
,
PlatformAntigravity
,
false
)
if
err
!=
nil
{
return
false
,
err
}
...
...
@@ -282,13 +298,7 @@ func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context
// 3) OAuth accounts explicitly marked as ai_studio
// 4) Any remaining Gemini accounts (fallback)
func
(
s
*
GeminiMessagesCompatService
)
SelectAccountForAIStudioEndpoints
(
ctx
context
.
Context
,
groupID
*
int64
)
(
*
Account
,
error
)
{
var
accounts
[]
Account
var
err
error
if
groupID
!=
nil
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByGroupIDAndPlatform
(
ctx
,
*
groupID
,
PlatformGemini
)
}
else
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatform
(
ctx
,
PlatformGemini
)
}
accounts
,
err
:=
s
.
listSchedulableAccountsOnce
(
ctx
,
groupID
,
PlatformGemini
,
true
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
}
...
...
@@ -535,14 +545,30 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
requestIDHeader
=
idHeader
// Capture upstream request body for ops retry of this attempt.
if
c
!=
nil
{
// In this code path `body` is already the JSON sent to upstream.
c
.
Set
(
OpsUpstreamRequestBodyKey
,
string
(
body
))
}
resp
,
err
=
s
.
httpUpstream
.
Do
(
upstreamReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
err
!=
nil
{
safeErr
:=
sanitizeUpstreamErrorMessage
(
err
.
Error
())
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
0
,
Kind
:
"request_error"
,
Message
:
safeErr
,
})
if
attempt
<
geminiMaxRetries
{
log
.
Printf
(
"Gemini account %d: upstream request failed, retry %d/%d: %v"
,
account
.
ID
,
attempt
,
geminiMaxRetries
,
err
)
sleepGeminiBackoff
(
attempt
)
continue
}
return
nil
,
s
.
writeClaudeError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream request failed after retries: "
+
sanitizeUpstreamErrorMessage
(
err
.
Error
()))
setOpsUpstreamError
(
c
,
0
,
safeErr
,
""
)
return
nil
,
s
.
writeClaudeError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream request failed after retries: "
+
safeErr
)
}
// Special-case: signature/thought_signature validation errors are not transient, but may be fixed by
...
...
@@ -552,6 +578,31 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
_
=
resp
.
Body
.
Close
()
if
isGeminiSignatureRelatedError
(
respBody
)
{
upstreamReqID
:=
resp
.
Header
.
Get
(
requestIDHeader
)
if
upstreamReqID
==
""
{
upstreamReqID
=
resp
.
Header
.
Get
(
"x-goog-request-id"
)
}
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
upstreamDetail
:=
""
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
maxBytes
:=
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
if
maxBytes
<=
0
{
maxBytes
=
2048
}
upstreamDetail
=
truncateString
(
string
(
respBody
),
maxBytes
)
}
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
upstreamReqID
,
Kind
:
"signature_error"
,
Message
:
upstreamMsg
,
Detail
:
upstreamDetail
,
})
var
strippedClaudeBody
[]
byte
stageName
:=
""
switch
signatureRetryStage
{
...
...
@@ -602,6 +653,31 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
s
.
handleGeminiUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
)
}
if
attempt
<
geminiMaxRetries
{
upstreamReqID
:=
resp
.
Header
.
Get
(
requestIDHeader
)
if
upstreamReqID
==
""
{
upstreamReqID
=
resp
.
Header
.
Get
(
"x-goog-request-id"
)
}
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
upstreamDetail
:=
""
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
maxBytes
:=
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
if
maxBytes
<=
0
{
maxBytes
=
2048
}
upstreamDetail
=
truncateString
(
string
(
respBody
),
maxBytes
)
}
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
upstreamReqID
,
Kind
:
"retry"
,
Message
:
upstreamMsg
,
Detail
:
upstreamDetail
,
})
log
.
Printf
(
"Gemini account %d: upstream status %d, retry %d/%d"
,
account
.
ID
,
resp
.
StatusCode
,
attempt
,
geminiMaxRetries
)
sleepGeminiBackoff
(
attempt
)
continue
...
...
@@ -627,12 +703,64 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
s
.
handleGeminiUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
)
if
tempMatched
{
upstreamReqID
:=
resp
.
Header
.
Get
(
requestIDHeader
)
if
upstreamReqID
==
""
{
upstreamReqID
=
resp
.
Header
.
Get
(
"x-goog-request-id"
)
}
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
upstreamDetail
:=
""
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
maxBytes
:=
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
if
maxBytes
<=
0
{
maxBytes
=
2048
}
upstreamDetail
=
truncateString
(
string
(
respBody
),
maxBytes
)
}
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
upstreamReqID
,
Kind
:
"failover"
,
Message
:
upstreamMsg
,
Detail
:
upstreamDetail
,
})
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
resp
.
StatusCode
}
}
if
s
.
shouldFailoverGeminiUpstreamError
(
resp
.
StatusCode
)
{
upstreamReqID
:=
resp
.
Header
.
Get
(
requestIDHeader
)
if
upstreamReqID
==
""
{
upstreamReqID
=
resp
.
Header
.
Get
(
"x-goog-request-id"
)
}
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
upstreamDetail
:=
""
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
maxBytes
:=
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
if
maxBytes
<=
0
{
maxBytes
=
2048
}
upstreamDetail
=
truncateString
(
string
(
respBody
),
maxBytes
)
}
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
upstreamReqID
,
Kind
:
"failover"
,
Message
:
upstreamMsg
,
Detail
:
upstreamDetail
,
})
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
resp
.
StatusCode
}
}
return
nil
,
s
.
writeGeminiMappedError
(
c
,
resp
.
StatusCode
,
respBody
)
upstreamReqID
:=
resp
.
Header
.
Get
(
requestIDHeader
)
if
upstreamReqID
==
""
{
upstreamReqID
=
resp
.
Header
.
Get
(
"x-goog-request-id"
)
}
return
nil
,
s
.
writeGeminiMappedError
(
c
,
account
,
resp
.
StatusCode
,
upstreamReqID
,
respBody
)
}
requestID
:=
resp
.
Header
.
Get
(
requestIDHeader
)
...
...
@@ -855,8 +983,23 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}
requestIDHeader
=
idHeader
// Capture upstream request body for ops retry of this attempt.
if
c
!=
nil
{
// In this code path `body` is already the JSON sent to upstream.
c
.
Set
(
OpsUpstreamRequestBodyKey
,
string
(
body
))
}
resp
,
err
=
s
.
httpUpstream
.
Do
(
upstreamReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
err
!=
nil
{
safeErr
:=
sanitizeUpstreamErrorMessage
(
err
.
Error
())
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
0
,
Kind
:
"request_error"
,
Message
:
safeErr
,
})
if
attempt
<
geminiMaxRetries
{
log
.
Printf
(
"Gemini account %d: upstream request failed, retry %d/%d: %v"
,
account
.
ID
,
attempt
,
geminiMaxRetries
,
err
)
sleepGeminiBackoff
(
attempt
)
...
...
@@ -874,7 +1017,8 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
FirstTokenMs
:
nil
,
},
nil
}
return
nil
,
s
.
writeGoogleError
(
c
,
http
.
StatusBadGateway
,
"Upstream request failed after retries: "
+
sanitizeUpstreamErrorMessage
(
err
.
Error
()))
setOpsUpstreamError
(
c
,
0
,
safeErr
,
""
)
return
nil
,
s
.
writeGoogleError
(
c
,
http
.
StatusBadGateway
,
"Upstream request failed after retries: "
+
safeErr
)
}
if
resp
.
StatusCode
>=
400
&&
s
.
shouldRetryGeminiUpstreamError
(
account
,
resp
.
StatusCode
)
{
...
...
@@ -893,6 +1037,31 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
s
.
handleGeminiUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
)
}
if
attempt
<
geminiMaxRetries
{
upstreamReqID
:=
resp
.
Header
.
Get
(
requestIDHeader
)
if
upstreamReqID
==
""
{
upstreamReqID
=
resp
.
Header
.
Get
(
"x-goog-request-id"
)
}
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
upstreamDetail
:=
""
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
maxBytes
:=
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
if
maxBytes
<=
0
{
maxBytes
=
2048
}
upstreamDetail
=
truncateString
(
string
(
respBody
),
maxBytes
)
}
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
upstreamReqID
,
Kind
:
"retry"
,
Message
:
upstreamMsg
,
Detail
:
upstreamDetail
,
})
log
.
Printf
(
"Gemini account %d: upstream status %d, retry %d/%d"
,
account
.
ID
,
resp
.
StatusCode
,
attempt
,
geminiMaxRetries
)
sleepGeminiBackoff
(
attempt
)
continue
...
...
@@ -956,19 +1125,87 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}
if
tempMatched
{
evBody
:=
unwrapIfNeeded
(
isOAuth
,
respBody
)
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
evBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
upstreamDetail
:=
""
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
maxBytes
:=
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
if
maxBytes
<=
0
{
maxBytes
=
2048
}
upstreamDetail
=
truncateString
(
string
(
evBody
),
maxBytes
)
}
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
requestID
,
Kind
:
"failover"
,
Message
:
upstreamMsg
,
Detail
:
upstreamDetail
,
})
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
resp
.
StatusCode
}
}
if
s
.
shouldFailoverGeminiUpstreamError
(
resp
.
StatusCode
)
{
evBody
:=
unwrapIfNeeded
(
isOAuth
,
respBody
)
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
evBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
upstreamDetail
:=
""
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
maxBytes
:=
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
if
maxBytes
<=
0
{
maxBytes
=
2048
}
upstreamDetail
=
truncateString
(
string
(
evBody
),
maxBytes
)
}
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
requestID
,
Kind
:
"failover"
,
Message
:
upstreamMsg
,
Detail
:
upstreamDetail
,
})
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
resp
.
StatusCode
}
}
respBody
=
unwrapIfNeeded
(
isOAuth
,
respBody
)
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
upstreamDetail
:=
""
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
maxBytes
:=
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
if
maxBytes
<=
0
{
maxBytes
=
2048
}
upstreamDetail
=
truncateString
(
string
(
respBody
),
maxBytes
)
log
.
Printf
(
"[Gemini] native upstream error %d: %s"
,
resp
.
StatusCode
,
truncateForLog
(
respBody
,
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
))
}
setOpsUpstreamError
(
c
,
resp
.
StatusCode
,
upstreamMsg
,
upstreamDetail
)
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
requestID
,
Kind
:
"http_error"
,
Message
:
upstreamMsg
,
Detail
:
upstreamDetail
,
})
contentType
:=
resp
.
Header
.
Get
(
"Content-Type"
)
if
contentType
==
""
{
contentType
=
"application/json"
}
c
.
Data
(
resp
.
StatusCode
,
contentType
,
respBody
)
return
nil
,
fmt
.
Errorf
(
"gemini upstream error: %d"
,
resp
.
StatusCode
)
if
upstreamMsg
==
""
{
return
nil
,
fmt
.
Errorf
(
"gemini upstream error: %d"
,
resp
.
StatusCode
)
}
return
nil
,
fmt
.
Errorf
(
"gemini upstream error: %d message=%s"
,
resp
.
StatusCode
,
upstreamMsg
)
}
var
usage
*
ClaudeUsage
...
...
@@ -1070,7 +1307,33 @@ func sanitizeUpstreamErrorMessage(msg string) string {
return
sensitiveQueryParamRegex
.
ReplaceAllString
(
msg
,
`$1***`
)
}
func
(
s
*
GeminiMessagesCompatService
)
writeGeminiMappedError
(
c
*
gin
.
Context
,
upstreamStatus
int
,
body
[]
byte
)
error
{
func
(
s
*
GeminiMessagesCompatService
)
writeGeminiMappedError
(
c
*
gin
.
Context
,
account
*
Account
,
upstreamStatus
int
,
upstreamRequestID
string
,
body
[]
byte
)
error
{
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
body
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
upstreamDetail
:=
""
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
maxBytes
:=
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
if
maxBytes
<=
0
{
maxBytes
=
2048
}
upstreamDetail
=
truncateString
(
string
(
body
),
maxBytes
)
}
setOpsUpstreamError
(
c
,
upstreamStatus
,
upstreamMsg
,
upstreamDetail
)
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
upstreamStatus
,
UpstreamRequestID
:
upstreamRequestID
,
Kind
:
"http_error"
,
Message
:
upstreamMsg
,
Detail
:
upstreamDetail
,
})
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
log
.
Printf
(
"[Gemini] upstream error %d: %s"
,
upstreamStatus
,
truncateForLog
(
body
,
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
))
}
var
statusCode
int
var
errType
,
errMsg
string
...
...
@@ -1178,7 +1441,10 @@ func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, ups
"type"
:
"error"
,
"error"
:
gin
.
H
{
"type"
:
errType
,
"message"
:
errMsg
},
})
return
fmt
.
Errorf
(
"upstream error: %d"
,
upstreamStatus
)
if
upstreamMsg
==
""
{
return
fmt
.
Errorf
(
"upstream error: %d"
,
upstreamStatus
)
}
return
fmt
.
Errorf
(
"upstream error: %d message=%s"
,
upstreamStatus
,
upstreamMsg
)
}
type
claudeErrorMapping
struct
{
...
...
backend/internal/service/gemini_multiplatform_test.go
View file @
b9b4db3d
...
...
@@ -8,6 +8,7 @@ import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
...
...
@@ -127,6 +128,9 @@ func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64,
func
(
m
*
mockAccountRepoForGemini
)
SetAntigravityQuotaScopeLimit
(
ctx
context
.
Context
,
id
int64
,
scope
AntigravityQuotaScope
,
resetAt
time
.
Time
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
SetModelRateLimit
(
ctx
context
.
Context
,
id
int64
,
scope
string
,
resetAt
time
.
Time
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
SetOverloaded
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
)
error
{
return
nil
}
...
...
@@ -140,6 +144,9 @@ func (m *mockAccountRepoForGemini) ClearRateLimit(ctx context.Context, id int64)
func
(
m
*
mockAccountRepoForGemini
)
ClearAntigravityQuotaScopes
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
ClearModelRateLimits
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
UpdateSessionWindow
(
ctx
context
.
Context
,
id
int64
,
start
,
end
*
time
.
Time
,
status
string
)
error
{
return
nil
}
...
...
@@ -155,10 +162,21 @@ var _ AccountRepository = (*mockAccountRepoForGemini)(nil)
// mockGroupRepoForGemini Gemini 测试用的 group repo mock
type
mockGroupRepoForGemini
struct
{
groups
map
[
int64
]
*
Group
groups
map
[
int64
]
*
Group
getByIDCalls
int
getByIDLiteCalls
int
}
func
(
m
*
mockGroupRepoForGemini
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
m
.
getByIDCalls
++
if
g
,
ok
:=
m
.
groups
[
id
];
ok
{
return
g
,
nil
}
return
nil
,
errors
.
New
(
"group not found"
)
}
func
(
m
*
mockGroupRepoForGemini
)
GetByIDLite
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
m
.
getByIDLiteCalls
++
if
g
,
ok
:=
m
.
groups
[
id
];
ok
{
return
g
,
nil
}
...
...
@@ -251,6 +269,77 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiP
require
.
Equal
(
t
,
PlatformGemini
,
acc
.
Platform
,
"无分组时应只返回 gemini 平台账户"
)
}
func
TestGeminiMessagesCompatService_GroupResolution_ReusesContextGroup
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
7
)
group
:=
&
Group
{
ID
:
groupID
,
Platform
:
PlatformGemini
,
Status
:
StatusActive
,
Hydrated
:
true
,
}
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
Group
,
group
)
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForGemini
{}
groupRepo
:=
&
mockGroupRepoForGemini
{
groups
:
map
[
int64
]
*
Group
{}}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
&
groupID
,
""
,
"gemini-2.5-flash"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
0
,
groupRepo
.
getByIDCalls
)
require
.
Equal
(
t
,
0
,
groupRepo
.
getByIDLiteCalls
)
}
func
TestGeminiMessagesCompatService_GroupResolution_UsesLiteFetch
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
7
)
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForGemini
{}
groupRepo
:=
&
mockGroupRepoForGemini
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
{
ID
:
groupID
,
Platform
:
PlatformGemini
},
},
}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
&
groupID
,
""
,
"gemini-2.5-flash"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
0
,
groupRepo
.
getByIDCalls
)
require
.
Equal
(
t
,
1
,
groupRepo
.
getByIDLiteCalls
)
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup 测试 antigravity 分组
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
...
...
backend/internal/service/gemini_token_cache.go
View file @
b9b4db3d
...
...
@@ -10,6 +10,7 @@ type GeminiTokenCache interface {
// cacheKey should be stable for the token scope; for GeminiCli OAuth we primarily use project_id.
GetAccessToken
(
ctx
context
.
Context
,
cacheKey
string
)
(
string
,
error
)
SetAccessToken
(
ctx
context
.
Context
,
cacheKey
string
,
token
string
,
ttl
time
.
Duration
)
error
DeleteAccessToken
(
ctx
context
.
Context
,
cacheKey
string
)
error
AcquireRefreshLock
(
ctx
context
.
Context
,
cacheKey
string
,
ttl
time
.
Duration
)
(
bool
,
error
)
ReleaseRefreshLock
(
ctx
context
.
Context
,
cacheKey
string
)
error
...
...
backend/internal/service/gemini_token_provider.go
View file @
b9b4db3d
...
...
@@ -40,7 +40,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return
""
,
errors
.
New
(
"not a gemini oauth account"
)
}
cacheKey
:=
g
eminiTokenCacheKey
(
account
)
cacheKey
:=
G
eminiTokenCacheKey
(
account
)
// 1) Try cache first.
if
p
.
tokenCache
!=
nil
{
...
...
@@ -151,10 +151,10 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return
accessToken
,
nil
}
func
g
eminiTokenCacheKey
(
account
*
Account
)
string
{
func
G
eminiTokenCacheKey
(
account
*
Account
)
string
{
projectID
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"project_id"
))
if
projectID
!=
""
{
return
projectID
return
"gemini:"
+
projectID
}
return
"account:"
+
strconv
.
FormatInt
(
account
.
ID
,
10
)
return
"
gemini:
account:"
+
strconv
.
FormatInt
(
account
.
ID
,
10
)
}
Prev
1
…
6
7
8
9
10
11
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment