Unverified Commit bf455811 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #1455 from touwaeriol/feat/channel-management

feat(channel): add channel management with multi-mode pricing and billing integration
parents b384570d e88b2890
package service
import "strings"
// resolveOpenAIForwardModel resolves the account/group mapping result for
// OpenAI-compatible forwarding. Group-level default mapping only applies when
// the account itself did not match any explicit model_mapping rule.
// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible
// forwarding. Group-level default mapping only applies when the account itself
// did not match any explicit model_mapping rule.
func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string {
if account == nil {
if defaultMappedModel != "" {
......@@ -19,23 +17,3 @@ func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedMo
}
return mappedModel
}
func resolveOpenAIUpstreamModel(model string) string {
if isBareGPT53CodexSparkModel(model) {
return "gpt-5.3-codex-spark"
}
return normalizeCodexModel(strings.TrimSpace(model))
}
func isBareGPT53CodexSparkModel(model string) bool {
modelID := strings.TrimSpace(model)
if modelID == "" {
return false
}
if strings.Contains(modelID, "/") {
parts := strings.Split(modelID, "/")
modelID = parts[len(parts)-1]
}
normalized := strings.ToLower(strings.TrimSpace(modelID))
return normalized == "gpt-5.3-codex-spark" || normalized == "gpt 5.3 codex spark"
}
......@@ -74,30 +74,28 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t *
Credentials: map[string]any{},
}
withoutDefault := resolveOpenAIUpstreamModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", ""))
withoutDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", ""))
if withoutDefault != "gpt-5.1" {
t.Fatalf("resolveOpenAIUpstreamModel(...) = %q, want %q", withoutDefault, "gpt-5.1")
t.Fatalf("normalizeCodexModel(...) = %q, want %q", withoutDefault, "gpt-5.1")
}
withDefault := resolveOpenAIUpstreamModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4"))
withDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4"))
if withDefault != "gpt-5.4" {
t.Fatalf("resolveOpenAIUpstreamModel(...) = %q, want %q", withDefault, "gpt-5.4")
t.Fatalf("normalizeCodexModel(...) = %q, want %q", withDefault, "gpt-5.4")
}
}
func TestResolveOpenAIUpstreamModel(t *testing.T) {
func TestNormalizeCodexModel(t *testing.T) {
cases := map[string]string{
"gpt-5.3-codex-spark": "gpt-5.3-codex-spark",
"gpt 5.3 codex spark": "gpt-5.3-codex-spark",
" openai/gpt-5.3-codex-spark ": "gpt-5.3-codex-spark",
"gpt-5.3-codex-spark": "gpt-5.3-codex",
"gpt-5.3-codex-spark-high": "gpt-5.3-codex",
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex",
"gpt-5.3": "gpt-5.3-codex",
}
for input, expected := range cases {
if got := resolveOpenAIUpstreamModel(input); got != expected {
t.Fatalf("resolveOpenAIUpstreamModel(%q) = %q, want %q", input, got, expected)
if got := normalizeCodexModel(input); got != expected {
t.Fatalf("normalizeCodexModel(%q) = %q, want %q", input, got, expected)
}
}
}
......@@ -2515,7 +2515,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
}
normalized = next
}
upstreamModel := resolveOpenAIUpstreamModel(account.GetMappedModel(originalModel))
upstreamModel := normalizeCodexModel(account.GetMappedModel(originalModel))
if upstreamModel != originalModel {
next, setErr := applyPayloadMutation(normalized, "model", upstreamModel)
if setErr != nil {
......@@ -2773,7 +2773,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
mappedModel := ""
var mappedModelBytes []byte
if originalModel != "" {
mappedModel = resolveOpenAIUpstreamModel(account.GetMappedModel(originalModel))
mappedModel = normalizeCodexModel(account.GetMappedModel(originalModel))
needModelReplace = mappedModel != "" && mappedModel != originalModel
if needModelReplace {
mappedModelBytes = []byte(mappedModel)
......
......@@ -615,6 +615,8 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil,
nil,
nil,
nil,
nil,
)
decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)
......
......@@ -519,7 +519,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry
if s.gatewayService == nil {
return nil, fmt.Errorf("gateway service not available")
}
return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs, "") // 重试不使用会话限制
return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs, "", int64(0)) // 重试不使用会话限制
default:
return nil, fmt.Errorf("unsupported retry type: %s", reqType)
}
......
......@@ -71,6 +71,7 @@ type LiteLLMModelPricing struct {
Mode string `json:"mode"`
SupportsPromptCaching bool `json:"supports_prompt_caching"`
OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格
OutputCostPerImageToken float64 `json:"output_cost_per_image_token"` // 图片输出 token 价格
}
// PricingRemoteClient 远程价格数据获取接口
......@@ -94,6 +95,7 @@ type LiteLLMRawEntry struct {
Mode string `json:"mode"`
SupportsPromptCaching bool `json:"supports_prompt_caching"`
OutputCostPerImage *float64 `json:"output_cost_per_image"`
OutputCostPerImageToken *float64 `json:"output_cost_per_image_token"`
}
// PricingService 动态价格服务
......@@ -408,6 +410,9 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel
if entry.OutputCostPerImage != nil {
pricing.OutputCostPerImage = *entry.OutputCostPerImage
}
if entry.OutputCostPerImageToken != nil {
pricing.OutputCostPerImageToken = *entry.OutputCostPerImageToken
}
result[modelName] = pricing
}
......
//go:build unit
package service
// testPtrFloat64 returns a pointer to the given float64 value.
func testPtrFloat64(v float64) *float64 { return &v }
// testPtrInt returns a pointer to the given int value.
func testPtrInt(v int) *int { return &v }
// testPtrString returns a pointer to the given string value.
func testPtrString(v string) *string { return &v }
// testPtrBool returns a pointer to the given bool value.
func testPtrBool(v bool) *bool { return &v }
......@@ -104,6 +104,14 @@ type UsageLog struct {
// UpstreamModel is the actual model sent to the upstream provider after mapping.
// Nil means no mapping was applied (requested model was used as-is).
UpstreamModel *string
// ChannelID 渠道 ID
ChannelID *int64
// ModelMappingChain 模型映射链,如 "a→b→c"
ModelMappingChain *string
// BillingTier 计费层级标签(per_request/image 模式)
BillingTier *string
// BillingMode 计费模式:token/image(sora 路径为 nil)
BillingMode *string
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
ServiceTier *string
// ReasoningEffort is the request's reasoning effort level.
......@@ -126,6 +134,9 @@ type UsageLog struct {
CacheCreation5mTokens int `gorm:"column:cache_creation_5m_tokens"`
CacheCreation1hTokens int `gorm:"column:cache_creation_1h_tokens"`
ImageOutputTokens int
ImageOutputCost float64
InputCost float64
OutputCost float64
CacheCreationCost float64
......
......@@ -26,3 +26,10 @@ func forwardResultBillingModel(requestedModel, upstreamModel string) string {
}
return strings.TrimSpace(upstreamModel)
}
func optionalInt64Ptr(v int64) *int64 {
if v == 0 {
return nil
}
return &v
}
......@@ -490,4 +490,6 @@ var ProviderSet = wire.NewSet(
ProvideScheduledTestService,
ProvideScheduledTestRunnerService,
NewGroupCapacityService,
NewChannelService,
NewModelPricingResolver,
)
-- Create channels table for managing pricing channels.
-- A channel groups multiple groups together and provides custom model pricing.
SET LOCAL lock_timeout = '5s';
SET LOCAL statement_timeout = '10min';
-- 渠道表
CREATE TABLE IF NOT EXISTS channels (
id BIGSERIAL PRIMARY KEY,
name VARCHAR(100) NOT NULL,
description TEXT DEFAULT '',
status VARCHAR(20) NOT NULL DEFAULT 'active',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
-- 渠道名称唯一索引
CREATE UNIQUE INDEX IF NOT EXISTS idx_channels_name ON channels (name);
CREATE INDEX IF NOT EXISTS idx_channels_status ON channels (status);
-- 渠道-分组关联表(每个分组只能属于一个渠道)
CREATE TABLE IF NOT EXISTS channel_groups (
id BIGSERIAL PRIMARY KEY,
channel_id BIGINT NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
group_id BIGINT NOT NULL REFERENCES groups(id) ON DELETE CASCADE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_channel_groups_group_id ON channel_groups (group_id);
CREATE INDEX IF NOT EXISTS idx_channel_groups_channel_id ON channel_groups (channel_id);
-- 渠道模型定价表(一条定价可绑定多个模型)
CREATE TABLE IF NOT EXISTS channel_model_pricing (
id BIGSERIAL PRIMARY KEY,
channel_id BIGINT NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
models JSONB NOT NULL DEFAULT '[]',
input_price NUMERIC(20,12),
output_price NUMERIC(20,12),
cache_write_price NUMERIC(20,12),
cache_read_price NUMERIC(20,12),
image_output_price NUMERIC(20,8),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_channel_model_pricing_channel_id ON channel_model_pricing (channel_id);
COMMENT ON TABLE channels IS '渠道管理:关联多个分组,提供自定义模型定价';
COMMENT ON TABLE channel_groups IS '渠道-分组关联表:每个分组最多属于一个渠道';
COMMENT ON TABLE channel_model_pricing IS '渠道模型定价:一条定价可绑定多个模型,价格一致';
COMMENT ON COLUMN channel_model_pricing.models IS '绑定的模型列表,JSON 数组,如 ["claude-opus-4-6","claude-opus-4-6-thinking"]';
COMMENT ON COLUMN channel_model_pricing.input_price IS '每 token 输入价格(USD),NULL 表示使用默认';
COMMENT ON COLUMN channel_model_pricing.output_price IS '每 token 输出价格(USD),NULL 表示使用默认';
COMMENT ON COLUMN channel_model_pricing.cache_write_price IS '缓存写入每 token 价格,NULL 表示使用默认';
COMMENT ON COLUMN channel_model_pricing.cache_read_price IS '缓存读取每 token 价格,NULL 表示使用默认';
COMMENT ON COLUMN channel_model_pricing.image_output_price IS '图片输出价格(Gemini Image 等),NULL 表示使用默认';
-- Extend channel_model_pricing with billing_mode and add context-interval child table.
-- Supports three billing modes: token (per-token with context intervals),
-- per_request (per-request with context-size tiers), and image (per-image).
SET LOCAL lock_timeout = '5s';
SET LOCAL statement_timeout = '10min';
-- 1. 为 channel_model_pricing 添加 billing_mode 列
ALTER TABLE channel_model_pricing
ADD COLUMN IF NOT EXISTS billing_mode VARCHAR(20) NOT NULL DEFAULT 'token';
COMMENT ON COLUMN channel_model_pricing.billing_mode IS '计费模式:token(按 token 区间计费)、per_request(按次计费)、image(图片计费)';
-- 2. 创建区间定价子表
CREATE TABLE IF NOT EXISTS channel_pricing_intervals (
id BIGSERIAL PRIMARY KEY,
pricing_id BIGINT NOT NULL REFERENCES channel_model_pricing(id) ON DELETE CASCADE,
min_tokens INT NOT NULL DEFAULT 0,
max_tokens INT,
tier_label VARCHAR(50),
input_price NUMERIC(20,12),
output_price NUMERIC(20,12),
cache_write_price NUMERIC(20,12),
cache_read_price NUMERIC(20,12),
per_request_price NUMERIC(20,12),
sort_order INT NOT NULL DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_channel_pricing_intervals_pricing_id
ON channel_pricing_intervals (pricing_id);
COMMENT ON TABLE channel_pricing_intervals IS '渠道定价区间:支持按 token 区间、按次分层、图片分辨率分层';
COMMENT ON COLUMN channel_pricing_intervals.min_tokens IS '区间下界(含),token 模式使用';
COMMENT ON COLUMN channel_pricing_intervals.max_tokens IS '区间上界(不含),NULL 表示无上限';
COMMENT ON COLUMN channel_pricing_intervals.tier_label IS '层级标签,按次/图片模式使用(如 1K、2K、4K、HD)';
COMMENT ON COLUMN channel_pricing_intervals.input_price IS 'token 模式:每 token 输入价';
COMMENT ON COLUMN channel_pricing_intervals.output_price IS 'token 模式:每 token 输出价';
COMMENT ON COLUMN channel_pricing_intervals.cache_write_price IS 'token 模式:缓存写入价';
COMMENT ON COLUMN channel_pricing_intervals.cache_read_price IS 'token 模式:缓存读取价';
COMMENT ON COLUMN channel_pricing_intervals.per_request_price IS '按次/图片模式:每次请求价格';
-- 3. 迁移现有 flat 定价为单区间 [0, +inf)
-- 仅迁移有明确定价(至少一个价格字段非 NULL)的条目
INSERT INTO channel_pricing_intervals (pricing_id, min_tokens, max_tokens, input_price, output_price, cache_write_price, cache_read_price, sort_order)
SELECT
cmp.id,
0,
NULL,
cmp.input_price,
cmp.output_price,
cmp.cache_write_price,
cmp.cache_read_price,
0
FROM channel_model_pricing cmp
WHERE cmp.billing_mode = 'token'
AND (cmp.input_price IS NOT NULL OR cmp.output_price IS NOT NULL
OR cmp.cache_write_price IS NOT NULL OR cmp.cache_read_price IS NOT NULL)
AND NOT EXISTS (
SELECT 1 FROM channel_pricing_intervals cpi WHERE cpi.pricing_id = cmp.id
);
-- 4. 迁移 image_output_price 为 image 模式的区间条目
-- 将有 image_output_price 的现有条目复制为 billing_mode='image' 的独立条目
-- 注意:这里不改变原条目的 billing_mode,而是将 image_output_price 作为向后兼容字段保留
-- 实际的 image 计费在未来由独立的 billing_mode='image' 条目处理
SET LOCAL lock_timeout = '5s';
SET LOCAL statement_timeout = '10min';
ALTER TABLE channels ADD COLUMN IF NOT EXISTS model_mapping JSONB DEFAULT '{}';
COMMENT ON COLUMN channels.model_mapping IS '渠道级模型映射,在账号映射之前执行。格式:{"source_model": "target_model"}';
-- Add billing_model_source to channels (controls whether billing uses requested or upstream model)
ALTER TABLE channels ADD COLUMN IF NOT EXISTS billing_model_source VARCHAR(20) DEFAULT 'requested';
-- Add channel tracking fields to usage_logs
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS channel_id BIGINT;
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS model_mapping_chain VARCHAR(500);
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS billing_tier VARCHAR(50);
-- Add model restriction switch to channels
ALTER TABLE channels ADD COLUMN IF NOT EXISTS restrict_models BOOLEAN DEFAULT false;
-- Add default per_request_price to channel_model_pricing (fallback when no tier matches)
ALTER TABLE channel_model_pricing ADD COLUMN IF NOT EXISTS per_request_price NUMERIC(20,10);
-- 086_channel_platform_pricing.sql
-- 渠道按平台维度:model_pricing 加 platform 列,model_mapping 改为嵌套格式
-- 1. channel_model_pricing 加 platform 列
ALTER TABLE channel_model_pricing
ADD COLUMN IF NOT EXISTS platform VARCHAR(50) NOT NULL DEFAULT 'anthropic';
CREATE INDEX IF NOT EXISTS idx_channel_model_pricing_platform
ON channel_model_pricing (platform);
-- 2. model_mapping: 从扁平 {"src":"dst"} 迁移为嵌套 {"anthropic":{"src":"dst"}}
-- 仅迁移非空、非 '{}' 的旧格式数据(通过检查第一个 value 是否为字符串来判断是否为旧格式)
UPDATE channels
SET model_mapping = jsonb_build_object('anthropic', model_mapping)
WHERE model_mapping IS NOT NULL
AND model_mapping::text NOT IN ('{}', 'null', '')
AND NOT EXISTS (
SELECT 1 FROM jsonb_each(model_mapping) AS kv
WHERE jsonb_typeof(kv.value) = 'object'
LIMIT 1
);
-- Add billing_mode to usage_logs (records the billing mode: token/per_request/image)
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS billing_mode VARCHAR(20);
-- Change default billing_model_source for new channels to 'channel_mapped'
-- Existing channels keep their current setting (no UPDATE on existing rows)
ALTER TABLE channels ALTER COLUMN billing_model_source SET DEFAULT 'channel_mapped';
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS image_output_tokens INTEGER NOT NULL DEFAULT 0;
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS image_output_cost DECIMAL(20, 10) NOT NULL DEFAULT 0;
/**
* Admin Channels API endpoints
* Handles channel management for administrators
*/
import { apiClient } from '../client'
export type BillingMode = 'token' | 'per_request' | 'image'
export interface PricingInterval {
id?: number
min_tokens: number
max_tokens: number | null
tier_label: string
input_price: number | null
output_price: number | null
cache_write_price: number | null
cache_read_price: number | null
per_request_price: number | null
sort_order: number
}
export interface ChannelModelPricing {
id?: number
platform: string
models: string[]
billing_mode: BillingMode
input_price: number | null
output_price: number | null
cache_write_price: number | null
cache_read_price: number | null
image_output_price: number | null
per_request_price: number | null
intervals: PricingInterval[]
}
export interface Channel {
id: number
name: string
description: string
status: string
billing_model_source: string // "requested" | "upstream"
restrict_models: boolean
group_ids: number[]
model_pricing: ChannelModelPricing[]
model_mapping: Record<string, Record<string, string>> // platform → {src→dst}
created_at: string
updated_at: string
}
export interface CreateChannelRequest {
name: string
description?: string
group_ids?: number[]
model_pricing?: ChannelModelPricing[]
model_mapping?: Record<string, Record<string, string>>
billing_model_source?: string
restrict_models?: boolean
}
export interface UpdateChannelRequest {
name?: string
description?: string
status?: string
group_ids?: number[]
model_pricing?: ChannelModelPricing[]
model_mapping?: Record<string, Record<string, string>>
billing_model_source?: string
restrict_models?: boolean
}
interface PaginatedResponse<T> {
items: T[]
total: number
}
/**
* List channels with pagination
*/
export async function list(
page: number = 1,
pageSize: number = 20,
filters?: {
status?: string
search?: string
},
options?: { signal?: AbortSignal }
): Promise<PaginatedResponse<Channel>> {
const { data } = await apiClient.get<PaginatedResponse<Channel>>('/admin/channels', {
params: {
page,
page_size: pageSize,
...filters
},
signal: options?.signal
})
return data
}
/**
* Get channel by ID
*/
export async function getById(id: number): Promise<Channel> {
const { data } = await apiClient.get<Channel>(`/admin/channels/${id}`)
return data
}
/**
* Create a new channel
*/
export async function create(req: CreateChannelRequest): Promise<Channel> {
const { data } = await apiClient.post<Channel>('/admin/channels', req)
return data
}
/**
* Update a channel
*/
export async function update(id: number, req: UpdateChannelRequest): Promise<Channel> {
const { data } = await apiClient.put<Channel>(`/admin/channels/${id}`, req)
return data
}
/**
* Delete a channel
*/
export async function remove(id: number): Promise<void> {
await apiClient.delete(`/admin/channels/${id}`)
}
export interface ModelDefaultPricing {
found: boolean
input_price?: number // per-token price
output_price?: number
cache_write_price?: number
cache_read_price?: number
image_output_price?: number
}
export async function getModelDefaultPricing(model: string): Promise<ModelDefaultPricing> {
const { data } = await apiClient.get<ModelDefaultPricing>('/admin/channels/model-pricing', {
params: { model }
})
return data
}
const channelsAPI = { list, getById, create, update, remove, getModelDefaultPricing }
export default channelsAPI
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment