Unverified Commit 1cf51b14 authored by 程序猿MT's avatar 程序猿MT Committed by GitHub
Browse files

Merge branch 'Wei-Shaw:main' into main

parents 174d7c77 a817cafe
...@@ -32,7 +32,7 @@ jobs: ...@@ -32,7 +32,7 @@ jobs:
working-directory: backend working-directory: backend
run: | run: |
go install github.com/securego/gosec/v2/cmd/gosec@latest go install github.com/securego/gosec/v2/cmd/gosec@latest
gosec -severity high -confidence high ./... gosec -conf .gosec.json -severity high -confidence high ./...
frontend-security: frontend-security:
runs-on: ubuntu-latest runs-on: ubuntu-latest
......
{
"global": {
"exclude": "G704"
}
}
...@@ -341,7 +341,7 @@ func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, acc ...@@ -341,7 +341,7 @@ func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, acc
pageSize := dataPageCap pageSize := dataPageCap
var out []service.Account var out []service.Account
for { for {
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search) items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
...@@ -157,7 +157,12 @@ func (h *AccountHandler) List(c *gin.Context) { ...@@ -157,7 +157,12 @@ func (h *AccountHandler) List(c *gin.Context) {
search = search[:100] search = search[:100]
} }
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search) var groupID int64
if groupIDStr := c.Query("group"); groupIDStr != "" {
groupID, _ = strconv.ParseInt(groupIDStr, 10, 64)
}
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -1413,7 +1418,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) { ...@@ -1413,7 +1418,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
accounts := make([]*service.Account, 0) accounts := make([]*service.Account, 0)
if len(req.AccountIDs) == 0 { if len(req.AccountIDs) == 0 {
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "") allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
......
...@@ -166,7 +166,7 @@ func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, p ...@@ -166,7 +166,7 @@ func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, p
return s.apiKeys, int64(len(s.apiKeys)), nil return s.apiKeys, int64(len(s.apiKeys)), nil
} }
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]service.Account, int64, error) { func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
return s.accounts, int64(len(s.accounts)), nil return s.accounts, int64(len(s.accounts)), nil
} }
......
...@@ -435,10 +435,10 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error { ...@@ -435,10 +435,10 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
} }
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "", "") return r.ListWithFilters(ctx, params, "", "", "", "", 0)
} }
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) { func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) {
q := r.client.Account.Query() q := r.client.Account.Query()
if platform != "" { if platform != "" {
...@@ -458,6 +458,9 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati ...@@ -458,6 +458,9 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
if search != "" { if search != "" {
q = q.Where(dbaccount.NameContainsFold(search)) q = q.Where(dbaccount.NameContainsFold(search))
} }
if groupID > 0 {
q = q.Where(dbaccount.HasAccountGroupsWith(dbaccountgroup.GroupIDEQ(groupID)))
}
total, err := q.Count(ctx) total, err := q.Count(ctx)
if err != nil { if err != nil {
......
...@@ -238,7 +238,7 @@ func (s *AccountRepoSuite) TestListWithFilters() { ...@@ -238,7 +238,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
tt.setup(client) tt.setup(client)
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search) accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, 0)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(accounts, tt.wantCount) s.Require().Len(accounts, tt.wantCount)
if tt.validate != nil { if tt.validate != nil {
...@@ -305,7 +305,7 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() { ...@@ -305,7 +305,7 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
s.Require().Len(got.Groups, 1, "expected Groups to be populated") s.Require().Len(got.Groups, 1, "expected Groups to be populated")
s.Require().Equal(group.ID, got.Groups[0].ID) s.Require().Equal(group.ID, got.Groups[0].ID)
accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc") accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc", 0)
s.Require().NoError(err, "ListWithFilters") s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total) s.Require().Equal(int64(1), page.Total)
s.Require().Len(accounts, 1) s.Require().Len(accounts, 1)
......
...@@ -936,7 +936,7 @@ func (s *stubAccountRepo) List(ctx context.Context, params pagination.Pagination ...@@ -936,7 +936,7 @@ func (s *stubAccountRepo) List(ctx context.Context, params pagination.Pagination
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) { func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
......
...@@ -32,7 +32,7 @@ type AccountRepository interface { ...@@ -32,7 +32,7 @@ type AccountRepository interface {
Delete(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error)
ListByGroup(ctx context.Context, groupID int64) ([]Account, error) ListByGroup(ctx context.Context, groupID int64) ([]Account, error)
ListActive(ctx context.Context) ([]Account, error) ListActive(ctx context.Context) ([]Account, error)
ListByPlatform(ctx context.Context, platform string) ([]Account, error) ListByPlatform(ctx context.Context, platform string) ([]Account, error)
......
...@@ -75,7 +75,7 @@ func (s *accountRepoStub) List(ctx context.Context, params pagination.Pagination ...@@ -75,7 +75,7 @@ func (s *accountRepoStub) List(ctx context.Context, params pagination.Pagination
panic("unexpected List call") panic("unexpected List call")
} }
func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) { func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call") panic("unexpected ListWithFilters call")
} }
......
...@@ -39,7 +39,7 @@ type AdminService interface { ...@@ -39,7 +39,7 @@ type AdminService interface {
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
// Account management // Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error)
GetAccount(ctx context.Context, id int64) (*Account, error) GetAccount(ctx context.Context, id int64) (*Account, error)
GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error)
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
...@@ -1021,9 +1021,9 @@ func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates [] ...@@ -1021,9 +1021,9 @@ func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []
} }
// Account management implementations // Account management implementations
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) { func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search) accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
......
...@@ -24,7 +24,7 @@ type accountRepoStubForAdminList struct { ...@@ -24,7 +24,7 @@ type accountRepoStubForAdminList struct {
listWithFiltersErr error listWithFiltersErr error
} }
func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) { func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
s.listWithFiltersCalls++ s.listWithFiltersCalls++
s.listWithFiltersParams = params s.listWithFiltersParams = params
s.listWithFiltersPlatform = platform s.listWithFiltersPlatform = platform
...@@ -168,7 +168,7 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) { ...@@ -168,7 +168,7 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{accountRepo: repo} svc := &adminServiceImpl{accountRepo: repo}
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc") accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(10), total) require.Equal(t, int64(10), total)
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts) require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)
......
...@@ -4119,6 +4119,15 @@ func (s *AntigravityGatewayService) extractSSEUsage(line string, usage *ClaudeUs ...@@ -4119,6 +4119,15 @@ func (s *AntigravityGatewayService) extractSSEUsage(line string, usage *ClaudeUs
if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 { if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 {
usage.CacheCreationInputTokens = int(v) usage.CacheCreationInputTokens = int(v)
} }
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
if cc, ok := u["cache_creation"].(map[string]any); ok {
if v, ok := cc["ephemeral_5m_input_tokens"].(float64); ok {
usage.CacheCreation5mTokens = int(v)
}
if v, ok := cc["ephemeral_1h_input_tokens"].(float64); ok {
usage.CacheCreation1hTokens = int(v)
}
}
} }
// extractClaudeUsage 从非流式 Claude 响应提取 usage // extractClaudeUsage 从非流式 Claude 响应提取 usage
...@@ -4141,6 +4150,15 @@ func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage ...@@ -4141,6 +4150,15 @@ func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage
if v, ok := u["cache_creation_input_tokens"].(float64); ok { if v, ok := u["cache_creation_input_tokens"].(float64); ok {
usage.CacheCreationInputTokens = int(v) usage.CacheCreationInputTokens = int(v)
} }
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
if cc, ok := u["cache_creation"].(map[string]any); ok {
if v, ok := cc["ephemeral_5m_input_tokens"].(float64); ok {
usage.CacheCreation5mTokens = int(v)
}
if v, ok := cc["ephemeral_1h_input_tokens"].(float64); ok {
usage.CacheCreation1hTokens = int(v)
}
}
} }
return usage return usage
} }
...@@ -31,8 +31,8 @@ type ModelPricing struct { ...@@ -31,8 +31,8 @@ type ModelPricing struct {
OutputPricePerToken float64 // 每token输出价格 (USD) OutputPricePerToken float64 // 每token输出价格 (USD)
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD) CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD) CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
CacheCreation5mPrice float64 // 5分钟缓存创建价格(每百万token)- 仅用于硬编码回退 CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD)
CacheCreation1hPrice float64 // 1小时缓存创建价格(每百万token)- 仅用于硬编码回退 CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD)
SupportsCacheBreakdown bool // 是否支持详细的缓存分类 SupportsCacheBreakdown bool // 是否支持详细的缓存分类
} }
...@@ -172,12 +172,20 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) { ...@@ -172,12 +172,20 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
if s.pricingService != nil { if s.pricingService != nil {
litellmPricing := s.pricingService.GetModelPricing(model) litellmPricing := s.pricingService.GetModelPricing(model)
if litellmPricing != nil { if litellmPricing != nil {
// 启用 5m/1h 分类计费的条件:
// 1. 存在 1h 价格
// 2. 1h 价格 > 5m 价格(防止 LiteLLM 数据错误导致少收费)
price5m := litellmPricing.CacheCreationInputTokenCost
price1h := litellmPricing.CacheCreationInputTokenCostAbove1hr
enableBreakdown := price1h > 0 && price1h > price5m
return &ModelPricing{ return &ModelPricing{
InputPricePerToken: litellmPricing.InputCostPerToken, InputPricePerToken: litellmPricing.InputCostPerToken,
OutputPricePerToken: litellmPricing.OutputCostPerToken, OutputPricePerToken: litellmPricing.OutputCostPerToken,
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost, CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost, CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
SupportsCacheBreakdown: false, CacheCreation5mPrice: price5m,
CacheCreation1hPrice: price1h,
SupportsCacheBreakdown: enableBreakdown,
}, nil }, nil
} }
} }
...@@ -209,9 +217,14 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul ...@@ -209,9 +217,14 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul
// 计算缓存费用 // 计算缓存费用
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) { if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
// 支持详细缓存分类的模型(5分钟/1小时缓存) // 支持详细缓存分类的模型(5分钟/1小时缓存,价格为 per-token)
breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)/1_000_000*pricing.CacheCreation5mPrice + if tokens.CacheCreation5mTokens == 0 && tokens.CacheCreation1hTokens == 0 && tokens.CacheCreationTokens > 0 {
float64(tokens.CacheCreation1hTokens)/1_000_000*pricing.CacheCreation1hPrice // API 未返回 ephemeral 明细,回退到全部按 5m 单价计费
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice
} else {
breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice +
float64(tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice
}
} else { } else {
// 标准缓存创建价格(per-token) // 标准缓存创建价格(per-token)
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
...@@ -280,10 +293,12 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage ...@@ -280,10 +293,12 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage
// 范围内部分:正常计费 // 范围内部分:正常计费
inRangeTokens := UsageTokens{ inRangeTokens := UsageTokens{
InputTokens: inRangeInputTokens, InputTokens: inRangeInputTokens,
OutputTokens: tokens.OutputTokens, // 输出只算一次 OutputTokens: tokens.OutputTokens, // 输出只算一次
CacheCreationTokens: tokens.CacheCreationTokens, CacheCreationTokens: tokens.CacheCreationTokens,
CacheReadTokens: inRangeCacheTokens, CacheReadTokens: inRangeCacheTokens,
CacheCreation5mTokens: tokens.CacheCreation5mTokens,
CacheCreation1hTokens: tokens.CacheCreation1hTokens,
} }
inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier) inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier)
if err != nil { if err != nil {
......
...@@ -87,7 +87,7 @@ func (m *mockAccountRepoForPlatform) Delete(ctx context.Context, id int64) error ...@@ -87,7 +87,7 @@ func (m *mockAccountRepoForPlatform) Delete(ctx context.Context, id int64) error
func (m *mockAccountRepoForPlatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { func (m *mockAccountRepoForPlatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) { func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (m *mockAccountRepoForPlatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) { func (m *mockAccountRepoForPlatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
......
...@@ -349,6 +349,8 @@ type ClaudeUsage struct { ...@@ -349,6 +349,8 @@ type ClaudeUsage struct {
OutputTokens int `json:"output_tokens"` OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens"` CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
CacheReadInputTokens int `json:"cache_read_input_tokens"` CacheReadInputTokens int `json:"cache_read_input_tokens"`
CacheCreation5mTokens int // 5分钟缓存创建token(来自嵌套 cache_creation 对象)
CacheCreation1hTokens int // 1小时缓存创建token(来自嵌套 cache_creation 对象)
} }
// ForwardResult 转发结果 // ForwardResult 转发结果
...@@ -4403,6 +4405,14 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { ...@@ -4403,6 +4405,14 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
usage.InputTokens = msgStart.Message.Usage.InputTokens usage.InputTokens = msgStart.Message.Usage.InputTokens
usage.CacheCreationInputTokens = msgStart.Message.Usage.CacheCreationInputTokens usage.CacheCreationInputTokens = msgStart.Message.Usage.CacheCreationInputTokens
usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
cc5m := gjson.Get(data, "message.usage.cache_creation.ephemeral_5m_input_tokens")
cc1h := gjson.Get(data, "message.usage.cache_creation.ephemeral_1h_input_tokens")
if cc5m.Exists() || cc1h.Exists() {
usage.CacheCreation5mTokens = int(cc5m.Int())
usage.CacheCreation1hTokens = int(cc1h.Int())
}
} }
// 解析message_delta获取tokens(兼容GLM等把所有usage放在delta中的API) // 解析message_delta获取tokens(兼容GLM等把所有usage放在delta中的API)
...@@ -4431,6 +4441,14 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { ...@@ -4431,6 +4441,14 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
if msgDelta.Usage.CacheReadInputTokens > 0 { if msgDelta.Usage.CacheReadInputTokens > 0 {
usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens
} }
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
cc5m := gjson.Get(data, "usage.cache_creation.ephemeral_5m_input_tokens")
cc1h := gjson.Get(data, "usage.cache_creation.ephemeral_1h_input_tokens")
if cc5m.Exists() || cc1h.Exists() {
usage.CacheCreation5mTokens = int(cc5m.Int())
usage.CacheCreation1hTokens = int(cc1h.Int())
}
} }
} }
...@@ -4451,6 +4469,14 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h ...@@ -4451,6 +4469,14 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
return nil, fmt.Errorf("parse response: %w", err) return nil, fmt.Errorf("parse response: %w", err)
} }
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
cc5m := gjson.GetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens")
cc1h := gjson.GetBytes(body, "usage.cache_creation.ephemeral_1h_input_tokens")
if cc5m.Exists() || cc1h.Exists() {
response.Usage.CacheCreation5mTokens = int(cc5m.Int())
response.Usage.CacheCreation1hTokens = int(cc1h.Int())
}
// 兼容 Kimi cached_tokens → cache_read_input_tokens // 兼容 Kimi cached_tokens → cache_read_input_tokens
if response.Usage.CacheReadInputTokens == 0 { if response.Usage.CacheReadInputTokens == 0 {
cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int() cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int()
...@@ -4560,10 +4586,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -4560,10 +4586,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
} else { } else {
// Token 计费 // Token 计费
tokens := UsageTokens{ tokens := UsageTokens{
InputTokens: result.Usage.InputTokens, InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens, OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
} }
var err error var err error
cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier) cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier)
...@@ -4597,6 +4625,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -4597,6 +4625,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
OutputTokens: result.Usage.OutputTokens, OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
InputCost: cost.InputCost, InputCost: cost.InputCost,
OutputCost: cost.OutputCost, OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost, CacheCreationCost: cost.CacheCreationCost,
...@@ -4741,10 +4771,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -4741,10 +4771,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
} else { } else {
// Token 计费(使用长上下文计费方法) // Token 计费(使用长上下文计费方法)
tokens := UsageTokens{ tokens := UsageTokens{
InputTokens: result.Usage.InputTokens, InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens, OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
} }
var err error var err error
cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier) cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
...@@ -4778,6 +4810,8 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -4778,6 +4810,8 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
OutputTokens: result.Usage.OutputTokens, OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
InputCost: cost.InputCost, InputCost: cost.InputCost,
OutputCost: cost.OutputCost, OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost, CacheCreationCost: cost.CacheCreationCost,
......
...@@ -74,7 +74,7 @@ func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error ...@@ -74,7 +74,7 @@ func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error
func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) { func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) { func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
......
...@@ -24,7 +24,7 @@ func (s *OpsService) listAllAccountsForOps(ctx context.Context, platformFilter s ...@@ -24,7 +24,7 @@ func (s *OpsService) listAllAccountsForOps(ctx context.Context, platformFilter s
accounts, pageInfo, err := s.accountRepo.ListWithFilters(ctx, pagination.PaginationParams{ accounts, pageInfo, err := s.accountRepo.ListWithFilters(ctx, pagination.PaginationParams{
Page: page, Page: page,
PageSize: opsAccountsPageSize, PageSize: opsAccountsPageSize,
}, platformFilter, "", "", "") }, platformFilter, "", "", "", 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
...@@ -27,14 +27,15 @@ var ( ...@@ -27,14 +27,15 @@ var (
// LiteLLMModelPricing LiteLLM价格数据结构 // LiteLLMModelPricing LiteLLM价格数据结构
// 只保留我们需要的字段,使用指针来处理可能缺失的值 // 只保留我们需要的字段,使用指针来处理可能缺失的值
type LiteLLMModelPricing struct { type LiteLLMModelPricing struct {
InputCostPerToken float64 `json:"input_cost_per_token"` InputCostPerToken float64 `json:"input_cost_per_token"`
OutputCostPerToken float64 `json:"output_cost_per_token"` OutputCostPerToken float64 `json:"output_cost_per_token"`
CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"` CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"`
CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"` CacheCreationInputTokenCostAbove1hr float64 `json:"cache_creation_input_token_cost_above_1hr"`
LiteLLMProvider string `json:"litellm_provider"` CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"`
Mode string `json:"mode"` LiteLLMProvider string `json:"litellm_provider"`
SupportsPromptCaching bool `json:"supports_prompt_caching"` Mode string `json:"mode"`
OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格 SupportsPromptCaching bool `json:"supports_prompt_caching"`
OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格
} }
// PricingRemoteClient 远程价格数据获取接口 // PricingRemoteClient 远程价格数据获取接口
...@@ -45,14 +46,15 @@ type PricingRemoteClient interface { ...@@ -45,14 +46,15 @@ type PricingRemoteClient interface {
// LiteLLMRawEntry 用于解析原始JSON数据 // LiteLLMRawEntry 用于解析原始JSON数据
type LiteLLMRawEntry struct { type LiteLLMRawEntry struct {
InputCostPerToken *float64 `json:"input_cost_per_token"` InputCostPerToken *float64 `json:"input_cost_per_token"`
OutputCostPerToken *float64 `json:"output_cost_per_token"` OutputCostPerToken *float64 `json:"output_cost_per_token"`
CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"` CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"`
CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"` CacheCreationInputTokenCostAbove1hr *float64 `json:"cache_creation_input_token_cost_above_1hr"`
LiteLLMProvider string `json:"litellm_provider"` CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"`
Mode string `json:"mode"` LiteLLMProvider string `json:"litellm_provider"`
SupportsPromptCaching bool `json:"supports_prompt_caching"` Mode string `json:"mode"`
OutputCostPerImage *float64 `json:"output_cost_per_image"` SupportsPromptCaching bool `json:"supports_prompt_caching"`
OutputCostPerImage *float64 `json:"output_cost_per_image"`
} }
// PricingService 动态价格服务 // PricingService 动态价格服务
...@@ -318,6 +320,9 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel ...@@ -318,6 +320,9 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel
if entry.CacheCreationInputTokenCost != nil { if entry.CacheCreationInputTokenCost != nil {
pricing.CacheCreationInputTokenCost = *entry.CacheCreationInputTokenCost pricing.CacheCreationInputTokenCost = *entry.CacheCreationInputTokenCost
} }
if entry.CacheCreationInputTokenCostAbove1hr != nil {
pricing.CacheCreationInputTokenCostAbove1hr = *entry.CacheCreationInputTokenCostAbove1hr
}
if entry.CacheReadInputTokenCost != nil { if entry.CacheReadInputTokenCost != nil {
pricing.CacheReadInputTokenCost = *entry.CacheReadInputTokenCost pricing.CacheReadInputTokenCost = *entry.CacheReadInputTokenCost
} }
......
...@@ -381,10 +381,31 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head ...@@ -381,10 +381,31 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
} }
} }
// 2. 尝试从响应头解析重置时间(Anthropic) // 2. Anthropic 平台:尝试解析 per-window 头(5h / 7d),选择实际触发的窗口
if result := calculateAnthropic429ResetTime(headers); result != nil {
if err := s.accountRepo.SetRateLimited(ctx, account.ID, result.resetAt); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
return
}
// 更新 session window:优先使用 5h-reset 头精确计算,否则从 resetAt 反推
windowEnd := result.resetAt
if result.fiveHourReset != nil {
windowEnd = *result.fiveHourReset
}
windowStart := windowEnd.Add(-5 * time.Hour)
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
slog.Warn("rate_limit_update_session_window_failed", "account_id", account.ID, "error", err)
}
slog.Info("anthropic_account_rate_limited", "account_id", account.ID, "reset_at", result.resetAt, "reset_in", time.Until(result.resetAt).Truncate(time.Second))
return
}
// 3. 尝试从响应头解析重置时间(Anthropic 聚合头,向后兼容)
resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset") resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset")
// 3. 如果响应头没有,尝试从响应体解析(OpenAI usage_limit_reached, Gemini) // 4. 如果响应头没有,尝试从响应体解析(OpenAI usage_limit_reached, Gemini)
if resetTimestamp == "" { if resetTimestamp == "" {
switch account.Platform { switch account.Platform {
case PlatformOpenAI: case PlatformOpenAI:
...@@ -497,6 +518,112 @@ func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *tim ...@@ -497,6 +518,112 @@ func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *tim
return nil return nil
} }
// anthropic429Result holds the parsed Anthropic 429 rate-limit information.
type anthropic429Result struct {
resetAt time.Time // The correct reset time to use for SetRateLimited
fiveHourReset *time.Time // 5h window reset timestamp (for session window calculation), nil if not available
}
// calculateAnthropic429ResetTime parses Anthropic's per-window rate-limit headers
// to determine which window (5h or 7d) actually triggered the 429.
//
// Headers used:
// - anthropic-ratelimit-unified-5h-utilization / anthropic-ratelimit-unified-5h-surpassed-threshold
// - anthropic-ratelimit-unified-5h-reset
// - anthropic-ratelimit-unified-7d-utilization / anthropic-ratelimit-unified-7d-surpassed-threshold
// - anthropic-ratelimit-unified-7d-reset
//
// Returns nil when the per-window headers are absent (caller should fall back to
// the aggregated anthropic-ratelimit-unified-reset header).
func calculateAnthropic429ResetTime(headers http.Header) *anthropic429Result {
reset5hStr := headers.Get("anthropic-ratelimit-unified-5h-reset")
reset7dStr := headers.Get("anthropic-ratelimit-unified-7d-reset")
if reset5hStr == "" && reset7dStr == "" {
return nil
}
var reset5h, reset7d *time.Time
if ts, err := strconv.ParseInt(reset5hStr, 10, 64); err == nil {
t := time.Unix(ts, 0)
reset5h = &t
}
if ts, err := strconv.ParseInt(reset7dStr, 10, 64); err == nil {
t := time.Unix(ts, 0)
reset7d = &t
}
is5hExceeded := isAnthropicWindowExceeded(headers, "5h")
is7dExceeded := isAnthropicWindowExceeded(headers, "7d")
slog.Info("anthropic_429_window_analysis",
"is_5h_exceeded", is5hExceeded,
"is_7d_exceeded", is7dExceeded,
"reset_5h", reset5hStr,
"reset_7d", reset7dStr,
)
// Select the correct reset time based on which window(s) are exceeded.
var chosen *time.Time
switch {
case is5hExceeded && is7dExceeded:
// Both exceeded → prefer 7d (longer cooldown), fall back to 5h
chosen = reset7d
if chosen == nil {
chosen = reset5h
}
case is5hExceeded:
chosen = reset5h
case is7dExceeded:
chosen = reset7d
default:
// Neither flag clearly exceeded — pick the sooner reset as best guess
chosen = pickSooner(reset5h, reset7d)
}
if chosen == nil {
return nil
}
return &anthropic429Result{resetAt: *chosen, fiveHourReset: reset5h}
}
// isAnthropicWindowExceeded checks whether a given Anthropic rate-limit window
// (e.g. "5h" or "7d") has been exceeded, using utilization and surpassed-threshold headers.
func isAnthropicWindowExceeded(headers http.Header, window string) bool {
prefix := "anthropic-ratelimit-unified-" + window + "-"
// Check surpassed-threshold first (most explicit signal)
if st := headers.Get(prefix + "surpassed-threshold"); strings.EqualFold(st, "true") {
return true
}
// Fall back to utilization >= 1.0
if utilStr := headers.Get(prefix + "utilization"); utilStr != "" {
if util, err := strconv.ParseFloat(utilStr, 64); err == nil && util >= 1.0-1e-9 {
// Use a small epsilon to handle floating point: treat 0.9999999... as >= 1.0
return true
}
}
return false
}
// pickSooner returns whichever of the two time pointers is earlier.
// If only one is non-nil, it is returned. If both are nil, returns nil.
func pickSooner(a, b *time.Time) *time.Time {
switch {
case a != nil && b != nil:
if a.Before(*b) {
return a
}
return b
case a != nil:
return a
default:
return b
}
}
// parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳 // parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳
// OpenAI 的 usage_limit_reached 错误格式: // OpenAI 的 usage_limit_reached 错误格式:
// //
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment