Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
538ae31a
Commit
538ae31a
authored
Apr 30, 2026
by
陈曦
Browse files
merge v0.1.121 and fixed conflict
parents
74828a7c
48912014
Pipeline
#82338
passed with stage
in 17 seconds
Changes
151
Pipelines
3
Hide whitespace changes
Inline
Side-by-side
backend/internal/pkg/apicompat/anthropic_to_responses.go
View file @
538ae31a
...
...
@@ -75,7 +75,7 @@ func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) {
// {"type":"auto"} → "auto"
// {"type":"any"} → "required"
// {"type":"none"} → "none"
// {"type":"tool","name":"X"} → {"type":"function","
function":{"
name":"X"}
}
// {"type":"tool","name":"X"} → {"type":"function","name":"X"}
func
convertAnthropicToolChoiceToResponses
(
raw
json
.
RawMessage
)
(
json
.
RawMessage
,
error
)
{
var
tc
struct
{
Type
string
`json:"type"`
...
...
@@ -94,8 +94,8 @@ func convertAnthropicToolChoiceToResponses(raw json.RawMessage) (json.RawMessage
return
json
.
Marshal
(
"none"
)
case
"tool"
:
return
json
.
Marshal
(
map
[
string
]
any
{
"type"
:
"function"
,
"function"
:
map
[
string
]
string
{
"name"
:
tc
.
Name
}
,
"type"
:
"function"
,
"name"
:
tc
.
Name
,
})
default
:
// Pass through unknown types as-is
...
...
backend/internal/pkg/apicompat/chatcompletions_responses_test.go
View file @
538ae31a
...
...
@@ -281,6 +281,8 @@ func TestChatCompletionsToResponses_LegacyFunctions(t *testing.T) {
var
tc
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
resp
.
ToolChoice
,
&
tc
))
assert
.
Equal
(
t
,
"function"
,
tc
[
"type"
])
assert
.
Equal
(
t
,
"get_weather"
,
tc
[
"name"
])
assert
.
NotContains
(
t
,
tc
,
"function"
)
}
func
TestChatCompletionsToResponses_ServiceTier
(
t
*
testing
.
T
)
{
...
...
backend/internal/pkg/apicompat/chatcompletions_to_responses.go
View file @
538ae31a
...
...
@@ -420,7 +420,7 @@ func convertChatToolsToResponses(tools []ChatTool, functions []ChatFunction) []R
//
// "auto" → "auto"
// "none" → "none"
// {"name":"X"} → {"type":"function","
function":{"
name":"X"}
}
// {"name":"X"} → {"type":"function","name":"X"}
func
convertChatFunctionCallToToolChoice
(
raw
json
.
RawMessage
)
(
json
.
RawMessage
,
error
)
{
// Try string first ("auto", "none", etc.) — pass through as-is.
var
s
string
...
...
@@ -436,7 +436,7 @@ func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage,
return
nil
,
err
}
return
json
.
Marshal
(
map
[
string
]
any
{
"type"
:
"function"
,
"function"
:
map
[
string
]
string
{
"name"
:
obj
.
Name
}
,
"type"
:
"function"
,
"name"
:
obj
.
Name
,
})
}
backend/internal/pkg/apicompat/responses_to_anthropic_request.go
View file @
538ae31a
...
...
@@ -428,7 +428,8 @@ func normalizeAnthropicInputSchema(schema json.RawMessage) json.RawMessage {
// "auto" → {"type":"auto"}
// "required" → {"type":"any"}
// "none" → {"type":"none"}
// {"type":"function","function":{"name":"X"}} → {"type":"tool","name":"X"}
// {"type":"function","name":"X"} → {"type":"tool","name":"X"}
// {"type":"function","function":{"name":"X"}} → {"type":"tool","name":"X"} // legacy
func
convertResponsesToAnthropicToolChoice
(
raw
json
.
RawMessage
)
(
json
.
RawMessage
,
error
)
{
// Try as string first
var
s
string
...
...
@@ -448,14 +449,22 @@ func convertResponsesToAnthropicToolChoice(raw json.RawMessage) (json.RawMessage
// Try as object with type=function
var
tc
struct
{
Type
string
`json:"type"`
Name
string
`json:"name"`
Function
struct
{
Name
string
`json:"name"`
}
`json:"function"`
}
if
err
:=
json
.
Unmarshal
(
raw
,
&
tc
);
err
==
nil
&&
tc
.
Type
==
"function"
&&
tc
.
Function
.
Name
!=
""
{
if
err
:=
json
.
Unmarshal
(
raw
,
&
tc
);
err
==
nil
&&
tc
.
Type
==
"function"
{
name
:=
strings
.
TrimSpace
(
tc
.
Name
)
if
name
==
""
{
name
=
strings
.
TrimSpace
(
tc
.
Function
.
Name
)
}
if
name
==
""
{
return
raw
,
nil
}
return
json
.
Marshal
(
map
[
string
]
string
{
"type"
:
"tool"
,
"name"
:
tc
.
Function
.
N
ame
,
"name"
:
n
ame
,
})
}
...
...
backend/internal/pkg/httputil/body.go
View file @
538ae31a
...
...
@@ -2,16 +2,28 @@ package httputil
import
(
"bytes"
"compress/gzip"
"compress/zlib"
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/klauspost/compress/zstd"
)
const
(
requestBodyReadInitCap
=
512
requestBodyReadMaxInitCap
=
1
<<
20
// maxDecompressedBodySize limits the decompressed request body to 64 MB
// to prevent decompression bomb attacks.
maxDecompressedBodySize
=
64
<<
20
)
// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based on content length.
// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based
// on content length, transparently decoding any Content-Encoding the upstream
// client used to compress the body (zstd, gzip, deflate).
func
ReadRequestBodyWithPrealloc
(
req
*
http
.
Request
)
([]
byte
,
error
)
{
if
req
==
nil
||
req
.
Body
==
nil
{
return
nil
,
nil
...
...
@@ -33,5 +45,49 @@ func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) {
if
_
,
err
:=
io
.
Copy
(
buf
,
req
.
Body
);
err
!=
nil
{
return
nil
,
err
}
return
buf
.
Bytes
(),
nil
raw
:=
buf
.
Bytes
()
enc
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
req
.
Header
.
Get
(
"Content-Encoding"
)))
if
enc
==
""
||
enc
==
"identity"
{
return
raw
,
nil
}
decoded
,
err
:=
decompressRequestBody
(
enc
,
raw
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"decode Content-Encoding %q: %w"
,
enc
,
err
)
}
req
.
Header
.
Del
(
"Content-Encoding"
)
req
.
Header
.
Del
(
"Content-Length"
)
req
.
ContentLength
=
int64
(
len
(
decoded
))
return
decoded
,
nil
}
func
decompressRequestBody
(
encoding
string
,
raw
[]
byte
)
([]
byte
,
error
)
{
switch
encoding
{
case
"zstd"
:
dec
,
err
:=
zstd
.
NewReader
(
bytes
.
NewReader
(
raw
))
if
err
!=
nil
{
return
nil
,
err
}
defer
dec
.
Close
()
return
io
.
ReadAll
(
io
.
LimitReader
(
dec
,
maxDecompressedBodySize
))
case
"gzip"
,
"x-gzip"
:
gr
,
err
:=
gzip
.
NewReader
(
bytes
.
NewReader
(
raw
))
if
err
!=
nil
{
return
nil
,
err
}
defer
func
()
{
_
=
gr
.
Close
()
}()
return
io
.
ReadAll
(
io
.
LimitReader
(
gr
,
maxDecompressedBodySize
))
case
"deflate"
:
zr
,
err
:=
zlib
.
NewReader
(
bytes
.
NewReader
(
raw
))
if
err
!=
nil
{
return
nil
,
err
}
defer
func
()
{
_
=
zr
.
Close
()
}()
return
io
.
ReadAll
(
io
.
LimitReader
(
zr
,
maxDecompressedBodySize
))
default
:
return
nil
,
errors
.
New
(
"unsupported Content-Encoding"
)
}
}
backend/internal/pkg/httputil/body_test.go
0 → 100644
View file @
538ae31a
package
httputil
import
(
"bytes"
"compress/gzip"
"compress/zlib"
"net/http"
"strings"
"testing"
"github.com/klauspost/compress/zstd"
)
const
samplePayload
=
`{"model":"gpt-5.5","input":"hi","stream":false}`
func
newRequestWithBody
(
t
*
testing
.
T
,
body
[]
byte
,
encoding
string
)
*
http
.
Request
{
t
.
Helper
()
req
,
err
:=
http
.
NewRequest
(
http
.
MethodPost
,
"/v1/responses"
,
bytes
.
NewReader
(
body
))
if
err
!=
nil
{
t
.
Fatalf
(
"NewRequest: %v"
,
err
)
}
if
encoding
!=
""
{
req
.
Header
.
Set
(
"Content-Encoding"
,
encoding
)
}
req
.
ContentLength
=
int64
(
len
(
body
))
return
req
}
func
TestReadRequestBodyWithPrealloc_PassesThroughIdentity
(
t
*
testing
.
T
)
{
req
:=
newRequestWithBody
(
t
,
[]
byte
(
samplePayload
),
""
)
got
,
err
:=
ReadRequestBodyWithPrealloc
(
req
)
if
err
!=
nil
{
t
.
Fatalf
(
"unexpected error: %v"
,
err
)
}
if
string
(
got
)
!=
samplePayload
{
t
.
Fatalf
(
"body mismatch: got %q"
,
got
)
}
}
func
TestReadRequestBodyWithPrealloc_DecodesZstd
(
t
*
testing
.
T
)
{
enc
,
_
:=
zstd
.
NewWriter
(
nil
)
compressed
:=
enc
.
EncodeAll
([]
byte
(
samplePayload
),
nil
)
_
=
enc
.
Close
()
req
:=
newRequestWithBody
(
t
,
compressed
,
"zstd"
)
got
,
err
:=
ReadRequestBodyWithPrealloc
(
req
)
if
err
!=
nil
{
t
.
Fatalf
(
"unexpected error: %v"
,
err
)
}
if
string
(
got
)
!=
samplePayload
{
t
.
Fatalf
(
"body mismatch: got %q"
,
got
)
}
if
req
.
Header
.
Get
(
"Content-Encoding"
)
!=
""
{
t
.
Fatalf
(
"Content-Encoding should be cleared after decoding"
)
}
if
req
.
ContentLength
!=
int64
(
len
(
samplePayload
))
{
t
.
Fatalf
(
"ContentLength not updated: %d"
,
req
.
ContentLength
)
}
}
func
TestReadRequestBodyWithPrealloc_DecodesGzip
(
t
*
testing
.
T
)
{
var
buf
bytes
.
Buffer
gw
:=
gzip
.
NewWriter
(
&
buf
)
if
_
,
err
:=
gw
.
Write
([]
byte
(
samplePayload
));
err
!=
nil
{
t
.
Fatalf
(
"gzip write: %v"
,
err
)
}
if
err
:=
gw
.
Close
();
err
!=
nil
{
t
.
Fatalf
(
"gzip close: %v"
,
err
)
}
req
:=
newRequestWithBody
(
t
,
buf
.
Bytes
(),
"gzip"
)
got
,
err
:=
ReadRequestBodyWithPrealloc
(
req
)
if
err
!=
nil
{
t
.
Fatalf
(
"unexpected error: %v"
,
err
)
}
if
string
(
got
)
!=
samplePayload
{
t
.
Fatalf
(
"body mismatch: got %q"
,
got
)
}
}
func
TestReadRequestBodyWithPrealloc_DecodesDeflate
(
t
*
testing
.
T
)
{
var
buf
bytes
.
Buffer
zw
:=
zlib
.
NewWriter
(
&
buf
)
if
_
,
err
:=
zw
.
Write
([]
byte
(
samplePayload
));
err
!=
nil
{
t
.
Fatalf
(
"zlib write: %v"
,
err
)
}
if
err
:=
zw
.
Close
();
err
!=
nil
{
t
.
Fatalf
(
"zlib close: %v"
,
err
)
}
req
:=
newRequestWithBody
(
t
,
buf
.
Bytes
(),
"deflate"
)
got
,
err
:=
ReadRequestBodyWithPrealloc
(
req
)
if
err
!=
nil
{
t
.
Fatalf
(
"unexpected error: %v"
,
err
)
}
if
string
(
got
)
!=
samplePayload
{
t
.
Fatalf
(
"body mismatch: got %q"
,
got
)
}
}
func
TestReadRequestBodyWithPrealloc_RejectsUnsupportedEncoding
(
t
*
testing
.
T
)
{
req
:=
newRequestWithBody
(
t
,
[]
byte
(
samplePayload
),
"br"
)
_
,
err
:=
ReadRequestBodyWithPrealloc
(
req
)
if
err
==
nil
{
t
.
Fatal
(
"expected error for unsupported encoding, got nil"
)
}
if
!
strings
.
Contains
(
err
.
Error
(),
"br"
)
{
t
.
Fatalf
(
"error should mention encoding, got %v"
,
err
)
}
}
func
TestReadRequestBodyWithPrealloc_RejectsCorruptZstd
(
t
*
testing
.
T
)
{
req
:=
newRequestWithBody
(
t
,
[]
byte
(
"not actually zstd"
),
"zstd"
)
_
,
err
:=
ReadRequestBodyWithPrealloc
(
req
)
if
err
==
nil
{
t
.
Fatal
(
"expected error for corrupt zstd body, got nil"
)
}
}
func
TestReadRequestBodyWithPrealloc_NilBody
(
t
*
testing
.
T
)
{
req
,
err
:=
http
.
NewRequest
(
http
.
MethodPost
,
"/v1/responses"
,
nil
)
if
err
!=
nil
{
t
.
Fatalf
(
"NewRequest: %v"
,
err
)
}
got
,
err
:=
ReadRequestBodyWithPrealloc
(
req
)
if
err
!=
nil
{
t
.
Fatalf
(
"unexpected error: %v"
,
err
)
}
if
got
!=
nil
{
t
.
Fatalf
(
"expected nil body, got %q"
,
got
)
}
}
func
TestReadRequestBodyWithPrealloc_RespectsIdentityEncoding
(
t
*
testing
.
T
)
{
req
:=
newRequestWithBody
(
t
,
[]
byte
(
samplePayload
),
"identity"
)
got
,
err
:=
ReadRequestBodyWithPrealloc
(
req
)
if
err
!=
nil
{
t
.
Fatalf
(
"unexpected error: %v"
,
err
)
}
if
string
(
got
)
!=
samplePayload
{
t
.
Fatalf
(
"body mismatch: got %q"
,
got
)
}
}
backend/internal/repository/account_repo_integration_test.go
View file @
538ae31a
...
...
@@ -64,6 +64,10 @@ func (s *schedulerCacheRecorder) TryLockBucket(ctx context.Context, bucket servi
return
true
,
nil
}
func
(
s
*
schedulerCacheRecorder
)
UnlockBucket
(
ctx
context
.
Context
,
bucket
service
.
SchedulerBucket
)
error
{
return
nil
}
func
(
s
*
schedulerCacheRecorder
)
ListBuckets
(
ctx
context
.
Context
)
([]
service
.
SchedulerBucket
,
error
)
{
return
nil
,
nil
}
...
...
backend/internal/repository/scheduler_cache.go
View file @
538ae31a
...
...
@@ -24,6 +24,49 @@ const (
defaultSchedulerSnapshotMGetChunkSize
=
128
defaultSchedulerSnapshotWriteChunkSize
=
256
// snapshotGraceTTLSeconds 旧快照过期的宽限期(秒)。
// 替代立即 DEL,让正在读取旧版本的 reader 有足够时间完成 ZRANGE。
snapshotGraceTTLSeconds
=
60
)
var
(
// activateSnapshotScript 原子 CAS 切换快照版本。
// 仅当新版本号 >= 当前激活版本时才切换,防止并发写入导致版本回滚。
// 旧快照使用 EXPIRE 设置宽限期而非立即 DEL,避免与 reader 竞态。
//
// KEYS[1] = activeKey (sched:active:{bucket})
// KEYS[2] = readyKey (sched:ready:{bucket})
// KEYS[3] = bucketSetKey (sched:buckets)
// KEYS[4] = snapshotKey (新写入的快照 key)
// ARGV[1] = 新版本号字符串
// ARGV[2] = bucket 字符串 (用于 SADD)
// ARGV[3] = 快照 key 前缀 (用于构造旧快照 key)
// ARGV[4] = 宽限期 TTL 秒数
//
// 返回 1 = 已激活, 0 = 版本过旧未激活
activateSnapshotScript
=
redis
.
NewScript
(
`
local currentActive = redis.call('GET', KEYS[1])
local newVersion = tonumber(ARGV[1])
if currentActive ~= false then
local curVersion = tonumber(currentActive)
if curVersion and newVersion < curVersion then
redis.call('DEL', KEYS[4])
return 0
end
end
redis.call('SET', KEYS[1], ARGV[1])
redis.call('SET', KEYS[2], '1')
redis.call('SADD', KEYS[3], ARGV[2])
if currentActive ~= false and currentActive ~= ARGV[1] then
redis.call('EXPIRE', ARGV[3] .. currentActive, tonumber(ARGV[4]))
end
return 1
`
)
)
type
schedulerCache
struct
{
...
...
@@ -108,9 +151,9 @@ func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.Schedul
}
func
(
c
*
schedulerCache
)
SetSnapshot
(
ctx
context
.
Context
,
bucket
service
.
SchedulerBucket
,
accounts
[]
service
.
Account
)
error
{
activeKey
:=
schedulerBucketKey
(
schedulerActivePrefix
,
bucket
)
oldActive
,
_
:=
c
.
rdb
.
Get
(
ctx
,
activeKey
)
.
Result
()
// Phase 1: 分配新版本号并写入快照数据。
// INCR 保证每个调用方获得唯一递增版本号。
// 写入的 snapshotKey 是新的版本化 key,reader 尚不知晓,因此无竞态。
versionKey
:=
schedulerBucketKey
(
schedulerVersionPrefix
,
bucket
)
version
,
err
:=
c
.
rdb
.
Incr
(
ctx
,
versionKey
)
.
Result
()
if
err
!=
nil
{
...
...
@@ -124,7 +167,6 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul
return
err
}
pipe
:=
c
.
rdb
.
Pipeline
()
if
len
(
accounts
)
>
0
{
// 使用序号作为 score,保持数据库返回的排序语义。
members
:=
make
([]
redis
.
Z
,
0
,
len
(
accounts
))
...
...
@@ -134,6 +176,7 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul
Member
:
strconv
.
FormatInt
(
account
.
ID
,
10
),
})
}
pipe
:=
c
.
rdb
.
Pipeline
()
for
start
:=
0
;
start
<
len
(
members
);
start
+=
c
.
writeChunkSize
{
end
:=
start
+
c
.
writeChunkSize
if
end
>
len
(
members
)
{
...
...
@@ -141,18 +184,25 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul
}
pipe
.
ZAdd
(
ctx
,
snapshotKey
,
members
[
start
:
end
]
...
)
}
}
else
{
pipe
.
Del
(
ctx
,
snapshotKey
)
}
pipe
.
Set
(
ctx
,
activeKey
,
versionStr
,
0
)
pipe
.
Set
(
ctx
,
schedulerBucketKey
(
schedulerReadyPrefix
,
bucket
),
"1"
,
0
)
pipe
.
SAdd
(
ctx
,
schedulerBucketSetKey
,
bucket
.
String
())
if
_
,
err
:=
pipe
.
Exec
(
ctx
);
err
!=
nil
{
return
err
if
_
,
err
:=
pipe
.
Exec
(
ctx
);
err
!=
nil
{
return
err
}
}
if
oldActive
!=
""
&&
oldActive
!=
versionStr
{
_
=
c
.
rdb
.
Del
(
ctx
,
schedulerSnapshotKey
(
bucket
,
oldActive
))
.
Err
()
// Phase 2: 原子 CAS 激活版本。
// Lua 脚本保证:仅当新版本 >= 当前激活版本时才切换 active 指针,
// 防止并发写入导致版本回滚。
// 旧快照使用 EXPIRE 宽限期而非立即 DEL,避免 reader 竞态。
activeKey
:=
schedulerBucketKey
(
schedulerActivePrefix
,
bucket
)
readyKey
:=
schedulerBucketKey
(
schedulerReadyPrefix
,
bucket
)
snapshotKeyPrefix
:=
fmt
.
Sprintf
(
"%s%d:%s:%s:v"
,
schedulerSnapshotPrefix
,
bucket
.
GroupID
,
bucket
.
Platform
,
bucket
.
Mode
)
keys
:=
[]
string
{
activeKey
,
readyKey
,
schedulerBucketSetKey
,
snapshotKey
}
args
:=
[]
any
{
versionStr
,
bucket
.
String
(),
snapshotKeyPrefix
,
snapshotGraceTTLSeconds
}
_
,
err
=
activateSnapshotScript
.
Run
(
ctx
,
c
.
rdb
,
keys
,
args
...
)
.
Result
()
if
err
!=
nil
{
return
err
}
return
nil
...
...
@@ -232,6 +282,11 @@ func (c *schedulerCache) TryLockBucket(ctx context.Context, bucket service.Sched
return
c
.
rdb
.
SetNX
(
ctx
,
key
,
time
.
Now
()
.
UnixNano
(),
ttl
)
.
Result
()
}
func
(
c
*
schedulerCache
)
UnlockBucket
(
ctx
context
.
Context
,
bucket
service
.
SchedulerBucket
)
error
{
key
:=
schedulerBucketKey
(
schedulerLockPrefix
,
bucket
)
return
c
.
rdb
.
Del
(
ctx
,
key
)
.
Err
()
}
func
(
c
*
schedulerCache
)
ListBuckets
(
ctx
context
.
Context
)
([]
service
.
SchedulerBucket
,
error
)
{
raw
,
err
:=
c
.
rdb
.
SMembers
(
ctx
,
schedulerBucketSetKey
)
.
Result
()
if
err
!=
nil
{
...
...
@@ -394,11 +449,69 @@ func buildSchedulerMetadataAccount(account service.Account) service.Account {
SessionWindowStart
:
account
.
SessionWindowStart
,
SessionWindowEnd
:
account
.
SessionWindowEnd
,
SessionWindowStatus
:
account
.
SessionWindowStatus
,
AccountGroups
:
filterSchedulerAccountGroups
(
account
.
AccountGroups
),
GroupIDs
:
filterSchedulerGroupIDs
(
account
.
GroupIDs
,
account
.
AccountGroups
),
Credentials
:
filterSchedulerCredentials
(
account
.
Credentials
),
Extra
:
filterSchedulerExtra
(
account
.
Extra
),
}
}
func
filterSchedulerAccountGroups
(
accountGroups
[]
service
.
AccountGroup
)
[]
service
.
AccountGroup
{
if
len
(
accountGroups
)
==
0
{
return
nil
}
filtered
:=
make
([]
service
.
AccountGroup
,
0
,
len
(
accountGroups
))
for
_
,
ag
:=
range
accountGroups
{
if
ag
.
GroupID
<=
0
{
continue
}
filtered
=
append
(
filtered
,
service
.
AccountGroup
{
AccountID
:
ag
.
AccountID
,
GroupID
:
ag
.
GroupID
,
Priority
:
ag
.
Priority
,
CreatedAt
:
ag
.
CreatedAt
,
})
}
if
len
(
filtered
)
==
0
{
return
nil
}
return
filtered
}
func
filterSchedulerGroupIDs
(
groupIDs
[]
int64
,
accountGroups
[]
service
.
AccountGroup
)
[]
int64
{
if
len
(
groupIDs
)
==
0
&&
len
(
accountGroups
)
==
0
{
return
nil
}
seen
:=
make
(
map
[
int64
]
struct
{},
len
(
groupIDs
)
+
len
(
accountGroups
))
filtered
:=
make
([]
int64
,
0
,
len
(
groupIDs
)
+
len
(
accountGroups
))
for
_
,
id
:=
range
groupIDs
{
if
id
<=
0
{
continue
}
if
_
,
ok
:=
seen
[
id
];
ok
{
continue
}
seen
[
id
]
=
struct
{}{}
filtered
=
append
(
filtered
,
id
)
}
for
_
,
ag
:=
range
accountGroups
{
if
ag
.
GroupID
<=
0
{
continue
}
if
_
,
ok
:=
seen
[
ag
.
GroupID
];
ok
{
continue
}
seen
[
ag
.
GroupID
]
=
struct
{}{}
filtered
=
append
(
filtered
,
ag
.
GroupID
)
}
if
len
(
filtered
)
==
0
{
return
nil
}
return
filtered
}
func
filterSchedulerCredentials
(
credentials
map
[
string
]
any
)
map
[
string
]
any
{
if
len
(
credentials
)
==
0
{
return
nil
...
...
backend/internal/repository/scheduler_cache_integration_test.go
View file @
538ae31a
...
...
@@ -56,6 +56,15 @@ func TestSchedulerCacheSnapshotUsesSlimMetadataButKeepsFullAccount(t *testing.T)
SessionWindowStart
:
&
now
,
SessionWindowEnd
:
&
windowEnd
,
SessionWindowStatus
:
"active"
,
GroupIDs
:
[]
int64
{
bucket
.
GroupID
},
AccountGroups
:
[]
service
.
AccountGroup
{
{
AccountID
:
101
,
GroupID
:
bucket
.
GroupID
,
Priority
:
5
,
Group
:
&
service
.
Group
{
ID
:
bucket
.
GroupID
,
Name
:
"gemini-group"
},
},
},
}
require
.
NoError
(
t
,
cache
.
SetSnapshot
(
ctx
,
bucket
,
[]
service
.
Account
{
account
}))
...
...
@@ -79,10 +88,17 @@ func TestSchedulerCacheSnapshotUsesSlimMetadataButKeepsFullAccount(t *testing.T)
require
.
Equal
(
t
,
4
,
got
.
GetMaxSessions
())
require
.
Equal
(
t
,
11
,
got
.
GetSessionIdleTimeoutMinutes
())
require
.
Nil
(
t
,
got
.
Extra
[
"unused_large_field"
])
require
.
Equal
(
t
,
[]
int64
{
bucket
.
GroupID
},
got
.
GroupIDs
)
require
.
Len
(
t
,
got
.
AccountGroups
,
1
)
require
.
Equal
(
t
,
account
.
ID
,
got
.
AccountGroups
[
0
]
.
AccountID
)
require
.
Equal
(
t
,
bucket
.
GroupID
,
got
.
AccountGroups
[
0
]
.
GroupID
)
require
.
Nil
(
t
,
got
.
AccountGroups
[
0
]
.
Group
)
full
,
err
:=
cache
.
GetAccount
(
ctx
,
account
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
full
)
require
.
Equal
(
t
,
"secret-access-token"
,
full
.
GetCredential
(
"access_token"
))
require
.
Equal
(
t
,
strings
.
Repeat
(
"x"
,
4096
),
full
.
GetCredential
(
"huge_blob"
))
require
.
Len
(
t
,
full
.
AccountGroups
,
1
)
require
.
NotNil
(
t
,
full
.
AccountGroups
[
0
]
.
Group
)
}
backend/internal/repository/scheduler_cache_unit_test.go
View file @
538ae31a
...
...
@@ -31,3 +31,43 @@ func TestBuildSchedulerMetadataAccount_KeepsOpenAIWSFlags(t *testing.T) {
require
.
Equal
(
t
,
true
,
got
.
Extra
[
"mixed_scheduling"
])
require
.
Nil
(
t
,
got
.
Extra
[
"unused_large_field"
])
}
func
TestBuildSchedulerMetadataAccount_KeepsSlimGroupMembership
(
t
*
testing
.
T
)
{
account
:=
service
.
Account
{
ID
:
42
,
Platform
:
service
.
PlatformAnthropic
,
GroupIDs
:
[]
int64
{
7
,
9
,
7
,
0
},
AccountGroups
:
[]
service
.
AccountGroup
{
{
AccountID
:
42
,
GroupID
:
7
,
Priority
:
2
,
Account
:
&
service
.
Account
{
ID
:
42
,
Name
:
"drop-from-metadata"
},
Group
:
&
service
.
Group
{
ID
:
7
,
Name
:
"drop-from-metadata"
},
},
{
AccountID
:
42
,
GroupID
:
11
,
Priority
:
3
,
Group
:
&
service
.
Group
{
ID
:
11
,
Name
:
"drop-from-metadata"
},
},
{
AccountID
:
42
,
GroupID
:
0
,
Priority
:
4
,
},
},
}
got
:=
buildSchedulerMetadataAccount
(
account
)
require
.
Equal
(
t
,
[]
int64
{
7
,
9
,
11
},
got
.
GroupIDs
)
require
.
Len
(
t
,
got
.
AccountGroups
,
2
)
require
.
Equal
(
t
,
int64
(
42
),
got
.
AccountGroups
[
0
]
.
AccountID
)
require
.
Equal
(
t
,
int64
(
7
),
got
.
AccountGroups
[
0
]
.
GroupID
)
require
.
Equal
(
t
,
2
,
got
.
AccountGroups
[
0
]
.
Priority
)
require
.
Nil
(
t
,
got
.
AccountGroups
[
0
]
.
Account
)
require
.
Nil
(
t
,
got
.
AccountGroups
[
0
]
.
Group
)
require
.
Equal
(
t
,
int64
(
11
),
got
.
AccountGroups
[
1
]
.
GroupID
)
require
.
Nil
(
t
,
got
.
Groups
)
}
backend/internal/server/api_contract_test.go
View file @
538ae31a
...
...
@@ -740,6 +740,7 @@ func TestAPIContracts(t *testing.T) {
"allow_ungrouped_key_scheduling": false,
"backend_mode_enabled": false,
"enable_cch_signing": false,
"enable_anthropic_cache_ttl_1h_injection": false,
"enable_fingerprint_unification": true,
"enable_metadata_passthrough": false,
"web_search_emulation_enabled": false,
...
...
@@ -748,6 +749,16 @@ func TestAPIContracts(t *testing.T) {
"payment_visible_method_alipay_enabled": true,
"payment_visible_method_wxpay_enabled": false,
"openai_advanced_scheduler_enabled": true,
"openai_fast_policy_settings": {
"rules": [
{
"service_tier": "priority",
"action": "filter",
"scope": "all",
"fallback_action": "pass"
}
]
},
"custom_menu_items": [],
"custom_endpoints": [],
"payment_enabled": false,
...
...
@@ -924,12 +935,23 @@ func TestAPIContracts(t *testing.T) {
"enable_fingerprint_unification": true,
"enable_metadata_passthrough": false,
"enable_cch_signing": false,
"enable_anthropic_cache_ttl_1h_injection": false,
"web_search_emulation_enabled": false,
"payment_visible_method_alipay_source": "",
"payment_visible_method_wxpay_source": "",
"payment_visible_method_alipay_enabled": false,
"payment_visible_method_wxpay_enabled": false,
"openai_advanced_scheduler_enabled": false,
"openai_fast_policy_settings": {
"rules": [
{
"service_tier": "priority",
"action": "filter",
"scope": "all",
"fallback_action": "pass"
}
]
},
"payment_enabled": false,
"payment_min_amount": 0,
"payment_max_amount": 0,
...
...
backend/internal/service/account_test_service.go
View file @
538ae31a
...
...
@@ -64,6 +64,7 @@ func isOpenAIImageModel(model string) bool {
type
AccountTestService
struct
{
accountRepo
AccountRepository
geminiTokenProvider
*
GeminiTokenProvider
claudeTokenProvider
*
ClaudeTokenProvider
antigravityGatewayService
*
AntigravityGatewayService
httpUpstream
HTTPUpstream
cfg
*
config
.
Config
...
...
@@ -74,6 +75,7 @@ type AccountTestService struct {
func
NewAccountTestService
(
accountRepo
AccountRepository
,
geminiTokenProvider
*
GeminiTokenProvider
,
claudeTokenProvider
*
ClaudeTokenProvider
,
antigravityGatewayService
*
AntigravityGatewayService
,
httpUpstream
HTTPUpstream
,
cfg
*
config
.
Config
,
...
...
@@ -82,6 +84,7 @@ func NewAccountTestService(
return
&
AccountTestService
{
accountRepo
:
accountRepo
,
geminiTokenProvider
:
geminiTokenProvider
,
claudeTokenProvider
:
claudeTokenProvider
,
antigravityGatewayService
:
antigravityGatewayService
,
httpUpstream
:
httpUpstream
,
cfg
:
cfg
,
...
...
@@ -210,6 +213,9 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
if
account
.
IsBedrock
()
{
return
s
.
testBedrockAccountConnection
(
c
,
ctx
,
account
,
testModelID
)
}
if
account
.
Type
==
AccountTypeServiceAccount
{
return
s
.
testClaudeVertexServiceAccountConnection
(
c
,
ctx
,
account
,
testModelID
)
}
// Determine authentication method and API URL
var
authToken
string
...
...
@@ -313,6 +319,74 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
return
s
.
processClaudeStream
(
c
,
resp
.
Body
)
}
func
(
s
*
AccountTestService
)
testClaudeVertexServiceAccountConnection
(
c
*
gin
.
Context
,
ctx
context
.
Context
,
account
*
Account
,
testModelID
string
)
error
{
if
mappedModel
,
matched
:=
account
.
ResolveMappedModel
(
testModelID
);
matched
{
testModelID
=
mappedModel
}
else
{
testModelID
=
normalizeVertexAnthropicModelID
(
claude
.
NormalizeModelID
(
testModelID
))
}
c
.
Writer
.
Header
()
.
Set
(
"Content-Type"
,
"text/event-stream"
)
c
.
Writer
.
Header
()
.
Set
(
"Cache-Control"
,
"no-cache"
)
c
.
Writer
.
Header
()
.
Set
(
"Connection"
,
"keep-alive"
)
c
.
Writer
.
Header
()
.
Set
(
"X-Accel-Buffering"
,
"no"
)
c
.
Writer
.
Flush
()
payload
,
err
:=
createTestPayload
(
testModelID
)
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
"Failed to create test payload"
)
}
payloadBytes
,
_
:=
json
.
Marshal
(
payload
)
vertexBody
,
err
:=
buildVertexAnthropicRequestBody
(
payloadBytes
)
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Failed to create Vertex request body: %s"
,
err
.
Error
()))
}
if
s
.
claudeTokenProvider
==
nil
{
return
s
.
sendErrorAndEnd
(
c
,
"Claude token provider not configured"
)
}
accessToken
,
err
:=
s
.
claudeTokenProvider
.
GetAccessToken
(
ctx
,
account
)
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Failed to get service account access token: %s"
,
err
.
Error
()))
}
fullURL
,
err
:=
buildVertexAnthropicURL
(
account
.
VertexProjectID
(),
account
.
VertexLocation
(
testModelID
),
testModelID
,
true
)
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Failed to build Vertex URL: %s"
,
err
.
Error
()))
}
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"test_start"
,
Model
:
testModelID
})
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodPost
,
fullURL
,
bytes
.
NewReader
(
vertexBody
))
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
"Failed to create request"
)
}
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
accessToken
)
proxyURL
:=
""
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
proxyURL
=
account
.
Proxy
.
URL
()
}
resp
,
err
:=
s
.
httpUpstream
.
DoWithTLS
(
req
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
s
.
tlsFPProfileService
.
ResolveTLSProfile
(
account
))
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Request failed: %s"
,
err
.
Error
()))
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
if
resp
.
StatusCode
!=
http
.
StatusOK
{
body
,
_
:=
io
.
ReadAll
(
resp
.
Body
)
errMsg
:=
fmt
.
Sprintf
(
"API returned %d: %s"
,
resp
.
StatusCode
,
string
(
body
))
if
resp
.
StatusCode
==
http
.
StatusForbidden
{
_
=
s
.
accountRepo
.
SetError
(
ctx
,
account
.
ID
,
errMsg
)
}
return
s
.
sendErrorAndEnd
(
c
,
errMsg
)
}
return
s
.
processClaudeStream
(
c
,
resp
.
Body
)
}
// testBedrockAccountConnection tests a Bedrock (SigV4 or API Key) account using non-streaming invoke
func
(
s
*
AccountTestService
)
testBedrockAccountConnection
(
c
*
gin
.
Context
,
ctx
context
.
Context
,
account
*
Account
,
testModelID
string
)
error
{
region
:=
bedrockRuntimeRegion
(
account
)
...
...
@@ -711,8 +785,8 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
testModelID
=
geminicli
.
DefaultTestModel
}
// For
API Key account
s with model mapping, map the model
if
account
.
Type
==
AccountTypeAPIKey
{
// For
static upstream credential
s with model mapping, map the model
if
account
.
Type
==
AccountTypeAPIKey
||
account
.
Type
==
AccountTypeServiceAccount
{
mapping
:=
account
.
GetModelMapping
()
if
len
(
mapping
)
>
0
{
if
mappedModel
,
exists
:=
mapping
[
testModelID
];
exists
{
...
...
@@ -740,6 +814,8 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
req
,
err
=
s
.
buildGeminiAPIKeyRequest
(
ctx
,
account
,
testModelID
,
payload
)
case
AccountTypeOAuth
:
req
,
err
=
s
.
buildGeminiOAuthRequest
(
ctx
,
account
,
testModelID
,
payload
)
case
AccountTypeServiceAccount
:
req
,
err
=
s
.
buildGeminiServiceAccountRequest
(
ctx
,
account
,
testModelID
,
payload
)
default
:
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Unsupported account type: %s"
,
account
.
Type
))
}
...
...
@@ -893,6 +969,27 @@ func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, accoun
return
s
.
buildCodeAssistRequest
(
ctx
,
accessToken
,
projectID
,
modelID
,
payload
)
}
func
(
s
*
AccountTestService
)
buildGeminiServiceAccountRequest
(
ctx
context
.
Context
,
account
*
Account
,
modelID
string
,
payload
[]
byte
)
(
*
http
.
Request
,
error
)
{
if
s
.
geminiTokenProvider
==
nil
{
return
nil
,
fmt
.
Errorf
(
"gemini token provider not configured"
)
}
accessToken
,
err
:=
s
.
geminiTokenProvider
.
GetAccessToken
(
ctx
,
account
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to get service account access token: %w"
,
err
)
}
fullURL
,
err
:=
buildVertexGeminiURL
(
account
.
VertexProjectID
(),
account
.
VertexLocation
(
modelID
),
modelID
,
"streamGenerateContent"
,
true
)
if
err
!=
nil
{
return
nil
,
err
}
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodPost
,
fullURL
,
bytes
.
NewReader
(
payload
))
if
err
!=
nil
{
return
nil
,
err
}
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
accessToken
)
return
req
,
nil
}
// buildCodeAssistRequest builds request for Google Code Assist API (used by Gemini CLI and Antigravity)
func
(
s
*
AccountTestService
)
buildCodeAssistRequest
(
ctx
context
.
Context
,
accessToken
,
projectID
,
modelID
string
,
payload
[]
byte
)
(
*
http
.
Request
,
error
)
{
var
inner
map
[
string
]
any
...
...
@@ -1227,7 +1324,7 @@ func (s *AccountTestService) testOpenAIImageAPIKey(c *gin.Context, ctx context.C
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Invalid base URL: %s"
,
err
.
Error
()))
}
apiURL
:=
strings
.
TrimSuffix
(
normalizedBaseURL
,
"/"
)
+
"/v1/i
mages
/g
enerations
"
apiURL
:=
buildOpenAIImagesURL
(
normalizedBaseURL
,
openAII
mages
G
enerations
Endpoint
)
// Set SSE headers
c
.
Writer
.
Header
()
.
Set
(
"Content-Type"
,
"text/event-stream"
)
...
...
backend/internal/service/account_test_service_openai_image_test.go
View file @
538ae31a
...
...
@@ -8,6 +8,7 @@ import (
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
...
...
@@ -48,3 +49,42 @@ func TestAccountTestService_OpenAIImageOAuthHandlesOutputItemDoneFallback(t *tes
require
.
Contains
(
t
,
rec
.
Body
.
String
(),
"data:image/png;base64,aGVsbG8="
)
require
.
Contains
(
t
,
rec
.
Body
.
String
(),
"
\"
success
\"
:true"
)
}
func
TestAccountTestService_OpenAIImageAPIKeyUsesConfiguredV1BaseURL
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/admin/accounts/1/test"
,
nil
)
upstream
:=
&
httpUpstreamRecorder
{
resp
:
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"application/json"
},
},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
`{"data":[{"b64_json":"aGVsbG8=","revised_prompt":"draw a cat"}]}`
)),
},
}
svc
:=
&
AccountTestService
{
httpUpstream
:
upstream
,
cfg
:
&
config
.
Config
{},
}
account
:=
&
Account
{
ID
:
54
,
Name
:
"openai-apikey"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"test-api-key"
,
"base_url"
:
"https://image-upstream.example/v1"
,
},
}
err
:=
svc
.
testOpenAIImageAPIKey
(
c
,
context
.
Background
(),
account
,
"gpt-image-2"
,
"draw a cat"
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
upstream
.
lastReq
)
require
.
Equal
(
t
,
"https://image-upstream.example/v1/images/generations"
,
upstream
.
lastReq
.
URL
.
String
())
require
.
Equal
(
t
,
"Bearer test-api-key"
,
upstream
.
lastReq
.
Header
.
Get
(
"Authorization"
))
require
.
Contains
(
t
,
rec
.
Body
.
String
(),
"data:image/png;base64,aGVsbG8="
)
require
.
Contains
(
t
,
rec
.
Body
.
String
(),
"
\"
success
\"
:true"
)
}
backend/internal/service/admin_service.go
View file @
538ae31a
...
...
@@ -9,6 +9,7 @@ import (
"log/slog"
"net/http"
"sort"
"strconv"
"strings"
"time"
...
...
@@ -58,6 +59,7 @@ type AdminService interface {
// API Key management (admin)
AdminUpdateAPIKeyGroupID
(
ctx
context
.
Context
,
keyID
int64
,
groupID
*
int64
)
(
*
AdminUpdateAPIKeyGroupIDResult
,
error
)
AdminResetAPIKeyRateLimitUsage
(
ctx
context
.
Context
,
keyID
int64
)
(
*
APIKey
,
error
)
AdminSetCaptureRequests
(
ctx
context
.
Context
,
keyID
int64
,
enabled
bool
)
(
*
APIKey
,
error
)
// ReplaceUserGroup 替换用户的专属分组:授予新分组权限、迁移 Key、移除旧分组权限
...
...
@@ -292,6 +294,7 @@ type UpdateAccountInput struct {
// BulkUpdateAccountsInput describes the payload for bulk updating accounts.
type
BulkUpdateAccountsInput
struct
{
AccountIDs
[]
int64
Filters
*
BulkUpdateAccountFilters
Name
string
ProxyID
*
int64
Concurrency
*
int
...
...
@@ -308,6 +311,15 @@ type BulkUpdateAccountsInput struct {
SkipMixedChannelCheck
bool
}
type
BulkUpdateAccountFilters
struct
{
Platform
string
Type
string
Status
string
Group
string
Search
string
PrivacyMode
string
}
// BulkUpdateAccountResult captures the result for a single account update.
type
BulkUpdateAccountResult
struct
{
AccountID
int64
`json:"account_id"`
...
...
@@ -1962,6 +1974,30 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
return
result
,
nil
}
// AdminResetAPIKeyRateLimitUsage resets all API key rate-limit usage windows.
func
(
s
*
adminServiceImpl
)
AdminResetAPIKeyRateLimitUsage
(
ctx
context
.
Context
,
keyID
int64
)
(
*
APIKey
,
error
)
{
apiKey
,
err
:=
s
.
apiKeyRepo
.
GetByID
(
ctx
,
keyID
)
if
err
!=
nil
{
return
nil
,
err
}
apiKey
.
Usage5h
=
0
apiKey
.
Usage1d
=
0
apiKey
.
Usage7d
=
0
apiKey
.
Window5hStart
=
nil
apiKey
.
Window1dStart
=
nil
apiKey
.
Window7dStart
=
nil
if
err
:=
s
.
apiKeyRepo
.
Update
(
ctx
,
apiKey
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"reset api key rate limit usage: %w"
,
err
)
}
if
s
.
authCacheInvalidator
!=
nil
{
s
.
authCacheInvalidator
.
InvalidateAuthCacheByKey
(
ctx
,
apiKey
.
Key
)
}
if
s
.
billingCacheService
!=
nil
{
_
=
s
.
billingCacheService
.
InvalidateAPIKeyRateLimit
(
ctx
,
apiKey
.
ID
)
}
return
apiKey
,
nil
}
// AdminSetCaptureRequests 设置或清除指定 API Key 的请求捕获开关,并立即失效认证缓存。
func
(
s
*
adminServiceImpl
)
AdminSetCaptureRequests
(
ctx
context
.
Context
,
keyID
int64
,
enabled
bool
)
(
*
APIKey
,
error
)
{
apiKey
,
err
:=
s
.
apiKeyRepo
.
GetByID
(
ctx
,
keyID
)
...
...
@@ -2303,6 +2339,14 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
// BulkUpdateAccounts updates multiple accounts in one request.
// It merges credentials/extra keys instead of overwriting the whole object.
func
(
s
*
adminServiceImpl
)
BulkUpdateAccounts
(
ctx
context
.
Context
,
input
*
BulkUpdateAccountsInput
)
(
*
BulkUpdateAccountsResult
,
error
)
{
if
len
(
input
.
AccountIDs
)
==
0
&&
input
.
Filters
!=
nil
{
accountIDs
,
err
:=
s
.
resolveBulkUpdateTargetIDs
(
ctx
,
input
.
Filters
)
if
err
!=
nil
{
return
nil
,
err
}
input
.
AccountIDs
=
accountIDs
}
result
:=
&
BulkUpdateAccountsResult
{
SuccessIDs
:
make
([]
int64
,
0
,
len
(
input
.
AccountIDs
)),
FailedIDs
:
make
([]
int64
,
0
,
len
(
input
.
AccountIDs
)),
...
...
@@ -2418,6 +2462,55 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
return
result
,
nil
}
func
(
s
*
adminServiceImpl
)
resolveBulkUpdateTargetIDs
(
ctx
context
.
Context
,
filters
*
BulkUpdateAccountFilters
)
([]
int64
,
error
)
{
if
filters
==
nil
{
return
nil
,
nil
}
groupID
:=
int64
(
0
)
switch
strings
.
TrimSpace
(
filters
.
Group
)
{
case
""
:
case
"ungrouped"
:
groupID
=
AccountListGroupUngrouped
default
:
parsedGroupID
,
err
:=
strconv
.
ParseInt
(
strings
.
TrimSpace
(
filters
.
Group
),
10
,
64
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"invalid group filter: %w"
,
err
)
}
groupID
=
parsedGroupID
}
const
pageSize
=
500
page
:=
1
accountIDs
:=
make
([]
int64
,
0
,
pageSize
)
for
{
accounts
,
total
,
err
:=
s
.
ListAccounts
(
ctx
,
page
,
pageSize
,
filters
.
Platform
,
filters
.
Type
,
filters
.
Status
,
filters
.
Search
,
groupID
,
filters
.
PrivacyMode
,
""
,
""
,
)
if
err
!=
nil
{
return
nil
,
err
}
for
_
,
account
:=
range
accounts
{
accountIDs
=
append
(
accountIDs
,
account
.
ID
)
}
if
int64
(
len
(
accountIDs
))
>=
total
||
len
(
accounts
)
==
0
{
return
accountIDs
,
nil
}
page
++
}
}
func
(
s
*
adminServiceImpl
)
DeleteAccount
(
ctx
context
.
Context
,
id
int64
)
error
{
if
err
:=
s
.
accountRepo
.
Delete
(
ctx
,
id
);
err
!=
nil
{
return
err
...
...
backend/internal/service/admin_service_bulk_update_test.go
View file @
538ae31a
...
...
@@ -5,8 +5,10 @@ package service
import
(
"context"
"errors"
"reflect"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
...
...
@@ -25,6 +27,19 @@ type accountRepoStubForBulkUpdate struct {
getByIDCalled
[]
int64
listByGroupData
map
[
int64
][]
Account
listByGroupErr
map
[
int64
]
error
listData
[]
Account
listResult
*
pagination
.
PaginationResult
listErr
error
listCalled
bool
lastListParams
pagination
.
PaginationParams
lastListFilters
struct
{
platform
string
accountType
string
status
string
search
string
groupID
int64
privacyMode
string
}
}
func
(
s
*
accountRepoStubForBulkUpdate
)
BulkUpdate
(
_
context
.
Context
,
ids
[]
int64
,
_
AccountBulkUpdate
)
(
int64
,
error
)
{
...
...
@@ -73,6 +88,24 @@ func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID in
return
nil
,
nil
}
func
(
s
*
accountRepoStubForBulkUpdate
)
ListWithFilters
(
_
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
accountType
,
status
,
search
string
,
groupID
int64
,
privacyMode
string
)
([]
Account
,
*
pagination
.
PaginationResult
,
error
)
{
s
.
listCalled
=
true
s
.
lastListParams
=
params
s
.
lastListFilters
.
platform
=
platform
s
.
lastListFilters
.
accountType
=
accountType
s
.
lastListFilters
.
status
=
status
s
.
lastListFilters
.
search
=
search
s
.
lastListFilters
.
groupID
=
groupID
s
.
lastListFilters
.
privacyMode
=
privacyMode
if
s
.
listErr
!=
nil
{
return
nil
,
nil
,
s
.
listErr
}
if
s
.
listResult
!=
nil
{
return
s
.
listData
,
s
.
listResult
,
nil
}
return
s
.
listData
,
&
pagination
.
PaginationResult
{
Total
:
int64
(
len
(
s
.
listData
))},
nil
}
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
func
TestAdminService_BulkUpdateAccounts_AllSuccessIDs
(
t
*
testing
.
T
)
{
repo
:=
&
accountRepoStubForBulkUpdate
{}
...
...
@@ -170,3 +203,46 @@ func TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingCon
// No BindGroups should have been called since the check runs before any write.
require
.
Empty
(
t
,
repo
.
bindGroupsCalls
)
}
func
TestAdminServiceBulkUpdateAccounts_ResolvesIDsFromFilters
(
t
*
testing
.
T
)
{
repo
:=
&
accountRepoStubForBulkUpdate
{
listData
:
[]
Account
{
{
ID
:
7
},
{
ID
:
11
},
},
listResult
:
&
pagination
.
PaginationResult
{
Total
:
2
},
}
svc
:=
&
adminServiceImpl
{
accountRepo
:
repo
}
schedulable
:=
true
input
:=
&
BulkUpdateAccountsInput
{
Schedulable
:
&
schedulable
,
}
filtersField
:=
reflect
.
ValueOf
(
input
)
.
Elem
()
.
FieldByName
(
"Filters"
)
require
.
True
(
t
,
filtersField
.
IsValid
(),
"BulkUpdateAccountsInput should expose Filters for filter-target bulk update"
)
require
.
Equal
(
t
,
reflect
.
Ptr
,
filtersField
.
Kind
(),
"BulkUpdateAccountsInput.Filters should be a pointer field"
)
filtersValue
:=
reflect
.
New
(
filtersField
.
Type
()
.
Elem
())
filtersValue
.
Elem
()
.
FieldByName
(
"Platform"
)
.
SetString
(
PlatformOpenAI
)
filtersValue
.
Elem
()
.
FieldByName
(
"Type"
)
.
SetString
(
AccountTypeOAuth
)
filtersValue
.
Elem
()
.
FieldByName
(
"Status"
)
.
SetString
(
StatusActive
)
filtersValue
.
Elem
()
.
FieldByName
(
"Group"
)
.
SetString
(
"12"
)
filtersValue
.
Elem
()
.
FieldByName
(
"PrivacyMode"
)
.
SetString
(
PrivacyModeCFBlocked
)
filtersValue
.
Elem
()
.
FieldByName
(
"Search"
)
.
SetString
(
"bulk-target"
)
filtersField
.
Set
(
filtersValue
)
result
,
err
:=
svc
.
BulkUpdateAccounts
(
context
.
Background
(),
input
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
repo
.
listCalled
,
"expected filter-target bulk update to resolve matching IDs via account list filters"
)
require
.
Equal
(
t
,
PlatformOpenAI
,
repo
.
lastListFilters
.
platform
)
require
.
Equal
(
t
,
AccountTypeOAuth
,
repo
.
lastListFilters
.
accountType
)
require
.
Equal
(
t
,
StatusActive
,
repo
.
lastListFilters
.
status
)
require
.
Equal
(
t
,
"bulk-target"
,
repo
.
lastListFilters
.
search
)
require
.
Equal
(
t
,
int64
(
12
),
repo
.
lastListFilters
.
groupID
)
require
.
Equal
(
t
,
PrivacyModeCFBlocked
,
repo
.
lastListFilters
.
privacyMode
)
require
.
Equal
(
t
,
[]
int64
{
7
,
11
},
repo
.
bulkUpdateIDs
)
require
.
Equal
(
t
,
2
,
result
.
Success
)
require
.
Equal
(
t
,
0
,
result
.
Failed
)
require
.
Equal
(
t
,
[]
int64
{
7
,
11
},
result
.
SuccessIDs
)
}
backend/internal/service/billing_cache_service.go
View file @
538ae31a
...
...
@@ -508,6 +508,18 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
return
nil
}
// InvalidateAPIKeyRateLimit invalidates the Redis rate-limit usage cache for an API key.
func
(
s
*
BillingCacheService
)
InvalidateAPIKeyRateLimit
(
ctx
context
.
Context
,
keyID
int64
)
error
{
if
s
.
cache
==
nil
{
return
nil
}
if
err
:=
s
.
cache
.
InvalidateAPIKeyRateLimit
(
ctx
,
keyID
);
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.billing_cache"
,
"Warning: invalidate api key rate limit cache failed for key %d: %v"
,
keyID
,
err
)
return
err
}
return
nil
}
// ============================================
// API Key 限速缓存方法
// ============================================
...
...
backend/internal/service/claude_token_provider.go
View file @
538ae31a
...
...
@@ -17,7 +17,7 @@ const (
// ClaudeTokenCache token cache interface.
type
ClaudeTokenCache
=
GeminiTokenCache
// ClaudeTokenProvider manages access_token for Claude OAuth accounts.
// ClaudeTokenProvider manages access_token for Claude OAuth
and Vertex service account
accounts.
type
ClaudeTokenProvider
struct
{
accountRepo
AccountRepository
tokenCache
ClaudeTokenCache
...
...
@@ -56,8 +56,11 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
if
account
==
nil
{
return
""
,
errors
.
New
(
"account is nil"
)
}
if
account
.
Platform
!=
PlatformAnthropic
||
account
.
Type
!=
AccountTypeOAuth
{
return
""
,
errors
.
New
(
"not an anthropic oauth account"
)
if
account
.
Platform
!=
PlatformAnthropic
||
(
account
.
Type
!=
AccountTypeOAuth
&&
account
.
Type
!=
AccountTypeServiceAccount
)
{
return
""
,
errors
.
New
(
"not an anthropic oauth or service account"
)
}
if
account
.
Type
==
AccountTypeServiceAccount
{
return
p
.
getServiceAccountAccessToken
(
ctx
,
account
)
}
cacheKey
:=
ClaudeTokenCacheKey
(
account
)
...
...
@@ -157,3 +160,7 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return
accessToken
,
nil
}
func
(
p
*
ClaudeTokenProvider
)
getServiceAccountAccessToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
error
)
{
return
getVertexServiceAccountAccessToken
(
ctx
,
p
.
tokenCache
,
account
)
}
backend/internal/service/claude_token_provider_test.go
View file @
538ae31a
...
...
@@ -137,7 +137,7 @@ func (p *testClaudeTokenProvider) GetAccessToken(ctx context.Context, account *A
return
""
,
errors
.
New
(
"account is nil"
)
}
if
account
.
Platform
!=
PlatformAnthropic
||
account
.
Type
!=
AccountTypeOAuth
{
return
""
,
errors
.
New
(
"not an anthropic oauth account"
)
return
""
,
errors
.
New
(
"not an anthropic oauth
or service
account"
)
}
cacheKey
:=
ClaudeTokenCacheKey
(
account
)
...
...
@@ -371,7 +371,7 @@ func TestClaudeTokenProvider_WrongPlatform(t *testing.T) {
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"not an anthropic oauth account"
)
require
.
Contains
(
t
,
err
.
Error
(),
"not an anthropic oauth
or service
account"
)
require
.
Empty
(
t
,
token
)
}
...
...
@@ -385,7 +385,7 @@ func TestClaudeTokenProvider_WrongAccountType(t *testing.T) {
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"not an anthropic oauth account"
)
require
.
Contains
(
t
,
err
.
Error
(),
"not an anthropic oauth
or service
account"
)
require
.
Empty
(
t
,
token
)
}
...
...
@@ -399,7 +399,7 @@ func TestClaudeTokenProvider_SetupTokenType(t *testing.T) {
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"not an anthropic oauth account"
)
require
.
Contains
(
t
,
err
.
Error
(),
"not an anthropic oauth
or service
account"
)
require
.
Empty
(
t
,
token
)
}
...
...
backend/internal/service/domain_constants.go
View file @
538ae31a
...
...
@@ -41,11 +41,12 @@ const (
// Account type constants
const
(
AccountTypeOAuth
=
domain
.
AccountTypeOAuth
// OAuth类型账号(full scope: profile + inference)
AccountTypeSetupToken
=
domain
.
AccountTypeSetupToken
// Setup Token类型账号(inference only scope)
AccountTypeAPIKey
=
domain
.
AccountTypeAPIKey
// API Key类型账号
AccountTypeUpstream
=
domain
.
AccountTypeUpstream
// 上游透传类型账号(通过 Base URL + API Key 连接上游)
AccountTypeBedrock
=
domain
.
AccountTypeBedrock
// AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
AccountTypeOAuth
=
domain
.
AccountTypeOAuth
// OAuth类型账号(full scope: profile + inference)
AccountTypeSetupToken
=
domain
.
AccountTypeSetupToken
// Setup Token类型账号(inference only scope)
AccountTypeAPIKey
=
domain
.
AccountTypeAPIKey
// API Key类型账号
AccountTypeUpstream
=
domain
.
AccountTypeUpstream
// 上游透传类型账号(通过 Base URL + API Key 连接上游)
AccountTypeBedrock
=
domain
.
AccountTypeBedrock
// AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
AccountTypeServiceAccount
=
domain
.
AccountTypeServiceAccount
// Google Service Account 类型账号(用于 Vertex AI)
)
// Redeem type constants
...
...
@@ -306,6 +307,12 @@ const (
// SettingKeyBetaPolicySettings stores JSON config for beta policy rules.
SettingKeyBetaPolicySettings
=
"beta_policy_settings"
// SettingKeyOpenAIFastPolicySettings stores JSON config for OpenAI
// service_tier (fast/flex) policy rules. Mirrors BetaPolicySettings but
// targets OpenAI's body-level service_tier field instead of Claude's
// anthropic-beta header.
SettingKeyOpenAIFastPolicySettings
=
"openai_fast_policy_settings"
// =========================
// Claude Code Version Check
// =========================
...
...
@@ -329,6 +336,8 @@ const (
SettingKeyEnableMetadataPassthrough
=
"enable_metadata_passthrough"
// SettingKeyEnableCCHSigning 是否对 billing header 中的 cch 进行 xxHash64 签名(默认 false)
SettingKeyEnableCCHSigning
=
"enable_cch_signing"
// SettingKeyEnableAnthropicCacheTTL1hInjection 是否对 Anthropic OAuth/SetupToken 请求体注入 1h cache_control ttl(默认 false)
SettingKeyEnableAnthropicCacheTTL1hInjection
=
"enable_anthropic_cache_ttl_1h_injection"
// Balance Low Notification
SettingKeyBalanceLowNotifyEnabled
=
"balance_low_notify_enabled"
// 全局开关
...
...
backend/internal/service/gateway_anthropic_vertex_service_account_test.go
0 → 100644
View file @
538ae31a
package
service
import
(
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func
TestGatewayService_BuildAnthropicVertexServiceAccountRequest
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
nil
)
c
.
Request
.
Header
.
Set
(
"Authorization"
,
"Bearer inbound-token"
)
c
.
Request
.
Header
.
Set
(
"X-Api-Key"
,
"inbound-api-key"
)
c
.
Request
.
Header
.
Set
(
"Anthropic-Version"
,
"2023-06-01"
)
c
.
Request
.
Header
.
Set
(
"Anthropic-Beta"
,
"interleaved-thinking-2025-05-14"
)
account
:=
&
Account
{
ID
:
301
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeServiceAccount
,
Credentials
:
map
[
string
]
any
{
"project_id"
:
"vertex-proj"
,
"location"
:
"us-east5"
,
},
}
body
:=
[]
byte
(
`{"model":"claude-sonnet-4-5","stream":false,"max_tokens":32,"messages":[{"role":"user","content":"hello"}]}`
)
svc
:=
&
GatewayService
{}
req
,
err
:=
svc
.
buildUpstreamRequest
(
context
.
Background
(),
c
,
account
,
body
,
"vertex-token"
,
"service_account"
,
"claude-sonnet-4-5@20250929"
,
false
,
false
,
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"https://us-east5-aiplatform.googleapis.com/v1/projects/vertex-proj/locations/us-east5/publishers/anthropic/models/claude-sonnet-4-5@20250929:rawPredict"
,
req
.
URL
.
String
())
require
.
Equal
(
t
,
"Bearer vertex-token"
,
getHeaderRaw
(
req
.
Header
,
"authorization"
))
require
.
Empty
(
t
,
getHeaderRaw
(
req
.
Header
,
"x-api-key"
))
require
.
Empty
(
t
,
getHeaderRaw
(
req
.
Header
,
"anthropic-version"
))
require
.
Equal
(
t
,
"interleaved-thinking-2025-05-14"
,
getHeaderRaw
(
req
.
Header
,
"anthropic-beta"
))
got
:=
readRequestBodyForTest
(
t
,
req
)
require
.
Equal
(
t
,
""
,
gjson
.
GetBytes
(
got
,
"model"
)
.
String
())
require
.
Equal
(
t
,
vertexAnthropicVersion
,
gjson
.
GetBytes
(
got
,
"anthropic_version"
)
.
String
())
require
.
Equal
(
t
,
"hello"
,
gjson
.
GetBytes
(
got
,
"messages.0.content"
)
.
String
())
}
func
readRequestBodyForTest
(
t
*
testing
.
T
,
req
*
http
.
Request
)
[]
byte
{
t
.
Helper
()
require
.
NotNil
(
t
,
req
.
Body
)
body
,
err
:=
io
.
ReadAll
(
req
.
Body
)
require
.
NoError
(
t
,
err
)
return
body
}
Prev
1
2
3
4
5
6
…
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment