Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
e83f0ee3
Commit
e83f0ee3
authored
Dec 30, 2025
by
yangjianbo
Browse files
Merge branch 'main' into test-dev
parents
bff3c66d
942c3e15
Changes
61
Show whitespace changes
Inline
Side-by-side
backend/internal/pkg/antigravity/stream_transformer.go
0 → 100644
View file @
e83f0ee3
package
antigravity
import
(
"bytes"
"encoding/json"
"fmt"
"strings"
)
// BlockType 内容块类型
type
BlockType
int
const
(
BlockTypeNone
BlockType
=
iota
BlockTypeText
BlockTypeThinking
BlockTypeFunction
)
// StreamingProcessor 流式响应处理器
type
StreamingProcessor
struct
{
blockType
BlockType
blockIndex
int
messageStartSent
bool
messageStopSent
bool
usedTool
bool
pendingSignature
string
trailingSignature
string
originalModel
string
// 累计 usage
inputTokens
int
outputTokens
int
}
// NewStreamingProcessor 创建流式响应处理器
func
NewStreamingProcessor
(
originalModel
string
)
*
StreamingProcessor
{
return
&
StreamingProcessor
{
blockType
:
BlockTypeNone
,
originalModel
:
originalModel
,
}
}
// ProcessLine 处理 SSE 行,返回 Claude SSE 事件
func
(
p
*
StreamingProcessor
)
ProcessLine
(
line
string
)
[]
byte
{
line
=
strings
.
TrimSpace
(
line
)
if
line
==
""
||
!
strings
.
HasPrefix
(
line
,
"data:"
)
{
return
nil
}
data
:=
strings
.
TrimSpace
(
strings
.
TrimPrefix
(
line
,
"data:"
))
if
data
==
""
||
data
==
"[DONE]"
{
return
nil
}
// 解包 v1internal 响应
var
v1Resp
V1InternalResponse
if
err
:=
json
.
Unmarshal
([]
byte
(
data
),
&
v1Resp
);
err
!=
nil
{
// 尝试直接解析为 GeminiResponse
var
directResp
GeminiResponse
if
err2
:=
json
.
Unmarshal
([]
byte
(
data
),
&
directResp
);
err2
!=
nil
{
return
nil
}
v1Resp
.
Response
=
directResp
v1Resp
.
ResponseID
=
directResp
.
ResponseID
v1Resp
.
ModelVersion
=
directResp
.
ModelVersion
}
geminiResp
:=
&
v1Resp
.
Response
var
result
bytes
.
Buffer
// 发送 message_start
if
!
p
.
messageStartSent
{
_
,
_
=
result
.
Write
(
p
.
emitMessageStart
(
&
v1Resp
))
}
// 更新 usage
if
geminiResp
.
UsageMetadata
!=
nil
{
p
.
inputTokens
=
geminiResp
.
UsageMetadata
.
PromptTokenCount
p
.
outputTokens
=
geminiResp
.
UsageMetadata
.
CandidatesTokenCount
}
// 处理 parts
if
len
(
geminiResp
.
Candidates
)
>
0
&&
geminiResp
.
Candidates
[
0
]
.
Content
!=
nil
{
for
_
,
part
:=
range
geminiResp
.
Candidates
[
0
]
.
Content
.
Parts
{
_
,
_
=
result
.
Write
(
p
.
processPart
(
&
part
))
}
}
// 检查是否结束
if
len
(
geminiResp
.
Candidates
)
>
0
{
finishReason
:=
geminiResp
.
Candidates
[
0
]
.
FinishReason
if
finishReason
!=
""
{
_
,
_
=
result
.
Write
(
p
.
emitFinish
(
finishReason
))
}
}
return
result
.
Bytes
()
}
// Finish 结束处理,返回最终事件和用量
func
(
p
*
StreamingProcessor
)
Finish
()
([]
byte
,
*
ClaudeUsage
)
{
var
result
bytes
.
Buffer
if
!
p
.
messageStopSent
{
_
,
_
=
result
.
Write
(
p
.
emitFinish
(
""
))
}
usage
:=
&
ClaudeUsage
{
InputTokens
:
p
.
inputTokens
,
OutputTokens
:
p
.
outputTokens
,
}
return
result
.
Bytes
(),
usage
}
// emitMessageStart 发送 message_start 事件
func
(
p
*
StreamingProcessor
)
emitMessageStart
(
v1Resp
*
V1InternalResponse
)
[]
byte
{
if
p
.
messageStartSent
{
return
nil
}
usage
:=
ClaudeUsage
{}
if
v1Resp
.
Response
.
UsageMetadata
!=
nil
{
usage
.
InputTokens
=
v1Resp
.
Response
.
UsageMetadata
.
PromptTokenCount
usage
.
OutputTokens
=
v1Resp
.
Response
.
UsageMetadata
.
CandidatesTokenCount
}
responseID
:=
v1Resp
.
ResponseID
if
responseID
==
""
{
responseID
=
v1Resp
.
Response
.
ResponseID
}
if
responseID
==
""
{
responseID
=
"msg_"
+
generateRandomID
()
}
message
:=
map
[
string
]
any
{
"id"
:
responseID
,
"type"
:
"message"
,
"role"
:
"assistant"
,
"content"
:
[]
any
{},
"model"
:
p
.
originalModel
,
"stop_reason"
:
nil
,
"stop_sequence"
:
nil
,
"usage"
:
usage
,
}
event
:=
map
[
string
]
any
{
"type"
:
"message_start"
,
"message"
:
message
,
}
p
.
messageStartSent
=
true
return
p
.
formatSSE
(
"message_start"
,
event
)
}
// processPart 处理单个 part
func
(
p
*
StreamingProcessor
)
processPart
(
part
*
GeminiPart
)
[]
byte
{
var
result
bytes
.
Buffer
signature
:=
part
.
ThoughtSignature
// 1. FunctionCall 处理
if
part
.
FunctionCall
!=
nil
{
// 先处理 trailingSignature
if
p
.
trailingSignature
!=
""
{
_
,
_
=
result
.
Write
(
p
.
endBlock
())
_
,
_
=
result
.
Write
(
p
.
emitEmptyThinkingWithSignature
(
p
.
trailingSignature
))
p
.
trailingSignature
=
""
}
_
,
_
=
result
.
Write
(
p
.
processFunctionCall
(
part
.
FunctionCall
,
signature
))
return
result
.
Bytes
()
}
// 2. Text 处理
if
part
.
Text
!=
""
||
part
.
Thought
{
if
part
.
Thought
{
_
,
_
=
result
.
Write
(
p
.
processThinking
(
part
.
Text
,
signature
))
}
else
{
_
,
_
=
result
.
Write
(
p
.
processText
(
part
.
Text
,
signature
))
}
}
// 3. InlineData (Image) 处理
if
part
.
InlineData
!=
nil
&&
part
.
InlineData
.
Data
!=
""
{
markdownImg
:=
fmt
.
Sprintf
(
""
,
part
.
InlineData
.
MimeType
,
part
.
InlineData
.
Data
)
_
,
_
=
result
.
Write
(
p
.
processText
(
markdownImg
,
""
))
}
return
result
.
Bytes
()
}
// processThinking 处理 thinking
func
(
p
*
StreamingProcessor
)
processThinking
(
text
,
signature
string
)
[]
byte
{
var
result
bytes
.
Buffer
// 处理之前的 trailingSignature
if
p
.
trailingSignature
!=
""
{
_
,
_
=
result
.
Write
(
p
.
endBlock
())
_
,
_
=
result
.
Write
(
p
.
emitEmptyThinkingWithSignature
(
p
.
trailingSignature
))
p
.
trailingSignature
=
""
}
// 开始或继续 thinking 块
if
p
.
blockType
!=
BlockTypeThinking
{
_
,
_
=
result
.
Write
(
p
.
startBlock
(
BlockTypeThinking
,
map
[
string
]
any
{
"type"
:
"thinking"
,
"thinking"
:
""
,
}))
}
if
text
!=
""
{
_
,
_
=
result
.
Write
(
p
.
emitDelta
(
"thinking_delta"
,
map
[
string
]
any
{
"thinking"
:
text
,
}))
}
// 暂存签名
if
signature
!=
""
{
p
.
pendingSignature
=
signature
}
return
result
.
Bytes
()
}
// processText 处理普通 text
func
(
p
*
StreamingProcessor
)
processText
(
text
,
signature
string
)
[]
byte
{
var
result
bytes
.
Buffer
// 空 text 带签名 - 暂存
if
text
==
""
{
if
signature
!=
""
{
p
.
trailingSignature
=
signature
}
return
nil
}
// 处理之前的 trailingSignature
if
p
.
trailingSignature
!=
""
{
_
,
_
=
result
.
Write
(
p
.
endBlock
())
_
,
_
=
result
.
Write
(
p
.
emitEmptyThinkingWithSignature
(
p
.
trailingSignature
))
p
.
trailingSignature
=
""
}
// 非空 text 带签名 - 特殊处理
if
signature
!=
""
{
_
,
_
=
result
.
Write
(
p
.
startBlock
(
BlockTypeText
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
""
,
}))
_
,
_
=
result
.
Write
(
p
.
emitDelta
(
"text_delta"
,
map
[
string
]
any
{
"text"
:
text
,
}))
_
,
_
=
result
.
Write
(
p
.
endBlock
())
_
,
_
=
result
.
Write
(
p
.
emitEmptyThinkingWithSignature
(
signature
))
return
result
.
Bytes
()
}
// 普通 text (无签名)
if
p
.
blockType
!=
BlockTypeText
{
_
,
_
=
result
.
Write
(
p
.
startBlock
(
BlockTypeText
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
""
,
}))
}
_
,
_
=
result
.
Write
(
p
.
emitDelta
(
"text_delta"
,
map
[
string
]
any
{
"text"
:
text
,
}))
return
result
.
Bytes
()
}
// processFunctionCall 处理 function call
func
(
p
*
StreamingProcessor
)
processFunctionCall
(
fc
*
GeminiFunctionCall
,
signature
string
)
[]
byte
{
var
result
bytes
.
Buffer
p
.
usedTool
=
true
toolID
:=
fc
.
ID
if
toolID
==
""
{
toolID
=
fmt
.
Sprintf
(
"%s-%s"
,
fc
.
Name
,
generateRandomID
())
}
toolUse
:=
map
[
string
]
any
{
"type"
:
"tool_use"
,
"id"
:
toolID
,
"name"
:
fc
.
Name
,
"input"
:
map
[
string
]
any
{},
}
if
signature
!=
""
{
toolUse
[
"signature"
]
=
signature
}
_
,
_
=
result
.
Write
(
p
.
startBlock
(
BlockTypeFunction
,
toolUse
))
// 发送 input_json_delta
if
fc
.
Args
!=
nil
{
argsJSON
,
_
:=
json
.
Marshal
(
fc
.
Args
)
_
,
_
=
result
.
Write
(
p
.
emitDelta
(
"input_json_delta"
,
map
[
string
]
any
{
"partial_json"
:
string
(
argsJSON
),
}))
}
_
,
_
=
result
.
Write
(
p
.
endBlock
())
return
result
.
Bytes
()
}
// startBlock 开始新的内容块
func
(
p
*
StreamingProcessor
)
startBlock
(
blockType
BlockType
,
contentBlock
map
[
string
]
any
)
[]
byte
{
var
result
bytes
.
Buffer
if
p
.
blockType
!=
BlockTypeNone
{
_
,
_
=
result
.
Write
(
p
.
endBlock
())
}
event
:=
map
[
string
]
any
{
"type"
:
"content_block_start"
,
"index"
:
p
.
blockIndex
,
"content_block"
:
contentBlock
,
}
_
,
_
=
result
.
Write
(
p
.
formatSSE
(
"content_block_start"
,
event
))
p
.
blockType
=
blockType
return
result
.
Bytes
()
}
// endBlock 结束当前内容块
func
(
p
*
StreamingProcessor
)
endBlock
()
[]
byte
{
if
p
.
blockType
==
BlockTypeNone
{
return
nil
}
var
result
bytes
.
Buffer
// Thinking 块结束时发送暂存的签名
if
p
.
blockType
==
BlockTypeThinking
&&
p
.
pendingSignature
!=
""
{
_
,
_
=
result
.
Write
(
p
.
emitDelta
(
"signature_delta"
,
map
[
string
]
any
{
"signature"
:
p
.
pendingSignature
,
}))
p
.
pendingSignature
=
""
}
event
:=
map
[
string
]
any
{
"type"
:
"content_block_stop"
,
"index"
:
p
.
blockIndex
,
}
_
,
_
=
result
.
Write
(
p
.
formatSSE
(
"content_block_stop"
,
event
))
p
.
blockIndex
++
p
.
blockType
=
BlockTypeNone
return
result
.
Bytes
()
}
// emitDelta 发送 delta 事件
func
(
p
*
StreamingProcessor
)
emitDelta
(
deltaType
string
,
deltaContent
map
[
string
]
any
)
[]
byte
{
delta
:=
map
[
string
]
any
{
"type"
:
deltaType
,
}
for
k
,
v
:=
range
deltaContent
{
delta
[
k
]
=
v
}
event
:=
map
[
string
]
any
{
"type"
:
"content_block_delta"
,
"index"
:
p
.
blockIndex
,
"delta"
:
delta
,
}
return
p
.
formatSSE
(
"content_block_delta"
,
event
)
}
// emitEmptyThinkingWithSignature 发送空 thinking 块承载签名
func
(
p
*
StreamingProcessor
)
emitEmptyThinkingWithSignature
(
signature
string
)
[]
byte
{
var
result
bytes
.
Buffer
_
,
_
=
result
.
Write
(
p
.
startBlock
(
BlockTypeThinking
,
map
[
string
]
any
{
"type"
:
"thinking"
,
"thinking"
:
""
,
}))
_
,
_
=
result
.
Write
(
p
.
emitDelta
(
"thinking_delta"
,
map
[
string
]
any
{
"thinking"
:
""
,
}))
_
,
_
=
result
.
Write
(
p
.
emitDelta
(
"signature_delta"
,
map
[
string
]
any
{
"signature"
:
signature
,
}))
_
,
_
=
result
.
Write
(
p
.
endBlock
())
return
result
.
Bytes
()
}
// emitFinish 发送结束事件
func
(
p
*
StreamingProcessor
)
emitFinish
(
finishReason
string
)
[]
byte
{
var
result
bytes
.
Buffer
// 关闭最后一个块
_
,
_
=
result
.
Write
(
p
.
endBlock
())
// 处理 trailingSignature
if
p
.
trailingSignature
!=
""
{
_
,
_
=
result
.
Write
(
p
.
emitEmptyThinkingWithSignature
(
p
.
trailingSignature
))
p
.
trailingSignature
=
""
}
// 确定 stop_reason
stopReason
:=
"end_turn"
if
p
.
usedTool
{
stopReason
=
"tool_use"
}
else
if
finishReason
==
"MAX_TOKENS"
{
stopReason
=
"max_tokens"
}
usage
:=
ClaudeUsage
{
InputTokens
:
p
.
inputTokens
,
OutputTokens
:
p
.
outputTokens
,
}
deltaEvent
:=
map
[
string
]
any
{
"type"
:
"message_delta"
,
"delta"
:
map
[
string
]
any
{
"stop_reason"
:
stopReason
,
"stop_sequence"
:
nil
,
},
"usage"
:
usage
,
}
_
,
_
=
result
.
Write
(
p
.
formatSSE
(
"message_delta"
,
deltaEvent
))
if
!
p
.
messageStopSent
{
stopEvent
:=
map
[
string
]
any
{
"type"
:
"message_stop"
,
}
_
,
_
=
result
.
Write
(
p
.
formatSSE
(
"message_stop"
,
stopEvent
))
p
.
messageStopSent
=
true
}
return
result
.
Bytes
()
}
// formatSSE 格式化 SSE 事件
func
(
p
*
StreamingProcessor
)
formatSSE
(
eventType
string
,
data
any
)
[]
byte
{
jsonData
,
err
:=
json
.
Marshal
(
data
)
if
err
!=
nil
{
return
nil
}
return
[]
byte
(
fmt
.
Sprintf
(
"event: %s
\n
data: %s
\n\n
"
,
eventType
,
string
(
jsonData
)))
}
backend/internal/pkg/ctxkey/ctxkey.go
0 → 100644
View file @
e83f0ee3
// Package ctxkey 定义用于 context.Value 的类型安全 key
package
ctxkey
// Key 定义 context key 的类型,避免使用内置 string 类型(staticcheck SA1029)
type
Key
string
const
(
// ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置
ForcePlatform
Key
=
"ctx_force_platform"
)
backend/internal/repository/account_repo.go
View file @
e83f0ee3
...
...
@@ -464,6 +464,56 @@ func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Cont
})
}
func
(
r
*
accountRepository
)
ListSchedulableByPlatforms
(
ctx
context
.
Context
,
platforms
[]
string
)
([]
service
.
Account
,
error
)
{
if
len
(
platforms
)
==
0
{
return
nil
,
nil
}
var
accounts
[]
accountModel
now
:=
time
.
Now
()
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"platform IN ?"
,
platforms
)
.
Where
(
"status = ? AND schedulable = ?"
,
service
.
StatusActive
,
true
)
.
Where
(
"(overload_until IS NULL OR overload_until <= ?)"
,
now
)
.
Where
(
"(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)"
,
now
)
.
Preload
(
"Proxy"
)
.
Order
(
"priority ASC"
)
.
Find
(
&
accounts
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
outAccounts
:=
make
([]
service
.
Account
,
0
,
len
(
accounts
))
for
i
:=
range
accounts
{
outAccounts
=
append
(
outAccounts
,
*
accountModelToService
(
&
accounts
[
i
]))
}
return
outAccounts
,
nil
}
func
(
r
*
accountRepository
)
ListSchedulableByGroupIDAndPlatforms
(
ctx
context
.
Context
,
groupID
int64
,
platforms
[]
string
)
([]
service
.
Account
,
error
)
{
if
len
(
platforms
)
==
0
{
return
nil
,
nil
}
var
accounts
[]
accountModel
now
:=
time
.
Now
()
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Joins
(
"JOIN account_groups ON account_groups.account_id = accounts.id"
)
.
Where
(
"account_groups.group_id = ?"
,
groupID
)
.
Where
(
"accounts.platform IN ?"
,
platforms
)
.
Where
(
"accounts.status = ? AND accounts.schedulable = ?"
,
service
.
StatusActive
,
true
)
.
Where
(
"(accounts.overload_until IS NULL OR accounts.overload_until <= ?)"
,
now
)
.
Where
(
"(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)"
,
now
)
.
Preload
(
"Proxy"
)
.
Order
(
"account_groups.priority ASC, accounts.priority ASC"
)
.
Find
(
&
accounts
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
outAccounts
:=
make
([]
service
.
Account
,
0
,
len
(
accounts
))
for
i
:=
range
accounts
{
outAccounts
=
append
(
outAccounts
,
*
accountModelToService
(
&
accounts
[
i
]))
}
return
outAccounts
,
nil
}
func
(
r
*
accountRepository
)
SetRateLimited
(
ctx
context
.
Context
,
id
int64
,
resetAt
time
.
Time
)
error
{
now
:=
time
.
Now
()
_
,
err
:=
r
.
client
.
Account
.
Update
()
.
...
...
backend/internal/repository/gateway_routing_integration_test.go
0 → 100644
View file @
e83f0ee3
//go:build integration
package
repository
import
(
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/datatypes"
"gorm.io/gorm"
)
// GatewayRoutingSuite 测试网关路由相关的数据库查询
// 验证账户选择和分流逻辑在真实数据库环境下的行为
type
GatewayRoutingSuite
struct
{
suite
.
Suite
ctx
context
.
Context
db
*
gorm
.
DB
accountRepo
*
accountRepository
}
func
(
s
*
GatewayRoutingSuite
)
SetupTest
()
{
s
.
ctx
=
context
.
Background
()
s
.
db
=
testTx
(
s
.
T
())
s
.
accountRepo
=
NewAccountRepository
(
s
.
db
)
.
(
*
accountRepository
)
}
func
TestGatewayRoutingSuite
(
t
*
testing
.
T
)
{
suite
.
Run
(
t
,
new
(
GatewayRoutingSuite
))
}
// TestListSchedulableByPlatforms_GeminiAndAntigravity 验证多平台账户查询
func
(
s
*
GatewayRoutingSuite
)
TestListSchedulableByPlatforms_GeminiAndAntigravity
()
{
// 创建各平台账户
geminiAcc
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
accountModel
{
Name
:
"gemini-oauth"
,
Platform
:
service
.
PlatformGemini
,
Type
:
service
.
AccountTypeOAuth
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
Priority
:
1
,
})
antigravityAcc
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
accountModel
{
Name
:
"antigravity-oauth"
,
Platform
:
service
.
PlatformAntigravity
,
Type
:
service
.
AccountTypeOAuth
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
Priority
:
2
,
Credentials
:
datatypes
.
JSONMap
{
"access_token"
:
"test-token"
,
"refresh_token"
:
"test-refresh"
,
"project_id"
:
"test-project"
,
},
})
// 创建不应被选中的 anthropic 账户
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
accountModel
{
Name
:
"anthropic-oauth"
,
Platform
:
service
.
PlatformAnthropic
,
Type
:
service
.
AccountTypeOAuth
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
Priority
:
0
,
})
// 查询 gemini + antigravity 平台
accounts
,
err
:=
s
.
accountRepo
.
ListSchedulableByPlatforms
(
s
.
ctx
,
[]
string
{
service
.
PlatformGemini
,
service
.
PlatformAntigravity
,
})
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
accounts
,
2
,
"应返回 gemini 和 antigravity 两个账户"
)
// 验证返回的账户平台
platforms
:=
make
(
map
[
string
]
bool
)
for
_
,
acc
:=
range
accounts
{
platforms
[
acc
.
Platform
]
=
true
}
s
.
Require
()
.
True
(
platforms
[
service
.
PlatformGemini
],
"应包含 gemini 账户"
)
s
.
Require
()
.
True
(
platforms
[
service
.
PlatformAntigravity
],
"应包含 antigravity 账户"
)
s
.
Require
()
.
False
(
platforms
[
service
.
PlatformAnthropic
],
"不应包含 anthropic 账户"
)
// 验证账户 ID 匹配
ids
:=
make
(
map
[
int64
]
bool
)
for
_
,
acc
:=
range
accounts
{
ids
[
acc
.
ID
]
=
true
}
s
.
Require
()
.
True
(
ids
[
geminiAcc
.
ID
])
s
.
Require
()
.
True
(
ids
[
antigravityAcc
.
ID
])
}
// TestListSchedulableByGroupIDAndPlatforms_WithGroupBinding 验证按分组过滤
func
(
s
*
GatewayRoutingSuite
)
TestListSchedulableByGroupIDAndPlatforms_WithGroupBinding
()
{
// 创建 gemini 分组
group
:=
mustCreateGroup
(
s
.
T
(),
s
.
db
,
&
groupModel
{
Name
:
"gemini-group"
,
Platform
:
service
.
PlatformGemini
,
Status
:
service
.
StatusActive
,
})
// 创建账户
boundAcc
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
accountModel
{
Name
:
"bound-antigravity"
,
Platform
:
service
.
PlatformAntigravity
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
})
unboundAcc
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
accountModel
{
Name
:
"unbound-antigravity"
,
Platform
:
service
.
PlatformAntigravity
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
})
// 只绑定一个账户到分组
mustBindAccountToGroup
(
s
.
T
(),
s
.
db
,
boundAcc
.
ID
,
group
.
ID
,
1
)
// 查询分组内的账户
accounts
,
err
:=
s
.
accountRepo
.
ListSchedulableByGroupIDAndPlatforms
(
s
.
ctx
,
group
.
ID
,
[]
string
{
service
.
PlatformGemini
,
service
.
PlatformAntigravity
,
})
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
accounts
,
1
,
"应只返回绑定到分组的账户"
)
s
.
Require
()
.
Equal
(
boundAcc
.
ID
,
accounts
[
0
]
.
ID
)
// 确认未绑定的账户不在结果中
for
_
,
acc
:=
range
accounts
{
s
.
Require
()
.
NotEqual
(
unboundAcc
.
ID
,
acc
.
ID
,
"不应包含未绑定的账户"
)
}
}
// TestListSchedulableByPlatform_Antigravity 验证单平台查询
func
(
s
*
GatewayRoutingSuite
)
TestListSchedulableByPlatform_Antigravity
()
{
// 创建多种平台账户
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
accountModel
{
Name
:
"gemini-1"
,
Platform
:
service
.
PlatformGemini
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
})
antigravity
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
accountModel
{
Name
:
"antigravity-1"
,
Platform
:
service
.
PlatformAntigravity
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
})
// 只查询 antigravity 平台
accounts
,
err
:=
s
.
accountRepo
.
ListSchedulableByPlatform
(
s
.
ctx
,
service
.
PlatformAntigravity
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
accounts
,
1
)
s
.
Require
()
.
Equal
(
antigravity
.
ID
,
accounts
[
0
]
.
ID
)
s
.
Require
()
.
Equal
(
service
.
PlatformAntigravity
,
accounts
[
0
]
.
Platform
)
}
// TestSchedulableFilter_ExcludesInactive 验证不可调度账户被过滤
func
(
s
*
GatewayRoutingSuite
)
TestSchedulableFilter_ExcludesInactive
()
{
// 创建可调度账户
activeAcc
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
accountModel
{
Name
:
"active-antigravity"
,
Platform
:
service
.
PlatformAntigravity
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
})
// 创建不可调度账户(需要先创建再更新,因为 fixture 默认设置 Schedulable=true)
inactiveAcc
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
accountModel
{
Name
:
"inactive-antigravity"
,
Platform
:
service
.
PlatformAntigravity
,
Status
:
service
.
StatusActive
,
})
s
.
Require
()
.
NoError
(
s
.
db
.
Model
(
&
accountModel
{})
.
Where
(
"id = ?"
,
inactiveAcc
.
ID
)
.
Update
(
"schedulable"
,
false
)
.
Error
)
// 创建错误状态账户
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
accountModel
{
Name
:
"error-antigravity"
,
Platform
:
service
.
PlatformAntigravity
,
Status
:
service
.
StatusError
,
Schedulable
:
true
,
})
accounts
,
err
:=
s
.
accountRepo
.
ListSchedulableByPlatform
(
s
.
ctx
,
service
.
PlatformAntigravity
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
accounts
,
1
,
"应只返回可调度的 active 账户"
)
s
.
Require
()
.
Equal
(
activeAcc
.
ID
,
accounts
[
0
]
.
ID
)
}
// TestPlatformRoutingDecision 验证平台路由决策
// 这个测试模拟 Handler 层在选择账户后的路由决策逻辑
func
(
s
*
GatewayRoutingSuite
)
TestPlatformRoutingDecision
()
{
// 创建两种平台的账户
geminiAcc
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
accountModel
{
Name
:
"gemini-route-test"
,
Platform
:
service
.
PlatformGemini
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
})
antigravityAcc
:=
mustCreateAccount
(
s
.
T
(),
s
.
db
,
&
accountModel
{
Name
:
"antigravity-route-test"
,
Platform
:
service
.
PlatformAntigravity
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
})
tests
:=
[]
struct
{
name
string
accountID
int64
expectedService
string
}{
{
name
:
"Gemini账户路由到ForwardNative"
,
accountID
:
geminiAcc
.
ID
,
expectedService
:
"GeminiMessagesCompatService.ForwardNative"
,
},
{
name
:
"Antigravity账户路由到ForwardGemini"
,
accountID
:
antigravityAcc
.
ID
,
expectedService
:
"AntigravityGatewayService.ForwardGemini"
,
},
}
for
_
,
tt
:=
range
tests
{
s
.
Run
(
tt
.
name
,
func
()
{
// 从数据库获取账户
account
,
err
:=
s
.
accountRepo
.
GetByID
(
s
.
ctx
,
tt
.
accountID
)
s
.
Require
()
.
NoError
(
err
)
// 模拟 Handler 层的路由决策
var
routedService
string
if
account
.
Platform
==
service
.
PlatformAntigravity
{
routedService
=
"AntigravityGatewayService.ForwardGemini"
}
else
{
routedService
=
"GeminiMessagesCompatService.ForwardNative"
}
s
.
Require
()
.
Equal
(
tt
.
expectedService
,
routedService
)
})
}
}
backend/internal/server/middleware/middleware.go
View file @
e83f0ee3
package
middleware
import
"github.com/gin-gonic/gin"
import
(
"context"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/gin-gonic/gin"
)
// ContextKey 定义上下文键类型
type
ContextKey
string
...
...
@@ -14,8 +19,39 @@ const (
ContextKeyApiKey
ContextKey
=
"api_key"
// ContextKeySubscription 订阅上下文键
ContextKeySubscription
ContextKey
=
"subscription"
// ContextKeyForcePlatform 强制平台(用于 /antigravity 路由)
ContextKeyForcePlatform
ContextKey
=
"force_platform"
)
// ForcePlatform 返回设置强制平台的中间件
// 同时设置 request.Context(供 Service 使用)和 gin.Context(供 Handler 快速检查)
func
ForcePlatform
(
platform
string
)
gin
.
HandlerFunc
{
return
func
(
c
*
gin
.
Context
)
{
// 设置到 request.Context,使用 ctxkey.ForcePlatform 供 Service 层读取
ctx
:=
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
ForcePlatform
,
platform
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
// 同时设置到 gin.Context,供 Handler 快速检查
c
.
Set
(
string
(
ContextKeyForcePlatform
),
platform
)
c
.
Next
()
}
}
// HasForcePlatform 检查是否有强制平台(用于 Handler 跳过分组检查)
func
HasForcePlatform
(
c
*
gin
.
Context
)
bool
{
_
,
exists
:=
c
.
Get
(
string
(
ContextKeyForcePlatform
))
return
exists
}
// GetForcePlatformFromContext 从 gin.Context 获取强制平台
func
GetForcePlatformFromContext
(
c
*
gin
.
Context
)
(
string
,
bool
)
{
value
,
exists
:=
c
.
Get
(
string
(
ContextKeyForcePlatform
))
if
!
exists
{
return
""
,
false
}
platform
,
ok
:=
value
.
(
string
)
return
platform
,
ok
}
// ErrorResponse 标准错误响应结构
type
ErrorResponse
struct
{
Code
string
`json:"code"`
...
...
backend/internal/server/routes/admin.go
View file @
e83f0ee3
...
...
@@ -34,6 +34,9 @@ func RegisterAdminRoutes(
// Gemini OAuth
registerGeminiOAuthRoutes
(
admin
,
h
)
// Antigravity OAuth
registerAntigravityOAuthRoutes
(
admin
,
h
)
// 代理管理
registerProxyRoutes
(
admin
,
h
)
...
...
@@ -148,6 +151,14 @@ func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
}
}
func
registerAntigravityOAuthRoutes
(
admin
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
)
{
antigravity
:=
admin
.
Group
(
"/antigravity"
)
{
antigravity
.
POST
(
"/oauth/auth-url"
,
h
.
Admin
.
AntigravityOAuth
.
GenerateAuthURL
)
antigravity
.
POST
(
"/oauth/exchange-code"
,
h
.
Admin
.
AntigravityOAuth
.
ExchangeCode
)
}
}
func
registerProxyRoutes
(
admin
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
)
{
proxies
:=
admin
.
Group
(
"/proxies"
)
{
...
...
backend/internal/server/routes/gateway.go
View file @
e83f0ee3
...
...
@@ -42,4 +42,24 @@ func RegisterGatewayRoutes(
// OpenAI Responses API(不带v1前缀的别名)
r
.
POST
(
"/responses"
,
gin
.
HandlerFunc
(
apiKeyAuth
),
h
.
OpenAIGateway
.
Responses
)
// Antigravity 专用路由(仅使用 antigravity 账户,不混合调度)
antigravityV1
:=
r
.
Group
(
"/antigravity/v1"
)
antigravityV1
.
Use
(
middleware
.
ForcePlatform
(
service
.
PlatformAntigravity
))
antigravityV1
.
Use
(
gin
.
HandlerFunc
(
apiKeyAuth
))
{
antigravityV1
.
POST
(
"/messages"
,
h
.
Gateway
.
Messages
)
antigravityV1
.
POST
(
"/messages/count_tokens"
,
h
.
Gateway
.
CountTokens
)
antigravityV1
.
GET
(
"/models"
,
h
.
Gateway
.
Models
)
antigravityV1
.
GET
(
"/usage"
,
h
.
Gateway
.
Usage
)
}
antigravityV1Beta
:=
r
.
Group
(
"/antigravity/v1beta"
)
antigravityV1Beta
.
Use
(
middleware
.
ForcePlatform
(
service
.
PlatformAntigravity
))
antigravityV1Beta
.
Use
(
middleware
.
ApiKeyAuthWithSubscriptionGoogle
(
apiKeyService
,
subscriptionService
,
cfg
))
{
antigravityV1Beta
.
GET
(
"/models"
,
h
.
Gateway
.
GeminiV1BetaListModels
)
antigravityV1Beta
.
GET
(
"/models/:model"
,
h
.
Gateway
.
GeminiV1BetaGetModel
)
antigravityV1Beta
.
POST
(
"/models/*modelAction"
,
h
.
Gateway
.
GeminiV1BetaModels
)
}
}
backend/internal/service/account.go
View file @
e83f0ee3
...
...
@@ -346,3 +346,20 @@ func (a *Account) IsOpenAITokenExpired() bool {
}
return
time
.
Now
()
.
Add
(
60
*
time
.
Second
)
.
After
(
*
expiresAt
)
}
// IsMixedSchedulingEnabled 检查 antigravity 账户是否启用混合调度
// 启用后可参与 anthropic/gemini 分组的账户调度
func
(
a
*
Account
)
IsMixedSchedulingEnabled
()
bool
{
if
a
.
Platform
!=
PlatformAntigravity
{
return
false
}
if
a
.
Extra
==
nil
{
return
false
}
if
v
,
ok
:=
a
.
Extra
[
"mixed_scheduling"
];
ok
{
if
enabled
,
ok
:=
v
.
(
bool
);
ok
{
return
enabled
}
}
return
false
}
backend/internal/service/account_service.go
View file @
e83f0ee3
...
...
@@ -40,6 +40,8 @@ type AccountRepository interface {
ListSchedulableByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
Account
,
error
)
ListSchedulableByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
Account
,
error
)
ListSchedulableByGroupIDAndPlatform
(
ctx
context
.
Context
,
groupID
int64
,
platform
string
)
([]
Account
,
error
)
ListSchedulableByPlatforms
(
ctx
context
.
Context
,
platforms
[]
string
)
([]
Account
,
error
)
ListSchedulableByGroupIDAndPlatforms
(
ctx
context
.
Context
,
groupID
int64
,
platforms
[]
string
)
([]
Account
,
error
)
SetRateLimited
(
ctx
context
.
Context
,
id
int64
,
resetAt
time
.
Time
)
error
SetOverloaded
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
)
error
...
...
backend/internal/service/antigravity_gateway_service.go
0 → 100644
View file @
e83f0ee3
package
service
import
(
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
const
(
antigravityStickySessionTTL
=
time
.
Hour
antigravityMaxRetries
=
5
antigravityRetryBaseDelay
=
1
*
time
.
Second
antigravityRetryMaxDelay
=
16
*
time
.
Second
)
// Antigravity 直接支持的模型
var
antigravitySupportedModels
=
map
[
string
]
bool
{
"claude-opus-4-5-thinking"
:
true
,
"claude-sonnet-4-5"
:
true
,
"claude-sonnet-4-5-thinking"
:
true
,
"gemini-2.5-flash"
:
true
,
"gemini-2.5-flash-lite"
:
true
,
"gemini-2.5-flash-thinking"
:
true
,
"gemini-3-flash"
:
true
,
"gemini-3-pro-low"
:
true
,
"gemini-3-pro-high"
:
true
,
"gemini-3-pro-preview"
:
true
,
"gemini-3-pro-image"
:
true
,
}
// Antigravity 系统默认模型映射表(不支持 → 支持)
var
antigravityModelMapping
=
map
[
string
]
string
{
"claude-3-5-sonnet-20241022"
:
"claude-sonnet-4-5"
,
"claude-3-5-sonnet-20240620"
:
"claude-sonnet-4-5"
,
"claude-sonnet-4-5-20250929"
:
"claude-sonnet-4-5-thinking"
,
"claude-opus-4"
:
"claude-opus-4-5-thinking"
,
"claude-opus-4-5-20251101"
:
"claude-opus-4-5-thinking"
,
"claude-haiku-4"
:
"gemini-3-flash"
,
"claude-haiku-4-5"
:
"gemini-3-flash"
,
"claude-3-haiku-20240307"
:
"gemini-3-flash"
,
"claude-haiku-4-5-20251001"
:
"gemini-3-flash"
,
// 生图模型:官方名 → Antigravity 内部名
"gemini-3-pro-image-preview"
:
"gemini-3-pro-image"
,
}
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发
type
AntigravityGatewayService
struct
{
accountRepo
AccountRepository
tokenProvider
*
AntigravityTokenProvider
rateLimitService
*
RateLimitService
httpUpstream
HTTPUpstream
}
func
NewAntigravityGatewayService
(
accountRepo
AccountRepository
,
_
GatewayCache
,
tokenProvider
*
AntigravityTokenProvider
,
rateLimitService
*
RateLimitService
,
httpUpstream
HTTPUpstream
,
)
*
AntigravityGatewayService
{
return
&
AntigravityGatewayService
{
accountRepo
:
accountRepo
,
tokenProvider
:
tokenProvider
,
rateLimitService
:
rateLimitService
,
httpUpstream
:
httpUpstream
,
}
}
// GetTokenProvider 返回 token provider
func
(
s
*
AntigravityGatewayService
)
GetTokenProvider
()
*
AntigravityTokenProvider
{
return
s
.
tokenProvider
}
// getMappedModel 获取映射后的模型名
func
(
s
*
AntigravityGatewayService
)
getMappedModel
(
account
*
Account
,
requestedModel
string
)
string
{
// 1. 优先使用账户级映射(复用现有方法)
if
mapped
:=
account
.
GetMappedModel
(
requestedModel
);
mapped
!=
requestedModel
{
return
mapped
}
// 2. 系统默认映射
if
mapped
,
ok
:=
antigravityModelMapping
[
requestedModel
];
ok
{
return
mapped
}
// 3. Gemini 模型透传
if
strings
.
HasPrefix
(
requestedModel
,
"gemini-"
)
{
return
requestedModel
}
// 4. Claude 前缀透传直接支持的模型
if
antigravitySupportedModels
[
requestedModel
]
{
return
requestedModel
}
// 5. 默认值
return
"claude-sonnet-4-5"
}
// IsModelSupported 检查模型是否被支持
func
(
s
*
AntigravityGatewayService
)
IsModelSupported
(
requestedModel
string
)
bool
{
// 直接支持的模型
if
antigravitySupportedModels
[
requestedModel
]
{
return
true
}
// 可映射的模型
if
_
,
ok
:=
antigravityModelMapping
[
requestedModel
];
ok
{
return
true
}
// Gemini 前缀透传
if
strings
.
HasPrefix
(
requestedModel
,
"gemini-"
)
{
return
true
}
// Claude 模型支持(通过默认映射)
if
strings
.
HasPrefix
(
requestedModel
,
"claude-"
)
{
return
true
}
return
false
}
// wrapV1InternalRequest 包装请求为 v1internal 格式
func
(
s
*
AntigravityGatewayService
)
wrapV1InternalRequest
(
projectID
,
model
string
,
originalBody
[]
byte
)
([]
byte
,
error
)
{
var
request
any
if
err
:=
json
.
Unmarshal
(
originalBody
,
&
request
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"解析请求体失败: %w"
,
err
)
}
wrapped
:=
map
[
string
]
any
{
"project"
:
projectID
,
"requestId"
:
"agent-"
+
uuid
.
New
()
.
String
(),
"userAgent"
:
"sub2api"
,
"requestType"
:
"agent"
,
"model"
:
model
,
"request"
:
request
,
}
return
json
.
Marshal
(
wrapped
)
}
// unwrapV1InternalResponse 解包 v1internal 响应
func
(
s
*
AntigravityGatewayService
)
unwrapV1InternalResponse
(
body
[]
byte
)
([]
byte
,
error
)
{
var
outer
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
outer
);
err
!=
nil
{
return
nil
,
err
}
if
resp
,
ok
:=
outer
[
"response"
];
ok
{
return
json
.
Marshal
(
resp
)
}
return
body
,
nil
}
// Forward 转发 Claude 协议请求(Claude → Gemini 转换)
func
(
s
*
AntigravityGatewayService
)
Forward
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
)
(
*
ForwardResult
,
error
)
{
startTime
:=
time
.
Now
()
// 解析 Claude 请求
var
claudeReq
antigravity
.
ClaudeRequest
if
err
:=
json
.
Unmarshal
(
body
,
&
claudeReq
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"parse claude request: %w"
,
err
)
}
if
strings
.
TrimSpace
(
claudeReq
.
Model
)
==
""
{
return
nil
,
fmt
.
Errorf
(
"missing model"
)
}
originalModel
:=
claudeReq
.
Model
mappedModel
:=
s
.
getMappedModel
(
account
,
claudeReq
.
Model
)
if
mappedModel
!=
claudeReq
.
Model
{
log
.
Printf
(
"Antigravity model mapping: %s -> %s (account: %s)"
,
claudeReq
.
Model
,
mappedModel
,
account
.
Name
)
}
// 获取 access_token
if
s
.
tokenProvider
==
nil
{
return
nil
,
errors
.
New
(
"antigravity token provider not configured"
)
}
accessToken
,
err
:=
s
.
tokenProvider
.
GetAccessToken
(
ctx
,
account
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"获取 access_token 失败: %w"
,
err
)
}
// 获取 project_id
projectID
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"project_id"
))
if
projectID
==
""
{
return
nil
,
errors
.
New
(
"project_id not found in credentials"
)
}
// 代理 URL
proxyURL
:=
""
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
proxyURL
=
account
.
Proxy
.
URL
()
}
// 转换 Claude 请求为 Gemini 格式
geminiBody
,
err
:=
antigravity
.
TransformClaudeToGemini
(
&
claudeReq
,
projectID
,
mappedModel
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"transform request: %w"
,
err
)
}
// 构建上游 URL
action
:=
"generateContent"
if
claudeReq
.
Stream
{
action
=
"streamGenerateContent"
}
fullURL
:=
fmt
.
Sprintf
(
"%s/v1internal:%s"
,
antigravity
.
BaseURL
,
action
)
if
claudeReq
.
Stream
{
fullURL
+=
"?alt=sse"
}
// 重试循环
var
resp
*
http
.
Response
for
attempt
:=
1
;
attempt
<=
antigravityMaxRetries
;
attempt
++
{
upstreamReq
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodPost
,
fullURL
,
bytes
.
NewReader
(
geminiBody
))
if
err
!=
nil
{
return
nil
,
err
}
upstreamReq
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
upstreamReq
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
accessToken
)
upstreamReq
.
Header
.
Set
(
"User-Agent"
,
antigravity
.
UserAgent
)
resp
,
err
=
s
.
httpUpstream
.
Do
(
upstreamReq
,
proxyURL
)
if
err
!=
nil
{
if
attempt
<
antigravityMaxRetries
{
log
.
Printf
(
"Antigravity account %d: upstream request failed, retry %d/%d: %v"
,
account
.
ID
,
attempt
,
antigravityMaxRetries
,
err
)
sleepAntigravityBackoff
(
attempt
)
continue
}
return
nil
,
s
.
writeClaudeError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream request failed after retries"
)
}
if
resp
.
StatusCode
>=
400
&&
s
.
shouldRetryUpstreamError
(
resp
.
StatusCode
)
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
_
=
resp
.
Body
.
Close
()
if
attempt
<
antigravityMaxRetries
{
log
.
Printf
(
"Antigravity account %d: upstream status %d, retry %d/%d"
,
account
.
ID
,
resp
.
StatusCode
,
attempt
,
antigravityMaxRetries
)
sleepAntigravityBackoff
(
attempt
)
continue
}
// 所有重试都失败,标记限流状态
if
resp
.
StatusCode
==
429
{
s
.
handleUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
)
}
// 最后一次尝试也失败
resp
=
&
http
.
Response
{
StatusCode
:
resp
.
StatusCode
,
Header
:
resp
.
Header
.
Clone
(),
Body
:
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
)),
}
break
}
break
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
// 处理错误响应
if
resp
.
StatusCode
>=
400
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
s
.
handleUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
)
if
s
.
shouldFailoverUpstreamError
(
resp
.
StatusCode
)
{
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
resp
.
StatusCode
}
}
return
nil
,
s
.
writeMappedClaudeError
(
c
,
resp
.
StatusCode
,
respBody
)
}
requestID
:=
resp
.
Header
.
Get
(
"x-request-id"
)
if
requestID
!=
""
{
c
.
Header
(
"x-request-id"
,
requestID
)
}
var
usage
*
ClaudeUsage
var
firstTokenMs
*
int
if
claudeReq
.
Stream
{
streamRes
,
err
:=
s
.
handleClaudeStreamingResponse
(
c
,
resp
,
startTime
,
originalModel
)
if
err
!=
nil
{
return
nil
,
err
}
usage
=
streamRes
.
usage
firstTokenMs
=
streamRes
.
firstTokenMs
}
else
{
usage
,
err
=
s
.
handleClaudeNonStreamingResponse
(
c
,
resp
,
originalModel
)
if
err
!=
nil
{
return
nil
,
err
}
}
return
&
ForwardResult
{
RequestID
:
requestID
,
Usage
:
*
usage
,
Model
:
originalModel
,
// 使用原始模型用于计费和日志
Stream
:
claudeReq
.
Stream
,
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
},
nil
}
// ForwardGemini 转发 Gemini 协议请求
func
(
s
*
AntigravityGatewayService
)
ForwardGemini
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
originalModel
string
,
action
string
,
stream
bool
,
body
[]
byte
)
(
*
ForwardResult
,
error
)
{
startTime
:=
time
.
Now
()
if
strings
.
TrimSpace
(
originalModel
)
==
""
{
return
nil
,
s
.
writeGoogleError
(
c
,
http
.
StatusBadRequest
,
"Missing model in URL"
)
}
if
strings
.
TrimSpace
(
action
)
==
""
{
return
nil
,
s
.
writeGoogleError
(
c
,
http
.
StatusBadRequest
,
"Missing action in URL"
)
}
if
len
(
body
)
==
0
{
return
nil
,
s
.
writeGoogleError
(
c
,
http
.
StatusBadRequest
,
"Request body is empty"
)
}
switch
action
{
case
"generateContent"
,
"streamGenerateContent"
,
"countTokens"
:
// ok
default
:
return
nil
,
s
.
writeGoogleError
(
c
,
http
.
StatusNotFound
,
"Unsupported action: "
+
action
)
}
mappedModel
:=
s
.
getMappedModel
(
account
,
originalModel
)
// 获取 access_token
if
s
.
tokenProvider
==
nil
{
return
nil
,
errors
.
New
(
"antigravity token provider not configured"
)
}
accessToken
,
err
:=
s
.
tokenProvider
.
GetAccessToken
(
ctx
,
account
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"获取 access_token 失败: %w"
,
err
)
}
// 获取 project_id
projectID
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"project_id"
))
if
projectID
==
""
{
return
nil
,
errors
.
New
(
"project_id not found in credentials"
)
}
// 代理 URL
proxyURL
:=
""
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
proxyURL
=
account
.
Proxy
.
URL
()
}
// 包装请求
wrappedBody
,
err
:=
s
.
wrapV1InternalRequest
(
projectID
,
mappedModel
,
body
)
if
err
!=
nil
{
return
nil
,
err
}
// 构建上游 URL
upstreamAction
:=
action
if
action
==
"generateContent"
&&
stream
{
upstreamAction
=
"streamGenerateContent"
}
fullURL
:=
fmt
.
Sprintf
(
"%s/v1internal:%s"
,
antigravity
.
BaseURL
,
upstreamAction
)
if
stream
||
upstreamAction
==
"streamGenerateContent"
{
fullURL
+=
"?alt=sse"
}
// 重试循环
var
resp
*
http
.
Response
for
attempt
:=
1
;
attempt
<=
antigravityMaxRetries
;
attempt
++
{
upstreamReq
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodPost
,
fullURL
,
bytes
.
NewReader
(
wrappedBody
))
if
err
!=
nil
{
return
nil
,
err
}
upstreamReq
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
upstreamReq
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
accessToken
)
upstreamReq
.
Header
.
Set
(
"User-Agent"
,
antigravity
.
UserAgent
)
resp
,
err
=
s
.
httpUpstream
.
Do
(
upstreamReq
,
proxyURL
)
if
err
!=
nil
{
if
attempt
<
antigravityMaxRetries
{
log
.
Printf
(
"Antigravity account %d: upstream request failed, retry %d/%d: %v"
,
account
.
ID
,
attempt
,
antigravityMaxRetries
,
err
)
sleepAntigravityBackoff
(
attempt
)
continue
}
if
action
==
"countTokens"
{
estimated
:=
estimateGeminiCountTokens
(
body
)
c
.
JSON
(
http
.
StatusOK
,
map
[
string
]
any
{
"totalTokens"
:
estimated
})
return
&
ForwardResult
{
RequestID
:
""
,
Usage
:
ClaudeUsage
{},
Model
:
originalModel
,
Stream
:
false
,
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
nil
,
},
nil
}
return
nil
,
s
.
writeGoogleError
(
c
,
http
.
StatusBadGateway
,
"Upstream request failed after retries"
)
}
if
resp
.
StatusCode
>=
400
&&
s
.
shouldRetryUpstreamError
(
resp
.
StatusCode
)
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
_
=
resp
.
Body
.
Close
()
if
attempt
<
antigravityMaxRetries
{
log
.
Printf
(
"Antigravity account %d: upstream status %d, retry %d/%d"
,
account
.
ID
,
resp
.
StatusCode
,
attempt
,
antigravityMaxRetries
)
sleepAntigravityBackoff
(
attempt
)
continue
}
// 所有重试都失败,标记限流状态
if
resp
.
StatusCode
==
429
{
s
.
handleUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
)
}
if
action
==
"countTokens"
{
estimated
:=
estimateGeminiCountTokens
(
body
)
c
.
JSON
(
http
.
StatusOK
,
map
[
string
]
any
{
"totalTokens"
:
estimated
})
return
&
ForwardResult
{
RequestID
:
""
,
Usage
:
ClaudeUsage
{},
Model
:
originalModel
,
Stream
:
false
,
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
nil
,
},
nil
}
resp
=
&
http
.
Response
{
StatusCode
:
resp
.
StatusCode
,
Header
:
resp
.
Header
.
Clone
(),
Body
:
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
)),
}
break
}
break
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
requestID
:=
resp
.
Header
.
Get
(
"x-request-id"
)
if
requestID
!=
""
{
c
.
Header
(
"x-request-id"
,
requestID
)
}
// 处理错误响应
if
resp
.
StatusCode
>=
400
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
s
.
handleUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
)
if
action
==
"countTokens"
{
estimated
:=
estimateGeminiCountTokens
(
body
)
c
.
JSON
(
http
.
StatusOK
,
map
[
string
]
any
{
"totalTokens"
:
estimated
})
return
&
ForwardResult
{
RequestID
:
requestID
,
Usage
:
ClaudeUsage
{},
Model
:
originalModel
,
Stream
:
false
,
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
nil
,
},
nil
}
if
s
.
shouldFailoverUpstreamError
(
resp
.
StatusCode
)
{
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
resp
.
StatusCode
}
}
// 解包并返回错误
unwrapped
,
_
:=
s
.
unwrapV1InternalResponse
(
respBody
)
contentType
:=
resp
.
Header
.
Get
(
"Content-Type"
)
if
contentType
==
""
{
contentType
=
"application/json"
}
c
.
Data
(
resp
.
StatusCode
,
contentType
,
unwrapped
)
return
nil
,
fmt
.
Errorf
(
"antigravity upstream error: %d"
,
resp
.
StatusCode
)
}
var
usage
*
ClaudeUsage
var
firstTokenMs
*
int
if
stream
||
upstreamAction
==
"streamGenerateContent"
{
streamRes
,
err
:=
s
.
handleGeminiStreamingResponse
(
c
,
resp
,
startTime
)
if
err
!=
nil
{
return
nil
,
err
}
usage
=
streamRes
.
usage
firstTokenMs
=
streamRes
.
firstTokenMs
}
else
{
usageResp
,
err
:=
s
.
handleGeminiNonStreamingResponse
(
c
,
resp
)
if
err
!=
nil
{
return
nil
,
err
}
usage
=
usageResp
}
if
usage
==
nil
{
usage
=
&
ClaudeUsage
{}
}
return
&
ForwardResult
{
RequestID
:
requestID
,
Usage
:
*
usage
,
Model
:
originalModel
,
Stream
:
stream
,
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
},
nil
}
func
(
s
*
AntigravityGatewayService
)
shouldRetryUpstreamError
(
statusCode
int
)
bool
{
switch
statusCode
{
case
429
,
500
,
502
,
503
,
504
,
529
:
return
true
default
:
return
false
}
}
func
(
s
*
AntigravityGatewayService
)
shouldFailoverUpstreamError
(
statusCode
int
)
bool
{
switch
statusCode
{
case
401
,
403
,
429
,
529
:
return
true
default
:
return
statusCode
>=
500
}
}
func
sleepAntigravityBackoff
(
attempt
int
)
{
sleepGeminiBackoff
(
attempt
)
// 复用 Gemini 的退避逻辑
}
func
(
s
*
AntigravityGatewayService
)
handleUpstreamError
(
ctx
context
.
Context
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
)
{
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
if
statusCode
==
429
{
resetAt
:=
ParseGeminiRateLimitResetTime
(
body
)
if
resetAt
==
nil
{
// 解析失败:Gemini 有重试时间用 5 分钟,Claude 没有用 1 分钟
defaultDur
:=
1
*
time
.
Minute
if
bytes
.
Contains
(
body
,
[]
byte
(
"Please retry in"
))
||
bytes
.
Contains
(
body
,
[]
byte
(
"retryDelay"
))
{
defaultDur
=
5
*
time
.
Minute
}
ra
:=
time
.
Now
()
.
Add
(
defaultDur
)
_
=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
ra
)
return
}
_
=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
time
.
Unix
(
*
resetAt
,
0
))
return
}
// 其他错误码继续使用 rateLimitService
if
s
.
rateLimitService
==
nil
{
return
}
s
.
rateLimitService
.
HandleUpstreamError
(
ctx
,
account
,
statusCode
,
headers
,
body
)
}
type
antigravityStreamResult
struct
{
usage
*
ClaudeUsage
firstTokenMs
*
int
}
func
(
s
*
AntigravityGatewayService
)
handleGeminiStreamingResponse
(
c
*
gin
.
Context
,
resp
*
http
.
Response
,
startTime
time
.
Time
)
(
*
antigravityStreamResult
,
error
)
{
c
.
Status
(
resp
.
StatusCode
)
c
.
Header
(
"Cache-Control"
,
"no-cache"
)
c
.
Header
(
"Connection"
,
"keep-alive"
)
c
.
Header
(
"X-Accel-Buffering"
,
"no"
)
contentType
:=
resp
.
Header
.
Get
(
"Content-Type"
)
if
contentType
==
""
{
contentType
=
"text/event-stream; charset=utf-8"
}
c
.
Header
(
"Content-Type"
,
contentType
)
flusher
,
ok
:=
c
.
Writer
.
(
http
.
Flusher
)
if
!
ok
{
return
nil
,
errors
.
New
(
"streaming not supported"
)
}
reader
:=
bufio
.
NewReader
(
resp
.
Body
)
usage
:=
&
ClaudeUsage
{}
var
firstTokenMs
*
int
for
{
line
,
err
:=
reader
.
ReadString
(
'\n'
)
if
len
(
line
)
>
0
{
trimmed
:=
strings
.
TrimRight
(
line
,
"
\r\n
"
)
if
strings
.
HasPrefix
(
trimmed
,
"data:"
)
{
payload
:=
strings
.
TrimSpace
(
strings
.
TrimPrefix
(
trimmed
,
"data:"
))
if
payload
==
""
||
payload
==
"[DONE]"
{
_
,
_
=
io
.
WriteString
(
c
.
Writer
,
line
)
flusher
.
Flush
()
}
else
{
// 解包 v1internal 响应
inner
,
parseErr
:=
s
.
unwrapV1InternalResponse
([]
byte
(
payload
))
if
parseErr
==
nil
&&
inner
!=
nil
{
payload
=
string
(
inner
)
}
// 解析 usage
var
parsed
map
[
string
]
any
if
json
.
Unmarshal
(
inner
,
&
parsed
)
==
nil
{
if
u
:=
extractGeminiUsage
(
parsed
);
u
!=
nil
{
usage
=
u
}
}
if
firstTokenMs
==
nil
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
firstTokenMs
=
&
ms
}
_
,
_
=
fmt
.
Fprintf
(
c
.
Writer
,
"data: %s
\n\n
"
,
payload
)
flusher
.
Flush
()
}
}
else
{
_
,
_
=
io
.
WriteString
(
c
.
Writer
,
line
)
flusher
.
Flush
()
}
}
if
errors
.
Is
(
err
,
io
.
EOF
)
{
break
}
if
err
!=
nil
{
return
nil
,
err
}
}
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
func
(
s
*
AntigravityGatewayService
)
handleGeminiNonStreamingResponse
(
c
*
gin
.
Context
,
resp
*
http
.
Response
)
(
*
ClaudeUsage
,
error
)
{
respBody
,
err
:=
io
.
ReadAll
(
resp
.
Body
)
if
err
!=
nil
{
return
nil
,
err
}
// 解包 v1internal 响应
unwrapped
,
_
:=
s
.
unwrapV1InternalResponse
(
respBody
)
var
parsed
map
[
string
]
any
if
json
.
Unmarshal
(
unwrapped
,
&
parsed
)
==
nil
{
if
u
:=
extractGeminiUsage
(
parsed
);
u
!=
nil
{
c
.
Data
(
resp
.
StatusCode
,
"application/json"
,
unwrapped
)
return
u
,
nil
}
}
c
.
Data
(
resp
.
StatusCode
,
"application/json"
,
unwrapped
)
return
&
ClaudeUsage
{},
nil
}
func
(
s
*
AntigravityGatewayService
)
writeClaudeError
(
c
*
gin
.
Context
,
status
int
,
errType
,
message
string
)
error
{
c
.
JSON
(
status
,
gin
.
H
{
"type"
:
"error"
,
"error"
:
gin
.
H
{
"type"
:
errType
,
"message"
:
message
},
})
return
fmt
.
Errorf
(
"%s"
,
message
)
}
func
(
s
*
AntigravityGatewayService
)
writeMappedClaudeError
(
c
*
gin
.
Context
,
upstreamStatus
int
,
body
[]
byte
)
error
{
// 记录上游错误详情便于调试
log
.
Printf
(
"Antigravity upstream error %d: %s"
,
upstreamStatus
,
string
(
body
))
var
statusCode
int
var
errType
,
errMsg
string
switch
upstreamStatus
{
case
400
:
statusCode
=
http
.
StatusBadRequest
errType
=
"invalid_request_error"
errMsg
=
"Invalid request"
case
401
:
statusCode
=
http
.
StatusBadGateway
errType
=
"authentication_error"
errMsg
=
"Upstream authentication failed"
case
403
:
statusCode
=
http
.
StatusBadGateway
errType
=
"permission_error"
errMsg
=
"Upstream access forbidden"
case
429
:
statusCode
=
http
.
StatusTooManyRequests
errType
=
"rate_limit_error"
errMsg
=
"Upstream rate limit exceeded"
case
529
:
statusCode
=
http
.
StatusServiceUnavailable
errType
=
"overloaded_error"
errMsg
=
"Upstream service overloaded"
default
:
statusCode
=
http
.
StatusBadGateway
errType
=
"upstream_error"
errMsg
=
"Upstream request failed"
}
c
.
JSON
(
statusCode
,
gin
.
H
{
"type"
:
"error"
,
"error"
:
gin
.
H
{
"type"
:
errType
,
"message"
:
errMsg
},
})
return
fmt
.
Errorf
(
"upstream error: %d"
,
upstreamStatus
)
}
func
(
s
*
AntigravityGatewayService
)
writeGoogleError
(
c
*
gin
.
Context
,
status
int
,
message
string
)
error
{
statusStr
:=
"UNKNOWN"
switch
status
{
case
400
:
statusStr
=
"INVALID_ARGUMENT"
case
404
:
statusStr
=
"NOT_FOUND"
case
429
:
statusStr
=
"RESOURCE_EXHAUSTED"
case
500
:
statusStr
=
"INTERNAL"
case
502
,
503
:
statusStr
=
"UNAVAILABLE"
}
c
.
JSON
(
status
,
gin
.
H
{
"error"
:
gin
.
H
{
"code"
:
status
,
"message"
:
message
,
"status"
:
statusStr
,
},
})
return
fmt
.
Errorf
(
"%s"
,
message
)
}
// handleClaudeNonStreamingResponse 处理 Claude 非流式响应(Gemini → Claude 转换)
func
(
s
*
AntigravityGatewayService
)
handleClaudeNonStreamingResponse
(
c
*
gin
.
Context
,
resp
*
http
.
Response
,
originalModel
string
)
(
*
ClaudeUsage
,
error
)
{
body
,
err
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
8
<<
20
))
if
err
!=
nil
{
return
nil
,
s
.
writeClaudeError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Failed to read upstream response"
)
}
// 转换 Gemini 响应为 Claude 格式
claudeResp
,
agUsage
,
err
:=
antigravity
.
TransformGeminiToClaude
(
body
,
originalModel
)
if
err
!=
nil
{
log
.
Printf
(
"Transform Gemini to Claude failed: %v, body: %s"
,
err
,
string
(
body
))
return
nil
,
s
.
writeClaudeError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Failed to parse upstream response"
)
}
c
.
Data
(
http
.
StatusOK
,
"application/json"
,
claudeResp
)
// 转换为 service.ClaudeUsage
usage
:=
&
ClaudeUsage
{
InputTokens
:
agUsage
.
InputTokens
,
OutputTokens
:
agUsage
.
OutputTokens
,
CacheCreationInputTokens
:
agUsage
.
CacheCreationInputTokens
,
CacheReadInputTokens
:
agUsage
.
CacheReadInputTokens
,
}
return
usage
,
nil
}
// handleClaudeStreamingResponse 处理 Claude 流式响应(Gemini SSE → Claude SSE 转换)
func
(
s
*
AntigravityGatewayService
)
handleClaudeStreamingResponse
(
c
*
gin
.
Context
,
resp
*
http
.
Response
,
startTime
time
.
Time
,
originalModel
string
)
(
*
antigravityStreamResult
,
error
)
{
c
.
Header
(
"Content-Type"
,
"text/event-stream"
)
c
.
Header
(
"Cache-Control"
,
"no-cache"
)
c
.
Header
(
"Connection"
,
"keep-alive"
)
c
.
Header
(
"X-Accel-Buffering"
,
"no"
)
c
.
Status
(
http
.
StatusOK
)
flusher
,
ok
:=
c
.
Writer
.
(
http
.
Flusher
)
if
!
ok
{
return
nil
,
errors
.
New
(
"streaming not supported"
)
}
processor
:=
antigravity
.
NewStreamingProcessor
(
originalModel
)
var
firstTokenMs
*
int
reader
:=
bufio
.
NewReader
(
resp
.
Body
)
// 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage
convertUsage
:=
func
(
agUsage
*
antigravity
.
ClaudeUsage
)
*
ClaudeUsage
{
if
agUsage
==
nil
{
return
&
ClaudeUsage
{}
}
return
&
ClaudeUsage
{
InputTokens
:
agUsage
.
InputTokens
,
OutputTokens
:
agUsage
.
OutputTokens
,
CacheCreationInputTokens
:
agUsage
.
CacheCreationInputTokens
,
CacheReadInputTokens
:
agUsage
.
CacheReadInputTokens
,
}
}
for
{
line
,
err
:=
reader
.
ReadString
(
'\n'
)
if
err
!=
nil
&&
!
errors
.
Is
(
err
,
io
.
EOF
)
{
return
nil
,
fmt
.
Errorf
(
"stream read error: %w"
,
err
)
}
if
len
(
line
)
>
0
{
// 处理 SSE 行,转换为 Claude 格式
claudeEvents
:=
processor
.
ProcessLine
(
strings
.
TrimRight
(
line
,
"
\r\n
"
))
if
len
(
claudeEvents
)
>
0
{
if
firstTokenMs
==
nil
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
firstTokenMs
=
&
ms
}
if
_
,
writeErr
:=
c
.
Writer
.
Write
(
claudeEvents
);
writeErr
!=
nil
{
finalEvents
,
agUsage
:=
processor
.
Finish
()
if
len
(
finalEvents
)
>
0
{
_
,
_
=
c
.
Writer
.
Write
(
finalEvents
)
}
return
&
antigravityStreamResult
{
usage
:
convertUsage
(
agUsage
),
firstTokenMs
:
firstTokenMs
},
writeErr
}
flusher
.
Flush
()
}
}
if
errors
.
Is
(
err
,
io
.
EOF
)
{
break
}
}
// 发送结束事件
finalEvents
,
agUsage
:=
processor
.
Finish
()
if
len
(
finalEvents
)
>
0
{
_
,
_
=
c
.
Writer
.
Write
(
finalEvents
)
flusher
.
Flush
()
}
return
&
antigravityStreamResult
{
usage
:
convertUsage
(
agUsage
),
firstTokenMs
:
firstTokenMs
},
nil
}
backend/internal/service/antigravity_model_mapping_test.go
0 → 100644
View file @
e83f0ee3
//go:build unit
package
service
import
(
"testing"
"github.com/stretchr/testify/require"
)
func
TestIsAntigravityModelSupported
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
model
string
expected
bool
}{
// 直接支持的模型
{
"直接支持 - claude-sonnet-4-5"
,
"claude-sonnet-4-5"
,
true
},
{
"直接支持 - claude-opus-4-5-thinking"
,
"claude-opus-4-5-thinking"
,
true
},
{
"直接支持 - claude-sonnet-4-5-thinking"
,
"claude-sonnet-4-5-thinking"
,
true
},
{
"直接支持 - gemini-2.5-flash"
,
"gemini-2.5-flash"
,
true
},
{
"直接支持 - gemini-2.5-flash-lite"
,
"gemini-2.5-flash-lite"
,
true
},
{
"直接支持 - gemini-3-pro-high"
,
"gemini-3-pro-high"
,
true
},
// 可映射的模型
{
"可映射 - claude-3-5-sonnet-20241022"
,
"claude-3-5-sonnet-20241022"
,
true
},
{
"可映射 - claude-3-5-sonnet-20240620"
,
"claude-3-5-sonnet-20240620"
,
true
},
{
"可映射 - claude-opus-4"
,
"claude-opus-4"
,
true
},
{
"可映射 - claude-haiku-4"
,
"claude-haiku-4"
,
true
},
{
"可映射 - claude-3-haiku-20240307"
,
"claude-3-haiku-20240307"
,
true
},
// Gemini 前缀透传
{
"Gemini前缀 - gemini-1.5-pro"
,
"gemini-1.5-pro"
,
true
},
{
"Gemini前缀 - gemini-unknown-model"
,
"gemini-unknown-model"
,
true
},
{
"Gemini前缀 - gemini-future-version"
,
"gemini-future-version"
,
true
},
// Claude 前缀兜底
{
"Claude前缀 - claude-unknown-model"
,
"claude-unknown-model"
,
true
},
{
"Claude前缀 - claude-3-opus-20240229"
,
"claude-3-opus-20240229"
,
true
},
{
"Claude前缀 - claude-future-version"
,
"claude-future-version"
,
true
},
// 不支持的模型
{
"不支持 - gpt-4"
,
"gpt-4"
,
false
},
{
"不支持 - gpt-4o"
,
"gpt-4o"
,
false
},
{
"不支持 - llama-3"
,
"llama-3"
,
false
},
{
"不支持 - mistral-7b"
,
"mistral-7b"
,
false
},
{
"不支持 - 空字符串"
,
""
,
false
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
IsAntigravityModelSupported
(
tt
.
model
)
require
.
Equal
(
t
,
tt
.
expected
,
got
,
"model: %s"
,
tt
.
model
)
})
}
}
func
TestAntigravityGatewayService_GetMappedModel
(
t
*
testing
.
T
)
{
svc
:=
&
AntigravityGatewayService
{}
tests
:=
[]
struct
{
name
string
requestedModel
string
accountMapping
map
[
string
]
string
expected
string
}{
// 1. 账户级映射优先(注意:model_mapping 在 credentials 中存储为 map[string]any)
{
name
:
"账户映射优先"
,
requestedModel
:
"claude-3-5-sonnet-20241022"
,
accountMapping
:
map
[
string
]
string
{
"claude-3-5-sonnet-20241022"
:
"custom-model"
},
expected
:
"custom-model"
,
},
{
name
:
"账户映射覆盖系统映射"
,
requestedModel
:
"claude-opus-4"
,
accountMapping
:
map
[
string
]
string
{
"claude-opus-4"
:
"my-opus"
},
expected
:
"my-opus"
,
},
// 2. 系统默认映射
{
name
:
"系统映射 - claude-3-5-sonnet-20241022"
,
requestedModel
:
"claude-3-5-sonnet-20241022"
,
accountMapping
:
nil
,
expected
:
"claude-sonnet-4-5"
,
},
{
name
:
"系统映射 - claude-3-5-sonnet-20240620"
,
requestedModel
:
"claude-3-5-sonnet-20240620"
,
accountMapping
:
nil
,
expected
:
"claude-sonnet-4-5"
,
},
{
name
:
"系统映射 - claude-opus-4"
,
requestedModel
:
"claude-opus-4"
,
accountMapping
:
nil
,
expected
:
"claude-opus-4-5-thinking"
,
},
{
name
:
"系统映射 - claude-opus-4-5-20251101"
,
requestedModel
:
"claude-opus-4-5-20251101"
,
accountMapping
:
nil
,
expected
:
"claude-opus-4-5-thinking"
,
},
{
name
:
"系统映射 - claude-haiku-4 → gemini-3-flash"
,
requestedModel
:
"claude-haiku-4"
,
accountMapping
:
nil
,
expected
:
"gemini-3-flash"
,
},
{
name
:
"系统映射 - claude-haiku-4-5 → gemini-3-flash"
,
requestedModel
:
"claude-haiku-4-5"
,
accountMapping
:
nil
,
expected
:
"gemini-3-flash"
,
},
{
name
:
"系统映射 - claude-3-haiku-20240307 → gemini-3-flash"
,
requestedModel
:
"claude-3-haiku-20240307"
,
accountMapping
:
nil
,
expected
:
"gemini-3-flash"
,
},
{
name
:
"系统映射 - claude-haiku-4-5-20251001 → gemini-3-flash"
,
requestedModel
:
"claude-haiku-4-5-20251001"
,
accountMapping
:
nil
,
expected
:
"gemini-3-flash"
,
},
{
name
:
"系统映射 - claude-sonnet-4-5-20250929"
,
requestedModel
:
"claude-sonnet-4-5-20250929"
,
accountMapping
:
nil
,
expected
:
"claude-sonnet-4-5-thinking"
,
},
// 3. Gemini 透传
{
name
:
"Gemini透传 - gemini-2.5-flash"
,
requestedModel
:
"gemini-2.5-flash"
,
accountMapping
:
nil
,
expected
:
"gemini-2.5-flash"
,
},
{
name
:
"Gemini透传 - gemini-1.5-pro"
,
requestedModel
:
"gemini-1.5-pro"
,
accountMapping
:
nil
,
expected
:
"gemini-1.5-pro"
,
},
{
name
:
"Gemini透传 - gemini-future-model"
,
requestedModel
:
"gemini-future-model"
,
accountMapping
:
nil
,
expected
:
"gemini-future-model"
,
},
// 4. 直接支持的模型
{
name
:
"直接支持 - claude-sonnet-4-5"
,
requestedModel
:
"claude-sonnet-4-5"
,
accountMapping
:
nil
,
expected
:
"claude-sonnet-4-5"
,
},
{
name
:
"直接支持 - claude-opus-4-5-thinking"
,
requestedModel
:
"claude-opus-4-5-thinking"
,
accountMapping
:
nil
,
expected
:
"claude-opus-4-5-thinking"
,
},
{
name
:
"直接支持 - claude-sonnet-4-5-thinking"
,
requestedModel
:
"claude-sonnet-4-5-thinking"
,
accountMapping
:
nil
,
expected
:
"claude-sonnet-4-5-thinking"
,
},
// 5. 默认值 fallback(未知 claude 模型)
{
name
:
"默认值 - claude-unknown"
,
requestedModel
:
"claude-unknown"
,
accountMapping
:
nil
,
expected
:
"claude-sonnet-4-5"
,
},
{
name
:
"默认值 - claude-3-opus-20240229"
,
requestedModel
:
"claude-3-opus-20240229"
,
accountMapping
:
nil
,
expected
:
"claude-sonnet-4-5"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
Platform
:
PlatformAntigravity
,
}
if
tt
.
accountMapping
!=
nil
{
// GetModelMapping 期望 model_mapping 是 map[string]any 格式
mappingAny
:=
make
(
map
[
string
]
any
)
for
k
,
v
:=
range
tt
.
accountMapping
{
mappingAny
[
k
]
=
v
}
account
.
Credentials
=
map
[
string
]
any
{
"model_mapping"
:
mappingAny
,
}
}
got
:=
svc
.
getMappedModel
(
account
,
tt
.
requestedModel
)
require
.
Equal
(
t
,
tt
.
expected
,
got
,
"model: %s"
,
tt
.
requestedModel
)
})
}
}
func
TestAntigravityGatewayService_GetMappedModel_EdgeCases
(
t
*
testing
.
T
)
{
svc
:=
&
AntigravityGatewayService
{}
tests
:=
[]
struct
{
name
string
requestedModel
string
expected
string
}{
// 空字符串回退到默认值
{
"空字符串"
,
""
,
"claude-sonnet-4-5"
},
// 非 claude/gemini 前缀回退到默认值
{
"非claude/gemini前缀 - gpt"
,
"gpt-4"
,
"claude-sonnet-4-5"
},
{
"非claude/gemini前缀 - llama"
,
"llama-3"
,
"claude-sonnet-4-5"
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
Platform
:
PlatformAntigravity
}
got
:=
svc
.
getMappedModel
(
account
,
tt
.
requestedModel
)
require
.
Equal
(
t
,
tt
.
expected
,
got
)
})
}
}
func
TestAntigravityGatewayService_IsModelSupported
(
t
*
testing
.
T
)
{
svc
:=
&
AntigravityGatewayService
{}
tests
:=
[]
struct
{
name
string
model
string
expected
bool
}{
// 直接支持
{
"直接支持 - claude-sonnet-4-5"
,
"claude-sonnet-4-5"
,
true
},
{
"直接支持 - gemini-3-flash"
,
"gemini-3-flash"
,
true
},
// 可映射
{
"可映射 - claude-opus-4"
,
"claude-opus-4"
,
true
},
// 前缀透传
{
"Gemini前缀"
,
"gemini-unknown"
,
true
},
{
"Claude前缀"
,
"claude-unknown"
,
true
},
// 不支持
{
"不支持 - gpt-4"
,
"gpt-4"
,
false
},
{
"不支持 - 空字符串"
,
""
,
false
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
svc
.
IsModelSupported
(
tt
.
model
)
require
.
Equal
(
t
,
tt
.
expected
,
got
)
})
}
}
backend/internal/service/antigravity_oauth_service.go
0 → 100644
View file @
e83f0ee3
package
service
import
(
"context"
"fmt"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
type
AntigravityOAuthService
struct
{
sessionStore
*
antigravity
.
SessionStore
proxyRepo
ProxyRepository
}
func
NewAntigravityOAuthService
(
proxyRepo
ProxyRepository
)
*
AntigravityOAuthService
{
return
&
AntigravityOAuthService
{
sessionStore
:
antigravity
.
NewSessionStore
(),
proxyRepo
:
proxyRepo
,
}
}
// AntigravityAuthURLResult is the result of generating an authorization URL
type
AntigravityAuthURLResult
struct
{
AuthURL
string
`json:"auth_url"`
SessionID
string
`json:"session_id"`
State
string
`json:"state"`
}
// GenerateAuthURL 生成 Google OAuth 授权链接
func
(
s
*
AntigravityOAuthService
)
GenerateAuthURL
(
ctx
context
.
Context
,
proxyID
*
int64
)
(
*
AntigravityAuthURLResult
,
error
)
{
state
,
err
:=
antigravity
.
GenerateState
()
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"生成 state 失败: %w"
,
err
)
}
codeVerifier
,
err
:=
antigravity
.
GenerateCodeVerifier
()
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"生成 code_verifier 失败: %w"
,
err
)
}
sessionID
,
err
:=
antigravity
.
GenerateSessionID
()
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"生成 session_id 失败: %w"
,
err
)
}
var
proxyURL
string
if
proxyID
!=
nil
{
proxy
,
err
:=
s
.
proxyRepo
.
GetByID
(
ctx
,
*
proxyID
)
if
err
==
nil
&&
proxy
!=
nil
{
proxyURL
=
proxy
.
URL
()
}
}
session
:=
&
antigravity
.
OAuthSession
{
State
:
state
,
CodeVerifier
:
codeVerifier
,
ProxyURL
:
proxyURL
,
CreatedAt
:
time
.
Now
(),
}
s
.
sessionStore
.
Set
(
sessionID
,
session
)
codeChallenge
:=
antigravity
.
GenerateCodeChallenge
(
codeVerifier
)
authURL
:=
antigravity
.
BuildAuthorizationURL
(
state
,
codeChallenge
)
return
&
AntigravityAuthURLResult
{
AuthURL
:
authURL
,
SessionID
:
sessionID
,
State
:
state
,
},
nil
}
// AntigravityExchangeCodeInput 交换 code 的输入
type
AntigravityExchangeCodeInput
struct
{
SessionID
string
State
string
Code
string
ProxyID
*
int64
}
// AntigravityTokenInfo token 信息
type
AntigravityTokenInfo
struct
{
AccessToken
string
`json:"access_token"`
RefreshToken
string
`json:"refresh_token"`
ExpiresIn
int64
`json:"expires_in"`
ExpiresAt
int64
`json:"expires_at"`
TokenType
string
`json:"token_type"`
Email
string
`json:"email,omitempty"`
ProjectID
string
`json:"project_id,omitempty"`
}
// ExchangeCode 用 authorization code 交换 token
func
(
s
*
AntigravityOAuthService
)
ExchangeCode
(
ctx
context
.
Context
,
input
*
AntigravityExchangeCodeInput
)
(
*
AntigravityTokenInfo
,
error
)
{
session
,
ok
:=
s
.
sessionStore
.
Get
(
input
.
SessionID
)
if
!
ok
{
return
nil
,
fmt
.
Errorf
(
"session 不存在或已过期"
)
}
if
strings
.
TrimSpace
(
input
.
State
)
==
""
||
input
.
State
!=
session
.
State
{
return
nil
,
fmt
.
Errorf
(
"state 无效"
)
}
// 确定代理 URL
proxyURL
:=
session
.
ProxyURL
if
input
.
ProxyID
!=
nil
{
proxy
,
err
:=
s
.
proxyRepo
.
GetByID
(
ctx
,
*
input
.
ProxyID
)
if
err
==
nil
&&
proxy
!=
nil
{
proxyURL
=
proxy
.
URL
()
}
}
client
:=
antigravity
.
NewClient
(
proxyURL
)
// 交换 token
tokenResp
,
err
:=
client
.
ExchangeCode
(
ctx
,
input
.
Code
,
session
.
CodeVerifier
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"token 交换失败: %w"
,
err
)
}
// 删除 session
s
.
sessionStore
.
Delete
(
input
.
SessionID
)
// 计算过期时间(减去 5 分钟安全窗口)
expiresAt
:=
time
.
Now
()
.
Unix
()
+
tokenResp
.
ExpiresIn
-
300
result
:=
&
AntigravityTokenInfo
{
AccessToken
:
tokenResp
.
AccessToken
,
RefreshToken
:
tokenResp
.
RefreshToken
,
ExpiresIn
:
tokenResp
.
ExpiresIn
,
ExpiresAt
:
expiresAt
,
TokenType
:
tokenResp
.
TokenType
,
}
// 获取用户信息
userInfo
,
err
:=
client
.
GetUserInfo
(
ctx
,
tokenResp
.
AccessToken
)
if
err
!=
nil
{
fmt
.
Printf
(
"[AntigravityOAuth] 警告: 获取用户信息失败: %v
\n
"
,
err
)
}
else
{
result
.
Email
=
userInfo
.
Email
}
// 获取 project_id
loadResp
,
err
:=
client
.
LoadCodeAssist
(
ctx
,
tokenResp
.
AccessToken
)
if
err
!=
nil
{
fmt
.
Printf
(
"[AntigravityOAuth] 警告: 获取 project_id 失败: %v
\n
"
,
err
)
}
else
if
loadResp
!=
nil
&&
loadResp
.
CloudAICompanionProject
!=
""
{
result
.
ProjectID
=
loadResp
.
CloudAICompanionProject
}
return
result
,
nil
}
// RefreshToken 刷新 token
func
(
s
*
AntigravityOAuthService
)
RefreshToken
(
ctx
context
.
Context
,
refreshToken
,
proxyURL
string
)
(
*
AntigravityTokenInfo
,
error
)
{
var
lastErr
error
for
attempt
:=
0
;
attempt
<=
3
;
attempt
++
{
if
attempt
>
0
{
backoff
:=
time
.
Duration
(
1
<<
uint
(
attempt
-
1
))
*
time
.
Second
if
backoff
>
30
*
time
.
Second
{
backoff
=
30
*
time
.
Second
}
time
.
Sleep
(
backoff
)
}
client
:=
antigravity
.
NewClient
(
proxyURL
)
tokenResp
,
err
:=
client
.
RefreshToken
(
ctx
,
refreshToken
)
if
err
==
nil
{
expiresAt
:=
time
.
Now
()
.
Unix
()
+
tokenResp
.
ExpiresIn
-
300
return
&
AntigravityTokenInfo
{
AccessToken
:
tokenResp
.
AccessToken
,
RefreshToken
:
tokenResp
.
RefreshToken
,
ExpiresIn
:
tokenResp
.
ExpiresIn
,
ExpiresAt
:
expiresAt
,
TokenType
:
tokenResp
.
TokenType
,
},
nil
}
if
isNonRetryableAntigravityOAuthError
(
err
)
{
return
nil
,
err
}
lastErr
=
err
}
return
nil
,
fmt
.
Errorf
(
"token 刷新失败 (重试后): %w"
,
lastErr
)
}
func
isNonRetryableAntigravityOAuthError
(
err
error
)
bool
{
msg
:=
err
.
Error
()
nonRetryable
:=
[]
string
{
"invalid_grant"
,
"invalid_client"
,
"unauthorized_client"
,
"access_denied"
,
}
for
_
,
needle
:=
range
nonRetryable
{
if
strings
.
Contains
(
msg
,
needle
)
{
return
true
}
}
return
false
}
// RefreshAccountToken 刷新账户的 token
func
(
s
*
AntigravityOAuthService
)
RefreshAccountToken
(
ctx
context
.
Context
,
account
*
Account
)
(
*
AntigravityTokenInfo
,
error
)
{
if
account
.
Platform
!=
PlatformAntigravity
||
account
.
Type
!=
AccountTypeOAuth
{
return
nil
,
fmt
.
Errorf
(
"非 Antigravity OAuth 账户"
)
}
refreshToken
:=
account
.
GetCredential
(
"refresh_token"
)
if
strings
.
TrimSpace
(
refreshToken
)
==
""
{
return
nil
,
fmt
.
Errorf
(
"无可用的 refresh_token"
)
}
var
proxyURL
string
if
account
.
ProxyID
!=
nil
{
proxy
,
err
:=
s
.
proxyRepo
.
GetByID
(
ctx
,
*
account
.
ProxyID
)
if
err
==
nil
&&
proxy
!=
nil
{
proxyURL
=
proxy
.
URL
()
}
}
tokenInfo
,
err
:=
s
.
RefreshToken
(
ctx
,
refreshToken
,
proxyURL
)
if
err
!=
nil
{
return
nil
,
err
}
// 保留原有的 project_id 和 email
existingProjectID
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"project_id"
))
if
existingProjectID
!=
""
{
tokenInfo
.
ProjectID
=
existingProjectID
}
existingEmail
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"email"
))
if
existingEmail
!=
""
{
tokenInfo
.
Email
=
existingEmail
}
return
tokenInfo
,
nil
}
// BuildAccountCredentials 构建账户凭证
func
(
s
*
AntigravityOAuthService
)
BuildAccountCredentials
(
tokenInfo
*
AntigravityTokenInfo
)
map
[
string
]
any
{
creds
:=
map
[
string
]
any
{
"access_token"
:
tokenInfo
.
AccessToken
,
"expires_at"
:
strconv
.
FormatInt
(
tokenInfo
.
ExpiresAt
,
10
),
}
if
tokenInfo
.
RefreshToken
!=
""
{
creds
[
"refresh_token"
]
=
tokenInfo
.
RefreshToken
}
if
tokenInfo
.
TokenType
!=
""
{
creds
[
"token_type"
]
=
tokenInfo
.
TokenType
}
if
tokenInfo
.
Email
!=
""
{
creds
[
"email"
]
=
tokenInfo
.
Email
}
if
tokenInfo
.
ProjectID
!=
""
{
creds
[
"project_id"
]
=
tokenInfo
.
ProjectID
}
return
creds
}
// Stop 停止服务
func
(
s
*
AntigravityOAuthService
)
Stop
()
{
s
.
sessionStore
.
Stop
()
}
backend/internal/service/antigravity_quota_refresher.go
0 → 100644
View file @
e83f0ee3
package
service
import
(
"context"
"log"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// AntigravityQuotaRefresher 定时刷新 Antigravity 账户的配额信息
type
AntigravityQuotaRefresher
struct
{
accountRepo
AccountRepository
proxyRepo
ProxyRepository
cfg
*
config
.
TokenRefreshConfig
stopCh
chan
struct
{}
wg
sync
.
WaitGroup
}
// NewAntigravityQuotaRefresher 创建配额刷新器
func
NewAntigravityQuotaRefresher
(
accountRepo
AccountRepository
,
proxyRepo
ProxyRepository
,
_
*
AntigravityOAuthService
,
cfg
*
config
.
Config
,
)
*
AntigravityQuotaRefresher
{
return
&
AntigravityQuotaRefresher
{
accountRepo
:
accountRepo
,
proxyRepo
:
proxyRepo
,
cfg
:
&
cfg
.
TokenRefresh
,
stopCh
:
make
(
chan
struct
{}),
}
}
// Start 启动后台配额刷新服务
func
(
r
*
AntigravityQuotaRefresher
)
Start
()
{
if
!
r
.
cfg
.
Enabled
{
log
.
Println
(
"[AntigravityQuota] Service disabled by configuration"
)
return
}
r
.
wg
.
Add
(
1
)
go
r
.
refreshLoop
()
log
.
Printf
(
"[AntigravityQuota] Service started (check every %d minutes)"
,
r
.
cfg
.
CheckIntervalMinutes
)
}
// Stop 停止服务
func
(
r
*
AntigravityQuotaRefresher
)
Stop
()
{
close
(
r
.
stopCh
)
r
.
wg
.
Wait
()
log
.
Println
(
"[AntigravityQuota] Service stopped"
)
}
// refreshLoop 刷新循环
func
(
r
*
AntigravityQuotaRefresher
)
refreshLoop
()
{
defer
r
.
wg
.
Done
()
checkInterval
:=
time
.
Duration
(
r
.
cfg
.
CheckIntervalMinutes
)
*
time
.
Minute
if
checkInterval
<
time
.
Minute
{
checkInterval
=
5
*
time
.
Minute
}
ticker
:=
time
.
NewTicker
(
checkInterval
)
defer
ticker
.
Stop
()
// 启动时立即执行一次
r
.
processRefresh
()
for
{
select
{
case
<-
ticker
.
C
:
r
.
processRefresh
()
case
<-
r
.
stopCh
:
return
}
}
}
// processRefresh 执行一次刷新
func
(
r
*
AntigravityQuotaRefresher
)
processRefresh
()
{
ctx
:=
context
.
Background
()
// 查询所有 active 的账户,然后过滤 antigravity 平台
allAccounts
,
err
:=
r
.
accountRepo
.
ListActive
(
ctx
)
if
err
!=
nil
{
log
.
Printf
(
"[AntigravityQuota] Failed to list accounts: %v"
,
err
)
return
}
// 过滤 antigravity 平台账户
var
accounts
[]
Account
for
_
,
acc
:=
range
allAccounts
{
if
acc
.
Platform
==
PlatformAntigravity
{
accounts
=
append
(
accounts
,
acc
)
}
}
if
len
(
accounts
)
==
0
{
return
}
refreshed
,
failed
:=
0
,
0
for
i
:=
range
accounts
{
account
:=
&
accounts
[
i
]
if
err
:=
r
.
refreshAccountQuota
(
ctx
,
account
);
err
!=
nil
{
log
.
Printf
(
"[AntigravityQuota] Account %d (%s) failed: %v"
,
account
.
ID
,
account
.
Name
,
err
)
failed
++
}
else
{
refreshed
++
}
}
log
.
Printf
(
"[AntigravityQuota] Cycle complete: total=%d, refreshed=%d, failed=%d"
,
len
(
accounts
),
refreshed
,
failed
)
}
// refreshAccountQuota 刷新单个账户的配额
func
(
r
*
AntigravityQuotaRefresher
)
refreshAccountQuota
(
ctx
context
.
Context
,
account
*
Account
)
error
{
accessToken
:=
account
.
GetCredential
(
"access_token"
)
projectID
:=
account
.
GetCredential
(
"project_id"
)
if
accessToken
==
""
||
projectID
==
""
{
return
nil
// 没有有效凭证,跳过
}
// token 过期则跳过,由 TokenRefreshService 负责刷新
if
r
.
isTokenExpired
(
account
)
{
return
nil
}
// 获取代理 URL
var
proxyURL
string
if
account
.
ProxyID
!=
nil
{
proxy
,
err
:=
r
.
proxyRepo
.
GetByID
(
ctx
,
*
account
.
ProxyID
)
if
err
==
nil
&&
proxy
!=
nil
{
proxyURL
=
proxy
.
URL
()
}
}
client
:=
antigravity
.
NewClient
(
proxyURL
)
// 获取账户类型(tier)
loadResp
,
_
:=
client
.
LoadCodeAssist
(
ctx
,
accessToken
)
if
loadResp
!=
nil
{
r
.
updateAccountTier
(
account
,
loadResp
)
}
// 调用 API 获取配额
modelsResp
,
err
:=
client
.
FetchAvailableModels
(
ctx
,
accessToken
,
projectID
)
if
err
!=
nil
{
return
err
}
// 解析配额数据并更新 extra 字段
r
.
updateAccountQuota
(
account
,
modelsResp
)
// 保存到数据库
return
r
.
accountRepo
.
Update
(
ctx
,
account
)
}
// isTokenExpired 检查 token 是否过期
func
(
r
*
AntigravityQuotaRefresher
)
isTokenExpired
(
account
*
Account
)
bool
{
expiresAt
:=
parseAntigravityExpiresAt
(
account
)
if
expiresAt
==
nil
{
return
false
}
// 提前 5 分钟认为过期
return
time
.
Now
()
.
Add
(
5
*
time
.
Minute
)
.
After
(
*
expiresAt
)
}
// updateAccountTier 更新账户类型信息
func
(
r
*
AntigravityQuotaRefresher
)
updateAccountTier
(
account
*
Account
,
loadResp
*
antigravity
.
LoadCodeAssistResponse
)
{
if
account
.
Extra
==
nil
{
account
.
Extra
=
make
(
map
[
string
]
any
)
}
tier
:=
loadResp
.
GetTier
()
if
tier
!=
""
{
account
.
Extra
[
"tier"
]
=
tier
}
// 保存不符合条件的原因(如 INELIGIBLE_ACCOUNT)
if
len
(
loadResp
.
IneligibleTiers
)
>
0
&&
loadResp
.
IneligibleTiers
[
0
]
!=
nil
{
ineligible
:=
loadResp
.
IneligibleTiers
[
0
]
if
ineligible
.
ReasonCode
!=
""
{
account
.
Extra
[
"ineligible_reason_code"
]
=
ineligible
.
ReasonCode
}
if
ineligible
.
ReasonMessage
!=
""
{
account
.
Extra
[
"ineligible_reason_message"
]
=
ineligible
.
ReasonMessage
}
}
}
// updateAccountQuota 更新账户的配额信息
func
(
r
*
AntigravityQuotaRefresher
)
updateAccountQuota
(
account
*
Account
,
modelsResp
*
antigravity
.
FetchAvailableModelsResponse
)
{
if
account
.
Extra
==
nil
{
account
.
Extra
=
make
(
map
[
string
]
any
)
}
quota
:=
make
(
map
[
string
]
any
)
for
modelName
,
modelInfo
:=
range
modelsResp
.
Models
{
if
modelInfo
.
QuotaInfo
==
nil
{
continue
}
// 转换 remainingFraction (0.0-1.0) 为百分比 (0-100)
remaining
:=
int
(
modelInfo
.
QuotaInfo
.
RemainingFraction
*
100
)
quota
[
modelName
]
=
map
[
string
]
any
{
"remaining"
:
remaining
,
"reset_time"
:
modelInfo
.
QuotaInfo
.
ResetTime
,
}
}
account
.
Extra
[
"quota"
]
=
quota
account
.
Extra
[
"last_quota_check"
]
=
time
.
Now
()
.
Format
(
time
.
RFC3339
)
}
backend/internal/service/antigravity_token_provider.go
0 → 100644
View file @
e83f0ee3
package
service
import
(
"context"
"errors"
"log"
"strconv"
"strings"
"time"
)
const
(
antigravityTokenRefreshSkew
=
3
*
time
.
Minute
antigravityTokenCacheSkew
=
5
*
time
.
Minute
)
// AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
type
AntigravityTokenCache
=
GeminiTokenCache
// AntigravityTokenProvider 管理 Antigravity 账户的 access_token
type
AntigravityTokenProvider
struct
{
accountRepo
AccountRepository
tokenCache
AntigravityTokenCache
antigravityOAuthService
*
AntigravityOAuthService
}
func
NewAntigravityTokenProvider
(
accountRepo
AccountRepository
,
tokenCache
AntigravityTokenCache
,
antigravityOAuthService
*
AntigravityOAuthService
,
)
*
AntigravityTokenProvider
{
return
&
AntigravityTokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
tokenCache
,
antigravityOAuthService
:
antigravityOAuthService
,
}
}
// GetAccessToken 获取有效的 access_token
func
(
p
*
AntigravityTokenProvider
)
GetAccessToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
error
)
{
if
account
==
nil
{
return
""
,
errors
.
New
(
"account is nil"
)
}
if
account
.
Platform
!=
PlatformAntigravity
||
account
.
Type
!=
AccountTypeOAuth
{
return
""
,
errors
.
New
(
"not an antigravity oauth account"
)
}
cacheKey
:=
antigravityTokenCacheKey
(
account
)
// 1. 先尝试缓存
if
p
.
tokenCache
!=
nil
{
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
return
token
,
nil
}
}
// 2. 如果即将过期则刷新
expiresAt
:=
parseAntigravityExpiresAt
(
account
)
needsRefresh
:=
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
antigravityTokenRefreshSkew
if
needsRefresh
&&
p
.
tokenCache
!=
nil
{
locked
,
err
:=
p
.
tokenCache
.
AcquireRefreshLock
(
ctx
,
cacheKey
,
30
*
time
.
Second
)
if
err
==
nil
&&
locked
{
defer
func
()
{
_
=
p
.
tokenCache
.
ReleaseRefreshLock
(
ctx
,
cacheKey
)
}()
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
return
token
,
nil
}
// 从数据库获取最新账户信息
fresh
,
err
:=
p
.
accountRepo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
==
nil
&&
fresh
!=
nil
{
account
=
fresh
}
expiresAt
=
parseAntigravityExpiresAt
(
account
)
if
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
antigravityTokenRefreshSkew
{
if
p
.
antigravityOAuthService
==
nil
{
return
""
,
errors
.
New
(
"antigravity oauth service not configured"
)
}
tokenInfo
,
err
:=
p
.
antigravityOAuthService
.
RefreshAccountToken
(
ctx
,
account
)
if
err
!=
nil
{
return
""
,
err
}
newCredentials
:=
p
.
antigravityOAuthService
.
BuildAccountCredentials
(
tokenInfo
)
for
k
,
v
:=
range
account
.
Credentials
{
if
_
,
exists
:=
newCredentials
[
k
];
!
exists
{
newCredentials
[
k
]
=
v
}
}
account
.
Credentials
=
newCredentials
if
updateErr
:=
p
.
accountRepo
.
Update
(
ctx
,
account
);
updateErr
!=
nil
{
log
.
Printf
(
"[AntigravityTokenProvider] Failed to update account credentials: %v"
,
updateErr
)
}
expiresAt
=
parseAntigravityExpiresAt
(
account
)
}
}
}
accessToken
:=
account
.
GetCredential
(
"access_token"
)
if
strings
.
TrimSpace
(
accessToken
)
==
""
{
return
""
,
errors
.
New
(
"access_token not found in credentials"
)
}
// 3. 存入缓存
if
p
.
tokenCache
!=
nil
{
ttl
:=
30
*
time
.
Minute
if
expiresAt
!=
nil
{
until
:=
time
.
Until
(
*
expiresAt
)
switch
{
case
until
>
antigravityTokenCacheSkew
:
ttl
=
until
-
antigravityTokenCacheSkew
case
until
>
0
:
ttl
=
until
default
:
ttl
=
time
.
Minute
}
}
_
=
p
.
tokenCache
.
SetAccessToken
(
ctx
,
cacheKey
,
accessToken
,
ttl
)
}
return
accessToken
,
nil
}
func
antigravityTokenCacheKey
(
account
*
Account
)
string
{
projectID
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"project_id"
))
if
projectID
!=
""
{
return
"ag:"
+
projectID
}
return
"ag:account:"
+
strconv
.
FormatInt
(
account
.
ID
,
10
)
}
func
parseAntigravityExpiresAt
(
account
*
Account
)
*
time
.
Time
{
raw
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"expires_at"
))
if
raw
==
""
{
return
nil
}
if
unixSec
,
err
:=
strconv
.
ParseInt
(
raw
,
10
,
64
);
err
==
nil
&&
unixSec
>
0
{
t
:=
time
.
Unix
(
unixSec
,
0
)
return
&
t
}
if
t
,
err
:=
time
.
Parse
(
time
.
RFC3339
,
raw
);
err
==
nil
{
return
&
t
}
return
nil
}
backend/internal/service/antigravity_token_refresher.go
0 → 100644
View file @
e83f0ee3
package
service
import
(
"context"
"strconv"
"time"
)
// AntigravityTokenRefresher 实现 TokenRefresher 接口
type
AntigravityTokenRefresher
struct
{
antigravityOAuthService
*
AntigravityOAuthService
}
func
NewAntigravityTokenRefresher
(
antigravityOAuthService
*
AntigravityOAuthService
)
*
AntigravityTokenRefresher
{
return
&
AntigravityTokenRefresher
{
antigravityOAuthService
:
antigravityOAuthService
,
}
}
// CanRefresh 检查是否可以刷新此账户
func
(
r
*
AntigravityTokenRefresher
)
CanRefresh
(
account
*
Account
)
bool
{
return
account
.
Platform
==
PlatformAntigravity
&&
account
.
Type
==
AccountTypeOAuth
}
// NeedsRefresh 检查账户是否需要刷新
func
(
r
*
AntigravityTokenRefresher
)
NeedsRefresh
(
account
*
Account
,
refreshWindow
time
.
Duration
)
bool
{
if
!
r
.
CanRefresh
(
account
)
{
return
false
}
expiresAtStr
:=
account
.
GetCredential
(
"expires_at"
)
if
expiresAtStr
==
""
{
return
false
}
expiresAt
,
err
:=
strconv
.
ParseInt
(
expiresAtStr
,
10
,
64
)
if
err
!=
nil
{
return
false
}
expiryTime
:=
time
.
Unix
(
expiresAt
,
0
)
return
time
.
Until
(
expiryTime
)
<
refreshWindow
}
// Refresh 执行 token 刷新
func
(
r
*
AntigravityTokenRefresher
)
Refresh
(
ctx
context
.
Context
,
account
*
Account
)
(
map
[
string
]
any
,
error
)
{
tokenInfo
,
err
:=
r
.
antigravityOAuthService
.
RefreshAccountToken
(
ctx
,
account
)
if
err
!=
nil
{
return
nil
,
err
}
newCredentials
:=
r
.
antigravityOAuthService
.
BuildAccountCredentials
(
tokenInfo
)
for
k
,
v
:=
range
account
.
Credentials
{
if
_
,
exists
:=
newCredentials
[
k
];
!
exists
{
newCredentials
[
k
]
=
v
}
}
return
newCredentials
,
nil
}
backend/internal/service/domain_constants.go
View file @
e83f0ee3
...
...
@@ -21,6 +21,7 @@ const (
PlatformAnthropic
=
"anthropic"
PlatformOpenAI
=
"openai"
PlatformGemini
=
"gemini"
PlatformAntigravity
=
"antigravity"
)
// Account type constants
...
...
backend/internal/service/gateway_multiplatform_test.go
0 → 100644
View file @
e83f0ee3
//go:build unit
package
service
import
(
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// testConfig 返回一个用于测试的默认配置
func
testConfig
()
*
config
.
Config
{
return
&
config
.
Config
{
RunMode
:
config
.
RunModeStandard
}
}
// mockAccountRepoForPlatform 单平台测试用的 mock
type
mockAccountRepoForPlatform
struct
{
accounts
[]
Account
accountsByID
map
[
int64
]
*
Account
listPlatformFunc
func
(
ctx
context
.
Context
,
platform
string
)
([]
Account
,
error
)
}
func
(
m
*
mockAccountRepoForPlatform
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Account
,
error
)
{
if
acc
,
ok
:=
m
.
accountsByID
[
id
];
ok
{
return
acc
,
nil
}
return
nil
,
errors
.
New
(
"account not found"
)
}
func
(
m
*
mockAccountRepoForPlatform
)
ListSchedulableByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
Account
,
error
)
{
if
m
.
listPlatformFunc
!=
nil
{
return
m
.
listPlatformFunc
(
ctx
,
platform
)
}
var
result
[]
Account
for
_
,
acc
:=
range
m
.
accounts
{
if
acc
.
Platform
==
platform
&&
acc
.
IsSchedulable
()
{
result
=
append
(
result
,
acc
)
}
}
return
result
,
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
ListSchedulableByGroupIDAndPlatform
(
ctx
context
.
Context
,
groupID
int64
,
platform
string
)
([]
Account
,
error
)
{
return
m
.
ListSchedulableByPlatform
(
ctx
,
platform
)
}
// Stub methods to implement AccountRepository interface
func
(
m
*
mockAccountRepoForPlatform
)
Create
(
ctx
context
.
Context
,
account
*
Account
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
GetByCRSAccountID
(
ctx
context
.
Context
,
crsAccountID
string
)
(
*
Account
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
Update
(
ctx
context
.
Context
,
account
*
Account
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
Account
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
accountType
,
status
,
search
string
)
([]
Account
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
ListByGroup
(
ctx
context
.
Context
,
groupID
int64
)
([]
Account
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
ListActive
(
ctx
context
.
Context
)
([]
Account
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
ListByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
Account
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
UpdateLastUsed
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
BatchUpdateLastUsed
(
ctx
context
.
Context
,
updates
map
[
int64
]
time
.
Time
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
SetError
(
ctx
context
.
Context
,
id
int64
,
errorMsg
string
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
SetSchedulable
(
ctx
context
.
Context
,
id
int64
,
schedulable
bool
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
BindGroups
(
ctx
context
.
Context
,
accountID
int64
,
groupIDs
[]
int64
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
ListSchedulable
(
ctx
context
.
Context
)
([]
Account
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
ListSchedulableByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
Account
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
ListSchedulableByPlatforms
(
ctx
context
.
Context
,
platforms
[]
string
)
([]
Account
,
error
)
{
var
result
[]
Account
platformSet
:=
make
(
map
[
string
]
bool
)
for
_
,
p
:=
range
platforms
{
platformSet
[
p
]
=
true
}
for
_
,
acc
:=
range
m
.
accounts
{
if
platformSet
[
acc
.
Platform
]
&&
acc
.
IsSchedulable
()
{
result
=
append
(
result
,
acc
)
}
}
return
result
,
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
ListSchedulableByGroupIDAndPlatforms
(
ctx
context
.
Context
,
groupID
int64
,
platforms
[]
string
)
([]
Account
,
error
)
{
return
m
.
ListSchedulableByPlatforms
(
ctx
,
platforms
)
}
func
(
m
*
mockAccountRepoForPlatform
)
SetRateLimited
(
ctx
context
.
Context
,
id
int64
,
resetAt
time
.
Time
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
SetOverloaded
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
ClearRateLimit
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
UpdateSessionWindow
(
ctx
context
.
Context
,
id
int64
,
start
,
end
*
time
.
Time
,
status
string
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
UpdateExtra
(
ctx
context
.
Context
,
id
int64
,
updates
map
[
string
]
any
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
BulkUpdate
(
ctx
context
.
Context
,
ids
[]
int64
,
updates
AccountBulkUpdate
)
(
int64
,
error
)
{
return
0
,
nil
}
// Verify interface implementation
var
_
AccountRepository
=
(
*
mockAccountRepoForPlatform
)(
nil
)
// mockGatewayCacheForPlatform 单平台测试用的 cache mock
type
mockGatewayCacheForPlatform
struct
{
sessionBindings
map
[
string
]
int64
}
func
(
m
*
mockGatewayCacheForPlatform
)
GetSessionAccountID
(
ctx
context
.
Context
,
sessionHash
string
)
(
int64
,
error
)
{
if
id
,
ok
:=
m
.
sessionBindings
[
sessionHash
];
ok
{
return
id
,
nil
}
return
0
,
errors
.
New
(
"not found"
)
}
func
(
m
*
mockGatewayCacheForPlatform
)
SetSessionAccountID
(
ctx
context
.
Context
,
sessionHash
string
,
accountID
int64
,
ttl
time
.
Duration
)
error
{
if
m
.
sessionBindings
==
nil
{
m
.
sessionBindings
=
make
(
map
[
string
]
int64
)
}
m
.
sessionBindings
[
sessionHash
]
=
accountID
return
nil
}
func
(
m
*
mockGatewayCacheForPlatform
)
RefreshSessionTTL
(
ctx
context
.
Context
,
sessionHash
string
,
ttl
time
.
Duration
)
error
{
return
nil
}
func
ptr
[
T
any
](
v
T
)
*
T
{
return
&
v
}
// TestGatewayService_SelectAccountForModelWithPlatform_Anthropic 测试 anthropic 单平台选择
func
TestGatewayService_SelectAccountForModelWithPlatform_Anthropic
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
3
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
// 应被隔离
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
selectAccountForModelWithPlatform
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
1
),
acc
.
ID
,
"应选择优先级最高的 anthropic 账户"
)
require
.
Equal
(
t
,
PlatformAnthropic
,
acc
.
Platform
,
"应只返回 anthropic 平台账户"
)
}
// TestGatewayService_SelectAccountForModelWithPlatform_Antigravity 测试 antigravity 单平台选择
func
TestGatewayService_SelectAccountForModelWithPlatform_Antigravity
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
// 应被隔离
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
selectAccountForModelWithPlatform
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
PlatformAntigravity
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
)
require
.
Equal
(
t
,
PlatformAntigravity
,
acc
.
Platform
,
"应只返回 antigravity 平台账户"
)
}
// TestGatewayService_SelectAccountForModelWithPlatform_PriorityAndLastUsed 测试优先级和最后使用时间
func
TestGatewayService_SelectAccountForModelWithPlatform_PriorityAndLastUsed
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
now
:=
time
.
Now
()
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
LastUsedAt
:
ptr
(
now
.
Add
(
-
1
*
time
.
Hour
))},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
LastUsedAt
:
ptr
(
now
.
Add
(
-
2
*
time
.
Hour
))},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
selectAccountForModelWithPlatform
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"同优先级应选择最久未用的账户"
)
}
// TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts 测试无可用账户
func
TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{},
accountsByID
:
map
[
int64
]
*
Account
{},
}
cache
:=
&
mockGatewayCacheForPlatform
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
selectAccountForModelWithPlatform
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
PlatformAnthropic
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
acc
)
require
.
Contains
(
t
,
err
.
Error
(),
"no available accounts"
)
}
// TestGatewayService_SelectAccountForModelWithPlatform_AllExcluded 测试所有账户被排除
func
TestGatewayService_SelectAccountForModelWithPlatform_AllExcluded
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
excludedIDs
:=
map
[
int64
]
struct
{}{
1
:
{},
2
:
{}}
acc
,
err
:=
svc
.
selectAccountForModelWithPlatform
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
excludedIDs
,
PlatformAnthropic
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
acc
)
}
// TestGatewayService_SelectAccountForModelWithPlatform_Schedulability 测试账户可调度性检查
func
TestGatewayService_SelectAccountForModelWithPlatform_Schedulability
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
now
:=
time
.
Now
()
tests
:=
[]
struct
{
name
string
accounts
[]
Account
expectedID
int64
}{
{
name
:
"过载账户被跳过"
,
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
OverloadUntil
:
ptr
(
now
.
Add
(
1
*
time
.
Hour
))},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
expectedID
:
2
,
},
{
name
:
"限流账户被跳过"
,
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
RateLimitResetAt
:
ptr
(
now
.
Add
(
1
*
time
.
Hour
))},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
expectedID
:
2
,
},
{
name
:
"非active账户被跳过"
,
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
"error"
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
expectedID
:
2
,
},
{
name
:
"schedulable=false被跳过"
,
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
false
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
expectedID
:
2
,
},
{
name
:
"过期的过载账户可调度"
,
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
OverloadUntil
:
ptr
(
now
.
Add
(
-
1
*
time
.
Hour
))},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
expectedID
:
1
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
tt
.
accounts
,
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
selectAccountForModelWithPlatform
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
tt
.
expectedID
,
acc
.
ID
)
})
}
}
// TestGatewayService_SelectAccountForModelWithPlatform_StickySession 测试粘性会话
func
TestGatewayService_SelectAccountForModelWithPlatform_StickySession
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
t
.
Run
(
"粘性会话命中-同平台"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
"session-123"
:
1
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
selectAccountForModelWithPlatform
(
ctx
,
nil
,
"session-123"
,
"claude-3-5-sonnet-20241022"
,
nil
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
1
),
acc
.
ID
,
"应返回粘性会话绑定的账户"
)
})
t
.
Run
(
"粘性会话不匹配平台-降级选择"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAntigravity
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
// 粘性会话绑定但平台不匹配
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
"session-123"
:
1
},
// 绑定 antigravity 账户
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
// 请求 anthropic 平台,但粘性会话绑定的是 antigravity 账户
acc
,
err
:=
svc
.
selectAccountForModelWithPlatform
(
ctx
,
nil
,
"session-123"
,
"claude-3-5-sonnet-20241022"
,
nil
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"粘性会话账户平台不匹配,应降级选择同平台账户"
)
require
.
Equal
(
t
,
PlatformAnthropic
,
acc
.
Platform
)
})
t
.
Run
(
"粘性会话账户被排除-降级选择"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
"session-123"
:
1
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
excludedIDs
:=
map
[
int64
]
struct
{}{
1
:
{}}
acc
,
err
:=
svc
.
selectAccountForModelWithPlatform
(
ctx
,
nil
,
"session-123"
,
"claude-3-5-sonnet-20241022"
,
excludedIDs
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"粘性会话账户被排除,应选择其他账户"
)
})
t
.
Run
(
"粘性会话账户不可调度-降级选择"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
"error"
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
"session-123"
:
1
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
selectAccountForModelWithPlatform
(
ctx
,
nil
,
"session-123"
,
"claude-3-5-sonnet-20241022"
,
nil
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"粘性会话账户不可调度,应选择其他账户"
)
})
}
func
TestGatewayService_isModelSupportedByAccount
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
tests
:=
[]
struct
{
name
string
account
*
Account
model
string
expected
bool
}{
{
name
:
"Antigravity平台-支持claude模型"
,
account
:
&
Account
{
Platform
:
PlatformAntigravity
},
model
:
"claude-3-5-sonnet-20241022"
,
expected
:
true
,
},
{
name
:
"Antigravity平台-支持gemini模型"
,
account
:
&
Account
{
Platform
:
PlatformAntigravity
},
model
:
"gemini-2.5-flash"
,
expected
:
true
,
},
{
name
:
"Antigravity平台-不支持gpt模型"
,
account
:
&
Account
{
Platform
:
PlatformAntigravity
},
model
:
"gpt-4"
,
expected
:
false
,
},
{
name
:
"Anthropic平台-无映射配置-支持所有模型"
,
account
:
&
Account
{
Platform
:
PlatformAnthropic
},
model
:
"claude-3-5-sonnet-20241022"
,
expected
:
true
,
},
{
name
:
"Anthropic平台-有映射配置-只支持配置的模型"
,
account
:
&
Account
{
Platform
:
PlatformAnthropic
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"claude-opus-4"
:
"x"
}},
},
model
:
"claude-3-5-sonnet-20241022"
,
expected
:
false
,
},
{
name
:
"Anthropic平台-有映射配置-支持配置的模型"
,
account
:
&
Account
{
Platform
:
PlatformAnthropic
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"claude-3-5-sonnet-20241022"
:
"x"
}},
},
model
:
"claude-3-5-sonnet-20241022"
,
expected
:
true
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
svc
.
isModelSupportedByAccount
(
tt
.
account
,
tt
.
model
)
require
.
Equal
(
t
,
tt
.
expected
,
got
)
})
}
}
// TestGatewayService_selectAccountWithMixedScheduling 测试混合调度
func
TestGatewayService_selectAccountWithMixedScheduling
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
t
.
Run
(
"混合调度-包含启用mixed_scheduling的antigravity账户"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Extra
:
map
[
string
]
any
{
"mixed_scheduling"
:
true
}},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
selectAccountWithMixedScheduling
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"应选择优先级最高的账户(包含启用混合调度的antigravity)"
)
})
t
.
Run
(
"混合调度-过滤未启用mixed_scheduling的antigravity账户"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
// 未启用 mixed_scheduling
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
selectAccountWithMixedScheduling
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
1
),
acc
.
ID
,
"未启用mixed_scheduling的antigravity账户应被过滤"
)
require
.
Equal
(
t
,
PlatformAnthropic
,
acc
.
Platform
)
})
t
.
Run
(
"混合调度-粘性会话命中启用mixed_scheduling的antigravity账户"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
,
Extra
:
map
[
string
]
any
{
"mixed_scheduling"
:
true
}},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
"session-123"
:
2
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
selectAccountWithMixedScheduling
(
ctx
,
nil
,
"session-123"
,
"claude-3-5-sonnet-20241022"
,
nil
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"应返回粘性会话绑定的启用mixed_scheduling的antigravity账户"
)
})
t
.
Run
(
"混合调度-粘性会话命中未启用mixed_scheduling的antigravity账户-降级选择"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
// 未启用 mixed_scheduling
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
"session-123"
:
2
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
selectAccountWithMixedScheduling
(
ctx
,
nil
,
"session-123"
,
"claude-3-5-sonnet-20241022"
,
nil
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
1
),
acc
.
ID
,
"粘性会话绑定的账户未启用mixed_scheduling,应降级选择anthropic账户"
)
})
t
.
Run
(
"混合调度-仅有启用mixed_scheduling的antigravity账户"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Extra
:
map
[
string
]
any
{
"mixed_scheduling"
:
true
}},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
selectAccountWithMixedScheduling
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
1
),
acc
.
ID
)
require
.
Equal
(
t
,
PlatformAntigravity
,
acc
.
Platform
)
})
t
.
Run
(
"混合调度-无可用账户"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
// 未启用 mixed_scheduling
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
selectAccountWithMixedScheduling
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
PlatformAnthropic
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
acc
)
require
.
Contains
(
t
,
err
.
Error
(),
"no available accounts"
)
})
}
// TestAccount_IsMixedSchedulingEnabled 测试混合调度开关检查
func
TestAccount_IsMixedSchedulingEnabled
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
account
Account
expected
bool
}{
{
name
:
"非antigravity平台-返回false"
,
account
:
Account
{
Platform
:
PlatformAnthropic
},
expected
:
false
,
},
{
name
:
"antigravity平台-无extra-返回false"
,
account
:
Account
{
Platform
:
PlatformAntigravity
},
expected
:
false
,
},
{
name
:
"antigravity平台-extra无mixed_scheduling-返回false"
,
account
:
Account
{
Platform
:
PlatformAntigravity
,
Extra
:
map
[
string
]
any
{}},
expected
:
false
,
},
{
name
:
"antigravity平台-mixed_scheduling=false-返回false"
,
account
:
Account
{
Platform
:
PlatformAntigravity
,
Extra
:
map
[
string
]
any
{
"mixed_scheduling"
:
false
}},
expected
:
false
,
},
{
name
:
"antigravity平台-mixed_scheduling=true-返回true"
,
account
:
Account
{
Platform
:
PlatformAntigravity
,
Extra
:
map
[
string
]
any
{
"mixed_scheduling"
:
true
}},
expected
:
true
,
},
{
name
:
"antigravity平台-mixed_scheduling非bool类型-返回false"
,
account
:
Account
{
Platform
:
PlatformAntigravity
,
Extra
:
map
[
string
]
any
{
"mixed_scheduling"
:
"true"
}},
expected
:
false
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
tt
.
account
.
IsMixedSchedulingEnabled
()
require
.
Equal
(
t
,
tt
.
expected
,
got
)
})
}
}
backend/internal/service/gateway_service.go
View file @
e83f0ee3
...
...
@@ -18,6 +18,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
...
...
@@ -93,6 +94,7 @@ func (e *UpstreamFailoverError) Error() string {
// GatewayService handles API gateway operations
type
GatewayService
struct
{
accountRepo
AccountRepository
groupRepo
GroupRepository
usageLogRepo
UsageLogRepository
userRepo
UserRepository
userSubRepo
UserSubscriptionRepository
...
...
@@ -109,6 +111,7 @@ type GatewayService struct {
// NewGatewayService creates a new GatewayService
func
NewGatewayService
(
accountRepo
AccountRepository
,
groupRepo
GroupRepository
,
usageLogRepo
UsageLogRepository
,
userRepo
UserRepository
,
userSubRepo
UserSubscriptionRepository
,
...
...
@@ -123,6 +126,7 @@ func NewGatewayService(
)
*
GatewayService
{
return
&
GatewayService
{
accountRepo
:
accountRepo
,
groupRepo
:
groupRepo
,
usageLogRepo
:
usageLogRepo
,
userRepo
:
userRepo
,
userSubRepo
:
userSubRepo
,
...
...
@@ -291,16 +295,53 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
func
(
s
*
GatewayService
)
SelectAccountForModelWithExclusions
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{})
(
*
Account
,
error
)
{
// 优先检查 context 中的强制平台(/antigravity 路由)
var
platform
string
forcePlatform
,
hasForcePlatform
:=
ctx
.
Value
(
ctxkey
.
ForcePlatform
)
.
(
string
)
if
hasForcePlatform
&&
forcePlatform
!=
""
{
platform
=
forcePlatform
}
else
if
groupID
!=
nil
{
// 根据分组 platform 决定查询哪种账号
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
*
groupID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get group failed: %w"
,
err
)
}
platform
=
group
.
Platform
}
else
{
// 无分组时只使用原生 anthropic 平台
platform
=
PlatformAnthropic
}
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
// 注意:强制平台模式不走混合调度
if
(
platform
==
PlatformAnthropic
||
platform
==
PlatformGemini
)
&&
!
hasForcePlatform
{
return
s
.
selectAccountWithMixedScheduling
(
ctx
,
groupID
,
sessionHash
,
requestedModel
,
excludedIDs
,
platform
)
}
// 强制平台模式:优先按分组查找,找不到再查全部该平台账户
if
hasForcePlatform
&&
groupID
!=
nil
{
account
,
err
:=
s
.
selectAccountForModelWithPlatform
(
ctx
,
groupID
,
sessionHash
,
requestedModel
,
excludedIDs
,
platform
)
if
err
==
nil
{
return
account
,
nil
}
// 分组中找不到,回退查询全部该平台账户
groupID
=
nil
}
// antigravity 分组、强制平台模式或无分组使用单平台选择
return
s
.
selectAccountForModelWithPlatform
(
ctx
,
groupID
,
sessionHash
,
requestedModel
,
excludedIDs
,
platform
)
}
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func
(
s
*
GatewayService
)
selectAccountForModelWithPlatform
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{},
platform
string
)
(
*
Account
,
error
)
{
// 1. 查询粘性会话
if
sessionHash
!=
""
{
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
sessionHash
)
if
err
==
nil
&&
accountID
>
0
{
if
_
,
excluded
:=
excludedIDs
[
accountID
];
!
excluded
{
account
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
accountID
)
// 使用IsSchedulable代替IsActive,确保限流/过载账号不会被选中
// 同时检查模型支持
if
err
==
nil
&&
account
.
IsSchedulable
()
&&
(
requestedModel
==
""
||
account
.
IsModelSupported
(
requestedModel
))
{
// 续期粘性会话
// 检查账号平台是否匹配(确保粘性会话不会跨平台)
if
err
==
nil
&&
account
.
Platform
==
platform
&&
account
.
IsSchedulable
()
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
if
err
:=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
sessionHash
,
stickySessionTTL
);
err
!=
nil
{
log
.
Printf
(
"refresh session ttl failed: session=%s err=%v"
,
sessionHash
,
err
)
}
...
...
@@ -310,16 +351,16 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
}
}
// 2. 获取可调度账号列表(
排除限流和过载的账号,仅限 Anthropic
平台)
// 2. 获取可调度账号列表(
单
平台)
var
accounts
[]
Account
var
err
error
if
s
.
cfg
.
RunMode
==
config
.
RunModeSimple
{
// 简易模式:忽略 groupID,查询所有可用账号
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatform
(
ctx
,
P
latform
Anthropic
)
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatform
(
ctx
,
p
latform
)
}
else
if
groupID
!=
nil
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByGroupIDAndPlatform
(
ctx
,
*
groupID
,
P
latform
Anthropic
)
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByGroupIDAndPlatform
(
ctx
,
*
groupID
,
p
latform
)
}
else
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatform
(
ctx
,
P
latform
Anthropic
)
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatform
(
ctx
,
p
latform
)
}
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
...
...
@@ -332,19 +373,16 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
if
_
,
excluded
:=
excludedIDs
[
acc
.
ID
];
excluded
{
continue
}
// 检查模型支持
if
requestedModel
!=
""
&&
!
acc
.
IsModelSupported
(
requestedModel
)
{
if
requestedModel
!=
""
&&
!
s
.
isModelSupportedByAccount
(
acc
,
requestedModel
)
{
continue
}
if
selected
==
nil
{
selected
=
acc
continue
}
// 优先选择priority值更小的(priority值越小优先级越高)
if
acc
.
Priority
<
selected
.
Priority
{
selected
=
acc
}
else
if
acc
.
Priority
==
selected
.
Priority
{
// 优先级相同时,选最久未用的
switch
{
case
acc
.
LastUsedAt
==
nil
&&
selected
.
LastUsedAt
!=
nil
:
selected
=
acc
...
...
@@ -377,6 +415,126 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
return
selected
,
nil
}
// selectAccountWithMixedScheduling 选择账户(支持混合调度)
// 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户
func
(
s
*
GatewayService
)
selectAccountWithMixedScheduling
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{},
nativePlatform
string
)
(
*
Account
,
error
)
{
platforms
:=
[]
string
{
nativePlatform
,
PlatformAntigravity
}
// 1. 查询粘性会话
if
sessionHash
!=
""
{
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
sessionHash
)
if
err
==
nil
&&
accountID
>
0
{
if
_
,
excluded
:=
excludedIDs
[
accountID
];
!
excluded
{
account
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
accountID
)
// 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度
if
err
==
nil
&&
account
.
IsSchedulable
()
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
if
account
.
Platform
==
nativePlatform
||
(
account
.
Platform
==
PlatformAntigravity
&&
account
.
IsMixedSchedulingEnabled
())
{
if
err
:=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
sessionHash
,
stickySessionTTL
);
err
!=
nil
{
log
.
Printf
(
"refresh session ttl failed: session=%s err=%v"
,
sessionHash
,
err
)
}
return
account
,
nil
}
}
}
}
}
// 2. 获取可调度账号列表
var
accounts
[]
Account
var
err
error
if
groupID
!=
nil
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByGroupIDAndPlatforms
(
ctx
,
*
groupID
,
platforms
)
}
else
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatforms
(
ctx
,
platforms
)
}
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
}
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
var
selected
*
Account
for
i
:=
range
accounts
{
acc
:=
&
accounts
[
i
]
if
_
,
excluded
:=
excludedIDs
[
acc
.
ID
];
excluded
{
continue
}
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
if
acc
.
Platform
==
PlatformAntigravity
&&
!
acc
.
IsMixedSchedulingEnabled
()
{
continue
}
if
requestedModel
!=
""
&&
!
s
.
isModelSupportedByAccount
(
acc
,
requestedModel
)
{
continue
}
if
selected
==
nil
{
selected
=
acc
continue
}
if
acc
.
Priority
<
selected
.
Priority
{
selected
=
acc
}
else
if
acc
.
Priority
==
selected
.
Priority
{
switch
{
case
acc
.
LastUsedAt
==
nil
&&
selected
.
LastUsedAt
!=
nil
:
selected
=
acc
case
acc
.
LastUsedAt
!=
nil
&&
selected
.
LastUsedAt
==
nil
:
// keep selected (never used is preferred)
case
acc
.
LastUsedAt
==
nil
&&
selected
.
LastUsedAt
==
nil
:
// keep selected (both never used)
default
:
if
acc
.
LastUsedAt
.
Before
(
*
selected
.
LastUsedAt
)
{
selected
=
acc
}
}
}
}
if
selected
==
nil
{
if
requestedModel
!=
""
{
return
nil
,
fmt
.
Errorf
(
"no available accounts supporting model: %s"
,
requestedModel
)
}
return
nil
,
errors
.
New
(
"no available accounts"
)
}
// 4. 建立粘性绑定
if
sessionHash
!=
""
{
if
err
:=
s
.
cache
.
SetSessionAccountID
(
ctx
,
sessionHash
,
selected
.
ID
,
stickySessionTTL
);
err
!=
nil
{
log
.
Printf
(
"set session account failed: session=%s account_id=%d err=%v"
,
sessionHash
,
selected
.
ID
,
err
)
}
}
return
selected
,
nil
}
// isModelSupportedByAccount 根据账户平台检查模型支持
func
(
s
*
GatewayService
)
isModelSupportedByAccount
(
account
*
Account
,
requestedModel
string
)
bool
{
if
account
.
Platform
==
PlatformAntigravity
{
// Antigravity 平台使用专门的模型支持检查
return
IsAntigravityModelSupported
(
requestedModel
)
}
// 其他平台使用账户的模型支持检查
return
account
.
IsModelSupported
(
requestedModel
)
}
// IsAntigravityModelSupported 检查 Antigravity 平台是否支持指定模型
func
IsAntigravityModelSupported
(
requestedModel
string
)
bool
{
// 直接支持的模型
if
antigravitySupportedModels
[
requestedModel
]
{
return
true
}
// 可映射的模型
if
_
,
ok
:=
antigravityModelMapping
[
requestedModel
];
ok
{
return
true
}
// Gemini 前缀透传
if
strings
.
HasPrefix
(
requestedModel
,
"gemini-"
)
{
return
true
}
// Claude 模型支持(通过默认映射到 claude-sonnet-4-5)
if
strings
.
HasPrefix
(
requestedModel
,
"claude-"
)
{
return
true
}
return
false
}
// GetAccessToken 获取账号凭证
func
(
s
*
GatewayService
)
GetAccessToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
string
,
error
)
{
switch
account
.
Type
{
...
...
@@ -1116,6 +1274,13 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
// ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应
func
(
s
*
GatewayService
)
ForwardCountTokens
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
)
error
{
// Antigravity 账户不支持 count_tokens 转发,返回估算值
// 参考 Antigravity-Manager 和 proxycast 实现
if
account
.
Platform
==
PlatformAntigravity
{
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"input_tokens"
:
100
})
return
nil
}
// 应用模型映射(仅对 apikey 类型账号)
if
account
.
Type
==
AccountTypeApiKey
{
var
req
struct
{
...
...
backend/internal/service/gemini_messages_compat_service.go
View file @
e83f0ee3
...
...
@@ -18,6 +18,7 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
...
...
@@ -34,25 +35,31 @@ const (
type
GeminiMessagesCompatService
struct
{
accountRepo
AccountRepository
groupRepo
GroupRepository
cache
GatewayCache
tokenProvider
*
GeminiTokenProvider
rateLimitService
*
RateLimitService
httpUpstream
HTTPUpstream
antigravityGatewayService
*
AntigravityGatewayService
}
func
NewGeminiMessagesCompatService
(
accountRepo
AccountRepository
,
groupRepo
GroupRepository
,
cache
GatewayCache
,
tokenProvider
*
GeminiTokenProvider
,
rateLimitService
*
RateLimitService
,
httpUpstream
HTTPUpstream
,
antigravityGatewayService
*
AntigravityGatewayService
,
)
*
GeminiMessagesCompatService
{
return
&
GeminiMessagesCompatService
{
accountRepo
:
accountRepo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
tokenProvider
:
tokenProvider
,
rateLimitService
:
rateLimitService
,
httpUpstream
:
httpUpstream
,
antigravityGatewayService
:
antigravityGatewayService
,
}
}
...
...
@@ -66,26 +73,71 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
}
func
(
s
*
GeminiMessagesCompatService
)
SelectAccountForModelWithExclusions
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{})
(
*
Account
,
error
)
{
// 优先检查 context 中的强制平台(/antigravity 路由)
var
platform
string
forcePlatform
,
hasForcePlatform
:=
ctx
.
Value
(
ctxkey
.
ForcePlatform
)
.
(
string
)
if
hasForcePlatform
&&
forcePlatform
!=
""
{
platform
=
forcePlatform
}
else
if
groupID
!=
nil
{
// 根据分组 platform 决定查询哪种账号
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
*
groupID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get group failed: %w"
,
err
)
}
platform
=
group
.
Platform
}
else
{
// 无分组时只使用原生 gemini 平台
platform
=
PlatformGemini
}
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
// 注意:强制平台模式不走混合调度
useMixedScheduling
:=
platform
==
PlatformGemini
&&
!
hasForcePlatform
var
queryPlatforms
[]
string
if
useMixedScheduling
{
queryPlatforms
=
[]
string
{
PlatformGemini
,
PlatformAntigravity
}
}
else
{
queryPlatforms
=
[]
string
{
platform
}
}
cacheKey
:=
"gemini:"
+
sessionHash
if
sessionHash
!=
""
{
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
cacheKey
)
if
err
==
nil
&&
accountID
>
0
{
if
_
,
excluded
:=
excludedIDs
[
accountID
];
!
excluded
{
account
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
accountID
)
if
err
==
nil
&&
account
.
IsSchedulable
()
&&
account
.
Platform
==
PlatformGemini
&&
(
requestedModel
==
""
||
account
.
IsModelSupported
(
requestedModel
))
{
// 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度
if
err
==
nil
&&
account
.
IsSchedulable
()
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
valid
:=
false
if
account
.
Platform
==
platform
{
valid
=
true
}
else
if
useMixedScheduling
&&
account
.
Platform
==
PlatformAntigravity
&&
account
.
IsMixedSchedulingEnabled
()
{
valid
=
true
}
if
valid
{
_
=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
cacheKey
,
geminiStickySessionTTL
)
return
account
,
nil
}
}
}
}
}
// 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
var
accounts
[]
Account
var
err
error
if
groupID
!=
nil
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByGroupIDAndPlatform
(
ctx
,
*
groupID
,
PlatformGemini
)
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByGroupIDAndPlatforms
(
ctx
,
*
groupID
,
queryPlatforms
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
}
// 强制平台模式下,分组中找不到账户时回退查询全部
if
len
(
accounts
)
==
0
&&
hasForcePlatform
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatforms
(
ctx
,
queryPlatforms
)
}
}
else
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatform
(
ctx
,
Platform
Gemini
)
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatform
s
(
ctx
,
query
Platform
s
)
}
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
...
...
@@ -97,7 +149,12 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
if
_
,
excluded
:=
excludedIDs
[
acc
.
ID
];
excluded
{
continue
}
if
requestedModel
!=
""
&&
!
acc
.
IsModelSupported
(
requestedModel
)
{
// 混合调度模式下:原生平台直接通过,antigravity 需要启用 mixed_scheduling
// 非混合调度模式(antigravity 分组):不需要过滤
if
useMixedScheduling
&&
acc
.
Platform
==
PlatformAntigravity
&&
!
acc
.
IsMixedSchedulingEnabled
()
{
continue
}
if
requestedModel
!=
""
&&
!
s
.
isModelSupportedByAccount
(
acc
,
requestedModel
)
{
continue
}
if
selected
==
nil
{
...
...
@@ -139,6 +196,34 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
return
selected
,
nil
}
// isModelSupportedByAccount 根据账户平台检查模型支持
func
(
s
*
GeminiMessagesCompatService
)
isModelSupportedByAccount
(
account
*
Account
,
requestedModel
string
)
bool
{
if
account
.
Platform
==
PlatformAntigravity
{
return
IsAntigravityModelSupported
(
requestedModel
)
}
return
account
.
IsModelSupported
(
requestedModel
)
}
// GetAntigravityGatewayService 返回 AntigravityGatewayService
func
(
s
*
GeminiMessagesCompatService
)
GetAntigravityGatewayService
()
*
AntigravityGatewayService
{
return
s
.
antigravityGatewayService
}
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户
func
(
s
*
GeminiMessagesCompatService
)
HasAntigravityAccounts
(
ctx
context
.
Context
,
groupID
*
int64
)
(
bool
,
error
)
{
var
accounts
[]
Account
var
err
error
if
groupID
!=
nil
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByGroupIDAndPlatform
(
ctx
,
*
groupID
,
PlatformAntigravity
)
}
else
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatform
(
ctx
,
PlatformAntigravity
)
}
if
err
!=
nil
{
return
false
,
err
}
return
len
(
accounts
)
>
0
,
nil
}
// SelectAccountForAIStudioEndpoints selects an account that is likely to succeed against
// generativelanguage.googleapis.com (e.g. GET /v1beta/models).
//
...
...
@@ -1798,7 +1883,7 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont
if
statusCode
!=
429
{
return
}
resetAt
:=
p
arseGeminiRateLimitResetTime
(
body
)
resetAt
:=
P
arseGeminiRateLimitResetTime
(
body
)
if
resetAt
==
nil
{
ra
:=
time
.
Now
()
.
Add
(
5
*
time
.
Minute
)
_
=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
ra
)
...
...
@@ -1807,7 +1892,8 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont
_
=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
time
.
Unix
(
*
resetAt
,
0
))
}
func
parseGeminiRateLimitResetTime
(
body
[]
byte
)
*
int64
{
// ParseGeminiRateLimitResetTime 解析 Gemini 格式的 429 响应,返回重置时间的 Unix 时间戳
func
ParseGeminiRateLimitResetTime
(
body
[]
byte
)
*
int64
{
// Try to parse metadata.quotaResetDelay like "12.345s"
var
parsed
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
parsed
);
err
==
nil
{
...
...
backend/internal/service/gemini_multiplatform_test.go
0 → 100644
View file @
e83f0ee3
//go:build unit
package
service
import
(
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// mockAccountRepoForGemini Gemini 测试用的 mock
type
mockAccountRepoForGemini
struct
{
accounts
[]
Account
accountsByID
map
[
int64
]
*
Account
}
func
(
m
*
mockAccountRepoForGemini
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Account
,
error
)
{
if
acc
,
ok
:=
m
.
accountsByID
[
id
];
ok
{
return
acc
,
nil
}
return
nil
,
errors
.
New
(
"account not found"
)
}
func
(
m
*
mockAccountRepoForGemini
)
ListSchedulableByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
Account
,
error
)
{
var
result
[]
Account
for
_
,
acc
:=
range
m
.
accounts
{
if
acc
.
Platform
==
platform
&&
acc
.
IsSchedulable
()
{
result
=
append
(
result
,
acc
)
}
}
return
result
,
nil
}
func
(
m
*
mockAccountRepoForGemini
)
ListSchedulableByGroupIDAndPlatform
(
ctx
context
.
Context
,
groupID
int64
,
platform
string
)
([]
Account
,
error
)
{
// 测试时不区分 groupID,直接按 platform 过滤
return
m
.
ListSchedulableByPlatform
(
ctx
,
platform
)
}
// Stub methods to implement AccountRepository interface
func
(
m
*
mockAccountRepoForGemini
)
Create
(
ctx
context
.
Context
,
account
*
Account
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
GetByCRSAccountID
(
ctx
context
.
Context
,
crsAccountID
string
)
(
*
Account
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForGemini
)
Update
(
ctx
context
.
Context
,
account
*
Account
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
Account
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
m
*
mockAccountRepoForGemini
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
accountType
,
status
,
search
string
)
([]
Account
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
m
*
mockAccountRepoForGemini
)
ListByGroup
(
ctx
context
.
Context
,
groupID
int64
)
([]
Account
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForGemini
)
ListActive
(
ctx
context
.
Context
)
([]
Account
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForGemini
)
ListByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
Account
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForGemini
)
UpdateLastUsed
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
BatchUpdateLastUsed
(
ctx
context
.
Context
,
updates
map
[
int64
]
time
.
Time
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
SetError
(
ctx
context
.
Context
,
id
int64
,
errorMsg
string
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
SetSchedulable
(
ctx
context
.
Context
,
id
int64
,
schedulable
bool
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
BindGroups
(
ctx
context
.
Context
,
accountID
int64
,
groupIDs
[]
int64
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
ListSchedulable
(
ctx
context
.
Context
)
([]
Account
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForGemini
)
ListSchedulableByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
Account
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForGemini
)
ListSchedulableByPlatforms
(
ctx
context
.
Context
,
platforms
[]
string
)
([]
Account
,
error
)
{
var
result
[]
Account
platformSet
:=
make
(
map
[
string
]
bool
)
for
_
,
p
:=
range
platforms
{
platformSet
[
p
]
=
true
}
for
_
,
acc
:=
range
m
.
accounts
{
if
platformSet
[
acc
.
Platform
]
&&
acc
.
IsSchedulable
()
{
result
=
append
(
result
,
acc
)
}
}
return
result
,
nil
}
func
(
m
*
mockAccountRepoForGemini
)
ListSchedulableByGroupIDAndPlatforms
(
ctx
context
.
Context
,
groupID
int64
,
platforms
[]
string
)
([]
Account
,
error
)
{
return
m
.
ListSchedulableByPlatforms
(
ctx
,
platforms
)
}
func
(
m
*
mockAccountRepoForGemini
)
SetRateLimited
(
ctx
context
.
Context
,
id
int64
,
resetAt
time
.
Time
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
SetOverloaded
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
ClearRateLimit
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
UpdateSessionWindow
(
ctx
context
.
Context
,
id
int64
,
start
,
end
*
time
.
Time
,
status
string
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
UpdateExtra
(
ctx
context
.
Context
,
id
int64
,
updates
map
[
string
]
any
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
BulkUpdate
(
ctx
context
.
Context
,
ids
[]
int64
,
updates
AccountBulkUpdate
)
(
int64
,
error
)
{
return
0
,
nil
}
// Verify interface implementation
var
_
AccountRepository
=
(
*
mockAccountRepoForGemini
)(
nil
)
// mockGroupRepoForGemini Gemini 测试用的 group repo mock
type
mockGroupRepoForGemini
struct
{
groups
map
[
int64
]
*
Group
}
func
(
m
*
mockGroupRepoForGemini
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
if
g
,
ok
:=
m
.
groups
[
id
];
ok
{
return
g
,
nil
}
return
nil
,
errors
.
New
(
"group not found"
)
}
// Stub methods to implement GroupRepository interface
func
(
m
*
mockGroupRepoForGemini
)
Create
(
ctx
context
.
Context
,
group
*
Group
)
error
{
return
nil
}
func
(
m
*
mockGroupRepoForGemini
)
Update
(
ctx
context
.
Context
,
group
*
Group
)
error
{
return
nil
}
func
(
m
*
mockGroupRepoForGemini
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
m
*
mockGroupRepoForGemini
)
DeleteCascade
(
ctx
context
.
Context
,
id
int64
)
([]
int64
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockGroupRepoForGemini
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
m
*
mockGroupRepoForGemini
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
status
string
,
isExclusive
*
bool
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
m
*
mockGroupRepoForGemini
)
ListActive
(
ctx
context
.
Context
)
([]
Group
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockGroupRepoForGemini
)
ListActiveByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
Group
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockGroupRepoForGemini
)
ExistsByName
(
ctx
context
.
Context
,
name
string
)
(
bool
,
error
)
{
return
false
,
nil
}
func
(
m
*
mockGroupRepoForGemini
)
GetAccountCount
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
m
*
mockGroupRepoForGemini
)
DeleteAccountGroupsByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
var
_
GroupRepository
=
(
*
mockGroupRepoForGemini
)(
nil
)
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
type
mockGatewayCacheForGemini
struct
{
sessionBindings
map
[
string
]
int64
}
func
(
m
*
mockGatewayCacheForGemini
)
GetSessionAccountID
(
ctx
context
.
Context
,
sessionHash
string
)
(
int64
,
error
)
{
if
id
,
ok
:=
m
.
sessionBindings
[
sessionHash
];
ok
{
return
id
,
nil
}
return
0
,
errors
.
New
(
"not found"
)
}
func
(
m
*
mockGatewayCacheForGemini
)
SetSessionAccountID
(
ctx
context
.
Context
,
sessionHash
string
,
accountID
int64
,
ttl
time
.
Duration
)
error
{
if
m
.
sessionBindings
==
nil
{
m
.
sessionBindings
=
make
(
map
[
string
]
int64
)
}
m
.
sessionBindings
[
sessionHash
]
=
accountID
return
nil
}
func
(
m
*
mockGatewayCacheForGemini
)
RefreshSessionTTL
(
ctx
context
.
Context
,
sessionHash
string
,
ttl
time
.
Duration
)
error
{
return
nil
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformGemini
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
3
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
// 应被隔离
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForGemini
{}
groupRepo
:=
&
mockGroupRepoForGemini
{
groups
:
map
[
int64
]
*
Group
{}}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
}
// 无分组时使用 gemini 平台
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
""
,
"gemini-2.5-flash"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
1
),
acc
.
ID
,
"应选择优先级最高的 gemini 账户"
)
require
.
Equal
(
t
,
PlatformGemini
,
acc
.
Platform
,
"无分组时应只返回 gemini 平台账户"
)
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup 测试 antigravity 分组
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
// 应被隔离
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
// 应被选择
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForGemini
{}
groupRepo
:=
&
mockGroupRepoForGemini
{
groups
:
map
[
int64
]
*
Group
{
1
:
{
ID
:
1
,
Platform
:
PlatformAntigravity
},
},
}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
}
groupID
:=
int64
(
1
)
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
&
groupID
,
""
,
"gemini-2.5-flash"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
)
require
.
Equal
(
t
,
PlatformAntigravity
,
acc
.
Platform
,
"antigravity 分组应只返回 antigravity 账户"
)
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPreferred 测试 OAuth 优先
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPreferred
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeApiKey
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
LastUsedAt
:
nil
},
{
ID
:
2
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
LastUsedAt
:
nil
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForGemini
{}
groupRepo
:=
&
mockGroupRepoForGemini
{
groups
:
map
[
int64
]
*
Group
{}}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
""
,
"gemini-2.5-flash"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"同优先级且都未使用时,应优先选择 OAuth 账户"
)
require
.
Equal
(
t
,
AccountTypeOAuth
,
acc
.
Type
)
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_NoAvailableAccounts 测试无可用账户
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_NoAvailableAccounts
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{},
accountsByID
:
map
[
int64
]
*
Account
{},
}
cache
:=
&
mockGatewayCacheForGemini
{}
groupRepo
:=
&
mockGroupRepoForGemini
{
groups
:
map
[
int64
]
*
Group
{}}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
""
,
"gemini-2.5-flash"
,
nil
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
acc
)
require
.
Contains
(
t
,
err
.
Error
(),
"no available"
)
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickySession 测试粘性会话
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickySession
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
t
.
Run
(
"粘性会话命中-同平台"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
// 注意:缓存键使用 "gemini:" 前缀
cache
:=
&
mockGatewayCacheForGemini
{
sessionBindings
:
map
[
string
]
int64
{
"gemini:session-123"
:
1
},
}
groupRepo
:=
&
mockGroupRepoForGemini
{
groups
:
map
[
int64
]
*
Group
{}}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
"session-123"
,
"gemini-2.5-flash"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
1
),
acc
.
ID
,
"应返回粘性会话绑定的账户"
)
})
t
.
Run
(
"粘性会话平台不匹配-降级选择"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAntigravity
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
// 粘性会话绑定
{
ID
:
2
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForGemini
{
sessionBindings
:
map
[
string
]
int64
{
"gemini:session-123"
:
1
},
// 绑定 antigravity 账户
}
groupRepo
:=
&
mockGroupRepoForGemini
{
groups
:
map
[
int64
]
*
Group
{}}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
}
// 无分组时使用 gemini 平台,粘性会话绑定的 antigravity 账户平台不匹配
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
"session-123"
,
"gemini-2.5-flash"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"粘性会话账户平台不匹配,应降级选择 gemini 账户"
)
require
.
Equal
(
t
,
PlatformGemini
,
acc
.
Platform
)
})
t
.
Run
(
"粘性会话不命中无前缀缓存键"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
// 缓存键没有 "gemini:" 前缀,不应命中
cache
:=
&
mockGatewayCacheForGemini
{
sessionBindings
:
map
[
string
]
int64
{
"session-123"
:
1
},
}
groupRepo
:=
&
mockGroupRepoForGemini
{
groups
:
map
[
int64
]
*
Group
{}}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
"session-123"
,
"gemini-2.5-flash"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
// 粘性会话未命中,按优先级选择
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"粘性会话未命中,应按优先级选择"
)
})
}
// TestGeminiPlatformRouting_DocumentRouteDecision 测试平台路由决策逻辑
func
TestGeminiPlatformRouting_DocumentRouteDecision
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
platform
string
expectedService
string
// "gemini" 表示 ForwardNative, "antigravity" 表示 ForwardGemini
}{
{
name
:
"Gemini平台走ForwardNative"
,
platform
:
PlatformGemini
,
expectedService
:
"gemini"
,
},
{
name
:
"Antigravity平台走ForwardGemini"
,
platform
:
PlatformAntigravity
,
expectedService
:
"antigravity"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
Platform
:
tt
.
platform
}
// 模拟 Handler 层的路由逻辑
var
serviceName
string
if
account
.
Platform
==
PlatformAntigravity
{
serviceName
=
"antigravity"
}
else
{
serviceName
=
"gemini"
}
require
.
Equal
(
t
,
tt
.
expectedService
,
serviceName
,
"平台 %s 应该路由到 %s 服务"
,
tt
.
platform
,
tt
.
expectedService
)
})
}
}
func
TestGeminiMessagesCompatService_isModelSupportedByAccount
(
t
*
testing
.
T
)
{
svc
:=
&
GeminiMessagesCompatService
{}
tests
:=
[]
struct
{
name
string
account
*
Account
model
string
expected
bool
}{
{
name
:
"Antigravity平台-支持gemini模型"
,
account
:
&
Account
{
Platform
:
PlatformAntigravity
},
model
:
"gemini-2.5-flash"
,
expected
:
true
,
},
{
name
:
"Antigravity平台-支持claude模型"
,
account
:
&
Account
{
Platform
:
PlatformAntigravity
},
model
:
"claude-3-5-sonnet-20241022"
,
expected
:
true
,
},
{
name
:
"Antigravity平台-不支持gpt模型"
,
account
:
&
Account
{
Platform
:
PlatformAntigravity
},
model
:
"gpt-4"
,
expected
:
false
,
},
{
name
:
"Gemini平台-无映射配置-支持所有模型"
,
account
:
&
Account
{
Platform
:
PlatformGemini
},
model
:
"gemini-2.5-flash"
,
expected
:
true
,
},
{
name
:
"Gemini平台-有映射配置-只支持配置的模型"
,
account
:
&
Account
{
Platform
:
PlatformGemini
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"gemini-1.5-pro"
:
"x"
}},
},
model
:
"gemini-2.5-flash"
,
expected
:
false
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
svc
.
isModelSupportedByAccount
(
tt
.
account
,
tt
.
model
)
require
.
Equal
(
t
,
tt
.
expected
,
got
)
})
}
}
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment