Commit 7b83d6e7 authored by 陈曦's avatar 陈曦
Browse files

Merge remote-tracking branch 'upstream/main'

parents daa2e6df dbb248df
...@@ -46,7 +46,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( ...@@ -46,7 +46,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
// 2. Resolve model mapping early so compat prompt_cache_key injection can // 2. Resolve model mapping early so compat prompt_cache_key injection can
// derive a stable seed from the final upstream model family. // derive a stable seed from the final upstream model family.
billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel) billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
upstreamModel := resolveOpenAIUpstreamModel(billingModel) upstreamModel := normalizeCodexModel(billingModel)
promptCacheKey = strings.TrimSpace(promptCacheKey) promptCacheKey = strings.TrimSpace(promptCacheKey)
compatPromptCacheInjected := false compatPromptCacheInjected := false
......
...@@ -62,7 +62,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( ...@@ -62,7 +62,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
// 3. Model mapping // 3. Model mapping
billingModel := resolveOpenAIForwardModel(account, normalizedModel, defaultMappedModel) billingModel := resolveOpenAIForwardModel(account, normalizedModel, defaultMappedModel)
upstreamModel := resolveOpenAIUpstreamModel(billingModel) upstreamModel := normalizeCodexModel(billingModel)
responsesReq.Model = upstreamModel responsesReq.Model = upstreamModel
logger.L().Debug("openai messages: model mapping applied", logger.L().Debug("openai messages: model mapping applied",
......
...@@ -145,6 +145,8 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U ...@@ -145,6 +145,8 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
nil, nil,
&DeferredService{}, &DeferredService{},
nil, nil,
nil,
nil,
) )
svc.userGroupRateResolver = newUserGroupRateResolver( svc.userGroupRateResolver = newUserGroupRateResolver(
rateRepo, rateRepo,
......
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log/slog"
"math/rand" "math/rand"
"net/http" "net/http"
"sort" "sort"
...@@ -204,6 +205,7 @@ type OpenAIUsage struct { ...@@ -204,6 +205,7 @@ type OpenAIUsage struct {
OutputTokens int `json:"output_tokens"` OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"` CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"` CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
ImageOutputTokens int `json:"image_output_tokens,omitempty"`
} }
// OpenAIForwardResult represents the result of forwarding // OpenAIForwardResult represents the result of forwarding
...@@ -322,6 +324,8 @@ type OpenAIGatewayService struct { ...@@ -322,6 +324,8 @@ type OpenAIGatewayService struct {
openAITokenProvider *OpenAITokenProvider openAITokenProvider *OpenAITokenProvider
toolCorrector *CodexToolCorrector toolCorrector *CodexToolCorrector
openaiWSResolver OpenAIWSProtocolResolver openaiWSResolver OpenAIWSProtocolResolver
resolver *ModelPricingResolver
channelService *ChannelService
openaiWSPoolOnce sync.Once openaiWSPoolOnce sync.Once
openaiWSStateStoreOnce sync.Once openaiWSStateStoreOnce sync.Once
...@@ -357,6 +361,8 @@ func NewOpenAIGatewayService( ...@@ -357,6 +361,8 @@ func NewOpenAIGatewayService(
httpUpstream HTTPUpstream, httpUpstream HTTPUpstream,
deferredService *DeferredService, deferredService *DeferredService,
openAITokenProvider *OpenAITokenProvider, openAITokenProvider *OpenAITokenProvider,
resolver *ModelPricingResolver,
channelService *ChannelService,
) *OpenAIGatewayService { ) *OpenAIGatewayService {
svc := &OpenAIGatewayService{ svc := &OpenAIGatewayService{
accountRepo: accountRepo, accountRepo: accountRepo,
...@@ -384,6 +390,8 @@ func NewOpenAIGatewayService( ...@@ -384,6 +390,8 @@ func NewOpenAIGatewayService(
openAITokenProvider: openAITokenProvider, openAITokenProvider: openAITokenProvider,
toolCorrector: NewCodexToolCorrector(), toolCorrector: NewCodexToolCorrector(),
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
resolver: resolver,
channelService: channelService,
responseHeaderFilter: compileResponseHeaderFilter(cfg), responseHeaderFilter: compileResponseHeaderFilter(cfg),
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval), codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
} }
...@@ -391,6 +399,74 @@ func NewOpenAIGatewayService( ...@@ -391,6 +399,74 @@ func NewOpenAIGatewayService(
return svc return svc
} }
// ResolveChannelMapping resolves the channel-level model mapping for the given
// group by delegating to ChannelService. When no channel service is configured,
// the requested model is returned unchanged as the mapped model.
func (s *OpenAIGatewayService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult {
	if svc := s.channelService; svc != nil {
		return svc.ResolveChannelMapping(ctx, groupID, model)
	}
	// No channel service: identity mapping.
	return ChannelMappingResult{MappedModel: model}
}
// IsModelRestricted reports whether the model is restricted by the group's
// channel, delegating the decision to ChannelService. Without a configured
// channel service nothing is considered restricted.
func (s *OpenAIGatewayService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
	if svc := s.channelService; svc != nil {
		return svc.IsModelRestricted(ctx, groupID, model)
	}
	return false
}
// ResolveChannelMappingAndRestrict resolves the channel mapping for the group
// by delegating to ChannelService. When no channel service is configured it
// returns an identity mapping and restricted=false.
//
// NOTE(review): the original comment states that the model restriction check
// was moved to the scheduling stage and that restricted is always false; that
// depends on the ChannelService implementation and is not verifiable here.
func (s *OpenAIGatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
	if svc := s.channelService; svc != nil {
		return svc.ResolveChannelMappingAndRestrict(ctx, groupID, model)
	}
	return ChannelMappingResult{MappedModel: model}, false
}
// checkChannelPricingRestriction reports whether the requested model is blocked
// by the group's channel restriction. It resolves the channel mapping, derives
// the billing model via billingModelForRestriction, and asks ChannelService
// whether that billing model is restricted.
//
// It returns false when groupID is nil, no channel service is configured, the
// requested model is empty, or no billing model could be derived.
func (s *OpenAIGatewayService) checkChannelPricingRestriction(ctx context.Context, groupID *int64, requestedModel string) bool {
	switch {
	case groupID == nil, s.channelService == nil, requestedModel == "":
		return false
	}

	gid := *groupID
	resolved := s.channelService.ResolveChannelMapping(ctx, gid, requestedModel)
	if model := billingModelForRestriction(resolved.BillingModelSource, requestedModel, resolved.MappedModel); model != "" {
		return s.channelService.IsModelRestricted(ctx, gid, model)
	}
	return false
}
// isUpstreamModelRestrictedByChannel reports whether the model that this
// account would forward upstream is restricted by the group's channel. The
// upstream model is obtained from resolveOpenAIForwardModel with an empty
// default mapped model.
//
// It returns false when no channel service is configured or the resolved model
// is empty.
func (s *OpenAIGatewayService) isUpstreamModelRestrictedByChannel(ctx context.Context, groupID int64, account *Account, requestedModel string) bool {
	svc := s.channelService
	if svc == nil {
		return false
	}
	resolved := resolveOpenAIForwardModel(account, requestedModel, "")
	return resolved != "" && svc.IsModelRestricted(ctx, groupID, resolved)
}
// needsUpstreamChannelRestrictionCheck reports whether per-account upstream
// model restriction checks are required for the group. That is the case only
// when the group's channel exists, has RestrictModels enabled, and bills by
// the upstream model (BillingModelSourceUpstream).
//
// It returns false when groupID is nil or no channel service is configured.
// A channel lookup error is logged as a warning and also yields false.
func (s *OpenAIGatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Context, groupID *int64) bool {
	if groupID == nil || s.channelService == nil {
		return false
	}

	channel, lookupErr := s.channelService.GetChannelForGroup(ctx, *groupID)
	if lookupErr != nil {
		slog.Warn("failed to check openai channel upstream restriction", "group_id", *groupID, "error", lookupErr)
		return false
	}

	return channel != nil &&
		channel.RestrictModels &&
		channel.BillingModelSource == BillingModelSourceUpstream
}
// ReplaceModelInBody replaces the JSON "model" field of a request body with
// newModel. It is a method-form convenience that forwards to the package-level
// ReplaceModelInBody helper (described in the original comment as a generic
// gjson/sjson implementation).
func (s *OpenAIGatewayService) ReplaceModelInBody(body []byte, newModel string) []byte {
	rewritten := ReplaceModelInBody(body, newModel)
	return rewritten
}
func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle { func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle {
if s != nil && s.codexSnapshotThrottle != nil { if s != nil && s.codexSnapshotThrottle != nil {
return s.codexSnapshotThrottle return s.codexSnapshotThrottle
...@@ -1125,6 +1201,13 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C ...@@ -1125,6 +1201,13 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
} }
func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) (*Account, error) { func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) (*Account, error) {
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
slog.Warn("channel pricing restriction blocked request",
"group_id", derefGroupID(groupID),
"model", requestedModel)
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
}
// 1. 尝试粘性会话命中 // 1. 尝试粘性会话命中
// Try sticky session hit // Try sticky session hit
if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID); account != nil { if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID); account != nil {
...@@ -1140,7 +1223,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C ...@@ -1140,7 +1223,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
// 3. 按优先级 + LRU 选择最佳账号 // 3. 按优先级 + LRU 选择最佳账号
// Select by priority + LRU // Select by priority + LRU
selected := s.selectBestAccount(ctx, accounts, requestedModel, excludedIDs) selected := s.selectBestAccount(ctx, groupID, accounts, requestedModel, excludedIDs)
if selected == nil { if selected == nil {
if requestedModel != "" { if requestedModel != "" {
...@@ -1206,6 +1289,11 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID ...@@ -1206,6 +1289,11 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
return nil return nil
} }
if groupID != nil && s.needsUpstreamChannelRestrictionCheck(ctx, groupID) &&
s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel) {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
return nil
}
// 刷新会话 TTL 并返回账号 // 刷新会话 TTL 并返回账号
// Refresh session TTL and return account // Refresh session TTL and return account
...@@ -1218,8 +1306,9 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID ...@@ -1218,8 +1306,9 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
// //
// selectBestAccount selects the best account from candidates (priority + LRU). // selectBestAccount selects the best account from candidates (priority + LRU).
// Returns nil if no available account. // Returns nil if no available account.
func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account { func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *int64, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account {
var selected *Account var selected *Account
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
for i := range accounts { for i := range accounts {
acc := &accounts[i] acc := &accounts[i]
...@@ -1238,6 +1327,9 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts [ ...@@ -1238,6 +1327,9 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts [
if fresh == nil { if fresh == nil {
continue continue
} }
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
continue
}
// 选择优先级最高且最久未使用的账号 // 选择优先级最高且最久未使用的账号
// Select highest priority and least recently used // Select highest priority and least recently used
...@@ -1289,7 +1381,15 @@ func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool ...@@ -1289,7 +1381,15 @@ func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan. // SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) { func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
slog.Warn("channel pricing restriction blocked request",
"group_id", derefGroupID(groupID),
"model", requestedModel)
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
}
cfg := s.schedulingConfig() cfg := s.schedulingConfig()
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
var stickyAccountID int64 var stickyAccountID int64
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil { if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil {
...@@ -1365,6 +1465,8 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ...@@ -1365,6 +1465,8 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel) account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
if account == nil { if account == nil {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
} else if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel) {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
} else { } else {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
...@@ -1410,6 +1512,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ...@@ -1410,6 +1512,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if requestedModel != "" && !acc.IsModelSupported(requestedModel) { if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue continue
} }
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) {
continue
}
candidates = append(candidates, acc) candidates = append(candidates, acc)
} }
...@@ -1434,6 +1539,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ...@@ -1434,6 +1539,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if fresh == nil { if fresh == nil {
continue continue
} }
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
continue
}
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
if sessionHash != "" { if sessionHash != "" {
...@@ -1488,6 +1596,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ...@@ -1488,6 +1596,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if fresh == nil { if fresh == nil {
continue continue
} }
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
continue
}
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
if sessionHash != "" { if sessionHash != "" {
...@@ -1510,6 +1621,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ...@@ -1510,6 +1621,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if fresh == nil { if fresh == nil {
continue continue
} }
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
continue
}
return &AccountSelectionResult{ return &AccountSelectionResult{
Account: fresh, Account: fresh,
WaitPlan: &AccountWaitPlan{ WaitPlan: &AccountWaitPlan{
...@@ -1825,7 +1939,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco ...@@ -1825,7 +1939,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
// 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。 // 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。
if model, ok := reqBody["model"].(string); ok { if model, ok := reqBody["model"].(string); ok {
upstreamModel = resolveOpenAIUpstreamModel(model) upstreamModel = normalizeCodexModel(model)
if upstreamModel != "" && upstreamModel != model { if upstreamModel != "" && upstreamModel != model {
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Upstream model resolved: %s -> %s (account: %s, type: %s, isCodexCLI: %v)", logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Upstream model resolved: %s -> %s (account: %s, type: %s, isCodexCLI: %v)",
model, upstreamModel, account.Name, account.Type, isCodexCLI) model, upstreamModel, account.Name, account.Type, isCodexCLI)
...@@ -4110,6 +4224,7 @@ type OpenAIRecordUsageInput struct { ...@@ -4110,6 +4224,7 @@ type OpenAIRecordUsageInput struct {
IPAddress string // 请求的客户端 IP 地址 IPAddress string // 请求的客户端 IP 地址
RequestPayloadHash string RequestPayloadHash string
APIKeyService APIKeyQuotaUpdater APIKeyService APIKeyQuotaUpdater
ChannelUsageFields
} }
// RecordUsage records usage and deducts balance // RecordUsage records usage and deducts balance
...@@ -4140,10 +4255,14 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec ...@@ -4140,10 +4255,14 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
OutputTokens: result.Usage.OutputTokens, OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
} }
// Get rate multiplier // Get rate multiplier
multiplier := s.cfg.Default.RateMultiplier multiplier := 1.0
if s.cfg != nil {
multiplier = s.cfg.Default.RateMultiplier
}
if apiKey.GroupID != nil && apiKey.Group != nil { if apiKey.GroupID != nil && apiKey.Group != nil {
resolver := s.userGroupRateResolver resolver := s.userGroupRateResolver
if resolver == nil { if resolver == nil {
...@@ -4152,12 +4271,37 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec ...@@ -4152,12 +4271,37 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier) multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier)
} }
var cost *CostBreakdown
var err error
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
if result.BillingModel != "" {
billingModel = strings.TrimSpace(result.BillingModel)
}
if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" {
billingModel = input.ChannelMappedModel
}
if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" {
billingModel = input.OriginalModel
}
serviceTier := "" serviceTier := ""
if result.ServiceTier != nil { if result.ServiceTier != nil {
serviceTier = strings.TrimSpace(*result.ServiceTier) serviceTier = strings.TrimSpace(*result.ServiceTier)
} }
cost, err := s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier) if s.resolver != nil && apiKey.Group != nil {
gid := apiKey.Group.ID
cost, err = s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
Model: billingModel,
GroupID: &gid,
Tokens: tokens,
RequestCount: 1,
RateMultiplier: multiplier,
ServiceTier: serviceTier,
Resolver: s.resolver,
})
} else {
cost, err = s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
}
if err != nil { if err != nil {
cost = &CostBreakdown{ActualCost: 0} cost = &CostBreakdown{ActualCost: 0}
} }
...@@ -4173,36 +4317,58 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec ...@@ -4173,36 +4317,58 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
durationMs := int(result.Duration.Milliseconds()) durationMs := int(result.Duration.Milliseconds())
accountRateMultiplier := account.BillingRateMultiplier() accountRateMultiplier := account.BillingRateMultiplier()
requestID := resolveUsageBillingRequestID(ctx, result.RequestID) requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
// 确定 RequestedModel(渠道映射前的原始模型)
requestedModel := result.Model
if input.OriginalModel != "" {
requestedModel = input.OriginalModel
}
usageLog := &UsageLog{ usageLog := &UsageLog{
UserID: user.ID, UserID: user.ID,
APIKeyID: apiKey.ID, APIKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
RequestID: requestID, RequestID: requestID,
Model: result.Model, Model: result.Model,
RequestedModel: result.Model, RequestedModel: requestedModel,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ServiceTier: result.ServiceTier, ServiceTier: result.ServiceTier,
ReasoningEffort: result.ReasoningEffort, ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint), UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint),
InputTokens: actualInputTokens, InputTokens: actualInputTokens,
OutputTokens: result.Usage.OutputTokens, OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens,
InputCost: cost.InputCost, ImageOutputTokens: result.Usage.ImageOutputTokens,
OutputCost: cost.OutputCost, }
CacheCreationCost: cost.CacheCreationCost, if cost != nil {
CacheReadCost: cost.CacheReadCost, usageLog.InputCost = cost.InputCost
TotalCost: cost.TotalCost, usageLog.OutputCost = cost.OutputCost
ActualCost: cost.ActualCost, usageLog.ImageOutputCost = cost.ImageOutputCost
RateMultiplier: multiplier, usageLog.CacheCreationCost = cost.CacheCreationCost
AccountRateMultiplier: &accountRateMultiplier, usageLog.CacheReadCost = cost.CacheReadCost
BillingType: billingType, usageLog.TotalCost = cost.TotalCost
Stream: result.Stream, usageLog.ActualCost = cost.ActualCost
OpenAIWSMode: result.OpenAIWSMode, }
DurationMs: &durationMs, usageLog.RateMultiplier = multiplier
FirstTokenMs: result.FirstTokenMs, usageLog.AccountRateMultiplier = &accountRateMultiplier
CreatedAt: time.Now(), usageLog.BillingType = billingType
usageLog.Stream = result.Stream
usageLog.OpenAIWSMode = result.OpenAIWSMode
usageLog.DurationMs = &durationMs
usageLog.FirstTokenMs = result.FirstTokenMs
usageLog.CreatedAt = time.Now()
// 设置渠道信息
usageLog.ChannelID = optionalInt64Ptr(input.ChannelID)
usageLog.ModelMappingChain = optionalTrimmedStringPtr(input.ModelMappingChain)
// 设置计费模式
if cost != nil && cost.BillingMode != "" {
billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode
} else {
billingMode := string(BillingModeToken)
usageLog.BillingMode = &billingMode
} }
// 添加 UserAgent // 添加 UserAgent
if input.UserAgent != "" { if input.UserAgent != "" {
......
package service package service
import "strings" // resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible
// forwarding. Group-level default mapping only applies when the account itself
// resolveOpenAIForwardModel resolves the account/group mapping result for // did not match any explicit model_mapping rule.
// OpenAI-compatible forwarding. Group-level default mapping only applies when
// the account itself did not match any explicit model_mapping rule.
func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string { func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string {
if account == nil { if account == nil {
if defaultMappedModel != "" { if defaultMappedModel != "" {
...@@ -19,23 +17,3 @@ func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedMo ...@@ -19,23 +17,3 @@ func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedMo
} }
return mappedModel return mappedModel
} }
// resolveOpenAIUpstreamModel returns the model name to send upstream. A bare
// "gpt-5.3-codex-spark" model (as recognized by isBareGPT53CodexSparkModel) is
// passed through under its canonical name; every other model is trimmed and
// run through normalizeCodexModel.
func resolveOpenAIUpstreamModel(model string) string {
	if !isBareGPT53CodexSparkModel(model) {
		return normalizeCodexModel(strings.TrimSpace(model))
	}
	return "gpt-5.3-codex-spark"
}
// isBareGPT53CodexSparkModel reports whether model names exactly the
// gpt-5.3-codex-spark model, with no extra suffix. Surrounding whitespace is
// ignored, any "provider/" prefix is dropped (only the text after the last
// slash is considered), and the comparison is case-insensitive. Both the
// hyphenated and the space-separated spellings are accepted.
func isBareGPT53CodexSparkModel(model string) bool {
	name := strings.TrimSpace(model)
	if name == "" {
		return false
	}
	// Keep only the final path segment, e.g. "openai/gpt-5.3-codex-spark".
	if slash := strings.LastIndex(name, "/"); slash >= 0 {
		name = name[slash+1:]
	}
	switch strings.ToLower(strings.TrimSpace(name)) {
	case "gpt-5.3-codex-spark", "gpt 5.3 codex spark":
		return true
	default:
		return false
	}
}
...@@ -74,30 +74,28 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t * ...@@ -74,30 +74,28 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t *
Credentials: map[string]any{}, Credentials: map[string]any{},
} }
withoutDefault := resolveOpenAIUpstreamModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "")) withoutDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", ""))
if withoutDefault != "gpt-5.1" { if withoutDefault != "gpt-5.1" {
t.Fatalf("resolveOpenAIUpstreamModel(...) = %q, want %q", withoutDefault, "gpt-5.1") t.Fatalf("normalizeCodexModel(...) = %q, want %q", withoutDefault, "gpt-5.1")
} }
withDefault := resolveOpenAIUpstreamModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4")) withDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4"))
if withDefault != "gpt-5.4" { if withDefault != "gpt-5.4" {
t.Fatalf("resolveOpenAIUpstreamModel(...) = %q, want %q", withDefault, "gpt-5.4") t.Fatalf("normalizeCodexModel(...) = %q, want %q", withDefault, "gpt-5.4")
} }
} }
func TestResolveOpenAIUpstreamModel(t *testing.T) { func TestNormalizeCodexModel(t *testing.T) {
cases := map[string]string{ cases := map[string]string{
"gpt-5.3-codex-spark": "gpt-5.3-codex-spark", "gpt-5.3-codex-spark": "gpt-5.3-codex",
"gpt 5.3 codex spark": "gpt-5.3-codex-spark", "gpt-5.3-codex-spark-high": "gpt-5.3-codex",
" openai/gpt-5.3-codex-spark ": "gpt-5.3-codex-spark", "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex",
"gpt-5.3-codex-spark-high": "gpt-5.3-codex", "gpt-5.3": "gpt-5.3-codex",
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex",
"gpt-5.3": "gpt-5.3-codex",
} }
for input, expected := range cases { for input, expected := range cases {
if got := resolveOpenAIUpstreamModel(input); got != expected { if got := normalizeCodexModel(input); got != expected {
t.Fatalf("resolveOpenAIUpstreamModel(%q) = %q, want %q", input, got, expected) t.Fatalf("normalizeCodexModel(%q) = %q, want %q", input, got, expected)
} }
} }
} }
...@@ -2515,7 +2515,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( ...@@ -2515,7 +2515,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
} }
normalized = next normalized = next
} }
upstreamModel := resolveOpenAIUpstreamModel(account.GetMappedModel(originalModel)) upstreamModel := normalizeCodexModel(account.GetMappedModel(originalModel))
if upstreamModel != originalModel { if upstreamModel != originalModel {
next, setErr := applyPayloadMutation(normalized, "model", upstreamModel) next, setErr := applyPayloadMutation(normalized, "model", upstreamModel)
if setErr != nil { if setErr != nil {
...@@ -2773,7 +2773,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( ...@@ -2773,7 +2773,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
mappedModel := "" mappedModel := ""
var mappedModelBytes []byte var mappedModelBytes []byte
if originalModel != "" { if originalModel != "" {
mappedModel = resolveOpenAIUpstreamModel(account.GetMappedModel(originalModel)) mappedModel = normalizeCodexModel(account.GetMappedModel(originalModel))
needModelReplace = mappedModel != "" && mappedModel != originalModel needModelReplace = mappedModel != "" && mappedModel != originalModel
if needModelReplace { if needModelReplace {
mappedModelBytes = []byte(mappedModel) mappedModelBytes = []byte(mappedModel)
......
...@@ -615,6 +615,8 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) { ...@@ -615,6 +615,8 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil, nil,
nil, nil,
nil, nil,
nil,
nil,
) )
decision := svc.getOpenAIWSProtocolResolver().Resolve(nil) decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)
......
...@@ -519,7 +519,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry ...@@ -519,7 +519,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry
if s.gatewayService == nil { if s.gatewayService == nil {
return nil, fmt.Errorf("gateway service not available") return nil, fmt.Errorf("gateway service not available")
} }
return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs, "") // 重试不使用会话限制 return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs, "", int64(0)) // 重试不使用会话限制
default: default:
return nil, fmt.Errorf("unsupported retry type: %s", reqType) return nil, fmt.Errorf("unsupported retry type: %s", reqType)
} }
......
...@@ -70,7 +70,8 @@ type LiteLLMModelPricing struct { ...@@ -70,7 +70,8 @@ type LiteLLMModelPricing struct {
LiteLLMProvider string `json:"litellm_provider"` LiteLLMProvider string `json:"litellm_provider"`
Mode string `json:"mode"` Mode string `json:"mode"`
SupportsPromptCaching bool `json:"supports_prompt_caching"` SupportsPromptCaching bool `json:"supports_prompt_caching"`
OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格 OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格
OutputCostPerImageToken float64 `json:"output_cost_per_image_token"` // 图片输出 token 价格
} }
// PricingRemoteClient 远程价格数据获取接口 // PricingRemoteClient 远程价格数据获取接口
...@@ -94,6 +95,7 @@ type LiteLLMRawEntry struct { ...@@ -94,6 +95,7 @@ type LiteLLMRawEntry struct {
Mode string `json:"mode"` Mode string `json:"mode"`
SupportsPromptCaching bool `json:"supports_prompt_caching"` SupportsPromptCaching bool `json:"supports_prompt_caching"`
OutputCostPerImage *float64 `json:"output_cost_per_image"` OutputCostPerImage *float64 `json:"output_cost_per_image"`
OutputCostPerImageToken *float64 `json:"output_cost_per_image_token"`
} }
// PricingService 动态价格服务 // PricingService 动态价格服务
...@@ -408,6 +410,9 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel ...@@ -408,6 +410,9 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel
if entry.OutputCostPerImage != nil { if entry.OutputCostPerImage != nil {
pricing.OutputCostPerImage = *entry.OutputCostPerImage pricing.OutputCostPerImage = *entry.OutputCostPerImage
} }
if entry.OutputCostPerImageToken != nil {
pricing.OutputCostPerImageToken = *entry.OutputCostPerImageToken
}
result[modelName] = pricing result[modelName] = pricing
} }
......
...@@ -131,9 +131,9 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ ...@@ -131,9 +131,9 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
return nil, errors.New("count must be greater than 0") return nil, errors.New("count must be greater than 0")
} }
// 邀请码类型不需要数值,其他类型需要 // 邀请码类型不需要数值,其他类型需要非零值(支持负数用于退款)
if req.Type != RedeemTypeInvitation && req.Value <= 0 { if req.Type != RedeemTypeInvitation && req.Value == 0 {
return nil, errors.New("value must be greater than 0") return nil, errors.New("value must not be zero")
} }
if req.Count > 1000 { if req.Count > 1000 {
...@@ -188,8 +188,8 @@ func (s *RedeemService) CreateCode(ctx context.Context, code *RedeemCode) error ...@@ -188,8 +188,8 @@ func (s *RedeemService) CreateCode(ctx context.Context, code *RedeemCode) error
if code.Type == "" { if code.Type == "" {
code.Type = RedeemTypeBalance code.Type = RedeemTypeBalance
} }
if code.Type != RedeemTypeInvitation && code.Value <= 0 { if code.Type != RedeemTypeInvitation && code.Value == 0 {
return errors.New("value must be greater than 0") return errors.New("value must not be zero")
} }
if code.Status == "" { if code.Status == "" {
code.Status = StatusUnused code.Status = StatusUnused
...@@ -292,7 +292,6 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( ...@@ -292,7 +292,6 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
if err != nil { if err != nil {
return nil, fmt.Errorf("get user: %w", err) return nil, fmt.Errorf("get user: %w", err)
} }
_ = user // 使用变量避免未使用错误
// 使用数据库事务保证兑换码标记与权益发放的原子性 // 使用数据库事务保证兑换码标记与权益发放的原子性
tx, err := s.entClient.Tx(ctx) tx, err := s.entClient.Tx(ctx)
...@@ -316,31 +315,46 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( ...@@ -316,31 +315,46 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
// 执行兑换逻辑(兑换码已被锁定,此时可安全操作) // 执行兑换逻辑(兑换码已被锁定,此时可安全操作)
switch redeemCode.Type { switch redeemCode.Type {
case RedeemTypeBalance: case RedeemTypeBalance:
// 增加用户余额 amount := redeemCode.Value
if err := s.userRepo.UpdateBalance(txCtx, userID, redeemCode.Value); err != nil { // 负数为退款扣减,余额最低为 0
if amount < 0 && user.Balance+amount < 0 {
amount = -user.Balance
}
if err := s.userRepo.UpdateBalance(txCtx, userID, amount); err != nil {
return nil, fmt.Errorf("update user balance: %w", err) return nil, fmt.Errorf("update user balance: %w", err)
} }
case RedeemTypeConcurrency: case RedeemTypeConcurrency:
// 增加用户并发数 delta := int(redeemCode.Value)
if err := s.userRepo.UpdateConcurrency(txCtx, userID, int(redeemCode.Value)); err != nil { // 负数为退款扣减,并发数最低为 0
if delta < 0 && user.Concurrency+delta < 0 {
delta = -user.Concurrency
}
if err := s.userRepo.UpdateConcurrency(txCtx, userID, delta); err != nil {
return nil, fmt.Errorf("update user concurrency: %w", err) return nil, fmt.Errorf("update user concurrency: %w", err)
} }
case RedeemTypeSubscription: case RedeemTypeSubscription:
validityDays := redeemCode.ValidityDays validityDays := redeemCode.ValidityDays
if validityDays <= 0 { if validityDays < 0 {
validityDays = 30 // 负数天数:缩短订阅,减到 0 则取消订阅
} if err := s.reduceOrCancelSubscription(txCtx, userID, *redeemCode.GroupID, -validityDays, redeemCode.Code); err != nil {
_, _, err := s.subscriptionService.AssignOrExtendSubscription(txCtx, &AssignSubscriptionInput{ return nil, fmt.Errorf("reduce or cancel subscription: %w", err)
UserID: userID, }
GroupID: *redeemCode.GroupID, } else {
ValidityDays: validityDays, if validityDays == 0 {
AssignedBy: 0, // 系统分配 validityDays = 30
Notes: fmt.Sprintf("通过兑换码 %s 兑换", redeemCode.Code), }
}) _, _, err := s.subscriptionService.AssignOrExtendSubscription(txCtx, &AssignSubscriptionInput{
if err != nil { UserID: userID,
return nil, fmt.Errorf("assign or extend subscription: %w", err) GroupID: *redeemCode.GroupID,
ValidityDays: validityDays,
AssignedBy: 0, // 系统分配
Notes: fmt.Sprintf("通过兑换码 %s 兑换", redeemCode.Code),
})
if err != nil {
return nil, fmt.Errorf("assign or extend subscription: %w", err)
}
} }
default: default:
...@@ -475,3 +489,51 @@ func (s *RedeemService) GetUserHistory(ctx context.Context, userID int64, limit ...@@ -475,3 +489,51 @@ func (s *RedeemService) GetUserHistory(ctx context.Context, userID int64, limit
} }
return codes, nil return codes, nil
} }
// reduceOrCancelSubscription shortens the subscription that userID holds in
// groupID by reduceDays days. When the shortened expiry would not lie in the
// future, the subscription is marked expired and its expiry is set to now.
//
// The redeem code is recorded in the subscription notes, and the subscription
// cache entry for (userID, groupID) is invalidated on success.
//
// Any failure to load the subscription is reported as ErrSubscriptionNotFound;
// repository write failures are returned wrapped with context.
func (s *RedeemService) reduceOrCancelSubscription(ctx context.Context, userID, groupID int64, reduceDays int, code string) error {
	repo := s.subscriptionService.userSubRepo

	sub, err := repo.GetByUserIDAndGroupID(ctx, userID, groupID)
	if err != nil {
		return ErrSubscriptionNotFound
	}

	now := time.Now()
	// Compare timestamps instead of truncated whole-day counts: truncating the
	// remaining time to days would cancel a subscription that still has a
	// fraction of a day left after the reduction.
	newExpiresAt := sub.ExpiresAt.AddDate(0, 0, -reduceDays)

	if !newExpiresAt.After(now) {
		// Nothing would remain after the reduction: cancel the subscription.
		if err := repo.UpdateStatus(ctx, sub.ID, SubscriptionStatusExpired); err != nil {
			return fmt.Errorf("cancel subscription: %w", err)
		}
		// Pin the expiry to the moment of cancellation.
		if err := repo.ExtendExpiry(ctx, sub.ID, now); err != nil {
			return fmt.Errorf("set subscription expiry: %w", err)
		}
	} else if err := repo.ExtendExpiry(ctx, sub.ID, newExpiresAt); err != nil {
		// Time remains: only move the expiry earlier.
		return fmt.Errorf("reduce subscription: %w", err)
	}

	// Append an audit note, keeping any existing notes on their own lines.
	newNotes := sub.Notes
	if newNotes != "" {
		newNotes += "\n"
	}
	newNotes += fmt.Sprintf("通过兑换码 %s 退款扣减 %d 天", code, reduceDays)
	if err := repo.UpdateNotes(ctx, sub.ID, newNotes); err != nil {
		return fmt.Errorf("update subscription notes: %w", err)
	}

	// Drop the cached subscription so later reads see the new state.
	s.subscriptionService.InvalidateSubCache(userID, groupID)
	return nil
}
//go:build unit
package service
// testPtrFloat64 allocates a fresh float64 holding v and returns its address.
func testPtrFloat64(v float64) *float64 {
	p := new(float64)
	*p = v
	return p
}
// testPtrInt allocates a fresh int holding v and returns its address.
func testPtrInt(v int) *int {
	p := new(int)
	*p = v
	return p
}
// testPtrString allocates a fresh string holding v and returns its address.
func testPtrString(v string) *string {
	p := new(string)
	*p = v
	return p
}
// testPtrBool allocates a fresh bool holding v and returns its address.
func testPtrBool(v bool) *bool {
	p := new(bool)
	*p = v
	return p
}
...@@ -104,6 +104,14 @@ type UsageLog struct { ...@@ -104,6 +104,14 @@ type UsageLog struct {
// UpstreamModel is the actual model sent to the upstream provider after mapping. // UpstreamModel is the actual model sent to the upstream provider after mapping.
// Nil means no mapping was applied (requested model was used as-is). // Nil means no mapping was applied (requested model was used as-is).
UpstreamModel *string UpstreamModel *string
// ChannelID 渠道 ID
ChannelID *int64
// ModelMappingChain 模型映射链,如 "a→b→c"
ModelMappingChain *string
// BillingTier 计费层级标签(per_request/image 模式)
BillingTier *string
// BillingMode 计费模式:token/image(sora 路径为 nil)
BillingMode *string
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex". // ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
ServiceTier *string ServiceTier *string
// ReasoningEffort is the request's reasoning effort level. // ReasoningEffort is the request's reasoning effort level.
...@@ -126,6 +134,9 @@ type UsageLog struct { ...@@ -126,6 +134,9 @@ type UsageLog struct {
CacheCreation5mTokens int `gorm:"column:cache_creation_5m_tokens"` CacheCreation5mTokens int `gorm:"column:cache_creation_5m_tokens"`
CacheCreation1hTokens int `gorm:"column:cache_creation_1h_tokens"` CacheCreation1hTokens int `gorm:"column:cache_creation_1h_tokens"`
ImageOutputTokens int
ImageOutputCost float64
InputCost float64 InputCost float64
OutputCost float64 OutputCost float64
CacheCreationCost float64 CacheCreationCost float64
......
...@@ -26,3 +26,10 @@ func forwardResultBillingModel(requestedModel, upstreamModel string) string { ...@@ -26,3 +26,10 @@ func forwardResultBillingModel(requestedModel, upstreamModel string) string {
} }
return strings.TrimSpace(upstreamModel) return strings.TrimSpace(upstreamModel)
} }
// optionalInt64Ptr maps the zero value to nil and any other value to a
// pointer holding a copy of v.
func optionalInt64Ptr(v int64) *int64 {
	if v != 0 {
		return &v
	}
	return nil
}
...@@ -490,4 +490,6 @@ var ProviderSet = wire.NewSet( ...@@ -490,4 +490,6 @@ var ProviderSet = wire.NewSet(
ProvideScheduledTestService, ProvideScheduledTestService,
ProvideScheduledTestRunnerService, ProvideScheduledTestRunnerService,
NewGroupCapacityService, NewGroupCapacityService,
NewChannelService,
NewModelPricingResolver,
) )
-- Create channels table for managing pricing channels.
-- A channel groups multiple groups together and provides custom model pricing.
-- Bound how long this migration may wait for locks and how long any single
-- statement may run. SET LOCAL only takes effect inside a transaction block.
SET LOCAL lock_timeout = '5s';
SET LOCAL statement_timeout = '10min';
-- Channels table.
CREATE TABLE IF NOT EXISTS channels (
id BIGSERIAL PRIMARY KEY,
name VARCHAR(100) NOT NULL,
description TEXT DEFAULT '',
status VARCHAR(20) NOT NULL DEFAULT 'active',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
-- Unique index on the channel name.
CREATE UNIQUE INDEX IF NOT EXISTS idx_channels_name ON channels (name);
-- Non-unique index on status.
CREATE INDEX IF NOT EXISTS idx_channels_status ON channels (status);
-- Channel-to-group association table. Each group can belong to only one
-- channel, which the unique index on group_id enforces.
-- Rows are removed automatically when the referenced channel or group is deleted.
CREATE TABLE IF NOT EXISTS channel_groups (
id BIGSERIAL PRIMARY KEY,
channel_id BIGINT NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
group_id BIGINT NOT NULL REFERENCES groups(id) ON DELETE CASCADE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_channel_groups_group_id ON channel_groups (group_id);
CREATE INDEX IF NOT EXISTS idx_channel_groups_channel_id ON channel_groups (channel_id);
-- Channel model pricing table (one pricing row can be bound to multiple models).
-- models holds the bound model names as a JSONB array; all price columns are
-- nullable. Rows are removed automatically when the owning channel is deleted.
CREATE TABLE IF NOT EXISTS channel_model_pricing (
id BIGSERIAL PRIMARY KEY,
channel_id BIGINT NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
models JSONB NOT NULL DEFAULT '[]',
input_price NUMERIC(20,12),
output_price NUMERIC(20,12),
cache_write_price NUMERIC(20,12),
cache_read_price NUMERIC(20,12),
image_output_price NUMERIC(20,8),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_channel_model_pricing_channel_id ON channel_model_pricing (channel_id);
-- Catalog comments for the channels, channel_groups and channel_model_pricing
-- tables and for the channel_model_pricing columns. The comment text is stored
-- in the database as written.
COMMENT ON TABLE channels IS '渠道管理:关联多个分组,提供自定义模型定价';
COMMENT ON TABLE channel_groups IS '渠道-分组关联表:每个分组最多属于一个渠道';
COMMENT ON TABLE channel_model_pricing IS '渠道模型定价:一条定价可绑定多个模型,价格一致';
COMMENT ON COLUMN channel_model_pricing.models IS '绑定的模型列表,JSON 数组,如 ["claude-opus-4-6","claude-opus-4-6-thinking"]';
COMMENT ON COLUMN channel_model_pricing.input_price IS '每 token 输入价格(USD),NULL 表示使用默认';
COMMENT ON COLUMN channel_model_pricing.output_price IS '每 token 输出价格(USD),NULL 表示使用默认';
COMMENT ON COLUMN channel_model_pricing.cache_write_price IS '缓存写入每 token 价格,NULL 表示使用默认';
COMMENT ON COLUMN channel_model_pricing.cache_read_price IS '缓存读取每 token 价格,NULL 表示使用默认';
COMMENT ON COLUMN channel_model_pricing.image_output_price IS '图片输出价格(Gemini Image 等),NULL 表示使用默认';
-- Extend channel_model_pricing with billing_mode and add context-interval child table.
-- Supports three billing modes: token (per-token with context intervals),
-- per_request (per-request with context-size tiers), and image (per-image).
-- Bound lock waits and per-statement runtime for this migration. SET LOCAL
-- only takes effect inside a transaction block.
SET LOCAL lock_timeout = '5s';
SET LOCAL statement_timeout = '10min';
-- 1. Add the billing_mode column to channel_model_pricing.
--    Existing rows receive the default value 'token'.
ALTER TABLE channel_model_pricing
ADD COLUMN IF NOT EXISTS billing_mode VARCHAR(20) NOT NULL DEFAULT 'token';
COMMENT ON COLUMN channel_model_pricing.billing_mode IS '计费模式:token(按 token 区间计费)、per_request(按次计费)、image(图片计费)';
-- 2. Create the interval pricing child table.
--    Each row belongs to one channel_model_pricing row and is removed with it.
--    max_tokens is nullable, as are tier_label and all price columns.
CREATE TABLE IF NOT EXISTS channel_pricing_intervals (
id BIGSERIAL PRIMARY KEY,
pricing_id BIGINT NOT NULL REFERENCES channel_model_pricing(id) ON DELETE CASCADE,
min_tokens INT NOT NULL DEFAULT 0,
max_tokens INT,
tier_label VARCHAR(50),
input_price NUMERIC(20,12),
output_price NUMERIC(20,12),
cache_write_price NUMERIC(20,12),
cache_read_price NUMERIC(20,12),
per_request_price NUMERIC(20,12),
sort_order INT NOT NULL DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_channel_pricing_intervals_pricing_id
ON channel_pricing_intervals (pricing_id);
-- Catalog comments for the new table and its columns. The comment text is
-- stored in the database as written.
COMMENT ON TABLE channel_pricing_intervals IS '渠道定价区间:支持按 token 区间、按次分层、图片分辨率分层';
COMMENT ON COLUMN channel_pricing_intervals.min_tokens IS '区间下界(含),token 模式使用';
COMMENT ON COLUMN channel_pricing_intervals.max_tokens IS '区间上界(不含),NULL 表示无上限';
COMMENT ON COLUMN channel_pricing_intervals.tier_label IS '层级标签,按次/图片模式使用(如 1K、2K、4K、HD)';
COMMENT ON COLUMN channel_pricing_intervals.input_price IS 'token 模式:每 token 输入价';
COMMENT ON COLUMN channel_pricing_intervals.output_price IS 'token 模式:每 token 输出价';
COMMENT ON COLUMN channel_pricing_intervals.cache_write_price IS 'token 模式:缓存写入价';
COMMENT ON COLUMN channel_pricing_intervals.cache_read_price IS 'token 模式:缓存读取价';
COMMENT ON COLUMN channel_pricing_intervals.per_request_price IS '按次/图片模式:每次请求价格';
-- 3. Migrate existing flat pricing into a single interval [0, +inf).
--    Only rows with explicit pricing (at least one non-NULL price field) are migrated.
--    The interval is written with min_tokens = 0, max_tokens = NULL and sort_order = 0.
--    The NOT EXISTS guard skips pricing rows that already have an interval, so
--    running the statement again does not insert duplicates.
INSERT INTO channel_pricing_intervals (pricing_id, min_tokens, max_tokens, input_price, output_price, cache_write_price, cache_read_price, sort_order)
SELECT
cmp.id,
0,
NULL,
cmp.input_price,
cmp.output_price,
cmp.cache_write_price,
cmp.cache_read_price,
0
FROM channel_model_pricing cmp
WHERE cmp.billing_mode = 'token'
AND (cmp.input_price IS NOT NULL OR cmp.output_price IS NOT NULL
OR cmp.cache_write_price IS NOT NULL OR cmp.cache_read_price IS NOT NULL)
AND NOT EXISTS (
SELECT 1 FROM channel_pricing_intervals cpi WHERE cpi.pricing_id = cmp.id
);
-- 4. Migrate image_output_price into image-mode interval entries.
--    Copy existing entries that have image_output_price as standalone billing_mode='image' entries.
--    Note: the original entry's billing_mode is not changed here; image_output_price is kept as a
--    backward-compatible field.
--    Actual image billing will be handled in the future by standalone billing_mode='image' entries.
--    NOTE(review): no SQL statement accompanies step 4 here; these comments describe intent only.
-- Bound lock waits and per-statement runtime for this migration. SET LOCAL
-- only takes effect inside a transaction block.
SET LOCAL lock_timeout = '5s';
SET LOCAL statement_timeout = '10min';
-- Add a nullable JSONB model_mapping column to channels, defaulting to an empty object.
ALTER TABLE channels ADD COLUMN IF NOT EXISTS model_mapping JSONB DEFAULT '{}';
COMMENT ON COLUMN channels.model_mapping IS '渠道级模型映射,在账号映射之前执行。格式:{"source_model": "target_model"}';
-- Add billing_model_source to channels (controls whether billing uses requested or upstream model)
-- The column is nullable and defaults to 'requested'.
ALTER TABLE channels ADD COLUMN IF NOT EXISTS billing_model_source VARCHAR(20) DEFAULT 'requested';
-- Add channel tracking fields to usage_logs
-- All three columns are nullable and have no default. channel_id is added
-- without a foreign key constraint.
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS channel_id BIGINT;
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS model_mapping_chain VARCHAR(500);
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS billing_tier VARCHAR(50);
-- Add model restriction switch to channels
-- The column is nullable and defaults to false.
ALTER TABLE channels ADD COLUMN IF NOT EXISTS restrict_models BOOLEAN DEFAULT false;
-- Add default per_request_price to channel_model_pricing (fallback when no tier matches)
-- NOTE(review): this column is NUMERIC(20,10), while
-- channel_pricing_intervals.per_request_price is NUMERIC(20,12) — confirm the
-- difference in scale is intentional.
ALTER TABLE channel_model_pricing ADD COLUMN IF NOT EXISTS per_request_price NUMERIC(20,10);
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment