Commit 7b83d6e7 authored by 陈曦's avatar 陈曦
Browse files

Merge remote-tracking branch 'upstream/main'

parents daa2e6df dbb248df
......@@ -46,7 +46,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
// 2. Resolve model mapping early so compat prompt_cache_key injection can
// derive a stable seed from the final upstream model family.
billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
upstreamModel := resolveOpenAIUpstreamModel(billingModel)
upstreamModel := normalizeCodexModel(billingModel)
promptCacheKey = strings.TrimSpace(promptCacheKey)
compatPromptCacheInjected := false
......
......@@ -62,7 +62,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
// 3. Model mapping
billingModel := resolveOpenAIForwardModel(account, normalizedModel, defaultMappedModel)
upstreamModel := resolveOpenAIUpstreamModel(billingModel)
upstreamModel := normalizeCodexModel(billingModel)
responsesReq.Model = upstreamModel
logger.L().Debug("openai messages: model mapping applied",
......
......@@ -145,6 +145,8 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
nil,
&DeferredService{},
nil,
nil,
nil,
)
svc.userGroupRateResolver = newUserGroupRateResolver(
rateRepo,
......
......@@ -10,6 +10,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"math/rand"
"net/http"
"sort"
......@@ -204,6 +205,7 @@ type OpenAIUsage struct {
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
ImageOutputTokens int `json:"image_output_tokens,omitempty"`
}
// OpenAIForwardResult represents the result of forwarding
......@@ -322,6 +324,8 @@ type OpenAIGatewayService struct {
openAITokenProvider *OpenAITokenProvider
toolCorrector *CodexToolCorrector
openaiWSResolver OpenAIWSProtocolResolver
resolver *ModelPricingResolver
channelService *ChannelService
openaiWSPoolOnce sync.Once
openaiWSStateStoreOnce sync.Once
......@@ -357,6 +361,8 @@ func NewOpenAIGatewayService(
httpUpstream HTTPUpstream,
deferredService *DeferredService,
openAITokenProvider *OpenAITokenProvider,
resolver *ModelPricingResolver,
channelService *ChannelService,
) *OpenAIGatewayService {
svc := &OpenAIGatewayService{
accountRepo: accountRepo,
......@@ -384,6 +390,8 @@ func NewOpenAIGatewayService(
openAITokenProvider: openAITokenProvider,
toolCorrector: NewCodexToolCorrector(),
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
resolver: resolver,
channelService: channelService,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
}
......@@ -391,6 +399,74 @@ func NewOpenAIGatewayService(
return svc
}
// ResolveChannelMapping resolves the channel-level model mapping for a group
// by delegating to ChannelService. When no channel service is configured, the
// requested model is returned unchanged as the mapped model.
func (s *OpenAIGatewayService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult {
	if channels := s.channelService; channels != nil {
		return channels.ResolveChannelMapping(ctx, groupID, model)
	}
	return ChannelMappingResult{MappedModel: model}
}
// IsModelRestricted reports whether the model is restricted by the channel,
// as decided by ChannelService. Without a configured channel service nothing
// is considered restricted.
func (s *OpenAIGatewayService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
	channels := s.channelService
	return channels != nil && channels.IsModelRestricted(ctx, groupID, model)
}
// ResolveChannelMappingAndRestrict resolves the channel mapping by delegating
// to ChannelService. Without a channel service the model maps to itself and
// restricted is false.
// NOTE(review): the original comment states that restriction checks moved to
// the scheduling stage, so the delegate is expected to always report
// restricted=false — confirm against ChannelService.
func (s *OpenAIGatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
	if s.channelService != nil {
		return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model)
	}
	return ChannelMappingResult{MappedModel: model}, false
}
// checkChannelPricingRestriction reports whether requestedModel is restricted
// for the group's channel. It resolves the channel mapping, derives the
// billing model from the mapping's BillingModelSource, the requested model and
// the mapped model, and asks ChannelService whether that billing model is
// restricted. It returns false when groupID is nil, no channel service is
// configured, requestedModel is empty, or no billing model can be derived.
func (s *OpenAIGatewayService) checkChannelPricingRestriction(ctx context.Context, groupID *int64, requestedModel string) bool {
	if groupID == nil {
		return false
	}
	if s.channelService == nil {
		return false
	}
	if requestedModel == "" {
		return false
	}
	resolved := s.channelService.ResolveChannelMapping(ctx, *groupID, requestedModel)
	model := billingModelForRestriction(resolved.BillingModelSource, requestedModel, resolved.MappedModel)
	return model != "" && s.channelService.IsModelRestricted(ctx, *groupID, model)
}
// isUpstreamModelRestrictedByChannel resolves the model this account would
// forward for requestedModel (with no default mapped model) and reports
// whether ChannelService restricts that model for the group. It returns false
// when no channel service is configured or the resolved model is empty.
func (s *OpenAIGatewayService) isUpstreamModelRestrictedByChannel(ctx context.Context, groupID int64, account *Account, requestedModel string) bool {
	if s.channelService != nil {
		if forwarded := resolveOpenAIForwardModel(account, requestedModel, ""); forwarded != "" {
			return s.channelService.IsModelRestricted(ctx, groupID, forwarded)
		}
	}
	return false
}
// needsUpstreamChannelRestrictionCheck reports whether candidate accounts must
// be checked against the channel's restriction using their upstream model:
// true only when the group's channel has RestrictModels enabled and its
// BillingModelSource is BillingModelSourceUpstream. A nil groupID, a missing
// channel service, or a missing channel yields false; a lookup error is logged
// as a warning and also yields false.
func (s *OpenAIGatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Context, groupID *int64) bool {
	if groupID == nil {
		return false
	}
	if s.channelService == nil {
		return false
	}
	channel, err := s.channelService.GetChannelForGroup(ctx, *groupID)
	switch {
	case err != nil:
		slog.Warn("failed to check openai channel upstream restriction", "group_id", *groupID, "error", err)
		return false
	case channel == nil:
		return false
	}
	return channel.RestrictModels && channel.BillingModelSource == BillingModelSourceUpstream
}
// ReplaceModelInBody replaces the JSON "model" field of a request body with
// newModel. It is a thin method wrapper around the package-level
// ReplaceModelInBody helper (described by the original comment as the generic
// gjson/sjson implementation).
func (s *OpenAIGatewayService) ReplaceModelInBody(body []byte, newModel string) []byte {
	rewritten := ReplaceModelInBody(body, newModel)
	return rewritten
}
func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle {
if s != nil && s.codexSnapshotThrottle != nil {
return s.codexSnapshotThrottle
......@@ -1125,6 +1201,13 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
}
func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) (*Account, error) {
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
slog.Warn("channel pricing restriction blocked request",
"group_id", derefGroupID(groupID),
"model", requestedModel)
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
}
// 1. 尝试粘性会话命中
// Try sticky session hit
if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID); account != nil {
......@@ -1140,7 +1223,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
// 3. 按优先级 + LRU 选择最佳账号
// Select by priority + LRU
selected := s.selectBestAccount(ctx, accounts, requestedModel, excludedIDs)
selected := s.selectBestAccount(ctx, groupID, accounts, requestedModel, excludedIDs)
if selected == nil {
if requestedModel != "" {
......@@ -1206,6 +1289,11 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
return nil
}
if groupID != nil && s.needsUpstreamChannelRestrictionCheck(ctx, groupID) &&
s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel) {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
return nil
}
// 刷新会话 TTL 并返回账号
// Refresh session TTL and return account
......@@ -1218,8 +1306,9 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
//
// selectBestAccount selects the best account from candidates (priority + LRU).
// Returns nil if no available account.
func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account {
func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *int64, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account {
var selected *Account
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
for i := range accounts {
acc := &accounts[i]
......@@ -1238,6 +1327,9 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts [
if fresh == nil {
continue
}
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
continue
}
// 选择优先级最高且最久未使用的账号
// Select highest priority and least recently used
......@@ -1289,7 +1381,15 @@ func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
slog.Warn("channel pricing restriction blocked request",
"group_id", derefGroupID(groupID),
"model", requestedModel)
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
}
cfg := s.schedulingConfig()
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
var stickyAccountID int64
if sessionHash != "" && s.cache != nil {
if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil {
......@@ -1365,6 +1465,8 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
if account == nil {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
} else if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel) {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
} else {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
......@@ -1410,6 +1512,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue
}
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) {
continue
}
candidates = append(candidates, acc)
}
......@@ -1434,6 +1539,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if fresh == nil {
continue
}
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
continue
}
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
......@@ -1488,6 +1596,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if fresh == nil {
continue
}
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
continue
}
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
......@@ -1510,6 +1621,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if fresh == nil {
continue
}
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
continue
}
return &AccountSelectionResult{
Account: fresh,
WaitPlan: &AccountWaitPlan{
......@@ -1825,7 +1939,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
// 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。
if model, ok := reqBody["model"].(string); ok {
upstreamModel = resolveOpenAIUpstreamModel(model)
upstreamModel = normalizeCodexModel(model)
if upstreamModel != "" && upstreamModel != model {
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Upstream model resolved: %s -> %s (account: %s, type: %s, isCodexCLI: %v)",
model, upstreamModel, account.Name, account.Type, isCodexCLI)
......@@ -4110,6 +4224,7 @@ type OpenAIRecordUsageInput struct {
IPAddress string // 请求的客户端 IP 地址
RequestPayloadHash string
APIKeyService APIKeyQuotaUpdater
ChannelUsageFields
}
// RecordUsage records usage and deducts balance
......@@ -4140,10 +4255,14 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
}
// Get rate multiplier
multiplier := s.cfg.Default.RateMultiplier
multiplier := 1.0
if s.cfg != nil {
multiplier = s.cfg.Default.RateMultiplier
}
if apiKey.GroupID != nil && apiKey.Group != nil {
resolver := s.userGroupRateResolver
if resolver == nil {
......@@ -4152,12 +4271,37 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier)
}
var cost *CostBreakdown
var err error
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
if result.BillingModel != "" {
billingModel = strings.TrimSpace(result.BillingModel)
}
if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" {
billingModel = input.ChannelMappedModel
}
if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" {
billingModel = input.OriginalModel
}
serviceTier := ""
if result.ServiceTier != nil {
serviceTier = strings.TrimSpace(*result.ServiceTier)
}
cost, err := s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
if s.resolver != nil && apiKey.Group != nil {
gid := apiKey.Group.ID
cost, err = s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
Model: billingModel,
GroupID: &gid,
Tokens: tokens,
RequestCount: 1,
RateMultiplier: multiplier,
ServiceTier: serviceTier,
Resolver: s.resolver,
})
} else {
cost, err = s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
}
if err != nil {
cost = &CostBreakdown{ActualCost: 0}
}
......@@ -4173,13 +4317,20 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
durationMs := int(result.Duration.Milliseconds())
accountRateMultiplier := account.BillingRateMultiplier()
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
// 确定 RequestedModel(渠道映射前的原始模型)
requestedModel := result.Model
if input.OriginalModel != "" {
requestedModel = input.OriginalModel
}
usageLog := &UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: result.Model,
RequestedModel: result.Model,
RequestedModel: requestedModel,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ServiceTier: result.ServiceTier,
ReasoningEffort: result.ReasoningEffort,
......@@ -4189,20 +4340,35 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
InputCost: cost.InputCost,
OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost,
CacheReadCost: cost.CacheReadCost,
TotalCost: cost.TotalCost,
ActualCost: cost.ActualCost,
RateMultiplier: multiplier,
AccountRateMultiplier: &accountRateMultiplier,
BillingType: billingType,
Stream: result.Stream,
OpenAIWSMode: result.OpenAIWSMode,
DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs,
CreatedAt: time.Now(),
ImageOutputTokens: result.Usage.ImageOutputTokens,
}
if cost != nil {
usageLog.InputCost = cost.InputCost
usageLog.OutputCost = cost.OutputCost
usageLog.ImageOutputCost = cost.ImageOutputCost
usageLog.CacheCreationCost = cost.CacheCreationCost
usageLog.CacheReadCost = cost.CacheReadCost
usageLog.TotalCost = cost.TotalCost
usageLog.ActualCost = cost.ActualCost
}
usageLog.RateMultiplier = multiplier
usageLog.AccountRateMultiplier = &accountRateMultiplier
usageLog.BillingType = billingType
usageLog.Stream = result.Stream
usageLog.OpenAIWSMode = result.OpenAIWSMode
usageLog.DurationMs = &durationMs
usageLog.FirstTokenMs = result.FirstTokenMs
usageLog.CreatedAt = time.Now()
// 设置渠道信息
usageLog.ChannelID = optionalInt64Ptr(input.ChannelID)
usageLog.ModelMappingChain = optionalTrimmedStringPtr(input.ModelMappingChain)
// 设置计费模式
if cost != nil && cost.BillingMode != "" {
billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode
} else {
billingMode := string(BillingModeToken)
usageLog.BillingMode = &billingMode
}
// 添加 UserAgent
if input.UserAgent != "" {
......
package service
import "strings"
// resolveOpenAIForwardModel resolves the account/group mapping result for
// OpenAI-compatible forwarding. Group-level default mapping only applies when
// the account itself did not match any explicit model_mapping rule.
// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible
// forwarding. Group-level default mapping only applies when the account itself
// did not match any explicit model_mapping rule.
func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string {
if account == nil {
if defaultMappedModel != "" {
......@@ -19,23 +17,3 @@ func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedMo
}
return mappedModel
}
// resolveOpenAIUpstreamModel returns the model name to use upstream. Bare
// gpt-5.3-codex-spark identifiers are returned in canonical form without
// going through normalizeCodexModel; every other model is whitespace-trimmed
// and passed to normalizeCodexModel.
func resolveOpenAIUpstreamModel(model string) string {
	if !isBareGPT53CodexSparkModel(model) {
		return normalizeCodexModel(strings.TrimSpace(model))
	}
	return "gpt-5.3-codex-spark"
}
// isBareGPT53CodexSparkModel reports whether model names exactly the
// gpt-5.3-codex-spark model, ignoring surrounding whitespace, letter case and
// any slash-separated prefix (e.g. "openai/gpt-5.3-codex-spark"). Both the
// hyphenated and the space-separated spellings are accepted; suffixed variants
// are not.
func isBareGPT53CodexSparkModel(model string) bool {
	name := strings.TrimSpace(model)
	if name == "" {
		return false
	}
	// Compare only the final path segment so prefixed IDs match by bare name.
	if slash := strings.LastIndex(name, "/"); slash >= 0 {
		name = name[slash+1:]
	}
	switch strings.ToLower(strings.TrimSpace(name)) {
	case "gpt-5.3-codex-spark", "gpt 5.3 codex spark":
		return true
	default:
		return false
	}
}
......@@ -74,30 +74,28 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t *
Credentials: map[string]any{},
}
withoutDefault := resolveOpenAIUpstreamModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", ""))
withoutDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", ""))
if withoutDefault != "gpt-5.1" {
t.Fatalf("resolveOpenAIUpstreamModel(...) = %q, want %q", withoutDefault, "gpt-5.1")
t.Fatalf("normalizeCodexModel(...) = %q, want %q", withoutDefault, "gpt-5.1")
}
withDefault := resolveOpenAIUpstreamModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4"))
withDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4"))
if withDefault != "gpt-5.4" {
t.Fatalf("resolveOpenAIUpstreamModel(...) = %q, want %q", withDefault, "gpt-5.4")
t.Fatalf("normalizeCodexModel(...) = %q, want %q", withDefault, "gpt-5.4")
}
}
func TestResolveOpenAIUpstreamModel(t *testing.T) {
func TestNormalizeCodexModel(t *testing.T) {
cases := map[string]string{
"gpt-5.3-codex-spark": "gpt-5.3-codex-spark",
"gpt 5.3 codex spark": "gpt-5.3-codex-spark",
" openai/gpt-5.3-codex-spark ": "gpt-5.3-codex-spark",
"gpt-5.3-codex-spark": "gpt-5.3-codex",
"gpt-5.3-codex-spark-high": "gpt-5.3-codex",
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex",
"gpt-5.3": "gpt-5.3-codex",
}
for input, expected := range cases {
if got := resolveOpenAIUpstreamModel(input); got != expected {
t.Fatalf("resolveOpenAIUpstreamModel(%q) = %q, want %q", input, got, expected)
if got := normalizeCodexModel(input); got != expected {
t.Fatalf("normalizeCodexModel(%q) = %q, want %q", input, got, expected)
}
}
}
......@@ -2515,7 +2515,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
}
normalized = next
}
upstreamModel := resolveOpenAIUpstreamModel(account.GetMappedModel(originalModel))
upstreamModel := normalizeCodexModel(account.GetMappedModel(originalModel))
if upstreamModel != originalModel {
next, setErr := applyPayloadMutation(normalized, "model", upstreamModel)
if setErr != nil {
......@@ -2773,7 +2773,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
mappedModel := ""
var mappedModelBytes []byte
if originalModel != "" {
mappedModel = resolveOpenAIUpstreamModel(account.GetMappedModel(originalModel))
mappedModel = normalizeCodexModel(account.GetMappedModel(originalModel))
needModelReplace = mappedModel != "" && mappedModel != originalModel
if needModelReplace {
mappedModelBytes = []byte(mappedModel)
......
......@@ -615,6 +615,8 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil,
nil,
nil,
nil,
nil,
)
decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)
......
......@@ -519,7 +519,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry
if s.gatewayService == nil {
return nil, fmt.Errorf("gateway service not available")
}
return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs, "") // 重试不使用会话限制
return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs, "", int64(0)) // 重试不使用会话限制
default:
return nil, fmt.Errorf("unsupported retry type: %s", reqType)
}
......
......@@ -71,6 +71,7 @@ type LiteLLMModelPricing struct {
Mode string `json:"mode"`
SupportsPromptCaching bool `json:"supports_prompt_caching"`
OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格
OutputCostPerImageToken float64 `json:"output_cost_per_image_token"` // 图片输出 token 价格
}
// PricingRemoteClient 远程价格数据获取接口
......@@ -94,6 +95,7 @@ type LiteLLMRawEntry struct {
Mode string `json:"mode"`
SupportsPromptCaching bool `json:"supports_prompt_caching"`
OutputCostPerImage *float64 `json:"output_cost_per_image"`
OutputCostPerImageToken *float64 `json:"output_cost_per_image_token"`
}
// PricingService 动态价格服务
......@@ -408,6 +410,9 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel
if entry.OutputCostPerImage != nil {
pricing.OutputCostPerImage = *entry.OutputCostPerImage
}
if entry.OutputCostPerImageToken != nil {
pricing.OutputCostPerImageToken = *entry.OutputCostPerImageToken
}
result[modelName] = pricing
}
......
......@@ -131,9 +131,9 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
return nil, errors.New("count must be greater than 0")
}
// 邀请码类型不需要数值,其他类型需要
if req.Type != RedeemTypeInvitation && req.Value <= 0 {
return nil, errors.New("value must be greater than 0")
// 邀请码类型不需要数值,其他类型需要非零值(支持负数用于退款)
if req.Type != RedeemTypeInvitation && req.Value == 0 {
return nil, errors.New("value must not be zero")
}
if req.Count > 1000 {
......@@ -188,8 +188,8 @@ func (s *RedeemService) CreateCode(ctx context.Context, code *RedeemCode) error
if code.Type == "" {
code.Type = RedeemTypeBalance
}
if code.Type != RedeemTypeInvitation && code.Value <= 0 {
return errors.New("value must be greater than 0")
if code.Type != RedeemTypeInvitation && code.Value == 0 {
return errors.New("value must not be zero")
}
if code.Status == "" {
code.Status = StatusUnused
......@@ -292,7 +292,6 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
_ = user // 使用变量避免未使用错误
// 使用数据库事务保证兑换码标记与权益发放的原子性
tx, err := s.entClient.Tx(ctx)
......@@ -316,20 +315,34 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
// 执行兑换逻辑(兑换码已被锁定,此时可安全操作)
switch redeemCode.Type {
case RedeemTypeBalance:
// 增加用户余额
if err := s.userRepo.UpdateBalance(txCtx, userID, redeemCode.Value); err != nil {
amount := redeemCode.Value
// 负数为退款扣减,余额最低为 0
if amount < 0 && user.Balance+amount < 0 {
amount = -user.Balance
}
if err := s.userRepo.UpdateBalance(txCtx, userID, amount); err != nil {
return nil, fmt.Errorf("update user balance: %w", err)
}
case RedeemTypeConcurrency:
// 增加用户并发数
if err := s.userRepo.UpdateConcurrency(txCtx, userID, int(redeemCode.Value)); err != nil {
delta := int(redeemCode.Value)
// 负数为退款扣减,并发数最低为 0
if delta < 0 && user.Concurrency+delta < 0 {
delta = -user.Concurrency
}
if err := s.userRepo.UpdateConcurrency(txCtx, userID, delta); err != nil {
return nil, fmt.Errorf("update user concurrency: %w", err)
}
case RedeemTypeSubscription:
validityDays := redeemCode.ValidityDays
if validityDays <= 0 {
if validityDays < 0 {
// 负数天数:缩短订阅,减到 0 则取消订阅
if err := s.reduceOrCancelSubscription(txCtx, userID, *redeemCode.GroupID, -validityDays, redeemCode.Code); err != nil {
return nil, fmt.Errorf("reduce or cancel subscription: %w", err)
}
} else {
if validityDays == 0 {
validityDays = 30
}
_, _, err := s.subscriptionService.AssignOrExtendSubscription(txCtx, &AssignSubscriptionInput{
......@@ -342,6 +355,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
if err != nil {
return nil, fmt.Errorf("assign or extend subscription: %w", err)
}
}
default:
return nil, fmt.Errorf("unsupported redeem type: %s", redeemCode.Type)
......@@ -475,3 +489,51 @@ func (s *RedeemService) GetUserHistory(ctx context.Context, userID int64, limit
}
return codes, nil
}
// reduceOrCancelSubscription shortens the user's subscription in groupID by
// reduceDays as a refund deduction triggered by redeem code `code`.
//
// If the whole days remaining are less than or equal to reduceDays, the
// subscription is marked expired and its expiry is set to the current time;
// otherwise the expiry is moved earlier by reduceDays. In both cases a note
// describing the deduction is appended to the subscription and the
// subscription cache entry for (userID, groupID) is invalidated.
//
// A lookup failure is reported as ErrSubscriptionNotFound (still matchable via
// errors.Is/errors.As) with the underlying cause appended, so that repository
// failures are not silently disguised as a plain "not found".
func (s *RedeemService) reduceOrCancelSubscription(ctx context.Context, userID, groupID int64, reduceDays int, code string) error {
	sub, err := s.subscriptionService.userSubRepo.GetByUserIDAndGroupID(ctx, userID, groupID)
	if err != nil {
		// Keep the sentinel in the chain for callers, but preserve the cause.
		return fmt.Errorf("%w: %v", ErrSubscriptionNotFound, err)
	}

	now := time.Now()
	// Whole days left, truncated; an already-expired subscription counts as 0.
	remaining := int(sub.ExpiresAt.Sub(now).Hours() / 24)
	if remaining < 0 {
		remaining = 0
	}
	notes := fmt.Sprintf("通过兑换码 %s 退款扣减 %d 天", code, reduceDays)

	if remaining <= reduceDays {
		// Not enough days left: cancel the subscription outright.
		if err := s.subscriptionService.userSubRepo.UpdateStatus(ctx, sub.ID, SubscriptionStatusExpired); err != nil {
			return fmt.Errorf("cancel subscription: %w", err)
		}
		// Move the expiry to the current time.
		if err := s.subscriptionService.userSubRepo.ExtendExpiry(ctx, sub.ID, now); err != nil {
			return fmt.Errorf("set subscription expiry: %w", err)
		}
	} else {
		// Enough days left: pull the expiry earlier by reduceDays.
		newExpiresAt := sub.ExpiresAt.AddDate(0, 0, -reduceDays)
		if err := s.subscriptionService.userSubRepo.ExtendExpiry(ctx, sub.ID, newExpiresAt); err != nil {
			return fmt.Errorf("reduce subscription: %w", err)
		}
	}

	// Append the deduction note after any existing notes.
	newNotes := sub.Notes
	if newNotes != "" {
		newNotes += "\n"
	}
	newNotes += notes
	if err := s.subscriptionService.userSubRepo.UpdateNotes(ctx, sub.ID, newNotes); err != nil {
		return fmt.Errorf("update subscription notes: %w", err)
	}

	// Invalidate the cached subscription for this user and group.
	s.subscriptionService.InvalidateSubCache(userID, groupID)
	return nil
}
//go:build unit
package service
// testPtrFloat64 stores v in a freshly allocated float64 and returns its address.
func testPtrFloat64(v float64) *float64 {
	p := new(float64)
	*p = v
	return p
}
// testPtrInt stores v in a freshly allocated int and returns its address.
func testPtrInt(v int) *int {
	p := new(int)
	*p = v
	return p
}
// testPtrString stores v in a freshly allocated string and returns its address.
func testPtrString(v string) *string {
	p := new(string)
	*p = v
	return p
}
// testPtrBool stores v in a freshly allocated bool and returns its address.
func testPtrBool(v bool) *bool {
	p := new(bool)
	*p = v
	return p
}
......@@ -104,6 +104,14 @@ type UsageLog struct {
// UpstreamModel is the actual model sent to the upstream provider after mapping.
// Nil means no mapping was applied (requested model was used as-is).
UpstreamModel *string
// ChannelID 渠道 ID
ChannelID *int64
// ModelMappingChain 模型映射链,如 "a→b→c"
ModelMappingChain *string
// BillingTier 计费层级标签(per_request/image 模式)
BillingTier *string
// BillingMode 计费模式:token/image(sora 路径为 nil)
BillingMode *string
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
ServiceTier *string
// ReasoningEffort is the request's reasoning effort level.
......@@ -126,6 +134,9 @@ type UsageLog struct {
CacheCreation5mTokens int `gorm:"column:cache_creation_5m_tokens"`
CacheCreation1hTokens int `gorm:"column:cache_creation_1h_tokens"`
ImageOutputTokens int
ImageOutputCost float64
InputCost float64
OutputCost float64
CacheCreationCost float64
......
......@@ -26,3 +26,10 @@ func forwardResultBillingModel(requestedModel, upstreamModel string) string {
}
return strings.TrimSpace(upstreamModel)
}
// optionalInt64Ptr returns a pointer to a copy of v, or nil when v is zero,
// so that a zero value is treated as "not set".
func optionalInt64Ptr(v int64) *int64 {
	if v != 0 {
		return &v
	}
	return nil
}
......@@ -490,4 +490,6 @@ var ProviderSet = wire.NewSet(
ProvideScheduledTestService,
ProvideScheduledTestRunnerService,
NewGroupCapacityService,
NewChannelService,
NewModelPricingResolver,
)
-- Create channels table for managing pricing channels.
-- A channel groups multiple groups together and provides custom model pricing.
-- Bound how long this migration may wait for locks and how long a statement
-- may run (SET LOCAL applies to the current transaction only).
SET LOCAL lock_timeout = '5s';
SET LOCAL statement_timeout = '10min';
-- Channels table.
CREATE TABLE IF NOT EXISTS channels (
id BIGSERIAL PRIMARY KEY,
name VARCHAR(100) NOT NULL,
description TEXT DEFAULT '',
status VARCHAR(20) NOT NULL DEFAULT 'active',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
-- Unique index on the channel name; plain index on status for filtering.
CREATE UNIQUE INDEX IF NOT EXISTS idx_channels_name ON channels (name);
CREATE INDEX IF NOT EXISTS idx_channels_status ON channels (status);
-- Channel-to-group association table (each group can belong to only one channel).
-- Rows are removed automatically when the referenced channel or group is deleted.
CREATE TABLE IF NOT EXISTS channel_groups (
id BIGSERIAL PRIMARY KEY,
channel_id BIGINT NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
group_id BIGINT NOT NULL REFERENCES groups(id) ON DELETE CASCADE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
-- The unique index on group_id enforces the one-channel-per-group rule.
CREATE UNIQUE INDEX IF NOT EXISTS idx_channel_groups_group_id ON channel_groups (group_id);
CREATE INDEX IF NOT EXISTS idx_channel_groups_channel_id ON channel_groups (channel_id);
-- Channel model pricing table (one pricing row can be bound to multiple models).
-- Price columns are nullable; per the column comments below, NULL means
-- "use the default price".
CREATE TABLE IF NOT EXISTS channel_model_pricing (
id BIGSERIAL PRIMARY KEY,
channel_id BIGINT NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
models JSONB NOT NULL DEFAULT '[]',
input_price NUMERIC(20,12),
output_price NUMERIC(20,12),
cache_write_price NUMERIC(20,12),
cache_read_price NUMERIC(20,12),
image_output_price NUMERIC(20,8),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_channel_model_pricing_channel_id ON channel_model_pricing (channel_id);
-- Catalog comments for the three tables and the pricing columns.
COMMENT ON TABLE channels IS '渠道管理:关联多个分组,提供自定义模型定价';
COMMENT ON TABLE channel_groups IS '渠道-分组关联表:每个分组最多属于一个渠道';
COMMENT ON TABLE channel_model_pricing IS '渠道模型定价:一条定价可绑定多个模型,价格一致';
COMMENT ON COLUMN channel_model_pricing.models IS '绑定的模型列表,JSON 数组,如 ["claude-opus-4-6","claude-opus-4-6-thinking"]';
COMMENT ON COLUMN channel_model_pricing.input_price IS '每 token 输入价格(USD),NULL 表示使用默认';
COMMENT ON COLUMN channel_model_pricing.output_price IS '每 token 输出价格(USD),NULL 表示使用默认';
COMMENT ON COLUMN channel_model_pricing.cache_write_price IS '缓存写入每 token 价格,NULL 表示使用默认';
COMMENT ON COLUMN channel_model_pricing.cache_read_price IS '缓存读取每 token 价格,NULL 表示使用默认';
COMMENT ON COLUMN channel_model_pricing.image_output_price IS '图片输出价格(Gemini Image 等),NULL 表示使用默认';
-- Extend channel_model_pricing with billing_mode and add context-interval child table.
-- Supports three billing modes: token (per-token with context intervals),
-- per_request (per-request with context-size tiers), and image (per-image).
-- Bound how long this migration may wait for locks and how long a statement
-- may run (SET LOCAL applies to the current transaction only).
SET LOCAL lock_timeout = '5s';
SET LOCAL statement_timeout = '10min';
-- 1. Add the billing_mode column to channel_model_pricing.
--    Existing rows receive the default 'token'.
ALTER TABLE channel_model_pricing
ADD COLUMN IF NOT EXISTS billing_mode VARCHAR(20) NOT NULL DEFAULT 'token';
COMMENT ON COLUMN channel_model_pricing.billing_mode IS '计费模式:token(按 token 区间计费)、per_request(按次计费)、image(图片计费)';
-- 2. Create the pricing-interval child table.
--    Each row belongs to one channel_model_pricing row and is deleted with it.
--    Per the column comments below, min_tokens is an inclusive lower bound and
--    max_tokens an exclusive upper bound, with NULL meaning "no upper limit".
CREATE TABLE IF NOT EXISTS channel_pricing_intervals (
id BIGSERIAL PRIMARY KEY,
pricing_id BIGINT NOT NULL REFERENCES channel_model_pricing(id) ON DELETE CASCADE,
min_tokens INT NOT NULL DEFAULT 0,
max_tokens INT,
tier_label VARCHAR(50),
input_price NUMERIC(20,12),
output_price NUMERIC(20,12),
cache_write_price NUMERIC(20,12),
cache_read_price NUMERIC(20,12),
per_request_price NUMERIC(20,12),
sort_order INT NOT NULL DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_channel_pricing_intervals_pricing_id
ON channel_pricing_intervals (pricing_id);
-- Catalog comments for the interval table and its columns.
COMMENT ON TABLE channel_pricing_intervals IS '渠道定价区间:支持按 token 区间、按次分层、图片分辨率分层';
COMMENT ON COLUMN channel_pricing_intervals.min_tokens IS '区间下界(含),token 模式使用';
COMMENT ON COLUMN channel_pricing_intervals.max_tokens IS '区间上界(不含),NULL 表示无上限';
COMMENT ON COLUMN channel_pricing_intervals.tier_label IS '层级标签,按次/图片模式使用(如 1K、2K、4K、HD)';
COMMENT ON COLUMN channel_pricing_intervals.input_price IS 'token 模式:每 token 输入价';
COMMENT ON COLUMN channel_pricing_intervals.output_price IS 'token 模式:每 token 输出价';
COMMENT ON COLUMN channel_pricing_intervals.cache_write_price IS 'token 模式:缓存写入价';
COMMENT ON COLUMN channel_pricing_intervals.cache_read_price IS 'token 模式:缓存读取价';
COMMENT ON COLUMN channel_pricing_intervals.per_request_price IS '按次/图片模式:每次请求价格';
-- 3. Migrate existing flat pricing into a single interval [0, +inf).
--    Only rows with explicit pricing (at least one non-NULL price column) are migrated.
--    The NOT EXISTS guard skips pricing rows that already have an interval, so
--    re-running this statement does not insert duplicates.
INSERT INTO channel_pricing_intervals (pricing_id, min_tokens, max_tokens, input_price, output_price, cache_write_price, cache_read_price, sort_order)
SELECT
cmp.id,
0,
NULL,
cmp.input_price,
cmp.output_price,
cmp.cache_write_price,
cmp.cache_read_price,
0
FROM channel_model_pricing cmp
WHERE cmp.billing_mode = 'token'
AND (cmp.input_price IS NOT NULL OR cmp.output_price IS NOT NULL
OR cmp.cache_write_price IS NOT NULL OR cmp.cache_read_price IS NOT NULL)
AND NOT EXISTS (
SELECT 1 FROM channel_pricing_intervals cpi WHERE cpi.pricing_id = cmp.id
);
-- 4. Migrate image_output_price into image-mode interval entries.
--    Existing entries that have image_output_price are to be copied as standalone billing_mode='image' entries.
--    Note: the original entry's billing_mode is not changed here; image_output_price is kept as a backward-compatible field.
--    Actual image billing will be handled in the future by standalone billing_mode='image' entries.
--    NOTE(review): no statement implementing step 4 is visible in this script — confirm this is intentional.
-- Add a channel-level model mapping column to channels.
-- Bound how long this migration may wait for locks and how long a statement
-- may run (SET LOCAL applies to the current transaction only).
SET LOCAL lock_timeout = '5s';
SET LOCAL statement_timeout = '10min';
-- model_mapping is a nullable JSONB column defaulting to an empty object; per
-- the column comment it maps source model names to target model names and is
-- applied before account-level mapping.
ALTER TABLE channels ADD COLUMN IF NOT EXISTS model_mapping JSONB DEFAULT '{}';
COMMENT ON COLUMN channels.model_mapping IS '渠道级模型映射,在账号映射之前执行。格式:{"source_model": "target_model"}';
-- Add billing_model_source to channels (controls whether billing uses requested or upstream model)
-- The column is nullable and defaults to 'requested'.
ALTER TABLE channels ADD COLUMN IF NOT EXISTS billing_model_source VARCHAR(20) DEFAULT 'requested';
-- Add channel tracking fields to usage_logs
-- All three columns are nullable; channel_id is added without a foreign-key constraint.
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS channel_id BIGINT;
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS model_mapping_chain VARCHAR(500);
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS billing_tier VARCHAR(50);
-- Add model restriction switch to channels
-- Restriction is off by default.
ALTER TABLE channels ADD COLUMN IF NOT EXISTS restrict_models BOOLEAN DEFAULT false;
-- Add default per_request_price to channel_model_pricing (fallback when no tier matches)
-- NOTE(review): this column is NUMERIC(20,10) while
-- channel_pricing_intervals.per_request_price is NUMERIC(20,12) — confirm the
-- precision difference is intended.
ALTER TABLE channel_model_pricing ADD COLUMN IF NOT EXISTS per_request_price NUMERIC(20,10);
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment