Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
7b83d6e7
Commit
7b83d6e7
authored
Apr 05, 2026
by
陈曦
Browse files
Merge remote-tracking branch 'upstream/main'
parents
daa2e6df
dbb248df
Changes
106
Hide whitespace changes
Inline
Side-by-side
backend/internal/service/openai_gateway_chat_completions.go
View file @
7b83d6e7
...
@@ -46,7 +46,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
...
@@ -46,7 +46,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
// 2. Resolve model mapping early so compat prompt_cache_key injection can
// 2. Resolve model mapping early so compat prompt_cache_key injection can
// derive a stable seed from the final upstream model family.
// derive a stable seed from the final upstream model family.
billingModel
:=
resolveOpenAIForwardModel
(
account
,
originalModel
,
defaultMappedModel
)
billingModel
:=
resolveOpenAIForwardModel
(
account
,
originalModel
,
defaultMappedModel
)
upstreamModel
:=
resolveOpenAIUpstream
Model
(
billingModel
)
upstreamModel
:=
normalizeCodex
Model
(
billingModel
)
promptCacheKey
=
strings
.
TrimSpace
(
promptCacheKey
)
promptCacheKey
=
strings
.
TrimSpace
(
promptCacheKey
)
compatPromptCacheInjected
:=
false
compatPromptCacheInjected
:=
false
...
...
backend/internal/service/openai_gateway_messages.go
View file @
7b83d6e7
...
@@ -62,7 +62,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
...
@@ -62,7 +62,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
// 3. Model mapping
// 3. Model mapping
billingModel
:=
resolveOpenAIForwardModel
(
account
,
normalizedModel
,
defaultMappedModel
)
billingModel
:=
resolveOpenAIForwardModel
(
account
,
normalizedModel
,
defaultMappedModel
)
upstreamModel
:=
resolveOpenAIUpstream
Model
(
billingModel
)
upstreamModel
:=
normalizeCodex
Model
(
billingModel
)
responsesReq
.
Model
=
upstreamModel
responsesReq
.
Model
=
upstreamModel
logger
.
L
()
.
Debug
(
"openai messages: model mapping applied"
,
logger
.
L
()
.
Debug
(
"openai messages: model mapping applied"
,
...
...
backend/internal/service/openai_gateway_record_usage_test.go
View file @
7b83d6e7
...
@@ -145,6 +145,8 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
...
@@ -145,6 +145,8 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
nil
,
nil
,
&
DeferredService
{},
&
DeferredService
{},
nil
,
nil
,
nil
,
nil
,
)
)
svc
.
userGroupRateResolver
=
newUserGroupRateResolver
(
svc
.
userGroupRateResolver
=
newUserGroupRateResolver
(
rateRepo
,
rateRepo
,
...
...
backend/internal/service/openai_gateway_service.go
View file @
7b83d6e7
...
@@ -10,6 +10,7 @@ import (
...
@@ -10,6 +10,7 @@ import (
"errors"
"errors"
"fmt"
"fmt"
"io"
"io"
"log/slog"
"math/rand"
"math/rand"
"net/http"
"net/http"
"sort"
"sort"
...
@@ -204,6 +205,7 @@ type OpenAIUsage struct {
...
@@ -204,6 +205,7 @@ type OpenAIUsage struct {
OutputTokens
int
`json:"output_tokens"`
OutputTokens
int
`json:"output_tokens"`
CacheCreationInputTokens
int
`json:"cache_creation_input_tokens,omitempty"`
CacheCreationInputTokens
int
`json:"cache_creation_input_tokens,omitempty"`
CacheReadInputTokens
int
`json:"cache_read_input_tokens,omitempty"`
CacheReadInputTokens
int
`json:"cache_read_input_tokens,omitempty"`
ImageOutputTokens
int
`json:"image_output_tokens,omitempty"`
}
}
// OpenAIForwardResult represents the result of forwarding
// OpenAIForwardResult represents the result of forwarding
...
@@ -322,6 +324,8 @@ type OpenAIGatewayService struct {
...
@@ -322,6 +324,8 @@ type OpenAIGatewayService struct {
openAITokenProvider
*
OpenAITokenProvider
openAITokenProvider
*
OpenAITokenProvider
toolCorrector
*
CodexToolCorrector
toolCorrector
*
CodexToolCorrector
openaiWSResolver
OpenAIWSProtocolResolver
openaiWSResolver
OpenAIWSProtocolResolver
resolver
*
ModelPricingResolver
channelService
*
ChannelService
openaiWSPoolOnce
sync
.
Once
openaiWSPoolOnce
sync
.
Once
openaiWSStateStoreOnce
sync
.
Once
openaiWSStateStoreOnce
sync
.
Once
...
@@ -357,6 +361,8 @@ func NewOpenAIGatewayService(
...
@@ -357,6 +361,8 @@ func NewOpenAIGatewayService(
httpUpstream
HTTPUpstream
,
httpUpstream
HTTPUpstream
,
deferredService
*
DeferredService
,
deferredService
*
DeferredService
,
openAITokenProvider
*
OpenAITokenProvider
,
openAITokenProvider
*
OpenAITokenProvider
,
resolver
*
ModelPricingResolver
,
channelService
*
ChannelService
,
)
*
OpenAIGatewayService
{
)
*
OpenAIGatewayService
{
svc
:=
&
OpenAIGatewayService
{
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
accountRepo
,
accountRepo
:
accountRepo
,
...
@@ -384,6 +390,8 @@ func NewOpenAIGatewayService(
...
@@ -384,6 +390,8 @@ func NewOpenAIGatewayService(
openAITokenProvider
:
openAITokenProvider
,
openAITokenProvider
:
openAITokenProvider
,
toolCorrector
:
NewCodexToolCorrector
(),
toolCorrector
:
NewCodexToolCorrector
(),
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
openaiWSResolver
:
NewOpenAIWSProtocolResolver
(
cfg
),
resolver
:
resolver
,
channelService
:
channelService
,
responseHeaderFilter
:
compileResponseHeaderFilter
(
cfg
),
responseHeaderFilter
:
compileResponseHeaderFilter
(
cfg
),
codexSnapshotThrottle
:
newAccountWriteThrottle
(
openAICodexSnapshotPersistMinInterval
),
codexSnapshotThrottle
:
newAccountWriteThrottle
(
openAICodexSnapshotPersistMinInterval
),
}
}
...
@@ -391,6 +399,74 @@ func NewOpenAIGatewayService(
...
@@ -391,6 +399,74 @@ func NewOpenAIGatewayService(
return
svc
return
svc
}
}
// ResolveChannelMapping 解析渠道级模型映射(代理到 ChannelService)
func
(
s
*
OpenAIGatewayService
)
ResolveChannelMapping
(
ctx
context
.
Context
,
groupID
int64
,
model
string
)
ChannelMappingResult
{
if
s
.
channelService
==
nil
{
return
ChannelMappingResult
{
MappedModel
:
model
}
}
return
s
.
channelService
.
ResolveChannelMapping
(
ctx
,
groupID
,
model
)
}
// IsModelRestricted 检查模型是否被渠道限制(代理到 ChannelService)
func
(
s
*
OpenAIGatewayService
)
IsModelRestricted
(
ctx
context
.
Context
,
groupID
int64
,
model
string
)
bool
{
if
s
.
channelService
==
nil
{
return
false
}
return
s
.
channelService
.
IsModelRestricted
(
ctx
,
groupID
,
model
)
}
// ResolveChannelMappingAndRestrict 解析渠道映射。
// 模型限制检查已移至调度阶段,restricted 始终返回 false。
func
(
s
*
OpenAIGatewayService
)
ResolveChannelMappingAndRestrict
(
ctx
context
.
Context
,
groupID
*
int64
,
model
string
)
(
ChannelMappingResult
,
bool
)
{
if
s
.
channelService
==
nil
{
return
ChannelMappingResult
{
MappedModel
:
model
},
false
}
return
s
.
channelService
.
ResolveChannelMappingAndRestrict
(
ctx
,
groupID
,
model
)
}
func
(
s
*
OpenAIGatewayService
)
checkChannelPricingRestriction
(
ctx
context
.
Context
,
groupID
*
int64
,
requestedModel
string
)
bool
{
if
groupID
==
nil
||
s
.
channelService
==
nil
||
requestedModel
==
""
{
return
false
}
mapping
:=
s
.
channelService
.
ResolveChannelMapping
(
ctx
,
*
groupID
,
requestedModel
)
billingModel
:=
billingModelForRestriction
(
mapping
.
BillingModelSource
,
requestedModel
,
mapping
.
MappedModel
)
if
billingModel
==
""
{
return
false
}
return
s
.
channelService
.
IsModelRestricted
(
ctx
,
*
groupID
,
billingModel
)
}
func
(
s
*
OpenAIGatewayService
)
isUpstreamModelRestrictedByChannel
(
ctx
context
.
Context
,
groupID
int64
,
account
*
Account
,
requestedModel
string
)
bool
{
if
s
.
channelService
==
nil
{
return
false
}
upstreamModel
:=
resolveOpenAIForwardModel
(
account
,
requestedModel
,
""
)
if
upstreamModel
==
""
{
return
false
}
return
s
.
channelService
.
IsModelRestricted
(
ctx
,
groupID
,
upstreamModel
)
}
func
(
s
*
OpenAIGatewayService
)
needsUpstreamChannelRestrictionCheck
(
ctx
context
.
Context
,
groupID
*
int64
)
bool
{
if
groupID
==
nil
||
s
.
channelService
==
nil
{
return
false
}
ch
,
err
:=
s
.
channelService
.
GetChannelForGroup
(
ctx
,
*
groupID
)
if
err
!=
nil
{
slog
.
Warn
(
"failed to check openai channel upstream restriction"
,
"group_id"
,
*
groupID
,
"error"
,
err
)
return
false
}
if
ch
==
nil
||
!
ch
.
RestrictModels
{
return
false
}
return
ch
.
BillingModelSource
==
BillingModelSourceUpstream
}
// ReplaceModelInBody 替换请求体中的 JSON model 字段(通用 gjson/sjson 实现)。
func
(
s
*
OpenAIGatewayService
)
ReplaceModelInBody
(
body
[]
byte
,
newModel
string
)
[]
byte
{
return
ReplaceModelInBody
(
body
,
newModel
)
}
func
(
s
*
OpenAIGatewayService
)
getCodexSnapshotThrottle
()
*
accountWriteThrottle
{
func
(
s
*
OpenAIGatewayService
)
getCodexSnapshotThrottle
()
*
accountWriteThrottle
{
if
s
!=
nil
&&
s
.
codexSnapshotThrottle
!=
nil
{
if
s
!=
nil
&&
s
.
codexSnapshotThrottle
!=
nil
{
return
s
.
codexSnapshotThrottle
return
s
.
codexSnapshotThrottle
...
@@ -1125,6 +1201,13 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
...
@@ -1125,6 +1201,13 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
}
}
func
(
s
*
OpenAIGatewayService
)
selectAccountForModelWithExclusions
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{},
stickyAccountID
int64
)
(
*
Account
,
error
)
{
func
(
s
*
OpenAIGatewayService
)
selectAccountForModelWithExclusions
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{},
stickyAccountID
int64
)
(
*
Account
,
error
)
{
if
s
.
checkChannelPricingRestriction
(
ctx
,
groupID
,
requestedModel
)
{
slog
.
Warn
(
"channel pricing restriction blocked request"
,
"group_id"
,
derefGroupID
(
groupID
),
"model"
,
requestedModel
)
return
nil
,
fmt
.
Errorf
(
"%w supporting model: %s (channel pricing restriction)"
,
ErrNoAvailableAccounts
,
requestedModel
)
}
// 1. 尝试粘性会话命中
// 1. 尝试粘性会话命中
// Try sticky session hit
// Try sticky session hit
if
account
:=
s
.
tryStickySessionHit
(
ctx
,
groupID
,
sessionHash
,
requestedModel
,
excludedIDs
,
stickyAccountID
);
account
!=
nil
{
if
account
:=
s
.
tryStickySessionHit
(
ctx
,
groupID
,
sessionHash
,
requestedModel
,
excludedIDs
,
stickyAccountID
);
account
!=
nil
{
...
@@ -1140,7 +1223,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
...
@@ -1140,7 +1223,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
// 3. 按优先级 + LRU 选择最佳账号
// 3. 按优先级 + LRU 选择最佳账号
// Select by priority + LRU
// Select by priority + LRU
selected
:=
s
.
selectBestAccount
(
ctx
,
accounts
,
requestedModel
,
excludedIDs
)
selected
:=
s
.
selectBestAccount
(
ctx
,
groupID
,
accounts
,
requestedModel
,
excludedIDs
)
if
selected
==
nil
{
if
selected
==
nil
{
if
requestedModel
!=
""
{
if
requestedModel
!=
""
{
...
@@ -1206,6 +1289,11 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
...
@@ -1206,6 +1289,11 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
_
=
s
.
deleteStickySessionAccountID
(
ctx
,
groupID
,
sessionHash
)
_
=
s
.
deleteStickySessionAccountID
(
ctx
,
groupID
,
sessionHash
)
return
nil
return
nil
}
}
if
groupID
!=
nil
&&
s
.
needsUpstreamChannelRestrictionCheck
(
ctx
,
groupID
)
&&
s
.
isUpstreamModelRestrictedByChannel
(
ctx
,
*
groupID
,
account
,
requestedModel
)
{
_
=
s
.
deleteStickySessionAccountID
(
ctx
,
groupID
,
sessionHash
)
return
nil
}
// 刷新会话 TTL 并返回账号
// 刷新会话 TTL 并返回账号
// Refresh session TTL and return account
// Refresh session TTL and return account
...
@@ -1218,8 +1306,9 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
...
@@ -1218,8 +1306,9 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
//
//
// selectBestAccount selects the best account from candidates (priority + LRU).
// selectBestAccount selects the best account from candidates (priority + LRU).
// Returns nil if no available account.
// Returns nil if no available account.
func
(
s
*
OpenAIGatewayService
)
selectBestAccount
(
ctx
context
.
Context
,
accounts
[]
Account
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{})
*
Account
{
func
(
s
*
OpenAIGatewayService
)
selectBestAccount
(
ctx
context
.
Context
,
groupID
*
int64
,
accounts
[]
Account
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{})
*
Account
{
var
selected
*
Account
var
selected
*
Account
needsUpstreamCheck
:=
s
.
needsUpstreamChannelRestrictionCheck
(
ctx
,
groupID
)
for
i
:=
range
accounts
{
for
i
:=
range
accounts
{
acc
:=
&
accounts
[
i
]
acc
:=
&
accounts
[
i
]
...
@@ -1238,6 +1327,9 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts [
...
@@ -1238,6 +1327,9 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts [
if
fresh
==
nil
{
if
fresh
==
nil
{
continue
continue
}
}
if
needsUpstreamCheck
&&
s
.
isUpstreamModelRestrictedByChannel
(
ctx
,
*
groupID
,
fresh
,
requestedModel
)
{
continue
}
// 选择优先级最高且最久未使用的账号
// 选择优先级最高且最久未使用的账号
// Select highest priority and least recently used
// Select highest priority and least recently used
...
@@ -1289,7 +1381,15 @@ func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool
...
@@ -1289,7 +1381,15 @@ func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
func
(
s
*
OpenAIGatewayService
)
SelectAccountWithLoadAwareness
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{})
(
*
AccountSelectionResult
,
error
)
{
func
(
s
*
OpenAIGatewayService
)
SelectAccountWithLoadAwareness
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{})
(
*
AccountSelectionResult
,
error
)
{
if
s
.
checkChannelPricingRestriction
(
ctx
,
groupID
,
requestedModel
)
{
slog
.
Warn
(
"channel pricing restriction blocked request"
,
"group_id"
,
derefGroupID
(
groupID
),
"model"
,
requestedModel
)
return
nil
,
fmt
.
Errorf
(
"%w supporting model: %s (channel pricing restriction)"
,
ErrNoAvailableAccounts
,
requestedModel
)
}
cfg
:=
s
.
schedulingConfig
()
cfg
:=
s
.
schedulingConfig
()
needsUpstreamCheck
:=
s
.
needsUpstreamChannelRestrictionCheck
(
ctx
,
groupID
)
var
stickyAccountID
int64
var
stickyAccountID
int64
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
if
accountID
,
err
:=
s
.
getStickySessionAccountID
(
ctx
,
groupID
,
sessionHash
);
err
==
nil
{
if
accountID
,
err
:=
s
.
getStickySessionAccountID
(
ctx
,
groupID
,
sessionHash
);
err
==
nil
{
...
@@ -1365,6 +1465,8 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
...
@@ -1365,6 +1465,8 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
account
=
s
.
recheckSelectedOpenAIAccountFromDB
(
ctx
,
account
,
requestedModel
)
account
=
s
.
recheckSelectedOpenAIAccountFromDB
(
ctx
,
account
,
requestedModel
)
if
account
==
nil
{
if
account
==
nil
{
_
=
s
.
deleteStickySessionAccountID
(
ctx
,
groupID
,
sessionHash
)
_
=
s
.
deleteStickySessionAccountID
(
ctx
,
groupID
,
sessionHash
)
}
else
if
needsUpstreamCheck
&&
s
.
isUpstreamModelRestrictedByChannel
(
ctx
,
*
groupID
,
account
,
requestedModel
)
{
_
=
s
.
deleteStickySessionAccountID
(
ctx
,
groupID
,
sessionHash
)
}
else
{
}
else
{
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
if
err
==
nil
&&
result
.
Acquired
{
...
@@ -1410,6 +1512,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
...
@@ -1410,6 +1512,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if
requestedModel
!=
""
&&
!
acc
.
IsModelSupported
(
requestedModel
)
{
if
requestedModel
!=
""
&&
!
acc
.
IsModelSupported
(
requestedModel
)
{
continue
continue
}
}
if
needsUpstreamCheck
&&
s
.
isUpstreamModelRestrictedByChannel
(
ctx
,
*
groupID
,
acc
,
requestedModel
)
{
continue
}
candidates
=
append
(
candidates
,
acc
)
candidates
=
append
(
candidates
,
acc
)
}
}
...
@@ -1434,6 +1539,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
...
@@ -1434,6 +1539,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if
fresh
==
nil
{
if
fresh
==
nil
{
continue
continue
}
}
if
needsUpstreamCheck
&&
s
.
isUpstreamModelRestrictedByChannel
(
ctx
,
*
groupID
,
fresh
,
requestedModel
)
{
continue
}
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
fresh
.
ID
,
fresh
.
Concurrency
)
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
fresh
.
ID
,
fresh
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
if
err
==
nil
&&
result
.
Acquired
{
if
sessionHash
!=
""
{
if
sessionHash
!=
""
{
...
@@ -1488,6 +1596,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
...
@@ -1488,6 +1596,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if
fresh
==
nil
{
if
fresh
==
nil
{
continue
continue
}
}
if
needsUpstreamCheck
&&
s
.
isUpstreamModelRestrictedByChannel
(
ctx
,
*
groupID
,
fresh
,
requestedModel
)
{
continue
}
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
fresh
.
ID
,
fresh
.
Concurrency
)
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
fresh
.
ID
,
fresh
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
if
err
==
nil
&&
result
.
Acquired
{
if
sessionHash
!=
""
{
if
sessionHash
!=
""
{
...
@@ -1510,6 +1621,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
...
@@ -1510,6 +1621,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if
fresh
==
nil
{
if
fresh
==
nil
{
continue
continue
}
}
if
needsUpstreamCheck
&&
s
.
isUpstreamModelRestrictedByChannel
(
ctx
,
*
groupID
,
fresh
,
requestedModel
)
{
continue
}
return
&
AccountSelectionResult
{
return
&
AccountSelectionResult
{
Account
:
fresh
,
Account
:
fresh
,
WaitPlan
:
&
AccountWaitPlan
{
WaitPlan
:
&
AccountWaitPlan
{
...
@@ -1825,7 +1939,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
...
@@ -1825,7 +1939,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
// 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。
// 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。
if
model
,
ok
:=
reqBody
[
"model"
]
.
(
string
);
ok
{
if
model
,
ok
:=
reqBody
[
"model"
]
.
(
string
);
ok
{
upstreamModel
=
resolveOpenAIUpstream
Model
(
model
)
upstreamModel
=
normalizeCodex
Model
(
model
)
if
upstreamModel
!=
""
&&
upstreamModel
!=
model
{
if
upstreamModel
!=
""
&&
upstreamModel
!=
model
{
logger
.
LegacyPrintf
(
"service.openai_gateway"
,
"[OpenAI] Upstream model resolved: %s -> %s (account: %s, type: %s, isCodexCLI: %v)"
,
logger
.
LegacyPrintf
(
"service.openai_gateway"
,
"[OpenAI] Upstream model resolved: %s -> %s (account: %s, type: %s, isCodexCLI: %v)"
,
model
,
upstreamModel
,
account
.
Name
,
account
.
Type
,
isCodexCLI
)
model
,
upstreamModel
,
account
.
Name
,
account
.
Type
,
isCodexCLI
)
...
@@ -4110,6 +4224,7 @@ type OpenAIRecordUsageInput struct {
...
@@ -4110,6 +4224,7 @@ type OpenAIRecordUsageInput struct {
IPAddress
string
// 请求的客户端 IP 地址
IPAddress
string
// 请求的客户端 IP 地址
RequestPayloadHash
string
RequestPayloadHash
string
APIKeyService
APIKeyQuotaUpdater
APIKeyService
APIKeyQuotaUpdater
ChannelUsageFields
}
}
// RecordUsage records usage and deducts balance
// RecordUsage records usage and deducts balance
...
@@ -4140,10 +4255,14 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
...
@@ -4140,10 +4255,14 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
OutputTokens
:
result
.
Usage
.
OutputTokens
,
OutputTokens
:
result
.
Usage
.
OutputTokens
,
CacheCreationTokens
:
result
.
Usage
.
CacheCreationInputTokens
,
CacheCreationTokens
:
result
.
Usage
.
CacheCreationInputTokens
,
CacheReadTokens
:
result
.
Usage
.
CacheReadInputTokens
,
CacheReadTokens
:
result
.
Usage
.
CacheReadInputTokens
,
ImageOutputTokens
:
result
.
Usage
.
ImageOutputTokens
,
}
}
// Get rate multiplier
// Get rate multiplier
multiplier
:=
s
.
cfg
.
Default
.
RateMultiplier
multiplier
:=
1.0
if
s
.
cfg
!=
nil
{
multiplier
=
s
.
cfg
.
Default
.
RateMultiplier
}
if
apiKey
.
GroupID
!=
nil
&&
apiKey
.
Group
!=
nil
{
if
apiKey
.
GroupID
!=
nil
&&
apiKey
.
Group
!=
nil
{
resolver
:=
s
.
userGroupRateResolver
resolver
:=
s
.
userGroupRateResolver
if
resolver
==
nil
{
if
resolver
==
nil
{
...
@@ -4152,12 +4271,37 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
...
@@ -4152,12 +4271,37 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
multiplier
=
resolver
.
Resolve
(
ctx
,
user
.
ID
,
*
apiKey
.
GroupID
,
apiKey
.
Group
.
RateMultiplier
)
multiplier
=
resolver
.
Resolve
(
ctx
,
user
.
ID
,
*
apiKey
.
GroupID
,
apiKey
.
Group
.
RateMultiplier
)
}
}
var
cost
*
CostBreakdown
var
err
error
billingModel
:=
forwardResultBillingModel
(
result
.
Model
,
result
.
UpstreamModel
)
billingModel
:=
forwardResultBillingModel
(
result
.
Model
,
result
.
UpstreamModel
)
if
result
.
BillingModel
!=
""
{
billingModel
=
strings
.
TrimSpace
(
result
.
BillingModel
)
}
if
input
.
BillingModelSource
==
BillingModelSourceChannelMapped
&&
input
.
ChannelMappedModel
!=
""
{
billingModel
=
input
.
ChannelMappedModel
}
if
input
.
BillingModelSource
==
BillingModelSourceRequested
&&
input
.
OriginalModel
!=
""
{
billingModel
=
input
.
OriginalModel
}
serviceTier
:=
""
serviceTier
:=
""
if
result
.
ServiceTier
!=
nil
{
if
result
.
ServiceTier
!=
nil
{
serviceTier
=
strings
.
TrimSpace
(
*
result
.
ServiceTier
)
serviceTier
=
strings
.
TrimSpace
(
*
result
.
ServiceTier
)
}
}
cost
,
err
:=
s
.
billingService
.
CalculateCostWithServiceTier
(
billingModel
,
tokens
,
multiplier
,
serviceTier
)
if
s
.
resolver
!=
nil
&&
apiKey
.
Group
!=
nil
{
gid
:=
apiKey
.
Group
.
ID
cost
,
err
=
s
.
billingService
.
CalculateCostUnified
(
CostInput
{
Ctx
:
ctx
,
Model
:
billingModel
,
GroupID
:
&
gid
,
Tokens
:
tokens
,
RequestCount
:
1
,
RateMultiplier
:
multiplier
,
ServiceTier
:
serviceTier
,
Resolver
:
s
.
resolver
,
})
}
else
{
cost
,
err
=
s
.
billingService
.
CalculateCostWithServiceTier
(
billingModel
,
tokens
,
multiplier
,
serviceTier
)
}
if
err
!=
nil
{
if
err
!=
nil
{
cost
=
&
CostBreakdown
{
ActualCost
:
0
}
cost
=
&
CostBreakdown
{
ActualCost
:
0
}
}
}
...
@@ -4173,36 +4317,58 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
...
@@ -4173,36 +4317,58 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
durationMs
:=
int
(
result
.
Duration
.
Milliseconds
())
durationMs
:=
int
(
result
.
Duration
.
Milliseconds
())
accountRateMultiplier
:=
account
.
BillingRateMultiplier
()
accountRateMultiplier
:=
account
.
BillingRateMultiplier
()
requestID
:=
resolveUsageBillingRequestID
(
ctx
,
result
.
RequestID
)
requestID
:=
resolveUsageBillingRequestID
(
ctx
,
result
.
RequestID
)
// 确定 RequestedModel(渠道映射前的原始模型)
requestedModel
:=
result
.
Model
if
input
.
OriginalModel
!=
""
{
requestedModel
=
input
.
OriginalModel
}
usageLog
:=
&
UsageLog
{
usageLog
:=
&
UsageLog
{
UserID
:
user
.
ID
,
UserID
:
user
.
ID
,
APIKeyID
:
apiKey
.
ID
,
APIKeyID
:
apiKey
.
ID
,
AccountID
:
account
.
ID
,
AccountID
:
account
.
ID
,
RequestID
:
requestID
,
RequestID
:
requestID
,
Model
:
result
.
Model
,
Model
:
result
.
Model
,
RequestedModel
:
result
.
Model
,
RequestedModel
:
requestedModel
,
UpstreamModel
:
optionalNonEqualStringPtr
(
result
.
UpstreamModel
,
result
.
Model
),
UpstreamModel
:
optionalNonEqualStringPtr
(
result
.
UpstreamModel
,
result
.
Model
),
ServiceTier
:
result
.
ServiceTier
,
ServiceTier
:
result
.
ServiceTier
,
ReasoningEffort
:
result
.
ReasoningEffort
,
ReasoningEffort
:
result
.
ReasoningEffort
,
InboundEndpoint
:
optionalTrimmedStringPtr
(
input
.
InboundEndpoint
),
InboundEndpoint
:
optionalTrimmedStringPtr
(
input
.
InboundEndpoint
),
UpstreamEndpoint
:
optionalTrimmedStringPtr
(
input
.
UpstreamEndpoint
),
UpstreamEndpoint
:
optionalTrimmedStringPtr
(
input
.
UpstreamEndpoint
),
InputTokens
:
actualInputTokens
,
InputTokens
:
actualInputTokens
,
OutputTokens
:
result
.
Usage
.
OutputTokens
,
OutputTokens
:
result
.
Usage
.
OutputTokens
,
CacheCreationTokens
:
result
.
Usage
.
CacheCreationInputTokens
,
CacheCreationTokens
:
result
.
Usage
.
CacheCreationInputTokens
,
CacheReadTokens
:
result
.
Usage
.
CacheReadInputTokens
,
CacheReadTokens
:
result
.
Usage
.
CacheReadInputTokens
,
InputCost
:
cost
.
InputCost
,
ImageOutputTokens
:
result
.
Usage
.
ImageOutputTokens
,
OutputCost
:
cost
.
OutputCost
,
}
CacheCreationCost
:
cost
.
CacheCreationCost
,
if
cost
!=
nil
{
CacheReadCost
:
cost
.
CacheReadCost
,
usageLog
.
InputCost
=
cost
.
InputCost
TotalCost
:
cost
.
TotalCost
,
usageLog
.
OutputCost
=
cost
.
OutputCost
ActualCost
:
cost
.
ActualCost
,
usageLog
.
ImageOutputCost
=
cost
.
ImageOutputCost
RateMultiplier
:
multiplier
,
usageLog
.
CacheCreationCost
=
cost
.
CacheCreationCost
AccountRateMultiplier
:
&
accountRateMultiplier
,
usageLog
.
CacheReadCost
=
cost
.
CacheReadCost
BillingType
:
billingType
,
usageLog
.
TotalCost
=
cost
.
TotalCost
Stream
:
result
.
Stream
,
usageLog
.
ActualCost
=
cost
.
ActualCost
OpenAIWSMode
:
result
.
OpenAIWSMode
,
}
DurationMs
:
&
durationMs
,
usageLog
.
RateMultiplier
=
multiplier
FirstTokenMs
:
result
.
FirstTokenMs
,
usageLog
.
AccountRateMultiplier
=
&
accountRateMultiplier
CreatedAt
:
time
.
Now
(),
usageLog
.
BillingType
=
billingType
usageLog
.
Stream
=
result
.
Stream
usageLog
.
OpenAIWSMode
=
result
.
OpenAIWSMode
usageLog
.
DurationMs
=
&
durationMs
usageLog
.
FirstTokenMs
=
result
.
FirstTokenMs
usageLog
.
CreatedAt
=
time
.
Now
()
// 设置渠道信息
usageLog
.
ChannelID
=
optionalInt64Ptr
(
input
.
ChannelID
)
usageLog
.
ModelMappingChain
=
optionalTrimmedStringPtr
(
input
.
ModelMappingChain
)
// 设置计费模式
if
cost
!=
nil
&&
cost
.
BillingMode
!=
""
{
billingMode
:=
cost
.
BillingMode
usageLog
.
BillingMode
=
&
billingMode
}
else
{
billingMode
:=
string
(
BillingModeToken
)
usageLog
.
BillingMode
=
&
billingMode
}
}
// 添加 UserAgent
// 添加 UserAgent
if
input
.
UserAgent
!=
""
{
if
input
.
UserAgent
!=
""
{
...
...
backend/internal/service/openai_model_mapping.go
View file @
7b83d6e7
package
service
package
service
import
"strings"
// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible
// forwarding. Group-level default mapping only applies when the account itself
// resolveOpenAIForwardModel resolves the account/group mapping result for
// did not match any explicit model_mapping rule.
// OpenAI-compatible forwarding. Group-level default mapping only applies when
// the account itself did not match any explicit model_mapping rule.
func
resolveOpenAIForwardModel
(
account
*
Account
,
requestedModel
,
defaultMappedModel
string
)
string
{
func
resolveOpenAIForwardModel
(
account
*
Account
,
requestedModel
,
defaultMappedModel
string
)
string
{
if
account
==
nil
{
if
account
==
nil
{
if
defaultMappedModel
!=
""
{
if
defaultMappedModel
!=
""
{
...
@@ -19,23 +17,3 @@ func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedMo
...
@@ -19,23 +17,3 @@ func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedMo
}
}
return
mappedModel
return
mappedModel
}
}
func
resolveOpenAIUpstreamModel
(
model
string
)
string
{
if
isBareGPT53CodexSparkModel
(
model
)
{
return
"gpt-5.3-codex-spark"
}
return
normalizeCodexModel
(
strings
.
TrimSpace
(
model
))
}
func
isBareGPT53CodexSparkModel
(
model
string
)
bool
{
modelID
:=
strings
.
TrimSpace
(
model
)
if
modelID
==
""
{
return
false
}
if
strings
.
Contains
(
modelID
,
"/"
)
{
parts
:=
strings
.
Split
(
modelID
,
"/"
)
modelID
=
parts
[
len
(
parts
)
-
1
]
}
normalized
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
modelID
))
return
normalized
==
"gpt-5.3-codex-spark"
||
normalized
==
"gpt 5.3 codex spark"
}
backend/internal/service/openai_model_mapping_test.go
View file @
7b83d6e7
...
@@ -74,30 +74,28 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t *
...
@@ -74,30 +74,28 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t *
Credentials
:
map
[
string
]
any
{},
Credentials
:
map
[
string
]
any
{},
}
}
withoutDefault
:=
resolveOpenAIUpstream
Model
(
resolveOpenAIForwardModel
(
account
,
"claude-opus-4-6"
,
""
))
withoutDefault
:=
normalizeCodex
Model
(
resolveOpenAIForwardModel
(
account
,
"claude-opus-4-6"
,
""
))
if
withoutDefault
!=
"gpt-5.1"
{
if
withoutDefault
!=
"gpt-5.1"
{
t
.
Fatalf
(
"
resolveOpenAIUpstream
Model(...) = %q, want %q"
,
withoutDefault
,
"gpt-5.1"
)
t
.
Fatalf
(
"
normalizeCodex
Model(...) = %q, want %q"
,
withoutDefault
,
"gpt-5.1"
)
}
}
withDefault
:=
resolveOpenAIUpstream
Model
(
resolveOpenAIForwardModel
(
account
,
"claude-opus-4-6"
,
"gpt-5.4"
))
withDefault
:=
normalizeCodex
Model
(
resolveOpenAIForwardModel
(
account
,
"claude-opus-4-6"
,
"gpt-5.4"
))
if
withDefault
!=
"gpt-5.4"
{
if
withDefault
!=
"gpt-5.4"
{
t
.
Fatalf
(
"
resolveOpenAIUpstream
Model(...) = %q, want %q"
,
withDefault
,
"gpt-5.4"
)
t
.
Fatalf
(
"
normalizeCodex
Model(...) = %q, want %q"
,
withDefault
,
"gpt-5.4"
)
}
}
}
}
func
Test
ResolveOpenAIUpstream
Model
(
t
*
testing
.
T
)
{
func
Test
NormalizeCodex
Model
(
t
*
testing
.
T
)
{
cases
:=
map
[
string
]
string
{
cases
:=
map
[
string
]
string
{
"gpt-5.3-codex-spark"
:
"gpt-5.3-codex-spark"
,
"gpt-5.3-codex-spark"
:
"gpt-5.3-codex"
,
"gpt 5.3 codex spark"
:
"gpt-5.3-codex-spark"
,
"gpt-5.3-codex-spark-high"
:
"gpt-5.3-codex"
,
" openai/gpt-5.3-codex-spark "
:
"gpt-5.3-codex-spark"
,
"gpt-5.3-codex-spark-xhigh"
:
"gpt-5.3-codex"
,
"gpt-5.3-codex-spark-high"
:
"gpt-5.3-codex"
,
"gpt-5.3"
:
"gpt-5.3-codex"
,
"gpt-5.3-codex-spark-xhigh"
:
"gpt-5.3-codex"
,
"gpt-5.3"
:
"gpt-5.3-codex"
,
}
}
for
input
,
expected
:=
range
cases
{
for
input
,
expected
:=
range
cases
{
if
got
:=
resolveOpenAIUpstream
Model
(
input
);
got
!=
expected
{
if
got
:=
normalizeCodex
Model
(
input
);
got
!=
expected
{
t
.
Fatalf
(
"
resolveOpenAIUpstream
Model(%q) = %q, want %q"
,
input
,
got
,
expected
)
t
.
Fatalf
(
"
normalizeCodex
Model(%q) = %q, want %q"
,
input
,
got
,
expected
)
}
}
}
}
}
}
backend/internal/service/openai_ws_forwarder.go
View file @
7b83d6e7
...
@@ -2515,7 +2515,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
...
@@ -2515,7 +2515,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
}
}
normalized
=
next
normalized
=
next
}
}
upstreamModel
:=
resolveOpenAIUpstream
Model
(
account
.
GetMappedModel
(
originalModel
))
upstreamModel
:=
normalizeCodex
Model
(
account
.
GetMappedModel
(
originalModel
))
if
upstreamModel
!=
originalModel
{
if
upstreamModel
!=
originalModel
{
next
,
setErr
:=
applyPayloadMutation
(
normalized
,
"model"
,
upstreamModel
)
next
,
setErr
:=
applyPayloadMutation
(
normalized
,
"model"
,
upstreamModel
)
if
setErr
!=
nil
{
if
setErr
!=
nil
{
...
@@ -2773,7 +2773,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
...
@@ -2773,7 +2773,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
mappedModel
:=
""
mappedModel
:=
""
var
mappedModelBytes
[]
byte
var
mappedModelBytes
[]
byte
if
originalModel
!=
""
{
if
originalModel
!=
""
{
mappedModel
=
resolveOpenAIUpstream
Model
(
account
.
GetMappedModel
(
originalModel
))
mappedModel
=
normalizeCodex
Model
(
account
.
GetMappedModel
(
originalModel
))
needModelReplace
=
mappedModel
!=
""
&&
mappedModel
!=
originalModel
needModelReplace
=
mappedModel
!=
""
&&
mappedModel
!=
originalModel
if
needModelReplace
{
if
needModelReplace
{
mappedModelBytes
=
[]
byte
(
mappedModel
)
mappedModelBytes
=
[]
byte
(
mappedModel
)
...
...
backend/internal/service/openai_ws_protocol_forward_test.go
View file @
7b83d6e7
...
@@ -615,6 +615,8 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
...
@@ -615,6 +615,8 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
)
)
decision
:=
svc
.
getOpenAIWSProtocolResolver
()
.
Resolve
(
nil
)
decision
:=
svc
.
getOpenAIWSProtocolResolver
()
.
Resolve
(
nil
)
...
...
backend/internal/service/ops_retry.go
View file @
7b83d6e7
...
@@ -519,7 +519,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry
...
@@ -519,7 +519,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry
if
s
.
gatewayService
==
nil
{
if
s
.
gatewayService
==
nil
{
return
nil
,
fmt
.
Errorf
(
"gateway service not available"
)
return
nil
,
fmt
.
Errorf
(
"gateway service not available"
)
}
}
return
s
.
gatewayService
.
SelectAccountWithLoadAwareness
(
ctx
,
groupID
,
""
,
model
,
excludedIDs
,
""
)
// 重试不使用会话限制
return
s
.
gatewayService
.
SelectAccountWithLoadAwareness
(
ctx
,
groupID
,
""
,
model
,
excludedIDs
,
""
,
int64
(
0
)
)
// 重试不使用会话限制
default
:
default
:
return
nil
,
fmt
.
Errorf
(
"unsupported retry type: %s"
,
reqType
)
return
nil
,
fmt
.
Errorf
(
"unsupported retry type: %s"
,
reqType
)
}
}
...
...
backend/internal/service/pricing_service.go
View file @
7b83d6e7
...
@@ -70,7 +70,8 @@ type LiteLLMModelPricing struct {
...
@@ -70,7 +70,8 @@ type LiteLLMModelPricing struct {
LiteLLMProvider
string
`json:"litellm_provider"`
LiteLLMProvider
string
`json:"litellm_provider"`
Mode
string
`json:"mode"`
Mode
string
`json:"mode"`
SupportsPromptCaching
bool
`json:"supports_prompt_caching"`
SupportsPromptCaching
bool
`json:"supports_prompt_caching"`
OutputCostPerImage
float64
`json:"output_cost_per_image"`
// 图片生成模型每张图片价格
OutputCostPerImage
float64
`json:"output_cost_per_image"`
// 图片生成模型每张图片价格
OutputCostPerImageToken
float64
`json:"output_cost_per_image_token"`
// 图片输出 token 价格
}
}
// PricingRemoteClient 远程价格数据获取接口
// PricingRemoteClient 远程价格数据获取接口
...
@@ -94,6 +95,7 @@ type LiteLLMRawEntry struct {
...
@@ -94,6 +95,7 @@ type LiteLLMRawEntry struct {
Mode
string
`json:"mode"`
Mode
string
`json:"mode"`
SupportsPromptCaching
bool
`json:"supports_prompt_caching"`
SupportsPromptCaching
bool
`json:"supports_prompt_caching"`
OutputCostPerImage
*
float64
`json:"output_cost_per_image"`
OutputCostPerImage
*
float64
`json:"output_cost_per_image"`
OutputCostPerImageToken
*
float64
`json:"output_cost_per_image_token"`
}
}
// PricingService 动态价格服务
// PricingService 动态价格服务
...
@@ -408,6 +410,9 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel
...
@@ -408,6 +410,9 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel
if
entry
.
OutputCostPerImage
!=
nil
{
if
entry
.
OutputCostPerImage
!=
nil
{
pricing
.
OutputCostPerImage
=
*
entry
.
OutputCostPerImage
pricing
.
OutputCostPerImage
=
*
entry
.
OutputCostPerImage
}
}
if
entry
.
OutputCostPerImageToken
!=
nil
{
pricing
.
OutputCostPerImageToken
=
*
entry
.
OutputCostPerImageToken
}
result
[
modelName
]
=
pricing
result
[
modelName
]
=
pricing
}
}
...
...
backend/internal/service/redeem_service.go
View file @
7b83d6e7
...
@@ -131,9 +131,9 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
...
@@ -131,9 +131,9 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
return
nil
,
errors
.
New
(
"count must be greater than 0"
)
return
nil
,
errors
.
New
(
"count must be greater than 0"
)
}
}
// 邀请码类型不需要数值,其他类型需要
// 邀请码类型不需要数值,其他类型需要
非零值(支持负数用于退款)
if
req
.
Type
!=
RedeemTypeInvitation
&&
req
.
Value
<
=
0
{
if
req
.
Type
!=
RedeemTypeInvitation
&&
req
.
Value
=
=
0
{
return
nil
,
errors
.
New
(
"value must
be greater than 0
"
)
return
nil
,
errors
.
New
(
"value must
not be zero
"
)
}
}
if
req
.
Count
>
1000
{
if
req
.
Count
>
1000
{
...
@@ -188,8 +188,8 @@ func (s *RedeemService) CreateCode(ctx context.Context, code *RedeemCode) error
...
@@ -188,8 +188,8 @@ func (s *RedeemService) CreateCode(ctx context.Context, code *RedeemCode) error
if
code
.
Type
==
""
{
if
code
.
Type
==
""
{
code
.
Type
=
RedeemTypeBalance
code
.
Type
=
RedeemTypeBalance
}
}
if
code
.
Type
!=
RedeemTypeInvitation
&&
code
.
Value
<
=
0
{
if
code
.
Type
!=
RedeemTypeInvitation
&&
code
.
Value
=
=
0
{
return
errors
.
New
(
"value must
be greater than 0
"
)
return
errors
.
New
(
"value must
not be zero
"
)
}
}
if
code
.
Status
==
""
{
if
code
.
Status
==
""
{
code
.
Status
=
StatusUnused
code
.
Status
=
StatusUnused
...
@@ -292,7 +292,6 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
...
@@ -292,7 +292,6 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get user: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
}
_
=
user
// 使用变量避免未使用错误
// 使用数据库事务保证兑换码标记与权益发放的原子性
// 使用数据库事务保证兑换码标记与权益发放的原子性
tx
,
err
:=
s
.
entClient
.
Tx
(
ctx
)
tx
,
err
:=
s
.
entClient
.
Tx
(
ctx
)
...
@@ -316,31 +315,46 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
...
@@ -316,31 +315,46 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
// 执行兑换逻辑(兑换码已被锁定,此时可安全操作)
// 执行兑换逻辑(兑换码已被锁定,此时可安全操作)
switch
redeemCode
.
Type
{
switch
redeemCode
.
Type
{
case
RedeemTypeBalance
:
case
RedeemTypeBalance
:
// 增加用户余额
amount
:=
redeemCode
.
Value
if
err
:=
s
.
userRepo
.
UpdateBalance
(
txCtx
,
userID
,
redeemCode
.
Value
);
err
!=
nil
{
// 负数为退款扣减,余额最低为 0
if
amount
<
0
&&
user
.
Balance
+
amount
<
0
{
amount
=
-
user
.
Balance
}
if
err
:=
s
.
userRepo
.
UpdateBalance
(
txCtx
,
userID
,
amount
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"update user balance: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"update user balance: %w"
,
err
)
}
}
case
RedeemTypeConcurrency
:
case
RedeemTypeConcurrency
:
// 增加用户并发数
delta
:=
int
(
redeemCode
.
Value
)
if
err
:=
s
.
userRepo
.
UpdateConcurrency
(
txCtx
,
userID
,
int
(
redeemCode
.
Value
));
err
!=
nil
{
// 负数为退款扣减,并发数最低为 0
if
delta
<
0
&&
user
.
Concurrency
+
delta
<
0
{
delta
=
-
user
.
Concurrency
}
if
err
:=
s
.
userRepo
.
UpdateConcurrency
(
txCtx
,
userID
,
delta
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"update user concurrency: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"update user concurrency: %w"
,
err
)
}
}
case
RedeemTypeSubscription
:
case
RedeemTypeSubscription
:
validityDays
:=
redeemCode
.
ValidityDays
validityDays
:=
redeemCode
.
ValidityDays
if
validityDays
<=
0
{
if
validityDays
<
0
{
validityDays
=
30
// 负数天数:缩短订阅,减到 0 则取消订阅
}
if
err
:=
s
.
reduceOrCancelSubscription
(
txCtx
,
userID
,
*
redeemCode
.
GroupID
,
-
validityDays
,
redeemCode
.
Code
);
err
!=
nil
{
_
,
_
,
err
:=
s
.
subscriptionService
.
AssignOrExtendSubscription
(
txCtx
,
&
AssignSubscriptionInput
{
return
nil
,
fmt
.
Errorf
(
"reduce or cancel subscription: %w"
,
err
)
UserID
:
userID
,
}
GroupID
:
*
redeemCode
.
GroupID
,
}
else
{
ValidityDays
:
validityDays
,
if
validityDays
==
0
{
AssignedBy
:
0
,
// 系统分配
validityDays
=
30
Notes
:
fmt
.
Sprintf
(
"通过兑换码 %s 兑换"
,
redeemCode
.
Code
),
}
})
_
,
_
,
err
:=
s
.
subscriptionService
.
AssignOrExtendSubscription
(
txCtx
,
&
AssignSubscriptionInput
{
if
err
!=
nil
{
UserID
:
userID
,
return
nil
,
fmt
.
Errorf
(
"assign or extend subscription: %w"
,
err
)
GroupID
:
*
redeemCode
.
GroupID
,
ValidityDays
:
validityDays
,
AssignedBy
:
0
,
// 系统分配
Notes
:
fmt
.
Sprintf
(
"通过兑换码 %s 兑换"
,
redeemCode
.
Code
),
})
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"assign or extend subscription: %w"
,
err
)
}
}
}
default
:
default
:
...
@@ -475,3 +489,51 @@ func (s *RedeemService) GetUserHistory(ctx context.Context, userID int64, limit
...
@@ -475,3 +489,51 @@ func (s *RedeemService) GetUserHistory(ctx context.Context, userID int64, limit
}
}
return
codes
,
nil
return
codes
,
nil
}
}
// reduceOrCancelSubscription 缩短订阅天数,剩余天数 <= 0 时取消订阅
func
(
s
*
RedeemService
)
reduceOrCancelSubscription
(
ctx
context
.
Context
,
userID
,
groupID
int64
,
reduceDays
int
,
code
string
)
error
{
sub
,
err
:=
s
.
subscriptionService
.
userSubRepo
.
GetByUserIDAndGroupID
(
ctx
,
userID
,
groupID
)
if
err
!=
nil
{
return
ErrSubscriptionNotFound
}
now
:=
time
.
Now
()
remaining
:=
int
(
sub
.
ExpiresAt
.
Sub
(
now
)
.
Hours
()
/
24
)
if
remaining
<
0
{
remaining
=
0
}
notes
:=
fmt
.
Sprintf
(
"通过兑换码 %s 退款扣减 %d 天"
,
code
,
reduceDays
)
if
remaining
<=
reduceDays
{
// 剩余天数不足,直接取消订阅
if
err
:=
s
.
subscriptionService
.
userSubRepo
.
UpdateStatus
(
ctx
,
sub
.
ID
,
SubscriptionStatusExpired
);
err
!=
nil
{
return
fmt
.
Errorf
(
"cancel subscription: %w"
,
err
)
}
// 设置过期时间为当前时间
if
err
:=
s
.
subscriptionService
.
userSubRepo
.
ExtendExpiry
(
ctx
,
sub
.
ID
,
now
);
err
!=
nil
{
return
fmt
.
Errorf
(
"set subscription expiry: %w"
,
err
)
}
}
else
{
// 缩短天数
newExpiresAt
:=
sub
.
ExpiresAt
.
AddDate
(
0
,
0
,
-
reduceDays
)
if
err
:=
s
.
subscriptionService
.
userSubRepo
.
ExtendExpiry
(
ctx
,
sub
.
ID
,
newExpiresAt
);
err
!=
nil
{
return
fmt
.
Errorf
(
"reduce subscription: %w"
,
err
)
}
}
// 追加备注
newNotes
:=
sub
.
Notes
if
newNotes
!=
""
{
newNotes
+=
"
\n
"
}
newNotes
+=
notes
if
err
:=
s
.
subscriptionService
.
userSubRepo
.
UpdateNotes
(
ctx
,
sub
.
ID
,
newNotes
);
err
!=
nil
{
return
fmt
.
Errorf
(
"update subscription notes: %w"
,
err
)
}
// 失效缓存
s
.
subscriptionService
.
InvalidateSubCache
(
userID
,
groupID
)
return
nil
}
backend/internal/service/testhelpers_test.go
0 → 100644
View file @
7b83d6e7
//go:build unit
package
service
// testPtrFloat64 returns a pointer to the given float64 value.
func
testPtrFloat64
(
v
float64
)
*
float64
{
return
&
v
}
// testPtrInt returns a pointer to the given int value.
func
testPtrInt
(
v
int
)
*
int
{
return
&
v
}
// testPtrString returns a pointer to the given string value.
func
testPtrString
(
v
string
)
*
string
{
return
&
v
}
// testPtrBool returns a pointer to the given bool value.
func
testPtrBool
(
v
bool
)
*
bool
{
return
&
v
}
backend/internal/service/usage_log.go
View file @
7b83d6e7
...
@@ -104,6 +104,14 @@ type UsageLog struct {
...
@@ -104,6 +104,14 @@ type UsageLog struct {
// UpstreamModel is the actual model sent to the upstream provider after mapping.
// UpstreamModel is the actual model sent to the upstream provider after mapping.
// Nil means no mapping was applied (requested model was used as-is).
// Nil means no mapping was applied (requested model was used as-is).
UpstreamModel
*
string
UpstreamModel
*
string
// ChannelID 渠道 ID
ChannelID
*
int64
// ModelMappingChain 模型映射链,如 "a→b→c"
ModelMappingChain
*
string
// BillingTier 计费层级标签(per_request/image 模式)
BillingTier
*
string
// BillingMode 计费模式:token/image(sora 路径为 nil)
BillingMode
*
string
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
ServiceTier
*
string
ServiceTier
*
string
// ReasoningEffort is the request's reasoning effort level.
// ReasoningEffort is the request's reasoning effort level.
...
@@ -126,6 +134,9 @@ type UsageLog struct {
...
@@ -126,6 +134,9 @@ type UsageLog struct {
CacheCreation5mTokens
int
`gorm:"column:cache_creation_5m_tokens"`
CacheCreation5mTokens
int
`gorm:"column:cache_creation_5m_tokens"`
CacheCreation1hTokens
int
`gorm:"column:cache_creation_1h_tokens"`
CacheCreation1hTokens
int
`gorm:"column:cache_creation_1h_tokens"`
ImageOutputTokens
int
ImageOutputCost
float64
InputCost
float64
InputCost
float64
OutputCost
float64
OutputCost
float64
CacheCreationCost
float64
CacheCreationCost
float64
...
...
backend/internal/service/usage_log_helpers.go
View file @
7b83d6e7
...
@@ -26,3 +26,10 @@ func forwardResultBillingModel(requestedModel, upstreamModel string) string {
...
@@ -26,3 +26,10 @@ func forwardResultBillingModel(requestedModel, upstreamModel string) string {
}
}
return
strings
.
TrimSpace
(
upstreamModel
)
return
strings
.
TrimSpace
(
upstreamModel
)
}
}
func
optionalInt64Ptr
(
v
int64
)
*
int64
{
if
v
==
0
{
return
nil
}
return
&
v
}
backend/internal/service/wire.go
View file @
7b83d6e7
...
@@ -490,4 +490,6 @@ var ProviderSet = wire.NewSet(
...
@@ -490,4 +490,6 @@ var ProviderSet = wire.NewSet(
ProvideScheduledTestService
,
ProvideScheduledTestService
,
ProvideScheduledTestRunnerService
,
ProvideScheduledTestRunnerService
,
NewGroupCapacityService
,
NewGroupCapacityService
,
NewChannelService
,
NewModelPricingResolver
,
)
)
backend/migrations/081_create_channels.sql
0 → 100644
View file @
7b83d6e7
-- Create channels table for managing pricing channels.
-- A channel groups multiple groups together and provides custom model pricing.
SET
LOCAL
lock_timeout
=
'5s'
;
SET
LOCAL
statement_timeout
=
'10min'
;
-- 渠道表
CREATE
TABLE
IF
NOT
EXISTS
channels
(
id
BIGSERIAL
PRIMARY
KEY
,
name
VARCHAR
(
100
)
NOT
NULL
,
description
TEXT
DEFAULT
''
,
status
VARCHAR
(
20
)
NOT
NULL
DEFAULT
'active'
,
created_at
TIMESTAMPTZ
NOT
NULL
DEFAULT
NOW
(),
updated_at
TIMESTAMPTZ
NOT
NULL
DEFAULT
NOW
()
);
-- 渠道名称唯一索引
CREATE
UNIQUE
INDEX
IF
NOT
EXISTS
idx_channels_name
ON
channels
(
name
);
CREATE
INDEX
IF
NOT
EXISTS
idx_channels_status
ON
channels
(
status
);
-- 渠道-分组关联表(每个分组只能属于一个渠道)
CREATE
TABLE
IF
NOT
EXISTS
channel_groups
(
id
BIGSERIAL
PRIMARY
KEY
,
channel_id
BIGINT
NOT
NULL
REFERENCES
channels
(
id
)
ON
DELETE
CASCADE
,
group_id
BIGINT
NOT
NULL
REFERENCES
groups
(
id
)
ON
DELETE
CASCADE
,
created_at
TIMESTAMPTZ
NOT
NULL
DEFAULT
NOW
()
);
CREATE
UNIQUE
INDEX
IF
NOT
EXISTS
idx_channel_groups_group_id
ON
channel_groups
(
group_id
);
CREATE
INDEX
IF
NOT
EXISTS
idx_channel_groups_channel_id
ON
channel_groups
(
channel_id
);
-- 渠道模型定价表(一条定价可绑定多个模型)
CREATE
TABLE
IF
NOT
EXISTS
channel_model_pricing
(
id
BIGSERIAL
PRIMARY
KEY
,
channel_id
BIGINT
NOT
NULL
REFERENCES
channels
(
id
)
ON
DELETE
CASCADE
,
models
JSONB
NOT
NULL
DEFAULT
'[]'
,
input_price
NUMERIC
(
20
,
12
),
output_price
NUMERIC
(
20
,
12
),
cache_write_price
NUMERIC
(
20
,
12
),
cache_read_price
NUMERIC
(
20
,
12
),
image_output_price
NUMERIC
(
20
,
8
),
created_at
TIMESTAMPTZ
NOT
NULL
DEFAULT
NOW
(),
updated_at
TIMESTAMPTZ
NOT
NULL
DEFAULT
NOW
()
);
CREATE
INDEX
IF
NOT
EXISTS
idx_channel_model_pricing_channel_id
ON
channel_model_pricing
(
channel_id
);
COMMENT
ON
TABLE
channels
IS
'渠道管理:关联多个分组,提供自定义模型定价'
;
COMMENT
ON
TABLE
channel_groups
IS
'渠道-分组关联表:每个分组最多属于一个渠道'
;
COMMENT
ON
TABLE
channel_model_pricing
IS
'渠道模型定价:一条定价可绑定多个模型,价格一致'
;
COMMENT
ON
COLUMN
channel_model_pricing
.
models
IS
'绑定的模型列表,JSON 数组,如 ["claude-opus-4-6","claude-opus-4-6-thinking"]'
;
COMMENT
ON
COLUMN
channel_model_pricing
.
input_price
IS
'每 token 输入价格(USD),NULL 表示使用默认'
;
COMMENT
ON
COLUMN
channel_model_pricing
.
output_price
IS
'每 token 输出价格(USD),NULL 表示使用默认'
;
COMMENT
ON
COLUMN
channel_model_pricing
.
cache_write_price
IS
'缓存写入每 token 价格,NULL 表示使用默认'
;
COMMENT
ON
COLUMN
channel_model_pricing
.
cache_read_price
IS
'缓存读取每 token 价格,NULL 表示使用默认'
;
COMMENT
ON
COLUMN
channel_model_pricing
.
image_output_price
IS
'图片输出价格(Gemini Image 等),NULL 表示使用默认'
;
backend/migrations/082_refactor_channel_pricing.sql
0 → 100644
View file @
7b83d6e7
-- Extend channel_model_pricing with billing_mode and add context-interval child table.
-- Supports three billing modes: token (per-token with context intervals),
-- per_request (per-request with context-size tiers), and image (per-image).
SET
LOCAL
lock_timeout
=
'5s'
;
SET
LOCAL
statement_timeout
=
'10min'
;
-- 1. 为 channel_model_pricing 添加 billing_mode 列
ALTER
TABLE
channel_model_pricing
ADD
COLUMN
IF
NOT
EXISTS
billing_mode
VARCHAR
(
20
)
NOT
NULL
DEFAULT
'token'
;
COMMENT
ON
COLUMN
channel_model_pricing
.
billing_mode
IS
'计费模式:token(按 token 区间计费)、per_request(按次计费)、image(图片计费)'
;
-- 2. 创建区间定价子表
CREATE
TABLE
IF
NOT
EXISTS
channel_pricing_intervals
(
id
BIGSERIAL
PRIMARY
KEY
,
pricing_id
BIGINT
NOT
NULL
REFERENCES
channel_model_pricing
(
id
)
ON
DELETE
CASCADE
,
min_tokens
INT
NOT
NULL
DEFAULT
0
,
max_tokens
INT
,
tier_label
VARCHAR
(
50
),
input_price
NUMERIC
(
20
,
12
),
output_price
NUMERIC
(
20
,
12
),
cache_write_price
NUMERIC
(
20
,
12
),
cache_read_price
NUMERIC
(
20
,
12
),
per_request_price
NUMERIC
(
20
,
12
),
sort_order
INT
NOT
NULL
DEFAULT
0
,
created_at
TIMESTAMPTZ
NOT
NULL
DEFAULT
NOW
(),
updated_at
TIMESTAMPTZ
NOT
NULL
DEFAULT
NOW
()
);
CREATE
INDEX
IF
NOT
EXISTS
idx_channel_pricing_intervals_pricing_id
ON
channel_pricing_intervals
(
pricing_id
);
COMMENT
ON
TABLE
channel_pricing_intervals
IS
'渠道定价区间:支持按 token 区间、按次分层、图片分辨率分层'
;
COMMENT
ON
COLUMN
channel_pricing_intervals
.
min_tokens
IS
'区间下界(含),token 模式使用'
;
COMMENT
ON
COLUMN
channel_pricing_intervals
.
max_tokens
IS
'区间上界(不含),NULL 表示无上限'
;
COMMENT
ON
COLUMN
channel_pricing_intervals
.
tier_label
IS
'层级标签,按次/图片模式使用(如 1K、2K、4K、HD)'
;
COMMENT
ON
COLUMN
channel_pricing_intervals
.
input_price
IS
'token 模式:每 token 输入价'
;
COMMENT
ON
COLUMN
channel_pricing_intervals
.
output_price
IS
'token 模式:每 token 输出价'
;
COMMENT
ON
COLUMN
channel_pricing_intervals
.
cache_write_price
IS
'token 模式:缓存写入价'
;
COMMENT
ON
COLUMN
channel_pricing_intervals
.
cache_read_price
IS
'token 模式:缓存读取价'
;
COMMENT
ON
COLUMN
channel_pricing_intervals
.
per_request_price
IS
'按次/图片模式:每次请求价格'
;
-- 3. 迁移现有 flat 定价为单区间 [0, +inf)
-- 仅迁移有明确定价(至少一个价格字段非 NULL)的条目
INSERT
INTO
channel_pricing_intervals
(
pricing_id
,
min_tokens
,
max_tokens
,
input_price
,
output_price
,
cache_write_price
,
cache_read_price
,
sort_order
)
SELECT
cmp
.
id
,
0
,
NULL
,
cmp
.
input_price
,
cmp
.
output_price
,
cmp
.
cache_write_price
,
cmp
.
cache_read_price
,
0
FROM
channel_model_pricing
cmp
WHERE
cmp
.
billing_mode
=
'token'
AND
(
cmp
.
input_price
IS
NOT
NULL
OR
cmp
.
output_price
IS
NOT
NULL
OR
cmp
.
cache_write_price
IS
NOT
NULL
OR
cmp
.
cache_read_price
IS
NOT
NULL
)
AND
NOT
EXISTS
(
SELECT
1
FROM
channel_pricing_intervals
cpi
WHERE
cpi
.
pricing_id
=
cmp
.
id
);
-- 4. 迁移 image_output_price 为 image 模式的区间条目
-- 将有 image_output_price 的现有条目复制为 billing_mode='image' 的独立条目
-- 注意:这里不改变原条目的 billing_mode,而是将 image_output_price 作为向后兼容字段保留
-- 实际的 image 计费在未来由独立的 billing_mode='image' 条目处理
backend/migrations/083_channel_model_mapping.sql
0 → 100644
View file @
7b83d6e7
SET
LOCAL
lock_timeout
=
'5s'
;
SET
LOCAL
statement_timeout
=
'10min'
;
ALTER
TABLE
channels
ADD
COLUMN
IF
NOT
EXISTS
model_mapping
JSONB
DEFAULT
'{}'
;
COMMENT
ON
COLUMN
channels
.
model_mapping
IS
'渠道级模型映射,在账号映射之前执行。格式:{"source_model": "target_model"}'
;
backend/migrations/084_channel_billing_model_source.sql
0 → 100644
View file @
7b83d6e7
-- Add billing_model_source to channels (controls whether billing uses requested or upstream model)
ALTER
TABLE
channels
ADD
COLUMN
IF
NOT
EXISTS
billing_model_source
VARCHAR
(
20
)
DEFAULT
'requested'
;
-- Add channel tracking fields to usage_logs
ALTER
TABLE
usage_logs
ADD
COLUMN
IF
NOT
EXISTS
channel_id
BIGINT
;
ALTER
TABLE
usage_logs
ADD
COLUMN
IF
NOT
EXISTS
model_mapping_chain
VARCHAR
(
500
);
ALTER
TABLE
usage_logs
ADD
COLUMN
IF
NOT
EXISTS
billing_tier
VARCHAR
(
50
);
backend/migrations/085_channel_restrict_and_per_request_price.sql
0 → 100644
View file @
7b83d6e7
-- Add model restriction switch to channels
ALTER
TABLE
channels
ADD
COLUMN
IF
NOT
EXISTS
restrict_models
BOOLEAN
DEFAULT
false
;
-- Add default per_request_price to channel_model_pricing (fallback when no tier matches)
ALTER
TABLE
channel_model_pricing
ADD
COLUMN
IF
NOT
EXISTS
per_request_price
NUMERIC
(
20
,
10
);
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment