Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
91c9b8d0
Commit
91c9b8d0
authored
Apr 04, 2026
by
erio
Browse files
feat(channel): 渠道管理系统 — 多模式定价 + 统一计费解析
Cherry-picked from release/custom-0.1.106: a9117600
parent
b384570d
Changes
27
Hide whitespace changes
Inline
Side-by-side
backend/cmd/server/wire_gen.go
View file @
91c9b8d0
...
...
@@ -49,6 +49,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
refreshTokenCache
:=
repository
.
NewRefreshTokenCache
(
redisClient
)
settingRepository
:=
repository
.
NewSettingRepository
(
client
)
groupRepository
:=
repository
.
NewGroupRepository
(
client
,
db
)
channelRepository
:=
repository
.
NewChannelRepository
(
db
)
settingService
:=
service
.
ProvideSettingService
(
settingRepository
,
groupRepository
,
configConfig
)
emailCache
:=
repository
.
NewEmailCache
(
redisClient
)
emailService
:=
service
.
NewEmailService
(
settingRepository
,
emailCache
)
...
...
@@ -175,7 +176,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
deferredService
:=
service
.
ProvideDeferredService
(
accountRepository
,
timingWheelService
)
claudeTokenProvider
:=
service
.
ProvideClaudeTokenProvider
(
accountRepository
,
geminiTokenCache
,
oAuthService
,
oauthRefreshAPI
)
digestSessionStore
:=
service
.
NewDigestSessionStore
()
gatewayService
:=
service
.
NewGatewayService
(
accountRepository
,
groupRepository
,
usageLogRepository
,
usageBillingRepository
,
userRepository
,
userSubscriptionRepository
,
userGroupRateRepository
,
gatewayCache
,
configConfig
,
schedulerSnapshotService
,
concurrencyService
,
billingService
,
rateLimitService
,
billingCacheService
,
identityService
,
httpUpstream
,
deferredService
,
claudeTokenProvider
,
sessionLimitCache
,
rpmCache
,
digestSessionStore
,
settingService
,
tlsFingerprintProfileService
)
channelService
:=
service
.
NewChannelService
(
channelRepository
,
apiKeyAuthCacheInvalidator
)
modelPricingResolver
:=
service
.
NewModelPricingResolver
(
channelService
,
billingService
)
_
=
modelPricingResolver
// Phase 4: 已注册,后续 Gateway 迁移时使用
gatewayService
:=
service
.
NewGatewayService
(
accountRepository
,
groupRepository
,
usageLogRepository
,
usageBillingRepository
,
userRepository
,
userSubscriptionRepository
,
userGroupRateRepository
,
gatewayCache
,
configConfig
,
schedulerSnapshotService
,
concurrencyService
,
billingService
,
rateLimitService
,
billingCacheService
,
identityService
,
httpUpstream
,
deferredService
,
claudeTokenProvider
,
sessionLimitCache
,
rpmCache
,
digestSessionStore
,
settingService
,
tlsFingerprintProfileService
,
channelService
)
openAITokenProvider
:=
service
.
ProvideOpenAITokenProvider
(
accountRepository
,
geminiTokenCache
,
openAIOAuthService
,
oauthRefreshAPI
)
openAIGatewayService
:=
service
.
NewOpenAIGatewayService
(
accountRepository
,
usageLogRepository
,
usageBillingRepository
,
userRepository
,
userSubscriptionRepository
,
userGroupRateRepository
,
gatewayCache
,
configConfig
,
schedulerSnapshotService
,
concurrencyService
,
billingService
,
rateLimitService
,
billingCacheService
,
httpUpstream
,
deferredService
,
openAITokenProvider
)
geminiMessagesCompatService
:=
service
.
NewGeminiMessagesCompatService
(
accountRepository
,
groupRepository
,
gatewayCache
,
schedulerSnapshotService
,
geminiTokenProvider
,
rateLimitService
,
httpUpstream
,
antigravityGatewayService
,
configConfig
)
...
...
@@ -213,7 +217,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
scheduledTestResultRepository
:=
repository
.
NewScheduledTestResultRepository
(
db
)
scheduledTestService
:=
service
.
ProvideScheduledTestService
(
scheduledTestPlanRepository
,
scheduledTestResultRepository
)
scheduledTestHandler
:=
admin
.
NewScheduledTestHandler
(
scheduledTestService
)
adminHandlers
:=
handler
.
ProvideAdminHandlers
(
dashboardHandler
,
adminUserHandler
,
groupHandler
,
accountHandler
,
adminAnnouncementHandler
,
dataManagementHandler
,
backupHandler
,
oAuthHandler
,
openAIOAuthHandler
,
geminiOAuthHandler
,
antigravityOAuthHandler
,
proxyHandler
,
adminRedeemHandler
,
promoHandler
,
settingHandler
,
opsHandler
,
systemHandler
,
adminSubscriptionHandler
,
adminUsageHandler
,
userAttributeHandler
,
errorPassthroughHandler
,
tlsFingerprintProfileHandler
,
adminAPIKeyHandler
,
scheduledTestHandler
)
channelHandler
:=
admin
.
NewChannelHandler
(
channelService
)
adminHandlers
:=
handler
.
ProvideAdminHandlers
(
dashboardHandler
,
adminUserHandler
,
groupHandler
,
accountHandler
,
adminAnnouncementHandler
,
dataManagementHandler
,
backupHandler
,
oAuthHandler
,
openAIOAuthHandler
,
geminiOAuthHandler
,
antigravityOAuthHandler
,
proxyHandler
,
adminRedeemHandler
,
promoHandler
,
settingHandler
,
opsHandler
,
systemHandler
,
adminSubscriptionHandler
,
adminUsageHandler
,
userAttributeHandler
,
errorPassthroughHandler
,
tlsFingerprintProfileHandler
,
adminAPIKeyHandler
,
scheduledTestHandler
,
channelHandler
)
usageRecordWorkerPool
:=
service
.
NewUsageRecordWorkerPool
(
configConfig
)
userMsgQueueCache
:=
repository
.
NewUserMsgQueueCache
(
redisClient
)
userMessageQueueService
:=
service
.
ProvideUserMessageQueueService
(
userMsgQueueCache
,
rpmCache
,
configConfig
)
...
...
backend/internal/handler/admin/channel_handler.go
0 → 100644
View file @
91c9b8d0
package
admin
import
(
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// ChannelHandler handles admin channel management
type
ChannelHandler
struct
{
channelService
*
service
.
ChannelService
}
// NewChannelHandler creates a new admin channel handler
func
NewChannelHandler
(
channelService
*
service
.
ChannelService
)
*
ChannelHandler
{
return
&
ChannelHandler
{
channelService
:
channelService
}
}
// --- Request / Response types ---
type
createChannelRequest
struct
{
Name
string
`json:"name" binding:"required,max=100"`
Description
string
`json:"description"`
GroupIDs
[]
int64
`json:"group_ids"`
ModelPricing
[]
channelModelPricingRequest
`json:"model_pricing"`
}
type
updateChannelRequest
struct
{
Name
string
`json:"name" binding:"omitempty,max=100"`
Description
*
string
`json:"description"`
Status
string
`json:"status" binding:"omitempty,oneof=active disabled"`
GroupIDs
*
[]
int64
`json:"group_ids"`
ModelPricing
*
[]
channelModelPricingRequest
`json:"model_pricing"`
}
type
channelModelPricingRequest
struct
{
Models
[]
string
`json:"models" binding:"required,min=1,max=100"`
BillingMode
string
`json:"billing_mode" binding:"omitempty,oneof=token per_request image"`
InputPrice
*
float64
`json:"input_price" binding:"omitempty,min=0"`
OutputPrice
*
float64
`json:"output_price" binding:"omitempty,min=0"`
CacheWritePrice
*
float64
`json:"cache_write_price" binding:"omitempty,min=0"`
CacheReadPrice
*
float64
`json:"cache_read_price" binding:"omitempty,min=0"`
ImageOutputPrice
*
float64
`json:"image_output_price" binding:"omitempty,min=0"`
Intervals
[]
pricingIntervalRequest
`json:"intervals"`
}
type
pricingIntervalRequest
struct
{
MinTokens
int
`json:"min_tokens"`
MaxTokens
*
int
`json:"max_tokens"`
TierLabel
string
`json:"tier_label"`
InputPrice
*
float64
`json:"input_price"`
OutputPrice
*
float64
`json:"output_price"`
CacheWritePrice
*
float64
`json:"cache_write_price"`
CacheReadPrice
*
float64
`json:"cache_read_price"`
PerRequestPrice
*
float64
`json:"per_request_price"`
SortOrder
int
`json:"sort_order"`
}
type
channelResponse
struct
{
ID
int64
`json:"id"`
Name
string
`json:"name"`
Description
string
`json:"description"`
Status
string
`json:"status"`
GroupIDs
[]
int64
`json:"group_ids"`
ModelPricing
[]
channelModelPricingResponse
`json:"model_pricing"`
CreatedAt
string
`json:"created_at"`
UpdatedAt
string
`json:"updated_at"`
}
type
channelModelPricingResponse
struct
{
ID
int64
`json:"id"`
Models
[]
string
`json:"models"`
BillingMode
string
`json:"billing_mode"`
InputPrice
*
float64
`json:"input_price"`
OutputPrice
*
float64
`json:"output_price"`
CacheWritePrice
*
float64
`json:"cache_write_price"`
CacheReadPrice
*
float64
`json:"cache_read_price"`
ImageOutputPrice
*
float64
`json:"image_output_price"`
Intervals
[]
pricingIntervalResponse
`json:"intervals"`
}
type
pricingIntervalResponse
struct
{
ID
int64
`json:"id"`
MinTokens
int
`json:"min_tokens"`
MaxTokens
*
int
`json:"max_tokens"`
TierLabel
string
`json:"tier_label,omitempty"`
InputPrice
*
float64
`json:"input_price"`
OutputPrice
*
float64
`json:"output_price"`
CacheWritePrice
*
float64
`json:"cache_write_price"`
CacheReadPrice
*
float64
`json:"cache_read_price"`
PerRequestPrice
*
float64
`json:"per_request_price"`
SortOrder
int
`json:"sort_order"`
}
func
channelToResponse
(
ch
*
service
.
Channel
)
*
channelResponse
{
if
ch
==
nil
{
return
nil
}
resp
:=
&
channelResponse
{
ID
:
ch
.
ID
,
Name
:
ch
.
Name
,
Description
:
ch
.
Description
,
Status
:
ch
.
Status
,
GroupIDs
:
ch
.
GroupIDs
,
CreatedAt
:
ch
.
CreatedAt
.
Format
(
"2006-01-02T15:04:05Z"
),
UpdatedAt
:
ch
.
UpdatedAt
.
Format
(
"2006-01-02T15:04:05Z"
),
}
if
resp
.
GroupIDs
==
nil
{
resp
.
GroupIDs
=
[]
int64
{}
}
resp
.
ModelPricing
=
make
([]
channelModelPricingResponse
,
0
,
len
(
ch
.
ModelPricing
))
for
_
,
p
:=
range
ch
.
ModelPricing
{
models
:=
p
.
Models
if
models
==
nil
{
models
=
[]
string
{}
}
billingMode
:=
string
(
p
.
BillingMode
)
if
billingMode
==
""
{
billingMode
=
"token"
}
intervals
:=
make
([]
pricingIntervalResponse
,
0
,
len
(
p
.
Intervals
))
for
_
,
iv
:=
range
p
.
Intervals
{
intervals
=
append
(
intervals
,
pricingIntervalResponse
{
ID
:
iv
.
ID
,
MinTokens
:
iv
.
MinTokens
,
MaxTokens
:
iv
.
MaxTokens
,
TierLabel
:
iv
.
TierLabel
,
InputPrice
:
iv
.
InputPrice
,
OutputPrice
:
iv
.
OutputPrice
,
CacheWritePrice
:
iv
.
CacheWritePrice
,
CacheReadPrice
:
iv
.
CacheReadPrice
,
PerRequestPrice
:
iv
.
PerRequestPrice
,
SortOrder
:
iv
.
SortOrder
,
})
}
resp
.
ModelPricing
=
append
(
resp
.
ModelPricing
,
channelModelPricingResponse
{
ID
:
p
.
ID
,
Models
:
models
,
BillingMode
:
billingMode
,
InputPrice
:
p
.
InputPrice
,
OutputPrice
:
p
.
OutputPrice
,
CacheWritePrice
:
p
.
CacheWritePrice
,
CacheReadPrice
:
p
.
CacheReadPrice
,
ImageOutputPrice
:
p
.
ImageOutputPrice
,
Intervals
:
intervals
,
})
}
return
resp
}
func
pricingRequestToService
(
reqs
[]
channelModelPricingRequest
)
[]
service
.
ChannelModelPricing
{
result
:=
make
([]
service
.
ChannelModelPricing
,
0
,
len
(
reqs
))
for
_
,
r
:=
range
reqs
{
billingMode
:=
service
.
BillingMode
(
r
.
BillingMode
)
if
billingMode
==
""
{
billingMode
=
service
.
BillingModeToken
}
intervals
:=
make
([]
service
.
PricingInterval
,
0
,
len
(
r
.
Intervals
))
for
_
,
iv
:=
range
r
.
Intervals
{
intervals
=
append
(
intervals
,
service
.
PricingInterval
{
MinTokens
:
iv
.
MinTokens
,
MaxTokens
:
iv
.
MaxTokens
,
TierLabel
:
iv
.
TierLabel
,
InputPrice
:
iv
.
InputPrice
,
OutputPrice
:
iv
.
OutputPrice
,
CacheWritePrice
:
iv
.
CacheWritePrice
,
CacheReadPrice
:
iv
.
CacheReadPrice
,
PerRequestPrice
:
iv
.
PerRequestPrice
,
SortOrder
:
iv
.
SortOrder
,
})
}
result
=
append
(
result
,
service
.
ChannelModelPricing
{
Models
:
r
.
Models
,
BillingMode
:
billingMode
,
InputPrice
:
r
.
InputPrice
,
OutputPrice
:
r
.
OutputPrice
,
CacheWritePrice
:
r
.
CacheWritePrice
,
CacheReadPrice
:
r
.
CacheReadPrice
,
ImageOutputPrice
:
r
.
ImageOutputPrice
,
Intervals
:
intervals
,
})
}
return
result
}
// --- Handlers ---
// List handles listing channels with pagination
// GET /api/v1/admin/channels
func
(
h
*
ChannelHandler
)
List
(
c
*
gin
.
Context
)
{
page
,
pageSize
:=
response
.
ParsePagination
(
c
)
status
:=
c
.
Query
(
"status"
)
search
:=
strings
.
TrimSpace
(
c
.
Query
(
"search"
))
if
len
(
search
)
>
100
{
search
=
search
[
:
100
]
}
channels
,
pag
,
err
:=
h
.
channelService
.
List
(
c
.
Request
.
Context
(),
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
},
status
,
search
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
out
:=
make
([]
*
channelResponse
,
0
,
len
(
channels
))
for
i
:=
range
channels
{
out
=
append
(
out
,
channelToResponse
(
&
channels
[
i
]))
}
response
.
Paginated
(
c
,
out
,
pag
.
Total
,
page
,
pageSize
)
}
// GetByID handles getting a channel by ID
// GET /api/v1/admin/channels/:id
func
(
h
*
ChannelHandler
)
GetByID
(
c
*
gin
.
Context
)
{
id
,
err
:=
strconv
.
ParseInt
(
c
.
Param
(
"id"
),
10
,
64
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid channel ID"
)
return
}
channel
,
err
:=
h
.
channelService
.
GetByID
(
c
.
Request
.
Context
(),
id
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
channelToResponse
(
channel
))
}
// Create handles creating a new channel
// POST /api/v1/admin/channels
func
(
h
*
ChannelHandler
)
Create
(
c
*
gin
.
Context
)
{
var
req
createChannelRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
channel
,
err
:=
h
.
channelService
.
Create
(
c
.
Request
.
Context
(),
&
service
.
CreateChannelInput
{
Name
:
req
.
Name
,
Description
:
req
.
Description
,
GroupIDs
:
req
.
GroupIDs
,
ModelPricing
:
pricingRequestToService
(
req
.
ModelPricing
),
})
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
channelToResponse
(
channel
))
}
// Update handles updating a channel
// PUT /api/v1/admin/channels/:id
func
(
h
*
ChannelHandler
)
Update
(
c
*
gin
.
Context
)
{
id
,
err
:=
strconv
.
ParseInt
(
c
.
Param
(
"id"
),
10
,
64
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid channel ID"
)
return
}
var
req
updateChannelRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
input
:=
&
service
.
UpdateChannelInput
{
Name
:
req
.
Name
,
Description
:
req
.
Description
,
Status
:
req
.
Status
,
GroupIDs
:
req
.
GroupIDs
,
}
if
req
.
ModelPricing
!=
nil
{
pricing
:=
pricingRequestToService
(
*
req
.
ModelPricing
)
input
.
ModelPricing
=
&
pricing
}
channel
,
err
:=
h
.
channelService
.
Update
(
c
.
Request
.
Context
(),
id
,
input
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
channelToResponse
(
channel
))
}
// Delete handles deleting a channel
// DELETE /api/v1/admin/channels/:id
func
(
h
*
ChannelHandler
)
Delete
(
c
*
gin
.
Context
)
{
id
,
err
:=
strconv
.
ParseInt
(
c
.
Param
(
"id"
),
10
,
64
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid channel ID"
)
return
}
if
err
:=
h
.
channelService
.
Delete
(
c
.
Request
.
Context
(),
id
);
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
gin
.
H
{
"message"
:
"Channel deleted successfully"
})
}
backend/internal/handler/handler.go
View file @
91c9b8d0
...
...
@@ -30,6 +30,7 @@ type AdminHandlers struct {
TLSFingerprintProfile
*
admin
.
TLSFingerprintProfileHandler
APIKey
*
admin
.
AdminAPIKeyHandler
ScheduledTest
*
admin
.
ScheduledTestHandler
Channel
*
admin
.
ChannelHandler
}
// Handlers contains all HTTP handlers
...
...
backend/internal/handler/wire.go
View file @
91c9b8d0
...
...
@@ -33,6 +33,7 @@ func ProvideAdminHandlers(
tlsFingerprintProfileHandler
*
admin
.
TLSFingerprintProfileHandler
,
apiKeyHandler
*
admin
.
AdminAPIKeyHandler
,
scheduledTestHandler
*
admin
.
ScheduledTestHandler
,
channelHandler
*
admin
.
ChannelHandler
,
)
*
AdminHandlers
{
return
&
AdminHandlers
{
Dashboard
:
dashboardHandler
,
...
...
@@ -59,6 +60,7 @@ func ProvideAdminHandlers(
TLSFingerprintProfile
:
tlsFingerprintProfileHandler
,
APIKey
:
apiKeyHandler
,
ScheduledTest
:
scheduledTestHandler
,
Channel
:
channelHandler
,
}
}
...
...
@@ -150,6 +152,7 @@ var ProviderSet = wire.NewSet(
admin
.
NewTLSFingerprintProfileHandler
,
admin
.
NewAdminAPIKeyHandler
,
admin
.
NewScheduledTestHandler
,
admin
.
NewChannelHandler
,
// AdminHandlers and Handlers constructors
ProvideAdminHandlers
,
...
...
backend/internal/repository/channel_repo.go
0 → 100644
View file @
91c9b8d0
package
repository
import
(
"context"
"database/sql"
"fmt"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
type
channelRepository
struct
{
db
*
sql
.
DB
}
// NewChannelRepository 创建渠道数据访问实例
func
NewChannelRepository
(
db
*
sql
.
DB
)
service
.
ChannelRepository
{
return
&
channelRepository
{
db
:
db
}
}
// runInTx 在事务中执行 fn,成功 commit,失败 rollback。
func
(
r
*
channelRepository
)
runInTx
(
ctx
context
.
Context
,
fn
func
(
tx
*
sql
.
Tx
)
error
)
error
{
tx
,
err
:=
r
.
db
.
BeginTx
(
ctx
,
nil
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"begin tx: %w"
,
err
)
}
defer
func
()
{
_
=
tx
.
Rollback
()
}()
if
err
:=
fn
(
tx
);
err
!=
nil
{
return
err
}
return
tx
.
Commit
()
}
func
(
r
*
channelRepository
)
Create
(
ctx
context
.
Context
,
channel
*
service
.
Channel
)
error
{
return
r
.
runInTx
(
ctx
,
func
(
tx
*
sql
.
Tx
)
error
{
err
:=
tx
.
QueryRowContext
(
ctx
,
`INSERT INTO channels (name, description, status) VALUES ($1, $2, $3)
RETURNING id, created_at, updated_at`
,
channel
.
Name
,
channel
.
Description
,
channel
.
Status
,
)
.
Scan
(
&
channel
.
ID
,
&
channel
.
CreatedAt
,
&
channel
.
UpdatedAt
)
if
err
!=
nil
{
if
isUniqueViolation
(
err
)
{
return
service
.
ErrChannelExists
}
return
fmt
.
Errorf
(
"insert channel: %w"
,
err
)
}
// 设置分组关联
if
len
(
channel
.
GroupIDs
)
>
0
{
if
err
:=
setGroupIDsTx
(
ctx
,
tx
,
channel
.
ID
,
channel
.
GroupIDs
);
err
!=
nil
{
return
err
}
}
// 设置模型定价
if
len
(
channel
.
ModelPricing
)
>
0
{
if
err
:=
replaceModelPricingTx
(
ctx
,
tx
,
channel
.
ID
,
channel
.
ModelPricing
);
err
!=
nil
{
return
err
}
}
return
nil
})
}
func
(
r
*
channelRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
Channel
,
error
)
{
ch
:=
&
service
.
Channel
{}
err
:=
r
.
db
.
QueryRowContext
(
ctx
,
`SELECT id, name, description, status, created_at, updated_at
FROM channels WHERE id = $1`
,
id
,
)
.
Scan
(
&
ch
.
ID
,
&
ch
.
Name
,
&
ch
.
Description
,
&
ch
.
Status
,
&
ch
.
CreatedAt
,
&
ch
.
UpdatedAt
)
if
err
==
sql
.
ErrNoRows
{
return
nil
,
service
.
ErrChannelNotFound
}
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get channel: %w"
,
err
)
}
groupIDs
,
err
:=
r
.
GetGroupIDs
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
err
}
ch
.
GroupIDs
=
groupIDs
pricing
,
err
:=
r
.
ListModelPricing
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
err
}
ch
.
ModelPricing
=
pricing
return
ch
,
nil
}
func
(
r
*
channelRepository
)
Update
(
ctx
context
.
Context
,
channel
*
service
.
Channel
)
error
{
return
r
.
runInTx
(
ctx
,
func
(
tx
*
sql
.
Tx
)
error
{
result
,
err
:=
tx
.
ExecContext
(
ctx
,
`UPDATE channels SET name = $1, description = $2, status = $3, updated_at = NOW()
WHERE id = $4`
,
channel
.
Name
,
channel
.
Description
,
channel
.
Status
,
channel
.
ID
,
)
if
err
!=
nil
{
if
isUniqueViolation
(
err
)
{
return
service
.
ErrChannelExists
}
return
fmt
.
Errorf
(
"update channel: %w"
,
err
)
}
rows
,
_
:=
result
.
RowsAffected
()
if
rows
==
0
{
return
service
.
ErrChannelNotFound
}
// 更新分组关联
if
channel
.
GroupIDs
!=
nil
{
if
err
:=
setGroupIDsTx
(
ctx
,
tx
,
channel
.
ID
,
channel
.
GroupIDs
);
err
!=
nil
{
return
err
}
}
// 更新模型定价
if
channel
.
ModelPricing
!=
nil
{
if
err
:=
replaceModelPricingTx
(
ctx
,
tx
,
channel
.
ID
,
channel
.
ModelPricing
);
err
!=
nil
{
return
err
}
}
return
nil
})
}
func
(
r
*
channelRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
result
,
err
:=
r
.
db
.
ExecContext
(
ctx
,
`DELETE FROM channels WHERE id = $1`
,
id
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"delete channel: %w"
,
err
)
}
rows
,
_
:=
result
.
RowsAffected
()
if
rows
==
0
{
return
service
.
ErrChannelNotFound
}
return
nil
}
func
(
r
*
channelRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
status
,
search
string
)
([]
service
.
Channel
,
*
pagination
.
PaginationResult
,
error
)
{
where
:=
[]
string
{
"1=1"
}
args
:=
[]
any
{}
argIdx
:=
1
if
status
!=
""
{
where
=
append
(
where
,
fmt
.
Sprintf
(
"c.status = $%d"
,
argIdx
))
args
=
append
(
args
,
status
)
argIdx
++
}
if
search
!=
""
{
where
=
append
(
where
,
fmt
.
Sprintf
(
"(c.name ILIKE $%d OR c.description ILIKE $%d)"
,
argIdx
,
argIdx
))
args
=
append
(
args
,
"%"
+
escapeLike
(
search
)
+
"%"
)
argIdx
++
}
whereClause
:=
strings
.
Join
(
where
,
" AND "
)
// 计数
var
total
int64
countQuery
:=
fmt
.
Sprintf
(
"SELECT COUNT(*) FROM channels c WHERE %s"
,
whereClause
)
if
err
:=
r
.
db
.
QueryRowContext
(
ctx
,
countQuery
,
args
...
)
.
Scan
(
&
total
);
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"count channels: %w"
,
err
)
}
pageSize
:=
params
.
Limit
()
// 约束在 [1, 100]
page
:=
params
.
Page
if
page
<
1
{
page
=
1
}
offset
:=
(
page
-
1
)
*
pageSize
// 查询 channel 列表
dataQuery
:=
fmt
.
Sprintf
(
`SELECT c.id, c.name, c.description, c.status, c.created_at, c.updated_at
FROM channels c WHERE %s ORDER BY c.id DESC LIMIT $%d OFFSET $%d`
,
whereClause
,
argIdx
,
argIdx
+
1
,
)
args
=
append
(
args
,
pageSize
,
offset
)
rows
,
err
:=
r
.
db
.
QueryContext
(
ctx
,
dataQuery
,
args
...
)
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"query channels: %w"
,
err
)
}
defer
rows
.
Close
()
var
channels
[]
service
.
Channel
var
channelIDs
[]
int64
for
rows
.
Next
()
{
var
ch
service
.
Channel
if
err
:=
rows
.
Scan
(
&
ch
.
ID
,
&
ch
.
Name
,
&
ch
.
Description
,
&
ch
.
Status
,
&
ch
.
CreatedAt
,
&
ch
.
UpdatedAt
);
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"scan channel: %w"
,
err
)
}
channels
=
append
(
channels
,
ch
)
channelIDs
=
append
(
channelIDs
,
ch
.
ID
)
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"iterate channels: %w"
,
err
)
}
// 批量加载分组 ID 和模型定价(避免 N+1)
if
len
(
channelIDs
)
>
0
{
groupMap
,
err
:=
r
.
batchLoadGroupIDs
(
ctx
,
channelIDs
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
pricingMap
,
err
:=
r
.
batchLoadModelPricing
(
ctx
,
channelIDs
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
for
i
:=
range
channels
{
channels
[
i
]
.
GroupIDs
=
groupMap
[
channels
[
i
]
.
ID
]
channels
[
i
]
.
ModelPricing
=
pricingMap
[
channels
[
i
]
.
ID
]
}
}
pages
:=
0
if
total
>
0
{
pages
=
int
((
total
+
int64
(
pageSize
)
-
1
)
/
int64
(
pageSize
))
}
paginationResult
:=
&
pagination
.
PaginationResult
{
Total
:
total
,
Page
:
page
,
PageSize
:
pageSize
,
Pages
:
pages
,
}
return
channels
,
paginationResult
,
nil
}
func
(
r
*
channelRepository
)
ListAll
(
ctx
context
.
Context
)
([]
service
.
Channel
,
error
)
{
rows
,
err
:=
r
.
db
.
QueryContext
(
ctx
,
`SELECT id, name, description, status, created_at, updated_at FROM channels ORDER BY id`
,
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query all channels: %w"
,
err
)
}
defer
rows
.
Close
()
var
channels
[]
service
.
Channel
var
channelIDs
[]
int64
for
rows
.
Next
()
{
var
ch
service
.
Channel
if
err
:=
rows
.
Scan
(
&
ch
.
ID
,
&
ch
.
Name
,
&
ch
.
Description
,
&
ch
.
Status
,
&
ch
.
CreatedAt
,
&
ch
.
UpdatedAt
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"scan channel: %w"
,
err
)
}
channels
=
append
(
channels
,
ch
)
channelIDs
=
append
(
channelIDs
,
ch
.
ID
)
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"iterate channels: %w"
,
err
)
}
if
len
(
channelIDs
)
==
0
{
return
channels
,
nil
}
// 批量加载分组 ID
groupMap
,
err
:=
r
.
batchLoadGroupIDs
(
ctx
,
channelIDs
)
if
err
!=
nil
{
return
nil
,
err
}
// 批量加载模型定价
pricingMap
,
err
:=
r
.
batchLoadModelPricing
(
ctx
,
channelIDs
)
if
err
!=
nil
{
return
nil
,
err
}
for
i
:=
range
channels
{
channels
[
i
]
.
GroupIDs
=
groupMap
[
channels
[
i
]
.
ID
]
channels
[
i
]
.
ModelPricing
=
pricingMap
[
channels
[
i
]
.
ID
]
}
return
channels
,
nil
}
// --- 批量加载辅助方法 ---
// batchLoadGroupIDs 批量加载多个渠道的分组 ID
func
(
r
*
channelRepository
)
batchLoadGroupIDs
(
ctx
context
.
Context
,
channelIDs
[]
int64
)
(
map
[
int64
][]
int64
,
error
)
{
rows
,
err
:=
r
.
db
.
QueryContext
(
ctx
,
`SELECT channel_id, group_id FROM channel_groups
WHERE channel_id = ANY($1) ORDER BY channel_id, group_id`
,
pq
.
Array
(
channelIDs
),
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"batch load group ids: %w"
,
err
)
}
defer
rows
.
Close
()
groupMap
:=
make
(
map
[
int64
][]
int64
,
len
(
channelIDs
))
for
rows
.
Next
()
{
var
channelID
,
groupID
int64
if
err
:=
rows
.
Scan
(
&
channelID
,
&
groupID
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"scan group id: %w"
,
err
)
}
groupMap
[
channelID
]
=
append
(
groupMap
[
channelID
],
groupID
)
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"iterate group ids: %w"
,
err
)
}
return
groupMap
,
nil
}
func
(
r
*
channelRepository
)
ExistsByName
(
ctx
context
.
Context
,
name
string
)
(
bool
,
error
)
{
var
exists
bool
err
:=
r
.
db
.
QueryRowContext
(
ctx
,
`SELECT EXISTS(SELECT 1 FROM channels WHERE name = $1)`
,
name
,
)
.
Scan
(
&
exists
)
return
exists
,
err
}
func
(
r
*
channelRepository
)
ExistsByNameExcluding
(
ctx
context
.
Context
,
name
string
,
excludeID
int64
)
(
bool
,
error
)
{
var
exists
bool
err
:=
r
.
db
.
QueryRowContext
(
ctx
,
`SELECT EXISTS(SELECT 1 FROM channels WHERE name = $1 AND id != $2)`
,
name
,
excludeID
,
)
.
Scan
(
&
exists
)
return
exists
,
err
}
// --- 分组关联 ---
func
(
r
*
channelRepository
)
GetGroupIDs
(
ctx
context
.
Context
,
channelID
int64
)
([]
int64
,
error
)
{
rows
,
err
:=
r
.
db
.
QueryContext
(
ctx
,
`SELECT group_id FROM channel_groups WHERE channel_id = $1 ORDER BY group_id`
,
channelID
,
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get group ids: %w"
,
err
)
}
defer
rows
.
Close
()
var
ids
[]
int64
for
rows
.
Next
()
{
var
id
int64
if
err
:=
rows
.
Scan
(
&
id
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"scan group id: %w"
,
err
)
}
ids
=
append
(
ids
,
id
)
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"iterate group ids: %w"
,
err
)
}
return
ids
,
nil
}
func
(
r
*
channelRepository
)
SetGroupIDs
(
ctx
context
.
Context
,
channelID
int64
,
groupIDs
[]
int64
)
error
{
return
setGroupIDsTx
(
ctx
,
r
.
db
,
channelID
,
groupIDs
)
}
func
(
r
*
channelRepository
)
GetChannelIDByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
var
channelID
int64
err
:=
r
.
db
.
QueryRowContext
(
ctx
,
`SELECT channel_id FROM channel_groups WHERE group_id = $1`
,
groupID
,
)
.
Scan
(
&
channelID
)
if
err
==
sql
.
ErrNoRows
{
return
0
,
nil
}
return
channelID
,
err
}
func
(
r
*
channelRepository
)
GetGroupsInOtherChannels
(
ctx
context
.
Context
,
channelID
int64
,
groupIDs
[]
int64
)
([]
int64
,
error
)
{
if
len
(
groupIDs
)
==
0
{
return
nil
,
nil
}
rows
,
err
:=
r
.
db
.
QueryContext
(
ctx
,
`SELECT group_id FROM channel_groups WHERE group_id = ANY($1) AND channel_id != $2`
,
pq
.
Array
(
groupIDs
),
channelID
,
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get groups in other channels: %w"
,
err
)
}
defer
rows
.
Close
()
var
conflicting
[]
int64
for
rows
.
Next
()
{
var
id
int64
if
err
:=
rows
.
Scan
(
&
id
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"scan conflicting group id: %w"
,
err
)
}
conflicting
=
append
(
conflicting
,
id
)
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"iterate conflicting group ids: %w"
,
err
)
}
return
conflicting
,
nil
}
backend/internal/repository/channel_repo_pricing.go
0 → 100644
View file @
91c9b8d0
package
repository
import
(
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
// --- 模型定价 ---
func
(
r
*
channelRepository
)
ListModelPricing
(
ctx
context
.
Context
,
channelID
int64
)
([]
service
.
ChannelModelPricing
,
error
)
{
rows
,
err
:=
r
.
db
.
QueryContext
(
ctx
,
`SELECT id, channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, created_at, updated_at
FROM channel_model_pricing WHERE channel_id = $1 ORDER BY id`
,
channelID
,
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"list model pricing: %w"
,
err
)
}
defer
rows
.
Close
()
result
,
pricingIDs
,
err
:=
scanModelPricingRows
(
rows
)
if
err
!=
nil
{
return
nil
,
err
}
if
len
(
pricingIDs
)
>
0
{
intervalMap
,
err
:=
r
.
batchLoadIntervals
(
ctx
,
pricingIDs
)
if
err
!=
nil
{
return
nil
,
err
}
for
i
:=
range
result
{
result
[
i
]
.
Intervals
=
intervalMap
[
result
[
i
]
.
ID
]
}
}
return
result
,
nil
}
func
(
r
*
channelRepository
)
CreateModelPricing
(
ctx
context
.
Context
,
pricing
*
service
.
ChannelModelPricing
)
error
{
return
createModelPricingExec
(
ctx
,
r
.
db
,
pricing
)
}
func
(
r
*
channelRepository
)
UpdateModelPricing
(
ctx
context
.
Context
,
pricing
*
service
.
ChannelModelPricing
)
error
{
modelsJSON
,
err
:=
json
.
Marshal
(
pricing
.
Models
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"marshal models: %w"
,
err
)
}
billingMode
:=
pricing
.
BillingMode
if
billingMode
==
""
{
billingMode
=
service
.
BillingModeToken
}
result
,
err
:=
r
.
db
.
ExecContext
(
ctx
,
`UPDATE channel_model_pricing
SET models = $1, billing_mode = $2, input_price = $3, output_price = $4, cache_write_price = $5, cache_read_price = $6, image_output_price = $7, updated_at = NOW()
WHERE id = $8`
,
modelsJSON
,
billingMode
,
pricing
.
InputPrice
,
pricing
.
OutputPrice
,
pricing
.
CacheWritePrice
,
pricing
.
CacheReadPrice
,
pricing
.
ImageOutputPrice
,
pricing
.
ID
,
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"update model pricing: %w"
,
err
)
}
rows
,
_
:=
result
.
RowsAffected
()
if
rows
==
0
{
return
fmt
.
Errorf
(
"pricing entry not found: %d"
,
pricing
.
ID
)
}
return
nil
}
func
(
r
*
channelRepository
)
DeleteModelPricing
(
ctx
context
.
Context
,
id
int64
)
error
{
_
,
err
:=
r
.
db
.
ExecContext
(
ctx
,
`DELETE FROM channel_model_pricing WHERE id = $1`
,
id
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"delete model pricing: %w"
,
err
)
}
return
nil
}
func
(
r
*
channelRepository
)
ReplaceModelPricing
(
ctx
context
.
Context
,
channelID
int64
,
pricingList
[]
service
.
ChannelModelPricing
)
error
{
return
r
.
runInTx
(
ctx
,
func
(
tx
*
sql
.
Tx
)
error
{
return
replaceModelPricingTx
(
ctx
,
tx
,
channelID
,
pricingList
)
})
}
// --- 批量加载辅助方法 ---
// batchLoadModelPricing 批量加载多个渠道的模型定价(含区间)
func
(
r
*
channelRepository
)
batchLoadModelPricing
(
ctx
context
.
Context
,
channelIDs
[]
int64
)
(
map
[
int64
][]
service
.
ChannelModelPricing
,
error
)
{
rows
,
err
:=
r
.
db
.
QueryContext
(
ctx
,
`SELECT id, channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, created_at, updated_at
FROM channel_model_pricing WHERE channel_id = ANY($1) ORDER BY channel_id, id`
,
pq
.
Array
(
channelIDs
),
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"batch load model pricing: %w"
,
err
)
}
defer
rows
.
Close
()
allPricing
,
allPricingIDs
,
err
:=
scanModelPricingRows
(
rows
)
if
err
!=
nil
{
return
nil
,
err
}
// 按 channelID 分组
pricingMap
:=
make
(
map
[
int64
][]
service
.
ChannelModelPricing
,
len
(
channelIDs
))
for
_
,
p
:=
range
allPricing
{
pricingMap
[
p
.
ChannelID
]
=
append
(
pricingMap
[
p
.
ChannelID
],
p
)
}
// 批量加载所有区间
if
len
(
allPricingIDs
)
>
0
{
intervalMap
,
err
:=
r
.
batchLoadIntervals
(
ctx
,
allPricingIDs
)
if
err
!=
nil
{
return
nil
,
err
}
for
chID
:=
range
pricingMap
{
for
i
:=
range
pricingMap
[
chID
]
{
pricingMap
[
chID
][
i
]
.
Intervals
=
intervalMap
[
pricingMap
[
chID
][
i
]
.
ID
]
}
}
}
return
pricingMap
,
nil
}
// batchLoadIntervals 批量加载多个定价条目的区间
func
(
r
*
channelRepository
)
batchLoadIntervals
(
ctx
context
.
Context
,
pricingIDs
[]
int64
)
(
map
[
int64
][]
service
.
PricingInterval
,
error
)
{
rows
,
err
:=
r
.
db
.
QueryContext
(
ctx
,
`SELECT id, pricing_id, min_tokens, max_tokens, tier_label,
input_price, output_price, cache_write_price, cache_read_price,
per_request_price, sort_order, created_at, updated_at
FROM channel_pricing_intervals
WHERE pricing_id = ANY($1) ORDER BY pricing_id, sort_order, id`
,
pq
.
Array
(
pricingIDs
),
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"batch load intervals: %w"
,
err
)
}
defer
rows
.
Close
()
intervalMap
:=
make
(
map
[
int64
][]
service
.
PricingInterval
,
len
(
pricingIDs
))
for
rows
.
Next
()
{
var
iv
service
.
PricingInterval
if
err
:=
rows
.
Scan
(
&
iv
.
ID
,
&
iv
.
PricingID
,
&
iv
.
MinTokens
,
&
iv
.
MaxTokens
,
&
iv
.
TierLabel
,
&
iv
.
InputPrice
,
&
iv
.
OutputPrice
,
&
iv
.
CacheWritePrice
,
&
iv
.
CacheReadPrice
,
&
iv
.
PerRequestPrice
,
&
iv
.
SortOrder
,
&
iv
.
CreatedAt
,
&
iv
.
UpdatedAt
,
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"scan interval: %w"
,
err
)
}
intervalMap
[
iv
.
PricingID
]
=
append
(
intervalMap
[
iv
.
PricingID
],
iv
)
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"iterate intervals: %w"
,
err
)
}
return
intervalMap
,
nil
}
// --- 共享 scan 辅助 ---
// scanModelPricingRows 扫描 model pricing 行,返回结果列表和 ID 列表
func
scanModelPricingRows
(
rows
*
sql
.
Rows
)
([]
service
.
ChannelModelPricing
,
[]
int64
,
error
)
{
var
result
[]
service
.
ChannelModelPricing
var
pricingIDs
[]
int64
for
rows
.
Next
()
{
var
p
service
.
ChannelModelPricing
var
modelsJSON
[]
byte
if
err
:=
rows
.
Scan
(
&
p
.
ID
,
&
p
.
ChannelID
,
&
modelsJSON
,
&
p
.
BillingMode
,
&
p
.
InputPrice
,
&
p
.
OutputPrice
,
&
p
.
CacheWritePrice
,
&
p
.
CacheReadPrice
,
&
p
.
ImageOutputPrice
,
&
p
.
CreatedAt
,
&
p
.
UpdatedAt
,
);
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"scan model pricing: %w"
,
err
)
}
if
err
:=
json
.
Unmarshal
(
modelsJSON
,
&
p
.
Models
);
err
!=
nil
{
p
.
Models
=
[]
string
{}
}
pricingIDs
=
append
(
pricingIDs
,
p
.
ID
)
result
=
append
(
result
,
p
)
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"iterate model pricing: %w"
,
err
)
}
return
result
,
pricingIDs
,
nil
}
// --- 事务内辅助方法 ---
// dbExec 是 *sql.DB 和 *sql.Tx 共享的最小 SQL 执行接口
type
dbExec
interface
{
ExecContext
(
ctx
context
.
Context
,
query
string
,
args
...
any
)
(
sql
.
Result
,
error
)
QueryContext
(
ctx
context
.
Context
,
query
string
,
args
...
any
)
(
*
sql
.
Rows
,
error
)
QueryRowContext
(
ctx
context
.
Context
,
query
string
,
args
...
any
)
*
sql
.
Row
}
func
setGroupIDsTx
(
ctx
context
.
Context
,
exec
dbExec
,
channelID
int64
,
groupIDs
[]
int64
)
error
{
if
_
,
err
:=
exec
.
ExecContext
(
ctx
,
`DELETE FROM channel_groups WHERE channel_id = $1`
,
channelID
);
err
!=
nil
{
return
fmt
.
Errorf
(
"delete old group associations: %w"
,
err
)
}
if
len
(
groupIDs
)
==
0
{
return
nil
}
_
,
err
:=
exec
.
ExecContext
(
ctx
,
`INSERT INTO channel_groups (channel_id, group_id)
SELECT $1, unnest($2::bigint[])`
,
channelID
,
pq
.
Array
(
groupIDs
),
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"insert group associations: %w"
,
err
)
}
return
nil
}
func
createModelPricingExec
(
ctx
context
.
Context
,
exec
dbExec
,
pricing
*
service
.
ChannelModelPricing
)
error
{
modelsJSON
,
err
:=
json
.
Marshal
(
pricing
.
Models
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"marshal models: %w"
,
err
)
}
billingMode
:=
pricing
.
BillingMode
if
billingMode
==
""
{
billingMode
=
service
.
BillingModeToken
}
err
=
exec
.
QueryRowContext
(
ctx
,
`INSERT INTO channel_model_pricing (channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id, created_at, updated_at`
,
pricing
.
ChannelID
,
modelsJSON
,
billingMode
,
pricing
.
InputPrice
,
pricing
.
OutputPrice
,
pricing
.
CacheWritePrice
,
pricing
.
CacheReadPrice
,
pricing
.
ImageOutputPrice
,
)
.
Scan
(
&
pricing
.
ID
,
&
pricing
.
CreatedAt
,
&
pricing
.
UpdatedAt
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"insert model pricing: %w"
,
err
)
}
for
i
:=
range
pricing
.
Intervals
{
pricing
.
Intervals
[
i
]
.
PricingID
=
pricing
.
ID
if
err
:=
createIntervalExec
(
ctx
,
exec
,
&
pricing
.
Intervals
[
i
]);
err
!=
nil
{
return
err
}
}
return
nil
}
func
createIntervalExec
(
ctx
context
.
Context
,
exec
dbExec
,
iv
*
service
.
PricingInterval
)
error
{
return
exec
.
QueryRowContext
(
ctx
,
`INSERT INTO channel_pricing_intervals
(pricing_id, min_tokens, max_tokens, tier_label, input_price, output_price, cache_write_price, cache_read_price, per_request_price, sort_order)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`
,
iv
.
PricingID
,
iv
.
MinTokens
,
iv
.
MaxTokens
,
iv
.
TierLabel
,
iv
.
InputPrice
,
iv
.
OutputPrice
,
iv
.
CacheWritePrice
,
iv
.
CacheReadPrice
,
iv
.
PerRequestPrice
,
iv
.
SortOrder
,
)
.
Scan
(
&
iv
.
ID
,
&
iv
.
CreatedAt
,
&
iv
.
UpdatedAt
)
}
func
replaceModelPricingTx
(
ctx
context
.
Context
,
exec
dbExec
,
channelID
int64
,
pricingList
[]
service
.
ChannelModelPricing
)
error
{
if
_
,
err
:=
exec
.
ExecContext
(
ctx
,
`DELETE FROM channel_model_pricing WHERE channel_id = $1`
,
channelID
);
err
!=
nil
{
return
fmt
.
Errorf
(
"delete old model pricing: %w"
,
err
)
}
for
i
:=
range
pricingList
{
pricingList
[
i
]
.
ChannelID
=
channelID
if
err
:=
createModelPricingExec
(
ctx
,
exec
,
&
pricingList
[
i
]);
err
!=
nil
{
return
fmt
.
Errorf
(
"insert model pricing: %w"
,
err
)
}
}
return
nil
}
// isUniqueViolation 检查 pq 唯一约束违反错误
func
isUniqueViolation
(
err
error
)
bool
{
if
pqErr
,
ok
:=
err
.
(
*
pq
.
Error
);
ok
{
return
pqErr
.
Code
==
"23505"
}
return
false
}
// escapeLike 转义 LIKE/ILIKE 模式中的特殊字符
func
escapeLike
(
s
string
)
string
{
s
=
strings
.
ReplaceAll
(
s
,
`\`
,
`\\`
)
s
=
strings
.
ReplaceAll
(
s
,
`%`
,
`\%`
)
s
=
strings
.
ReplaceAll
(
s
,
`_`
,
`\_`
)
return
s
}
backend/internal/repository/wire.go
View file @
91c9b8d0
...
...
@@ -74,6 +74,7 @@ var ProviderSet = wire.NewSet(
NewUserGroupRateRepository
,
NewErrorPassthroughRepository
,
NewTLSFingerprintProfileRepository
,
NewChannelRepository
,
// Cache implementations
NewGatewayCache
,
...
...
backend/internal/server/routes/admin.go
View file @
91c9b8d0
...
...
@@ -87,6 +87,9 @@ func RegisterAdminRoutes(
// 定时测试计划
registerScheduledTestRoutes
(
admin
,
h
)
// 渠道管理
registerChannelRoutes
(
admin
,
h
)
}
}
...
...
@@ -567,3 +570,14 @@ func registerTLSFingerprintProfileRoutes(admin *gin.RouterGroup, h *handler.Hand
profiles
.
DELETE
(
"/:id"
,
h
.
Admin
.
TLSFingerprintProfile
.
Delete
)
}
}
func
registerChannelRoutes
(
admin
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
)
{
channels
:=
admin
.
Group
(
"/channels"
)
{
channels
.
GET
(
""
,
h
.
Admin
.
Channel
.
List
)
channels
.
GET
(
"/:id"
,
h
.
Admin
.
Channel
.
GetByID
)
channels
.
POST
(
""
,
h
.
Admin
.
Channel
.
Create
)
channels
.
PUT
(
"/:id"
,
h
.
Admin
.
Channel
.
Update
)
channels
.
DELETE
(
"/:id"
,
h
.
Admin
.
Channel
.
Delete
)
}
}
backend/internal/service/billing_service.go
View file @
91c9b8d0
...
...
@@ -371,13 +371,193 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
return
nil
,
fmt
.
Errorf
(
"pricing not found for model: %s"
,
model
)
}
// GetModelPricingWithChannel 获取模型定价,渠道配置的价格覆盖默认值
// 仅覆盖渠道中非 nil 的价格字段,nil 字段使用默认定价
func
(
s
*
BillingService
)
GetModelPricingWithChannel
(
model
string
,
channelPricing
*
ChannelModelPricing
)
(
*
ModelPricing
,
error
)
{
pricing
,
err
:=
s
.
GetModelPricing
(
model
)
if
err
!=
nil
{
return
nil
,
err
}
if
channelPricing
==
nil
{
return
pricing
,
nil
}
if
channelPricing
.
InputPrice
!=
nil
{
pricing
.
InputPricePerToken
=
*
channelPricing
.
InputPrice
pricing
.
InputPricePerTokenPriority
=
*
channelPricing
.
InputPrice
}
if
channelPricing
.
OutputPrice
!=
nil
{
pricing
.
OutputPricePerToken
=
*
channelPricing
.
OutputPrice
pricing
.
OutputPricePerTokenPriority
=
*
channelPricing
.
OutputPrice
}
if
channelPricing
.
CacheWritePrice
!=
nil
{
pricing
.
CacheCreationPricePerToken
=
*
channelPricing
.
CacheWritePrice
pricing
.
CacheCreation5mPrice
=
*
channelPricing
.
CacheWritePrice
pricing
.
CacheCreation1hPrice
=
*
channelPricing
.
CacheWritePrice
}
if
channelPricing
.
CacheReadPrice
!=
nil
{
pricing
.
CacheReadPricePerToken
=
*
channelPricing
.
CacheReadPrice
pricing
.
CacheReadPricePerTokenPriority
=
*
channelPricing
.
CacheReadPrice
}
return
pricing
,
nil
}
// CalculateCostWithChannel 使用渠道定价计算费用
// Deprecated: 使用 CalculateCostUnified 代替
func
(
s
*
BillingService
)
CalculateCostWithChannel
(
model
string
,
tokens
UsageTokens
,
rateMultiplier
float64
,
channelPricing
*
ChannelModelPricing
)
(
*
CostBreakdown
,
error
)
{
return
s
.
calculateCostInternal
(
model
,
tokens
,
rateMultiplier
,
""
,
channelPricing
)
}
// --- 统一计费入口 ---
// CostInput 统一计费输入
type
CostInput
struct
{
Ctx
context
.
Context
Model
string
GroupID
*
int64
// 用于渠道定价查找
Tokens
UsageTokens
RequestCount
int
// 按次计费时使用
SizeTier
string
// 按次/图片模式的层级标签("1K","2K","4K","HD" 等)
RateMultiplier
float64
ServiceTier
string
// "priority","flex","" 等
Resolver
*
ModelPricingResolver
// 定价解析器
}
// CalculateCostUnified 统一计费入口,支持三种计费模式。
// 使用 ModelPricingResolver 解析定价,然后根据 BillingMode 分发计算。
func
(
s
*
BillingService
)
CalculateCostUnified
(
input
CostInput
)
(
*
CostBreakdown
,
error
)
{
if
input
.
Resolver
==
nil
{
// 无 Resolver,回退到旧路径
return
s
.
calculateCostInternal
(
input
.
Model
,
input
.
Tokens
,
input
.
RateMultiplier
,
input
.
ServiceTier
,
nil
)
}
resolved
:=
input
.
Resolver
.
Resolve
(
input
.
Ctx
,
PricingInput
{
Model
:
input
.
Model
,
GroupID
:
input
.
GroupID
,
})
if
input
.
RateMultiplier
<=
0
{
input
.
RateMultiplier
=
1.0
}
switch
resolved
.
Mode
{
case
BillingModePerRequest
,
BillingModeImage
:
return
s
.
calculatePerRequestCost
(
resolved
,
input
)
default
:
// BillingModeToken
return
s
.
calculateTokenCost
(
resolved
,
input
)
}
}
// calculateTokenCost 按 token 区间计费
func
(
s
*
BillingService
)
calculateTokenCost
(
resolved
*
ResolvedPricing
,
input
CostInput
)
(
*
CostBreakdown
,
error
)
{
totalContext
:=
input
.
Tokens
.
InputTokens
+
input
.
Tokens
.
CacheReadTokens
pricing
:=
input
.
Resolver
.
GetIntervalPricing
(
resolved
,
totalContext
)
if
pricing
==
nil
{
return
nil
,
fmt
.
Errorf
(
"no pricing available for model: %s"
,
input
.
Model
)
}
pricing
=
s
.
applyModelSpecificPricingPolicy
(
input
.
Model
,
pricing
)
breakdown
:=
&
CostBreakdown
{}
inputPricePerToken
:=
pricing
.
InputPricePerToken
outputPricePerToken
:=
pricing
.
OutputPricePerToken
cacheReadPricePerToken
:=
pricing
.
CacheReadPricePerToken
tierMultiplier
:=
1.0
if
usePriorityServiceTierPricing
(
input
.
ServiceTier
,
pricing
)
{
if
pricing
.
InputPricePerTokenPriority
>
0
{
inputPricePerToken
=
pricing
.
InputPricePerTokenPriority
}
if
pricing
.
OutputPricePerTokenPriority
>
0
{
outputPricePerToken
=
pricing
.
OutputPricePerTokenPriority
}
if
pricing
.
CacheReadPricePerTokenPriority
>
0
{
cacheReadPricePerToken
=
pricing
.
CacheReadPricePerTokenPriority
}
}
else
{
tierMultiplier
=
serviceTierCostMultiplier
(
input
.
ServiceTier
)
}
// 长上下文定价(仅在无区间定价时应用,区间定价已包含上下文分层)
if
len
(
resolved
.
Intervals
)
==
0
&&
s
.
shouldApplySessionLongContextPricing
(
input
.
Tokens
,
pricing
)
{
inputPricePerToken
*=
pricing
.
LongContextInputMultiplier
outputPricePerToken
*=
pricing
.
LongContextOutputMultiplier
}
breakdown
.
InputCost
=
float64
(
input
.
Tokens
.
InputTokens
)
*
inputPricePerToken
breakdown
.
OutputCost
=
float64
(
input
.
Tokens
.
OutputTokens
)
*
outputPricePerToken
if
pricing
.
SupportsCacheBreakdown
&&
(
pricing
.
CacheCreation5mPrice
>
0
||
pricing
.
CacheCreation1hPrice
>
0
)
{
if
input
.
Tokens
.
CacheCreation5mTokens
==
0
&&
input
.
Tokens
.
CacheCreation1hTokens
==
0
&&
input
.
Tokens
.
CacheCreationTokens
>
0
{
breakdown
.
CacheCreationCost
=
float64
(
input
.
Tokens
.
CacheCreationTokens
)
*
pricing
.
CacheCreation5mPrice
}
else
{
breakdown
.
CacheCreationCost
=
float64
(
input
.
Tokens
.
CacheCreation5mTokens
)
*
pricing
.
CacheCreation5mPrice
+
float64
(
input
.
Tokens
.
CacheCreation1hTokens
)
*
pricing
.
CacheCreation1hPrice
}
}
else
{
breakdown
.
CacheCreationCost
=
float64
(
input
.
Tokens
.
CacheCreationTokens
)
*
pricing
.
CacheCreationPricePerToken
}
breakdown
.
CacheReadCost
=
float64
(
input
.
Tokens
.
CacheReadTokens
)
*
cacheReadPricePerToken
if
tierMultiplier
!=
1.0
{
breakdown
.
InputCost
*=
tierMultiplier
breakdown
.
OutputCost
*=
tierMultiplier
breakdown
.
CacheCreationCost
*=
tierMultiplier
breakdown
.
CacheReadCost
*=
tierMultiplier
}
breakdown
.
TotalCost
=
breakdown
.
InputCost
+
breakdown
.
OutputCost
+
breakdown
.
CacheCreationCost
+
breakdown
.
CacheReadCost
breakdown
.
ActualCost
=
breakdown
.
TotalCost
*
input
.
RateMultiplier
return
breakdown
,
nil
}
// calculatePerRequestCost 按次/图片计费
func
(
s
*
BillingService
)
calculatePerRequestCost
(
resolved
*
ResolvedPricing
,
input
CostInput
)
(
*
CostBreakdown
,
error
)
{
count
:=
input
.
RequestCount
if
count
<=
0
{
count
=
1
}
var
unitPrice
float64
if
input
.
SizeTier
!=
""
{
unitPrice
=
input
.
Resolver
.
GetRequestTierPrice
(
resolved
,
input
.
SizeTier
)
}
if
unitPrice
==
0
{
totalContext
:=
input
.
Tokens
.
InputTokens
+
input
.
Tokens
.
CacheReadTokens
unitPrice
=
input
.
Resolver
.
GetRequestTierPriceByContext
(
resolved
,
totalContext
)
}
totalCost
:=
unitPrice
*
float64
(
count
)
actualCost
:=
totalCost
*
input
.
RateMultiplier
return
&
CostBreakdown
{
TotalCost
:
totalCost
,
ActualCost
:
actualCost
,
},
nil
}
// CalculateCost 计算使用费用
func
(
s
*
BillingService
)
CalculateCost
(
model
string
,
tokens
UsageTokens
,
rateMultiplier
float64
)
(
*
CostBreakdown
,
error
)
{
return
s
.
C
alculateCost
WithServiceTier
(
model
,
tokens
,
rateMultiplier
,
""
)
return
s
.
c
alculateCost
Internal
(
model
,
tokens
,
rateMultiplier
,
""
,
nil
)
}
func
(
s
*
BillingService
)
CalculateCostWithServiceTier
(
model
string
,
tokens
UsageTokens
,
rateMultiplier
float64
,
serviceTier
string
)
(
*
CostBreakdown
,
error
)
{
pricing
,
err
:=
s
.
GetModelPricing
(
model
)
return
s
.
calculateCostInternal
(
model
,
tokens
,
rateMultiplier
,
serviceTier
,
nil
)
}
func
(
s
*
BillingService
)
calculateCostInternal
(
model
string
,
tokens
UsageTokens
,
rateMultiplier
float64
,
serviceTier
string
,
channelPricing
*
ChannelModelPricing
)
(
*
CostBreakdown
,
error
)
{
var
pricing
*
ModelPricing
var
err
error
if
channelPricing
!=
nil
{
pricing
,
err
=
s
.
GetModelPricingWithChannel
(
model
,
channelPricing
)
}
else
{
pricing
,
err
=
s
.
GetModelPricing
(
model
)
}
if
err
!=
nil
{
return
nil
,
err
}
...
...
backend/internal/service/channel.go
0 → 100644
View file @
91c9b8d0
package
service
import
(
"strings"
"time"
)
// BillingMode 计费模式
type
BillingMode
string
const
(
BillingModeToken
BillingMode
=
"token"
// 按 token 区间计费
BillingModePerRequest
BillingMode
=
"per_request"
// 按次计费(支持上下文窗口分层)
BillingModeImage
BillingMode
=
"image"
// 图片计费(当前按次,预留 token 计费)
)
// IsValid 检查 BillingMode 是否为合法值
func
(
m
BillingMode
)
IsValid
()
bool
{
switch
m
{
case
BillingModeToken
,
BillingModePerRequest
,
BillingModeImage
,
""
:
return
true
}
return
false
}
// Channel 渠道实体
type
Channel
struct
{
ID
int64
Name
string
Description
string
Status
string
CreatedAt
time
.
Time
UpdatedAt
time
.
Time
// 关联的分组 ID 列表
GroupIDs
[]
int64
// 模型定价列表
ModelPricing
[]
ChannelModelPricing
}
// ChannelModelPricing 渠道模型定价条目
type
ChannelModelPricing
struct
{
ID
int64
ChannelID
int64
Models
[]
string
// 绑定的模型列表
BillingMode
BillingMode
// 计费模式
InputPrice
*
float64
// 每 token 输入价格(USD)— 向后兼容 flat 定价
OutputPrice
*
float64
// 每 token 输出价格(USD)
CacheWritePrice
*
float64
// 缓存写入价格
CacheReadPrice
*
float64
// 缓存读取价格
ImageOutputPrice
*
float64
// 图片输出价格(向后兼容)
Intervals
[]
PricingInterval
// 区间定价列表
CreatedAt
time
.
Time
UpdatedAt
time
.
Time
}
// PricingInterval 定价区间(token 区间 / 按次分层 / 图片分辨率分层)
type
PricingInterval
struct
{
ID
int64
PricingID
int64
MinTokens
int
// 区间下界(含)
MaxTokens
*
int
// 区间上界(不含),nil = 无上限
TierLabel
string
// 层级标签(按次/图片模式:1K, 2K, 4K, HD 等)
InputPrice
*
float64
// token 模式:每 token 输入价
OutputPrice
*
float64
// token 模式:每 token 输出价
CacheWritePrice
*
float64
// token 模式:缓存写入价
CacheReadPrice
*
float64
// token 模式:缓存读取价
PerRequestPrice
*
float64
// 按次/图片模式:每次请求价格
SortOrder
int
CreatedAt
time
.
Time
UpdatedAt
time
.
Time
}
// IsActive 判断渠道是否启用
func
(
c
*
Channel
)
IsActive
()
bool
{
return
c
.
Status
==
StatusActive
}
// GetModelPricing 根据模型名查找渠道定价,未找到返回 nil。
// 优先精确匹配,然后通配符匹配(如 claude-opus-*)。大小写不敏感。
// 返回值拷贝,不污染缓存。
func
(
c
*
Channel
)
GetModelPricing
(
model
string
)
*
ChannelModelPricing
{
modelLower
:=
strings
.
ToLower
(
model
)
// 第一轮:精确匹配
for
i
:=
range
c
.
ModelPricing
{
for
_
,
m
:=
range
c
.
ModelPricing
[
i
]
.
Models
{
if
strings
.
ToLower
(
m
)
==
modelLower
{
cp
:=
c
.
ModelPricing
[
i
]
.
Clone
()
return
&
cp
}
}
}
// 第二轮:通配符匹配(仅支持末尾 *)
for
i
:=
range
c
.
ModelPricing
{
for
_
,
m
:=
range
c
.
ModelPricing
[
i
]
.
Models
{
mLower
:=
strings
.
ToLower
(
m
)
if
strings
.
HasSuffix
(
mLower
,
"*"
)
{
prefix
:=
strings
.
TrimSuffix
(
mLower
,
"*"
)
if
strings
.
HasPrefix
(
modelLower
,
prefix
)
{
cp
:=
c
.
ModelPricing
[
i
]
.
Clone
()
return
&
cp
}
}
}
}
return
nil
}
// FindMatchingInterval 在区间列表中查找匹配 totalTokens 的区间。
// 通用辅助函数,供 GetIntervalForContext、ModelPricingResolver 等复用。
func
FindMatchingInterval
(
intervals
[]
PricingInterval
,
totalTokens
int
)
*
PricingInterval
{
for
i
:=
range
intervals
{
iv
:=
&
intervals
[
i
]
if
totalTokens
>=
iv
.
MinTokens
&&
(
iv
.
MaxTokens
==
nil
||
totalTokens
<
*
iv
.
MaxTokens
)
{
return
iv
}
}
return
nil
}
// GetIntervalForContext 根据总 context token 数查找匹配的区间。
func
(
p
*
ChannelModelPricing
)
GetIntervalForContext
(
totalTokens
int
)
*
PricingInterval
{
return
FindMatchingInterval
(
p
.
Intervals
,
totalTokens
)
}
// GetTierByLabel 根据标签查找层级(用于 per_request / image 模式)
func
(
p
*
ChannelModelPricing
)
GetTierByLabel
(
label
string
)
*
PricingInterval
{
labelLower
:=
strings
.
ToLower
(
label
)
for
i
:=
range
p
.
Intervals
{
if
strings
.
ToLower
(
p
.
Intervals
[
i
]
.
TierLabel
)
==
labelLower
{
return
&
p
.
Intervals
[
i
]
}
}
return
nil
}
// Clone 返回 ChannelModelPricing 的拷贝(切片独立,指针字段共享,调用方只读安全)
func
(
p
ChannelModelPricing
)
Clone
()
ChannelModelPricing
{
cp
:=
p
if
p
.
Models
!=
nil
{
cp
.
Models
=
make
([]
string
,
len
(
p
.
Models
))
copy
(
cp
.
Models
,
p
.
Models
)
}
if
p
.
Intervals
!=
nil
{
cp
.
Intervals
=
make
([]
PricingInterval
,
len
(
p
.
Intervals
))
copy
(
cp
.
Intervals
,
p
.
Intervals
)
}
return
cp
}
// Clone 返回 Channel 的深拷贝
func
(
c
*
Channel
)
Clone
()
*
Channel
{
if
c
==
nil
{
return
nil
}
cp
:=
*
c
if
c
.
GroupIDs
!=
nil
{
cp
.
GroupIDs
=
make
([]
int64
,
len
(
c
.
GroupIDs
))
copy
(
cp
.
GroupIDs
,
c
.
GroupIDs
)
}
if
c
.
ModelPricing
!=
nil
{
cp
.
ModelPricing
=
make
([]
ChannelModelPricing
,
len
(
c
.
ModelPricing
))
for
i
:=
range
c
.
ModelPricing
{
cp
.
ModelPricing
[
i
]
=
c
.
ModelPricing
[
i
]
.
Clone
()
}
}
return
&
cp
}
backend/internal/service/channel_service.go
0 → 100644
View file @
91c9b8d0
package
service
import
(
"context"
"fmt"
"log/slog"
"sync/atomic"
"time"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"golang.org/x/sync/singleflight"
)
var
(
ErrChannelNotFound
=
infraerrors
.
NotFound
(
"CHANNEL_NOT_FOUND"
,
"channel not found"
)
ErrChannelExists
=
infraerrors
.
Conflict
(
"CHANNEL_EXISTS"
,
"channel name already exists"
)
ErrGroupAlreadyInChannel
=
infraerrors
.
Conflict
(
"GROUP_ALREADY_IN_CHANNEL"
,
"one or more groups already belong to another channel"
,
)
)
// ChannelRepository 渠道数据访问接口
type
ChannelRepository
interface
{
Create
(
ctx
context
.
Context
,
channel
*
Channel
)
error
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Channel
,
error
)
Update
(
ctx
context
.
Context
,
channel
*
Channel
)
error
Delete
(
ctx
context
.
Context
,
id
int64
)
error
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
status
,
search
string
)
([]
Channel
,
*
pagination
.
PaginationResult
,
error
)
ListAll
(
ctx
context
.
Context
)
([]
Channel
,
error
)
ExistsByName
(
ctx
context
.
Context
,
name
string
)
(
bool
,
error
)
ExistsByNameExcluding
(
ctx
context
.
Context
,
name
string
,
excludeID
int64
)
(
bool
,
error
)
// 分组关联
GetGroupIDs
(
ctx
context
.
Context
,
channelID
int64
)
([]
int64
,
error
)
SetGroupIDs
(
ctx
context
.
Context
,
channelID
int64
,
groupIDs
[]
int64
)
error
GetChannelIDByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
GetGroupsInOtherChannels
(
ctx
context
.
Context
,
channelID
int64
,
groupIDs
[]
int64
)
([]
int64
,
error
)
// 模型定价
ListModelPricing
(
ctx
context
.
Context
,
channelID
int64
)
([]
ChannelModelPricing
,
error
)
CreateModelPricing
(
ctx
context
.
Context
,
pricing
*
ChannelModelPricing
)
error
UpdateModelPricing
(
ctx
context
.
Context
,
pricing
*
ChannelModelPricing
)
error
DeleteModelPricing
(
ctx
context
.
Context
,
id
int64
)
error
ReplaceModelPricing
(
ctx
context
.
Context
,
channelID
int64
,
pricingList
[]
ChannelModelPricing
)
error
}
// channelCache 渠道缓存快照
type
channelCache
struct
{
// byID: channelID -> *Channel(含 ModelPricing)
byID
map
[
int64
]
*
Channel
// byGroupID: groupID -> channelID
byGroupID
map
[
int64
]
int64
loadedAt
time
.
Time
}
const
(
channelCacheTTL
=
60
*
time
.
Second
channelErrorTTL
=
5
*
time
.
Second
// DB 错误时的短缓存
channelCacheDBTimeout
=
10
*
time
.
Second
)
// ChannelService 渠道管理服务
type
ChannelService
struct
{
repo
ChannelRepository
authCacheInvalidator
APIKeyAuthCacheInvalidator
cache
atomic
.
Value
// *channelCache
cacheSF
singleflight
.
Group
}
// NewChannelService 创建渠道服务实例
func
NewChannelService
(
repo
ChannelRepository
,
authCacheInvalidator
APIKeyAuthCacheInvalidator
)
*
ChannelService
{
s
:=
&
ChannelService
{
repo
:
repo
,
authCacheInvalidator
:
authCacheInvalidator
,
}
return
s
}
// loadCache 加载或返回缓存的渠道数据
func
(
s
*
ChannelService
)
loadCache
(
ctx
context
.
Context
)
(
*
channelCache
,
error
)
{
if
cached
,
ok
:=
s
.
cache
.
Load
()
.
(
*
channelCache
);
ok
{
if
time
.
Since
(
cached
.
loadedAt
)
<
channelCacheTTL
{
return
cached
,
nil
}
}
result
,
err
,
_
:=
s
.
cacheSF
.
Do
(
"channel_cache"
,
func
()
(
any
,
error
)
{
// 双重检查
if
cached
,
ok
:=
s
.
cache
.
Load
()
.
(
*
channelCache
);
ok
{
if
time
.
Since
(
cached
.
loadedAt
)
<
channelCacheTTL
{
return
cached
,
nil
}
}
return
s
.
buildCache
(
ctx
)
})
if
err
!=
nil
{
return
nil
,
err
}
return
result
.
(
*
channelCache
),
nil
}
// buildCache 从数据库构建渠道缓存。
// 使用独立 context 避免请求取消导致空值被长期缓存。
func
(
s
*
ChannelService
)
buildCache
(
ctx
context
.
Context
)
(
*
channelCache
,
error
)
{
// 断开请求取消链,避免客户端断连导致空值被长期缓存
dbCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
WithoutCancel
(
ctx
),
channelCacheDBTimeout
)
defer
cancel
()
channels
,
err
:=
s
.
repo
.
ListAll
(
dbCtx
)
if
err
!=
nil
{
// error-TTL:失败时存入短 TTL 空缓存,防止紧密重试
slog
.
Warn
(
"failed to build channel cache"
,
"error"
,
err
)
errorCache
:=
&
channelCache
{
byID
:
make
(
map
[
int64
]
*
Channel
),
byGroupID
:
make
(
map
[
int64
]
int64
),
loadedAt
:
time
.
Now
()
.
Add
(
channelCacheTTL
-
channelErrorTTL
),
// 使剩余 TTL = errorTTL
}
s
.
cache
.
Store
(
errorCache
)
return
nil
,
fmt
.
Errorf
(
"list all channels: %w"
,
err
)
}
cache
:=
&
channelCache
{
byID
:
make
(
map
[
int64
]
*
Channel
,
len
(
channels
)),
byGroupID
:
make
(
map
[
int64
]
int64
),
loadedAt
:
time
.
Now
(),
}
for
i
:=
range
channels
{
ch
:=
&
channels
[
i
]
cache
.
byID
[
ch
.
ID
]
=
ch
for
_
,
gid
:=
range
ch
.
GroupIDs
{
cache
.
byGroupID
[
gid
]
=
ch
.
ID
}
}
s
.
cache
.
Store
(
cache
)
return
cache
,
nil
}
// invalidateCache 使缓存失效,让下次读取时自然重建
func
(
s
*
ChannelService
)
invalidateCache
()
{
s
.
cache
.
Store
((
*
channelCache
)(
nil
))
s
.
cacheSF
.
Forget
(
"channel_cache"
)
}
// GetChannelForGroup 获取分组关联的渠道(热路径,从缓存读取)
// 返回深拷贝,不污染缓存。
func
(
s
*
ChannelService
)
GetChannelForGroup
(
ctx
context
.
Context
,
groupID
int64
)
(
*
Channel
,
error
)
{
cache
,
err
:=
s
.
loadCache
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
channelID
,
ok
:=
cache
.
byGroupID
[
groupID
]
if
!
ok
{
return
nil
,
nil
}
ch
,
ok
:=
cache
.
byID
[
channelID
]
if
!
ok
{
return
nil
,
nil
}
if
!
ch
.
IsActive
()
{
return
nil
,
nil
}
return
ch
.
Clone
(),
nil
}
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径)
func
(
s
*
ChannelService
)
GetChannelModelPricing
(
ctx
context
.
Context
,
groupID
int64
,
model
string
)
*
ChannelModelPricing
{
ch
,
err
:=
s
.
GetChannelForGroup
(
ctx
,
groupID
)
if
err
!=
nil
{
slog
.
Warn
(
"failed to get channel for group"
,
"group_id"
,
groupID
,
"error"
,
err
)
return
nil
}
if
ch
==
nil
{
return
nil
}
return
ch
.
GetModelPricing
(
model
)
}
// --- CRUD ---
// Create 创建渠道
func
(
s
*
ChannelService
)
Create
(
ctx
context
.
Context
,
input
*
CreateChannelInput
)
(
*
Channel
,
error
)
{
exists
,
err
:=
s
.
repo
.
ExistsByName
(
ctx
,
input
.
Name
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"check channel exists: %w"
,
err
)
}
if
exists
{
return
nil
,
ErrChannelExists
}
// 检查分组冲突
if
len
(
input
.
GroupIDs
)
>
0
{
conflicting
,
err
:=
s
.
repo
.
GetGroupsInOtherChannels
(
ctx
,
0
,
input
.
GroupIDs
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"check group conflicts: %w"
,
err
)
}
if
len
(
conflicting
)
>
0
{
return
nil
,
ErrGroupAlreadyInChannel
}
}
channel
:=
&
Channel
{
Name
:
input
.
Name
,
Description
:
input
.
Description
,
Status
:
StatusActive
,
GroupIDs
:
input
.
GroupIDs
,
ModelPricing
:
input
.
ModelPricing
,
}
if
err
:=
s
.
repo
.
Create
(
ctx
,
channel
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"create channel: %w"
,
err
)
}
s
.
invalidateCache
()
return
s
.
repo
.
GetByID
(
ctx
,
channel
.
ID
)
}
// GetByID 获取渠道详情
func
(
s
*
ChannelService
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Channel
,
error
)
{
return
s
.
repo
.
GetByID
(
ctx
,
id
)
}
// Update 更新渠道
func
(
s
*
ChannelService
)
Update
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateChannelInput
)
(
*
Channel
,
error
)
{
channel
,
err
:=
s
.
repo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get channel: %w"
,
err
)
}
if
input
.
Name
!=
""
&&
input
.
Name
!=
channel
.
Name
{
exists
,
err
:=
s
.
repo
.
ExistsByNameExcluding
(
ctx
,
input
.
Name
,
id
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"check channel exists: %w"
,
err
)
}
if
exists
{
return
nil
,
ErrChannelExists
}
channel
.
Name
=
input
.
Name
}
if
input
.
Description
!=
nil
{
channel
.
Description
=
*
input
.
Description
}
if
input
.
Status
!=
""
{
channel
.
Status
=
input
.
Status
}
// 检查分组冲突
if
input
.
GroupIDs
!=
nil
{
conflicting
,
err
:=
s
.
repo
.
GetGroupsInOtherChannels
(
ctx
,
id
,
*
input
.
GroupIDs
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"check group conflicts: %w"
,
err
)
}
if
len
(
conflicting
)
>
0
{
return
nil
,
ErrGroupAlreadyInChannel
}
channel
.
GroupIDs
=
*
input
.
GroupIDs
}
if
input
.
ModelPricing
!=
nil
{
channel
.
ModelPricing
=
*
input
.
ModelPricing
}
if
err
:=
s
.
repo
.
Update
(
ctx
,
channel
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"update channel: %w"
,
err
)
}
s
.
invalidateCache
()
// 失效关联分组的 auth 缓存
if
s
.
authCacheInvalidator
!=
nil
{
groupIDs
,
err
:=
s
.
repo
.
GetGroupIDs
(
ctx
,
id
)
if
err
!=
nil
{
slog
.
Warn
(
"failed to get group IDs for cache invalidation"
,
"channel_id"
,
id
,
"error"
,
err
)
}
for
_
,
gid
:=
range
groupIDs
{
s
.
authCacheInvalidator
.
InvalidateAuthCacheByGroupID
(
ctx
,
gid
)
}
}
return
s
.
repo
.
GetByID
(
ctx
,
id
)
}
// Delete 删除渠道
func
(
s
*
ChannelService
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
// 先获取关联分组用于失效缓存
groupIDs
,
err
:=
s
.
repo
.
GetGroupIDs
(
ctx
,
id
)
if
err
!=
nil
{
slog
.
Warn
(
"failed to get group IDs before delete"
,
"channel_id"
,
id
,
"error"
,
err
)
}
if
err
:=
s
.
repo
.
Delete
(
ctx
,
id
);
err
!=
nil
{
return
fmt
.
Errorf
(
"delete channel: %w"
,
err
)
}
s
.
invalidateCache
()
if
s
.
authCacheInvalidator
!=
nil
{
for
_
,
gid
:=
range
groupIDs
{
s
.
authCacheInvalidator
.
InvalidateAuthCacheByGroupID
(
ctx
,
gid
)
}
}
return
nil
}
// List 获取渠道列表
func
(
s
*
ChannelService
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
status
,
search
string
)
([]
Channel
,
*
pagination
.
PaginationResult
,
error
)
{
return
s
.
repo
.
List
(
ctx
,
params
,
status
,
search
)
}
// --- Input types ---
// CreateChannelInput 创建渠道输入
type
CreateChannelInput
struct
{
Name
string
Description
string
GroupIDs
[]
int64
ModelPricing
[]
ChannelModelPricing
}
// UpdateChannelInput 更新渠道输入
type
UpdateChannelInput
struct
{
Name
string
Description
*
string
Status
string
GroupIDs
*
[]
int64
ModelPricing
*
[]
ChannelModelPricing
}
backend/internal/service/channel_test.go
0 → 100644
View file @
91c9b8d0
//go:build unit
package
service
import
(
"testing"
"github.com/stretchr/testify/require"
)
func
channelTestPtrFloat64
(
v
float64
)
*
float64
{
return
&
v
}
func
channelTestPtrInt
(
v
int
)
*
int
{
return
&
v
}
func
TestGetModelPricing
(
t
*
testing
.
T
)
{
ch
:=
&
Channel
{
ModelPricing
:
[]
ChannelModelPricing
{
{
ID
:
1
,
Models
:
[]
string
{
"claude-sonnet-4"
},
BillingMode
:
BillingModeToken
,
InputPrice
:
channelTestPtrFloat64
(
3e-6
)},
{
ID
:
2
,
Models
:
[]
string
{
"claude-*"
},
BillingMode
:
BillingModeToken
,
InputPrice
:
channelTestPtrFloat64
(
5e-6
)},
{
ID
:
3
,
Models
:
[]
string
{
"gpt-5.1"
},
BillingMode
:
BillingModePerRequest
},
},
}
tests
:=
[]
struct
{
name
string
model
string
wantID
int64
wantNil
bool
}{
{
"exact match"
,
"claude-sonnet-4"
,
1
,
false
},
{
"case insensitive"
,
"Claude-Sonnet-4"
,
1
,
false
},
{
"wildcard match"
,
"claude-opus-4-20250514"
,
2
,
false
},
{
"exact takes priority over wildcard"
,
"claude-sonnet-4"
,
1
,
false
},
{
"not found"
,
"gemini-3.1-pro"
,
0
,
true
},
{
"per_request model"
,
"gpt-5.1"
,
3
,
false
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
result
:=
ch
.
GetModelPricing
(
tt
.
model
)
if
tt
.
wantNil
{
require
.
Nil
(
t
,
result
)
return
}
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
tt
.
wantID
,
result
.
ID
)
})
}
}
func
TestGetModelPricing_ReturnsCopy
(
t
*
testing
.
T
)
{
ch
:=
&
Channel
{
ModelPricing
:
[]
ChannelModelPricing
{
{
ID
:
1
,
Models
:
[]
string
{
"claude-sonnet-4"
},
InputPrice
:
channelTestPtrFloat64
(
3e-6
)},
},
}
result
:=
ch
.
GetModelPricing
(
"claude-sonnet-4"
)
require
.
NotNil
(
t
,
result
)
// Modify the returned copy's slice — original should be unchanged
result
.
Models
=
append
(
result
.
Models
,
"hacked"
)
// Original should be unchanged
require
.
Equal
(
t
,
1
,
len
(
ch
.
ModelPricing
[
0
]
.
Models
))
}
func
TestGetModelPricing_EmptyPricing
(
t
*
testing
.
T
)
{
ch
:=
&
Channel
{
ModelPricing
:
nil
}
require
.
Nil
(
t
,
ch
.
GetModelPricing
(
"any-model"
))
ch2
:=
&
Channel
{
ModelPricing
:
[]
ChannelModelPricing
{}}
require
.
Nil
(
t
,
ch2
.
GetModelPricing
(
"any-model"
))
}
func
TestGetIntervalForContext
(
t
*
testing
.
T
)
{
p
:=
&
ChannelModelPricing
{
Intervals
:
[]
PricingInterval
{
{
MinTokens
:
0
,
MaxTokens
:
channelTestPtrInt
(
128000
),
InputPrice
:
channelTestPtrFloat64
(
1e-6
)},
{
MinTokens
:
128000
,
MaxTokens
:
nil
,
InputPrice
:
channelTestPtrFloat64
(
2e-6
)},
},
}
tests
:=
[]
struct
{
name
string
tokens
int
wantPrice
*
float64
wantNil
bool
}{
{
"first interval"
,
50000
,
channelTestPtrFloat64
(
1e-6
),
false
},
{
"boundary: at min of second"
,
128000
,
channelTestPtrFloat64
(
2e-6
),
false
},
{
"boundary: at max of first (exclusive)"
,
128000
,
channelTestPtrFloat64
(
2e-6
),
false
},
{
"unbounded interval"
,
500000
,
channelTestPtrFloat64
(
2e-6
),
false
},
{
"zero tokens"
,
0
,
channelTestPtrFloat64
(
1e-6
),
false
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
result
:=
p
.
GetIntervalForContext
(
tt
.
tokens
)
if
tt
.
wantNil
{
require
.
Nil
(
t
,
result
)
return
}
require
.
NotNil
(
t
,
result
)
require
.
InDelta
(
t
,
*
tt
.
wantPrice
,
*
result
.
InputPrice
,
1e-12
)
})
}
}
func
TestGetIntervalForContext_NoMatch
(
t
*
testing
.
T
)
{
p
:=
&
ChannelModelPricing
{
Intervals
:
[]
PricingInterval
{
{
MinTokens
:
10000
,
MaxTokens
:
channelTestPtrInt
(
50000
)},
},
}
require
.
Nil
(
t
,
p
.
GetIntervalForContext
(
5000
))
require
.
Nil
(
t
,
p
.
GetIntervalForContext
(
50000
))
}
func
TestGetIntervalForContext_Empty
(
t
*
testing
.
T
)
{
p
:=
&
ChannelModelPricing
{
Intervals
:
nil
}
require
.
Nil
(
t
,
p
.
GetIntervalForContext
(
1000
))
}
func
TestGetTierByLabel
(
t
*
testing
.
T
)
{
p
:=
&
ChannelModelPricing
{
Intervals
:
[]
PricingInterval
{
{
TierLabel
:
"1K"
,
PerRequestPrice
:
channelTestPtrFloat64
(
0.04
)},
{
TierLabel
:
"2K"
,
PerRequestPrice
:
channelTestPtrFloat64
(
0.08
)},
{
TierLabel
:
"HD"
,
PerRequestPrice
:
channelTestPtrFloat64
(
0.12
)},
},
}
tests
:=
[]
struct
{
name
string
label
string
wantNil
bool
want
float64
}{
{
"exact match"
,
"1K"
,
false
,
0.04
},
{
"case insensitive"
,
"hd"
,
false
,
0.12
},
{
"not found"
,
"4K"
,
true
,
0
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
result
:=
p
.
GetTierByLabel
(
tt
.
label
)
if
tt
.
wantNil
{
require
.
Nil
(
t
,
result
)
return
}
require
.
NotNil
(
t
,
result
)
require
.
InDelta
(
t
,
tt
.
want
,
*
result
.
PerRequestPrice
,
1e-12
)
})
}
}
func
TestGetTierByLabel_Empty
(
t
*
testing
.
T
)
{
p
:=
&
ChannelModelPricing
{
Intervals
:
nil
}
require
.
Nil
(
t
,
p
.
GetTierByLabel
(
"1K"
))
}
func
TestChannelClone
(
t
*
testing
.
T
)
{
original
:=
&
Channel
{
ID
:
1
,
Name
:
"test"
,
GroupIDs
:
[]
int64
{
10
,
20
},
ModelPricing
:
[]
ChannelModelPricing
{
{
ID
:
100
,
Models
:
[]
string
{
"model-a"
},
InputPrice
:
channelTestPtrFloat64
(
5e-6
),
},
},
}
cloned
:=
original
.
Clone
()
require
.
NotNil
(
t
,
cloned
)
require
.
Equal
(
t
,
original
.
ID
,
cloned
.
ID
)
require
.
Equal
(
t
,
original
.
Name
,
cloned
.
Name
)
// Modify clone slices — original should not change
cloned
.
GroupIDs
[
0
]
=
999
require
.
Equal
(
t
,
int64
(
10
),
original
.
GroupIDs
[
0
])
cloned
.
ModelPricing
[
0
]
.
Models
[
0
]
=
"hacked"
require
.
Equal
(
t
,
"model-a"
,
original
.
ModelPricing
[
0
]
.
Models
[
0
])
}
func
TestChannelClone_Nil
(
t
*
testing
.
T
)
{
var
ch
*
Channel
require
.
Nil
(
t
,
ch
.
Clone
())
}
func
TestChannelModelPricingClone
(
t
*
testing
.
T
)
{
original
:=
ChannelModelPricing
{
Models
:
[]
string
{
"a"
,
"b"
},
Intervals
:
[]
PricingInterval
{
{
MinTokens
:
0
,
TierLabel
:
"tier1"
},
},
}
cloned
:=
original
.
Clone
()
// Modify clone slices — original unchanged
cloned
.
Models
[
0
]
=
"hacked"
require
.
Equal
(
t
,
"a"
,
original
.
Models
[
0
])
cloned
.
Intervals
[
0
]
.
TierLabel
=
"hacked"
require
.
Equal
(
t
,
"tier1"
,
original
.
Intervals
[
0
]
.
TierLabel
)
}
backend/internal/service/gateway_record_usage_test.go
View file @
91c9b8d0
...
...
@@ -41,6 +41,7 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo
nil
,
nil
,
nil
,
nil
,
)
}
...
...
backend/internal/service/gateway_service.go
View file @
91c9b8d0
...
...
@@ -568,6 +568,7 @@ type GatewayService struct {
responseHeaderFilter
*
responseheaders
.
CompiledHeaderFilter
debugModelRouting
atomic
.
Bool
debugClaudeMimic
atomic
.
Bool
channelService
*
ChannelService
debugGatewayBodyFile
atomic
.
Pointer
[
os
.
File
]
// non-nil when SUB2API_DEBUG_GATEWAY_BODY is set
tlsFPProfileService
*
TLSFingerprintProfileService
}
...
...
@@ -597,6 +598,7 @@ func NewGatewayService(
digestStore
*
DigestSessionStore
,
settingService
*
SettingService
,
tlsFPProfileService
*
TLSFingerprintProfileService
,
channelService
*
ChannelService
,
)
*
GatewayService
{
userGroupRateTTL
:=
resolveUserGroupRateCacheTTL
(
cfg
)
modelsListTTL
:=
resolveModelsListCacheTTL
(
cfg
)
...
...
@@ -629,6 +631,7 @@ func NewGatewayService(
modelsListCacheTTL
:
modelsListTTL
,
responseHeaderFilter
:
compileResponseHeaderFilter
(
cfg
),
tlsFPProfileService
:
tlsFPProfileService
,
channelService
:
channelService
,
}
svc
.
userGroupRateResolver
=
newUserGroupRateResolver
(
userGroupRateRepo
,
...
...
@@ -7771,7 +7774,16 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
CacheCreation1hTokens
:
result
.
Usage
.
CacheCreation1hTokens
,
}
var
err
error
cost
,
err
=
s
.
billingService
.
CalculateCost
(
billingModel
,
tokens
,
multiplier
)
// 渠道定价覆盖
var
chPricing
*
ChannelModelPricing
if
s
.
channelService
!=
nil
&&
apiKey
.
Group
!=
nil
{
chPricing
=
s
.
channelService
.
GetChannelModelPricing
(
ctx
,
apiKey
.
Group
.
ID
,
billingModel
)
}
if
chPricing
!=
nil
{
cost
,
err
=
s
.
billingService
.
CalculateCostWithChannel
(
billingModel
,
tokens
,
multiplier
,
chPricing
)
}
else
{
cost
,
err
=
s
.
billingService
.
CalculateCost
(
billingModel
,
tokens
,
multiplier
)
}
if
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.gateway"
,
"Calculate cost failed: %v"
,
err
)
cost
=
&
CostBreakdown
{
ActualCost
:
0
}
...
...
@@ -7959,7 +7971,16 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
CacheCreation1hTokens
:
result
.
Usage
.
CacheCreation1hTokens
,
}
var
err
error
cost
,
err
=
s
.
billingService
.
CalculateCostWithLongContext
(
billingModel
,
tokens
,
multiplier
,
input
.
LongContextThreshold
,
input
.
LongContextMultiplier
)
// 渠道定价覆盖
var
chPricing2
*
ChannelModelPricing
if
s
.
channelService
!=
nil
&&
apiKey
.
Group
!=
nil
{
chPricing2
=
s
.
channelService
.
GetChannelModelPricing
(
ctx
,
apiKey
.
Group
.
ID
,
billingModel
)
}
if
chPricing2
!=
nil
{
cost
,
err
=
s
.
billingService
.
CalculateCostWithChannel
(
billingModel
,
tokens
,
multiplier
,
chPricing2
)
}
else
{
cost
,
err
=
s
.
billingService
.
CalculateCostWithLongContext
(
billingModel
,
tokens
,
multiplier
,
input
.
LongContextThreshold
,
input
.
LongContextMultiplier
)
}
if
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.gateway"
,
"Calculate cost failed: %v"
,
err
)
cost
=
&
CostBreakdown
{
ActualCost
:
0
}
...
...
backend/internal/service/model_pricing_resolver.go
0 → 100644
View file @
91c9b8d0
package
service
import
(
"context"
"log/slog"
)
// ResolvedPricing 统一定价解析结果
type
ResolvedPricing
struct
{
// Mode 计费模式
Mode
BillingMode
// Token 模式:基础定价(来自 LiteLLM 或 fallback)
BasePricing
*
ModelPricing
// Token 模式:区间定价列表(如有,覆盖 BasePricing 中的对应字段)
Intervals
[]
PricingInterval
// 按次/图片模式:分层定价
RequestTiers
[]
PricingInterval
// 来源标识
Source
string
// "channel", "litellm", "fallback"
// 是否支持缓存细分
SupportsCacheBreakdown
bool
}
// ModelPricingResolver 统一模型定价解析器。
// 解析链:Channel → LiteLLM → Fallback。
type
ModelPricingResolver
struct
{
channelService
*
ChannelService
billingService
*
BillingService
}
// NewModelPricingResolver 创建定价解析器实例
func
NewModelPricingResolver
(
channelService
*
ChannelService
,
billingService
*
BillingService
)
*
ModelPricingResolver
{
return
&
ModelPricingResolver
{
channelService
:
channelService
,
billingService
:
billingService
,
}
}
// PricingInput 定价解析输入
type
PricingInput
struct
{
Model
string
GroupID
*
int64
// nil 表示不检查渠道
}
// Resolve 解析模型定价。
// 1. 获取基础定价(LiteLLM → Fallback)
// 2. 如果指定了 GroupID,查找渠道定价并覆盖
func
(
r
*
ModelPricingResolver
)
Resolve
(
ctx
context
.
Context
,
input
PricingInput
)
*
ResolvedPricing
{
// 1. 获取基础定价
basePricing
,
source
:=
r
.
resolveBasePricing
(
input
.
Model
)
resolved
:=
&
ResolvedPricing
{
Mode
:
BillingModeToken
,
BasePricing
:
basePricing
,
Source
:
source
,
SupportsCacheBreakdown
:
basePricing
!=
nil
&&
basePricing
.
SupportsCacheBreakdown
,
}
// 2. 如果有 GroupID,尝试渠道覆盖
if
input
.
GroupID
!=
nil
{
r
.
applyChannelOverrides
(
ctx
,
*
input
.
GroupID
,
input
.
Model
,
resolved
)
}
return
resolved
}
// resolveBasePricing 从 LiteLLM 或 Fallback 获取基础定价
func
(
r
*
ModelPricingResolver
)
resolveBasePricing
(
model
string
)
(
*
ModelPricing
,
string
)
{
pricing
,
err
:=
r
.
billingService
.
GetModelPricing
(
model
)
if
err
!=
nil
{
slog
.
Debug
(
"failed to get model pricing from LiteLLM, using fallback"
,
"model"
,
model
,
"error"
,
err
)
return
nil
,
"fallback"
}
return
pricing
,
"litellm"
}
// applyChannelOverrides 应用渠道定价覆盖
func
(
r
*
ModelPricingResolver
)
applyChannelOverrides
(
ctx
context
.
Context
,
groupID
int64
,
model
string
,
resolved
*
ResolvedPricing
)
{
chPricing
:=
r
.
channelService
.
GetChannelModelPricing
(
ctx
,
groupID
,
model
)
if
chPricing
==
nil
{
return
}
resolved
.
Source
=
"channel"
resolved
.
Mode
=
chPricing
.
BillingMode
if
resolved
.
Mode
==
""
{
resolved
.
Mode
=
BillingModeToken
}
switch
resolved
.
Mode
{
case
BillingModeToken
:
r
.
applyTokenOverrides
(
chPricing
,
resolved
)
case
BillingModePerRequest
,
BillingModeImage
:
r
.
applyRequestTierOverrides
(
chPricing
,
resolved
)
}
}
// applyTokenOverrides 应用 token 模式的渠道覆盖
func
(
r
*
ModelPricingResolver
)
applyTokenOverrides
(
chPricing
*
ChannelModelPricing
,
resolved
*
ResolvedPricing
)
{
// 如果有区间定价,使用区间
if
len
(
chPricing
.
Intervals
)
>
0
{
resolved
.
Intervals
=
chPricing
.
Intervals
return
}
// 否则用 flat 字段覆盖 BasePricing
if
resolved
.
BasePricing
==
nil
{
resolved
.
BasePricing
=
&
ModelPricing
{}
}
if
chPricing
.
InputPrice
!=
nil
{
resolved
.
BasePricing
.
InputPricePerToken
=
*
chPricing
.
InputPrice
resolved
.
BasePricing
.
InputPricePerTokenPriority
=
*
chPricing
.
InputPrice
}
if
chPricing
.
OutputPrice
!=
nil
{
resolved
.
BasePricing
.
OutputPricePerToken
=
*
chPricing
.
OutputPrice
resolved
.
BasePricing
.
OutputPricePerTokenPriority
=
*
chPricing
.
OutputPrice
}
if
chPricing
.
CacheWritePrice
!=
nil
{
resolved
.
BasePricing
.
CacheCreationPricePerToken
=
*
chPricing
.
CacheWritePrice
resolved
.
BasePricing
.
CacheCreation5mPrice
=
*
chPricing
.
CacheWritePrice
resolved
.
BasePricing
.
CacheCreation1hPrice
=
*
chPricing
.
CacheWritePrice
}
if
chPricing
.
CacheReadPrice
!=
nil
{
resolved
.
BasePricing
.
CacheReadPricePerToken
=
*
chPricing
.
CacheReadPrice
resolved
.
BasePricing
.
CacheReadPricePerTokenPriority
=
*
chPricing
.
CacheReadPrice
}
}
// applyRequestTierOverrides 应用按次/图片模式的渠道覆盖
func
(
r
*
ModelPricingResolver
)
applyRequestTierOverrides
(
chPricing
*
ChannelModelPricing
,
resolved
*
ResolvedPricing
)
{
resolved
.
RequestTiers
=
chPricing
.
Intervals
}
// GetIntervalPricing 根据 context token 数获取区间定价。
// 如果有区间列表,找到匹配区间并构造 ModelPricing;否则直接返回 BasePricing。
func
(
r
*
ModelPricingResolver
)
GetIntervalPricing
(
resolved
*
ResolvedPricing
,
totalContextTokens
int
)
*
ModelPricing
{
if
len
(
resolved
.
Intervals
)
==
0
{
return
resolved
.
BasePricing
}
iv
:=
FindMatchingInterval
(
resolved
.
Intervals
,
totalContextTokens
)
if
iv
==
nil
{
return
resolved
.
BasePricing
}
return
intervalToModelPricing
(
iv
,
resolved
.
SupportsCacheBreakdown
)
}
// intervalToModelPricing 将区间定价转换为 ModelPricing
func
intervalToModelPricing
(
iv
*
PricingInterval
,
supportsCacheBreakdown
bool
)
*
ModelPricing
{
pricing
:=
&
ModelPricing
{
SupportsCacheBreakdown
:
supportsCacheBreakdown
,
}
if
iv
.
InputPrice
!=
nil
{
pricing
.
InputPricePerToken
=
*
iv
.
InputPrice
pricing
.
InputPricePerTokenPriority
=
*
iv
.
InputPrice
}
if
iv
.
OutputPrice
!=
nil
{
pricing
.
OutputPricePerToken
=
*
iv
.
OutputPrice
pricing
.
OutputPricePerTokenPriority
=
*
iv
.
OutputPrice
}
if
iv
.
CacheWritePrice
!=
nil
{
pricing
.
CacheCreationPricePerToken
=
*
iv
.
CacheWritePrice
pricing
.
CacheCreation5mPrice
=
*
iv
.
CacheWritePrice
pricing
.
CacheCreation1hPrice
=
*
iv
.
CacheWritePrice
}
if
iv
.
CacheReadPrice
!=
nil
{
pricing
.
CacheReadPricePerToken
=
*
iv
.
CacheReadPrice
pricing
.
CacheReadPricePerTokenPriority
=
*
iv
.
CacheReadPrice
}
return
pricing
}
// GetRequestTierPrice 根据层级标签获取按次价格
func
(
r
*
ModelPricingResolver
)
GetRequestTierPrice
(
resolved
*
ResolvedPricing
,
tierLabel
string
)
float64
{
for
_
,
tier
:=
range
resolved
.
RequestTiers
{
if
tier
.
TierLabel
==
tierLabel
&&
tier
.
PerRequestPrice
!=
nil
{
return
*
tier
.
PerRequestPrice
}
}
return
0
}
// GetRequestTierPriceByContext 根据 context token 数获取按次价格
func
(
r
*
ModelPricingResolver
)
GetRequestTierPriceByContext
(
resolved
*
ResolvedPricing
,
totalContextTokens
int
)
float64
{
iv
:=
FindMatchingInterval
(
resolved
.
RequestTiers
,
totalContextTokens
)
if
iv
!=
nil
&&
iv
.
PerRequestPrice
!=
nil
{
return
*
iv
.
PerRequestPrice
}
return
0
}
backend/internal/service/model_pricing_resolver_test.go
0 → 100644
View file @
91c9b8d0
//go:build unit
package
service
import
(
"context"
"testing"
"github.com/stretchr/testify/require"
)
func
resolverPtrFloat64
(
v
float64
)
*
float64
{
return
&
v
}
func
resolverPtrInt
(
v
int
)
*
int
{
return
&
v
}
func
newTestBillingServiceForResolver
()
*
BillingService
{
bs
:=
&
BillingService
{
fallbackPrices
:
make
(
map
[
string
]
*
ModelPricing
),
}
bs
.
fallbackPrices
[
"claude-sonnet-4"
]
=
&
ModelPricing
{
InputPricePerToken
:
3e-6
,
OutputPricePerToken
:
15e-6
,
CacheCreationPricePerToken
:
3.75e-6
,
CacheReadPricePerToken
:
0.3e-6
,
SupportsCacheBreakdown
:
false
,
}
return
bs
}
func
TestResolve_NoGroupID
(
t
*
testing
.
T
)
{
bs
:=
newTestBillingServiceForResolver
()
r
:=
NewModelPricingResolver
(
&
ChannelService
{},
bs
)
resolved
:=
r
.
Resolve
(
context
.
Background
(),
PricingInput
{
Model
:
"claude-sonnet-4"
,
GroupID
:
nil
,
})
require
.
NotNil
(
t
,
resolved
)
require
.
Equal
(
t
,
BillingModeToken
,
resolved
.
Mode
)
require
.
NotNil
(
t
,
resolved
.
BasePricing
)
require
.
InDelta
(
t
,
3e-6
,
resolved
.
BasePricing
.
InputPricePerToken
,
1e-12
)
// BillingService.GetModelPricing uses fallback internally, but resolveBasePricing
// reports "litellm" when GetModelPricing succeeds (regardless of internal source)
require
.
Equal
(
t
,
"litellm"
,
resolved
.
Source
)
}
func
TestResolve_UnknownModel
(
t
*
testing
.
T
)
{
bs
:=
newTestBillingServiceForResolver
()
r
:=
NewModelPricingResolver
(
&
ChannelService
{},
bs
)
resolved
:=
r
.
Resolve
(
context
.
Background
(),
PricingInput
{
Model
:
"unknown-model-xyz"
,
GroupID
:
nil
,
})
require
.
NotNil
(
t
,
resolved
)
require
.
Nil
(
t
,
resolved
.
BasePricing
)
// Unknown model: GetModelPricing returns error, source is "fallback"
require
.
Equal
(
t
,
"fallback"
,
resolved
.
Source
)
}
func
TestGetIntervalPricing_NoIntervals
(
t
*
testing
.
T
)
{
bs
:=
newTestBillingServiceForResolver
()
r
:=
NewModelPricingResolver
(
&
ChannelService
{},
bs
)
basePricing
:=
&
ModelPricing
{
InputPricePerToken
:
5e-6
}
resolved
:=
&
ResolvedPricing
{
Mode
:
BillingModeToken
,
BasePricing
:
basePricing
,
Intervals
:
nil
,
}
result
:=
r
.
GetIntervalPricing
(
resolved
,
50000
)
require
.
Equal
(
t
,
basePricing
,
result
)
}
func
TestGetIntervalPricing_MatchesInterval
(
t
*
testing
.
T
)
{
bs
:=
newTestBillingServiceForResolver
()
r
:=
NewModelPricingResolver
(
&
ChannelService
{},
bs
)
resolved
:=
&
ResolvedPricing
{
Mode
:
BillingModeToken
,
BasePricing
:
&
ModelPricing
{
InputPricePerToken
:
5e-6
},
SupportsCacheBreakdown
:
true
,
Intervals
:
[]
PricingInterval
{
{
MinTokens
:
0
,
MaxTokens
:
resolverPtrInt
(
128000
),
InputPrice
:
resolverPtrFloat64
(
1e-6
),
OutputPrice
:
resolverPtrFloat64
(
2e-6
)},
{
MinTokens
:
128000
,
MaxTokens
:
nil
,
InputPrice
:
resolverPtrFloat64
(
3e-6
),
OutputPrice
:
resolverPtrFloat64
(
6e-6
)},
},
}
result
:=
r
.
GetIntervalPricing
(
resolved
,
50000
)
require
.
NotNil
(
t
,
result
)
require
.
InDelta
(
t
,
1e-6
,
result
.
InputPricePerToken
,
1e-12
)
require
.
InDelta
(
t
,
2e-6
,
result
.
OutputPricePerToken
,
1e-12
)
require
.
True
(
t
,
result
.
SupportsCacheBreakdown
)
result2
:=
r
.
GetIntervalPricing
(
resolved
,
200000
)
require
.
NotNil
(
t
,
result2
)
require
.
InDelta
(
t
,
3e-6
,
result2
.
InputPricePerToken
,
1e-12
)
}
func
TestGetIntervalPricing_NoMatch_FallsBackToBase
(
t
*
testing
.
T
)
{
bs
:=
newTestBillingServiceForResolver
()
r
:=
NewModelPricingResolver
(
&
ChannelService
{},
bs
)
basePricing
:=
&
ModelPricing
{
InputPricePerToken
:
99e-6
}
resolved
:=
&
ResolvedPricing
{
Mode
:
BillingModeToken
,
BasePricing
:
basePricing
,
Intervals
:
[]
PricingInterval
{
{
MinTokens
:
10000
,
MaxTokens
:
resolverPtrInt
(
50000
),
InputPrice
:
resolverPtrFloat64
(
1e-6
)},
},
}
result
:=
r
.
GetIntervalPricing
(
resolved
,
5000
)
require
.
Equal
(
t
,
basePricing
,
result
)
}
func
TestGetRequestTierPrice
(
t
*
testing
.
T
)
{
bs
:=
newTestBillingServiceForResolver
()
r
:=
NewModelPricingResolver
(
&
ChannelService
{},
bs
)
resolved
:=
&
ResolvedPricing
{
Mode
:
BillingModePerRequest
,
RequestTiers
:
[]
PricingInterval
{
{
TierLabel
:
"1K"
,
PerRequestPrice
:
resolverPtrFloat64
(
0.04
)},
{
TierLabel
:
"2K"
,
PerRequestPrice
:
resolverPtrFloat64
(
0.08
)},
},
}
require
.
InDelta
(
t
,
0.04
,
r
.
GetRequestTierPrice
(
resolved
,
"1K"
),
1e-12
)
require
.
InDelta
(
t
,
0.08
,
r
.
GetRequestTierPrice
(
resolved
,
"2K"
),
1e-12
)
require
.
InDelta
(
t
,
0.0
,
r
.
GetRequestTierPrice
(
resolved
,
"4K"
),
1e-12
)
}
func
TestGetRequestTierPriceByContext
(
t
*
testing
.
T
)
{
bs
:=
newTestBillingServiceForResolver
()
r
:=
NewModelPricingResolver
(
&
ChannelService
{},
bs
)
resolved
:=
&
ResolvedPricing
{
Mode
:
BillingModePerRequest
,
RequestTiers
:
[]
PricingInterval
{
{
MinTokens
:
0
,
MaxTokens
:
resolverPtrInt
(
128000
),
PerRequestPrice
:
resolverPtrFloat64
(
0.05
)},
{
MinTokens
:
128000
,
MaxTokens
:
nil
,
PerRequestPrice
:
resolverPtrFloat64
(
0.10
)},
},
}
require
.
InDelta
(
t
,
0.05
,
r
.
GetRequestTierPriceByContext
(
resolved
,
50000
),
1e-12
)
require
.
InDelta
(
t
,
0.10
,
r
.
GetRequestTierPriceByContext
(
resolved
,
200000
),
1e-12
)
}
func
TestGetRequestTierPrice_NilPerRequestPrice
(
t
*
testing
.
T
)
{
bs
:=
newTestBillingServiceForResolver
()
r
:=
NewModelPricingResolver
(
&
ChannelService
{},
bs
)
resolved
:=
&
ResolvedPricing
{
Mode
:
BillingModePerRequest
,
RequestTiers
:
[]
PricingInterval
{
{
TierLabel
:
"1K"
,
PerRequestPrice
:
nil
},
},
}
require
.
InDelta
(
t
,
0.0
,
r
.
GetRequestTierPrice
(
resolved
,
"1K"
),
1e-12
)
}
backend/internal/service/wire.go
View file @
91c9b8d0
...
...
@@ -490,4 +490,6 @@ var ProviderSet = wire.NewSet(
ProvideScheduledTestService
,
ProvideScheduledTestRunnerService
,
NewGroupCapacityService
,
NewChannelService
,
NewModelPricingResolver
,
)
backend/migrations/081_create_channels.sql
0 → 100644
View file @
91c9b8d0
-- Create channels table for managing pricing channels.
-- A channel groups multiple groups together and provides custom model pricing.
SET
LOCAL
lock_timeout
=
'5s'
;
SET
LOCAL
statement_timeout
=
'10min'
;
-- 渠道表
CREATE
TABLE
IF
NOT
EXISTS
channels
(
id
BIGSERIAL
PRIMARY
KEY
,
name
VARCHAR
(
100
)
NOT
NULL
,
description
TEXT
DEFAULT
''
,
status
VARCHAR
(
20
)
NOT
NULL
DEFAULT
'active'
,
created_at
TIMESTAMPTZ
NOT
NULL
DEFAULT
NOW
(),
updated_at
TIMESTAMPTZ
NOT
NULL
DEFAULT
NOW
()
);
-- 渠道名称唯一索引
CREATE
UNIQUE
INDEX
IF
NOT
EXISTS
idx_channels_name
ON
channels
(
name
);
CREATE
INDEX
IF
NOT
EXISTS
idx_channels_status
ON
channels
(
status
);
-- 渠道-分组关联表(每个分组只能属于一个渠道)
CREATE
TABLE
IF
NOT
EXISTS
channel_groups
(
id
BIGSERIAL
PRIMARY
KEY
,
channel_id
BIGINT
NOT
NULL
REFERENCES
channels
(
id
)
ON
DELETE
CASCADE
,
group_id
BIGINT
NOT
NULL
REFERENCES
groups
(
id
)
ON
DELETE
CASCADE
,
created_at
TIMESTAMPTZ
NOT
NULL
DEFAULT
NOW
()
);
CREATE
UNIQUE
INDEX
IF
NOT
EXISTS
idx_channel_groups_group_id
ON
channel_groups
(
group_id
);
CREATE
INDEX
IF
NOT
EXISTS
idx_channel_groups_channel_id
ON
channel_groups
(
channel_id
);
-- 渠道模型定价表(一条定价可绑定多个模型)
CREATE
TABLE
IF
NOT
EXISTS
channel_model_pricing
(
id
BIGSERIAL
PRIMARY
KEY
,
channel_id
BIGINT
NOT
NULL
REFERENCES
channels
(
id
)
ON
DELETE
CASCADE
,
models
JSONB
NOT
NULL
DEFAULT
'[]'
,
input_price
NUMERIC
(
20
,
12
),
output_price
NUMERIC
(
20
,
12
),
cache_write_price
NUMERIC
(
20
,
12
),
cache_read_price
NUMERIC
(
20
,
12
),
image_output_price
NUMERIC
(
20
,
8
),
created_at
TIMESTAMPTZ
NOT
NULL
DEFAULT
NOW
(),
updated_at
TIMESTAMPTZ
NOT
NULL
DEFAULT
NOW
()
);
CREATE
INDEX
IF
NOT
EXISTS
idx_channel_model_pricing_channel_id
ON
channel_model_pricing
(
channel_id
);
COMMENT
ON
TABLE
channels
IS
'渠道管理:关联多个分组,提供自定义模型定价'
;
COMMENT
ON
TABLE
channel_groups
IS
'渠道-分组关联表:每个分组最多属于一个渠道'
;
COMMENT
ON
TABLE
channel_model_pricing
IS
'渠道模型定价:一条定价可绑定多个模型,价格一致'
;
COMMENT
ON
COLUMN
channel_model_pricing
.
models
IS
'绑定的模型列表,JSON 数组,如 ["claude-opus-4-6","claude-opus-4-6-thinking"]'
;
COMMENT
ON
COLUMN
channel_model_pricing
.
input_price
IS
'每 token 输入价格(USD),NULL 表示使用默认'
;
COMMENT
ON
COLUMN
channel_model_pricing
.
output_price
IS
'每 token 输出价格(USD),NULL 表示使用默认'
;
COMMENT
ON
COLUMN
channel_model_pricing
.
cache_write_price
IS
'缓存写入每 token 价格,NULL 表示使用默认'
;
COMMENT
ON
COLUMN
channel_model_pricing
.
cache_read_price
IS
'缓存读取每 token 价格,NULL 表示使用默认'
;
COMMENT
ON
COLUMN
channel_model_pricing
.
image_output_price
IS
'图片输出价格(Gemini Image 等),NULL 表示使用默认'
;
backend/migrations/082_refactor_channel_pricing.sql
0 → 100644
View file @
91c9b8d0
-- Extend channel_model_pricing with billing_mode and add context-interval child table.
-- Supports three billing modes: token (per-token with context intervals),
-- per_request (per-request with context-size tiers), and image (per-image).
SET
LOCAL
lock_timeout
=
'5s'
;
SET
LOCAL
statement_timeout
=
'10min'
;
-- 1. 为 channel_model_pricing 添加 billing_mode 列
ALTER
TABLE
channel_model_pricing
ADD
COLUMN
IF
NOT
EXISTS
billing_mode
VARCHAR
(
20
)
NOT
NULL
DEFAULT
'token'
;
COMMENT
ON
COLUMN
channel_model_pricing
.
billing_mode
IS
'计费模式:token(按 token 区间计费)、per_request(按次计费)、image(图片计费)'
;
-- 2. 创建区间定价子表
CREATE
TABLE
IF
NOT
EXISTS
channel_pricing_intervals
(
id
BIGSERIAL
PRIMARY
KEY
,
pricing_id
BIGINT
NOT
NULL
REFERENCES
channel_model_pricing
(
id
)
ON
DELETE
CASCADE
,
min_tokens
INT
NOT
NULL
DEFAULT
0
,
max_tokens
INT
,
tier_label
VARCHAR
(
50
),
input_price
NUMERIC
(
20
,
12
),
output_price
NUMERIC
(
20
,
12
),
cache_write_price
NUMERIC
(
20
,
12
),
cache_read_price
NUMERIC
(
20
,
12
),
per_request_price
NUMERIC
(
20
,
12
),
sort_order
INT
NOT
NULL
DEFAULT
0
,
created_at
TIMESTAMPTZ
NOT
NULL
DEFAULT
NOW
(),
updated_at
TIMESTAMPTZ
NOT
NULL
DEFAULT
NOW
()
);
CREATE
INDEX
IF
NOT
EXISTS
idx_channel_pricing_intervals_pricing_id
ON
channel_pricing_intervals
(
pricing_id
);
COMMENT
ON
TABLE
channel_pricing_intervals
IS
'渠道定价区间:支持按 token 区间、按次分层、图片分辨率分层'
;
COMMENT
ON
COLUMN
channel_pricing_intervals
.
min_tokens
IS
'区间下界(含),token 模式使用'
;
COMMENT
ON
COLUMN
channel_pricing_intervals
.
max_tokens
IS
'区间上界(不含),NULL 表示无上限'
;
COMMENT
ON
COLUMN
channel_pricing_intervals
.
tier_label
IS
'层级标签,按次/图片模式使用(如 1K、2K、4K、HD)'
;
COMMENT
ON
COLUMN
channel_pricing_intervals
.
input_price
IS
'token 模式:每 token 输入价'
;
COMMENT
ON
COLUMN
channel_pricing_intervals
.
output_price
IS
'token 模式:每 token 输出价'
;
COMMENT
ON
COLUMN
channel_pricing_intervals
.
cache_write_price
IS
'token 模式:缓存写入价'
;
COMMENT
ON
COLUMN
channel_pricing_intervals
.
cache_read_price
IS
'token 模式:缓存读取价'
;
COMMENT
ON
COLUMN
channel_pricing_intervals
.
per_request_price
IS
'按次/图片模式:每次请求价格'
;
-- 3. 迁移现有 flat 定价为单区间 [0, +inf)
-- 仅迁移有明确定价(至少一个价格字段非 NULL)的条目
INSERT
INTO
channel_pricing_intervals
(
pricing_id
,
min_tokens
,
max_tokens
,
input_price
,
output_price
,
cache_write_price
,
cache_read_price
,
sort_order
)
SELECT
cmp
.
id
,
0
,
NULL
,
cmp
.
input_price
,
cmp
.
output_price
,
cmp
.
cache_write_price
,
cmp
.
cache_read_price
,
0
FROM
channel_model_pricing
cmp
WHERE
cmp
.
billing_mode
=
'token'
AND
(
cmp
.
input_price
IS
NOT
NULL
OR
cmp
.
output_price
IS
NOT
NULL
OR
cmp
.
cache_write_price
IS
NOT
NULL
OR
cmp
.
cache_read_price
IS
NOT
NULL
)
AND
NOT
EXISTS
(
SELECT
1
FROM
channel_pricing_intervals
cpi
WHERE
cpi
.
pricing_id
=
cmp
.
id
);
-- 4. 迁移 image_output_price 为 image 模式的区间条目
-- 将有 image_output_price 的现有条目复制为 billing_mode='image' 的独立条目
-- 注意:这里不改变原条目的 billing_mode,而是将 image_output_price 作为向后兼容字段保留
-- 实际的 image 计费在未来由独立的 billing_mode='image' 条目处理
frontend/src/api/admin/channels.ts
0 → 100644
View file @
91c9b8d0
/**
* Admin Channels API endpoints
* Handles channel management for administrators
*/
import
{
apiClient
}
from
'
../client
'
export
type
BillingMode
=
'
token
'
|
'
per_request
'
|
'
image
'
export
interface
PricingInterval
{
id
?:
number
min_tokens
:
number
max_tokens
:
number
|
null
tier_label
:
string
input_price
:
number
|
null
output_price
:
number
|
null
cache_write_price
:
number
|
null
cache_read_price
:
number
|
null
per_request_price
:
number
|
null
sort_order
:
number
}
export
interface
ChannelModelPricing
{
id
?:
number
models
:
string
[]
billing_mode
:
BillingMode
input_price
:
number
|
null
output_price
:
number
|
null
cache_write_price
:
number
|
null
cache_read_price
:
number
|
null
image_output_price
:
number
|
null
intervals
:
PricingInterval
[]
}
export
interface
Channel
{
id
:
number
name
:
string
description
:
string
status
:
string
group_ids
:
number
[]
model_pricing
:
ChannelModelPricing
[]
created_at
:
string
updated_at
:
string
}
export
interface
CreateChannelRequest
{
name
:
string
description
?:
string
group_ids
?:
number
[]
model_pricing
?:
ChannelModelPricing
[]
}
export
interface
UpdateChannelRequest
{
name
?:
string
description
?:
string
status
?:
string
group_ids
?:
number
[]
model_pricing
?:
ChannelModelPricing
[]
}
interface
PaginatedResponse
<
T
>
{
items
:
T
[]
total
:
number
}
/**
* List channels with pagination
*/
export
async
function
list
(
page
:
number
=
1
,
pageSize
:
number
=
20
,
filters
?:
{
status
?:
string
search
?:
string
},
options
?:
{
signal
?:
AbortSignal
}
):
Promise
<
PaginatedResponse
<
Channel
>>
{
const
{
data
}
=
await
apiClient
.
get
<
PaginatedResponse
<
Channel
>>
(
'
/admin/channels
'
,
{
params
:
{
page
,
page_size
:
pageSize
,
...
filters
},
signal
:
options
?.
signal
})
return
data
}
/**
* Get channel by ID
*/
export
async
function
getById
(
id
:
number
):
Promise
<
Channel
>
{
const
{
data
}
=
await
apiClient
.
get
<
Channel
>
(
`/admin/channels/
${
id
}
`
)
return
data
}
/**
* Create a new channel
*/
export
async
function
create
(
req
:
CreateChannelRequest
):
Promise
<
Channel
>
{
const
{
data
}
=
await
apiClient
.
post
<
Channel
>
(
'
/admin/channels
'
,
req
)
return
data
}
/**
* Update a channel
*/
export
async
function
update
(
id
:
number
,
req
:
UpdateChannelRequest
):
Promise
<
Channel
>
{
const
{
data
}
=
await
apiClient
.
put
<
Channel
>
(
`/admin/channels/
${
id
}
`
,
req
)
return
data
}
/**
* Delete a channel
*/
export
async
function
remove
(
id
:
number
):
Promise
<
void
>
{
await
apiClient
.
delete
(
`/admin/channels/
${
id
}
`
)
}
const
channelsAPI
=
{
list
,
getById
,
create
,
update
,
remove
}
export
default
channelsAPI
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment