Commit 17ae51c0 authored by yangjianbo's avatar yangjianbo
Browse files

merge: 合并远程分支并修复代码冲突

合并了远程分支 cb72262a 的功能更新,同时保留了 ESLint 修复:

**冲突解决详情:**

1. AccountTableFilters.vue
   -  保留 emit 模式修复(避免 vue/no-mutating-props 错误)
   -  添加第三个筛选器 type(账户类型)
   -  新增 antigravity 平台和 inactive 状态选项

2. UserBalanceModal.vue
   -  保留 console.error 错误日志
   -  添加输入验证(金额校验、余额不足检查)
   -  使用 appStore.showError 向用户显示友好错误

3. AccountsView.vue
   -  保留所有 console.error 错误日志(避免 no-empty 错误)
   -  使用新 API:clearRateLimit 和 setSchedulable

4. UsageView.vue
   -  添加 console.error 错误日志
   -  添加图表功能(模型分布、使用趋势)
   -  添加粒度选择(按天/按小时)
   -  保留 XLSX 动态导入优化

**测试结果:**
-  Go tests: PASS
-  golangci-lint: 0 issues
-  ESLint: 0 errors
-  TypeScript: PASS

🤖 Generated with [Claude Code](https://claude.com/claude-code

)
Co-Authored-By: default avatarClaude Opus 4.5 <noreply@anthropic.com>
parents 4790aced cb72262a
...@@ -33,6 +33,10 @@ type CreateGroupRequest struct { ...@@ -33,6 +33,10 @@ type CreateGroupRequest struct {
DailyLimitUSD *float64 `json:"daily_limit_usd"` DailyLimitUSD *float64 `json:"daily_limit_usd"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"` WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"` MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
} }
// UpdateGroupRequest represents update group request // UpdateGroupRequest represents update group request
...@@ -47,6 +51,10 @@ type UpdateGroupRequest struct { ...@@ -47,6 +51,10 @@ type UpdateGroupRequest struct {
DailyLimitUSD *float64 `json:"daily_limit_usd"` DailyLimitUSD *float64 `json:"daily_limit_usd"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"` WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"` MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
} }
// List handles listing all groups with pagination // List handles listing all groups with pagination
...@@ -139,6 +147,9 @@ func (h *GroupHandler) Create(c *gin.Context) { ...@@ -139,6 +147,9 @@ func (h *GroupHandler) Create(c *gin.Context) {
DailyLimitUSD: req.DailyLimitUSD, DailyLimitUSD: req.DailyLimitUSD,
WeeklyLimitUSD: req.WeeklyLimitUSD, WeeklyLimitUSD: req.WeeklyLimitUSD,
MonthlyLimitUSD: req.MonthlyLimitUSD, MonthlyLimitUSD: req.MonthlyLimitUSD,
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
}) })
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
...@@ -174,6 +185,9 @@ func (h *GroupHandler) Update(c *gin.Context) { ...@@ -174,6 +185,9 @@ func (h *GroupHandler) Update(c *gin.Context) {
DailyLimitUSD: req.DailyLimitUSD, DailyLimitUSD: req.DailyLimitUSD,
WeeklyLimitUSD: req.WeeklyLimitUSD, WeeklyLimitUSD: req.WeeklyLimitUSD,
MonthlyLimitUSD: req.MonthlyLimitUSD, MonthlyLimitUSD: req.MonthlyLimitUSD,
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
}) })
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
......
...@@ -102,8 +102,9 @@ func (h *UsageHandler) List(c *gin.Context) { ...@@ -102,8 +102,9 @@ func (h *UsageHandler) List(c *gin.Context) {
// Parse date range // Parse date range
var startTime, endTime *time.Time var startTime, endTime *time.Time
userTZ := c.Query("timezone") // Get user's timezone from request
if startDateStr := c.Query("start_date"); startDateStr != "" { if startDateStr := c.Query("start_date"); startDateStr != "" {
t, err := timezone.ParseInLocation("2006-01-02", startDateStr) t, err := timezone.ParseInUserLocation("2006-01-02", startDateStr, userTZ)
if err != nil { if err != nil {
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD") response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
return return
...@@ -112,7 +113,7 @@ func (h *UsageHandler) List(c *gin.Context) { ...@@ -112,7 +113,7 @@ func (h *UsageHandler) List(c *gin.Context) {
} }
if endDateStr := c.Query("end_date"); endDateStr != "" { if endDateStr := c.Query("end_date"); endDateStr != "" {
t, err := timezone.ParseInLocation("2006-01-02", endDateStr) t, err := timezone.ParseInUserLocation("2006-01-02", endDateStr, userTZ)
if err != nil { if err != nil {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD") response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return return
...@@ -172,7 +173,8 @@ func (h *UsageHandler) Stats(c *gin.Context) { ...@@ -172,7 +173,8 @@ func (h *UsageHandler) Stats(c *gin.Context) {
} }
// Parse date range // Parse date range
now := timezone.Now() userTZ := c.Query("timezone") // Get user's timezone from request
now := timezone.NowInUserLocation(userTZ)
var startTime, endTime time.Time var startTime, endTime time.Time
startDateStr := c.Query("start_date") startDateStr := c.Query("start_date")
...@@ -180,12 +182,12 @@ func (h *UsageHandler) Stats(c *gin.Context) { ...@@ -180,12 +182,12 @@ func (h *UsageHandler) Stats(c *gin.Context) {
if startDateStr != "" && endDateStr != "" { if startDateStr != "" && endDateStr != "" {
var err error var err error
startTime, err = timezone.ParseInLocation("2006-01-02", startDateStr) startTime, err = timezone.ParseInUserLocation("2006-01-02", startDateStr, userTZ)
if err != nil { if err != nil {
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD") response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
return return
} }
endTime, err = timezone.ParseInLocation("2006-01-02", endDateStr) endTime, err = timezone.ParseInUserLocation("2006-01-02", endDateStr, userTZ)
if err != nil { if err != nil {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD") response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return return
...@@ -195,13 +197,13 @@ func (h *UsageHandler) Stats(c *gin.Context) { ...@@ -195,13 +197,13 @@ func (h *UsageHandler) Stats(c *gin.Context) {
period := c.DefaultQuery("period", "today") period := c.DefaultQuery("period", "today")
switch period { switch period {
case "today": case "today":
startTime = timezone.StartOfDay(now) startTime = timezone.StartOfDayInUserLocation(now, userTZ)
case "week": case "week":
startTime = now.AddDate(0, 0, -7) startTime = now.AddDate(0, 0, -7)
case "month": case "month":
startTime = now.AddDate(0, -1, 0) startTime = now.AddDate(0, -1, 0)
default: default:
startTime = timezone.StartOfDay(now) startTime = timezone.StartOfDayInUserLocation(now, userTZ)
} }
endTime = now endTime = now
} }
......
...@@ -78,6 +78,9 @@ func GroupFromServiceShallow(g *service.Group) *Group { ...@@ -78,6 +78,9 @@ func GroupFromServiceShallow(g *service.Group) *Group {
DailyLimitUSD: g.DailyLimitUSD, DailyLimitUSD: g.DailyLimitUSD,
WeeklyLimitUSD: g.WeeklyLimitUSD, WeeklyLimitUSD: g.WeeklyLimitUSD,
MonthlyLimitUSD: g.MonthlyLimitUSD, MonthlyLimitUSD: g.MonthlyLimitUSD,
ImagePrice1K: g.ImagePrice1K,
ImagePrice2K: g.ImagePrice2K,
ImagePrice4K: g.ImagePrice4K,
CreatedAt: g.CreatedAt, CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt, UpdatedAt: g.UpdatedAt,
AccountCount: g.AccountCount, AccountCount: g.AccountCount,
...@@ -247,6 +250,8 @@ func UsageLogFromService(l *service.UsageLog) *UsageLog { ...@@ -247,6 +250,8 @@ func UsageLogFromService(l *service.UsageLog) *UsageLog {
Stream: l.Stream, Stream: l.Stream,
DurationMs: l.DurationMs, DurationMs: l.DurationMs,
FirstTokenMs: l.FirstTokenMs, FirstTokenMs: l.FirstTokenMs,
ImageCount: l.ImageCount,
ImageSize: l.ImageSize,
CreatedAt: l.CreatedAt, CreatedAt: l.CreatedAt,
User: UserFromServiceShallow(l.User), User: UserFromServiceShallow(l.User),
APIKey: APIKeyFromService(l.APIKey), APIKey: APIKeyFromService(l.APIKey),
......
...@@ -47,6 +47,11 @@ type Group struct { ...@@ -47,6 +47,11 @@ type Group struct {
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"` WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"` MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
// 图片生成计费配置(仅 antigravity 平台使用)
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
...@@ -169,6 +174,10 @@ type UsageLog struct { ...@@ -169,6 +174,10 @@ type UsageLog struct {
DurationMs *int `json:"duration_ms"` DurationMs *int `json:"duration_ms"`
FirstTokenMs *int `json:"first_token_ms"` FirstTokenMs *int `json:"first_token_ms"`
// 图片生成字段
ImageCount int `json:"image_count"`
ImageSize *string `json:"image_size"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
User *User `json:"user,omitempty"` User *User `json:"user,omitempty"`
......
...@@ -88,8 +88,9 @@ func (h *UsageHandler) List(c *gin.Context) { ...@@ -88,8 +88,9 @@ func (h *UsageHandler) List(c *gin.Context) {
// Parse date range // Parse date range
var startTime, endTime *time.Time var startTime, endTime *time.Time
userTZ := c.Query("timezone") // Get user's timezone from request
if startDateStr := c.Query("start_date"); startDateStr != "" { if startDateStr := c.Query("start_date"); startDateStr != "" {
t, err := timezone.ParseInLocation("2006-01-02", startDateStr) t, err := timezone.ParseInUserLocation("2006-01-02", startDateStr, userTZ)
if err != nil { if err != nil {
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD") response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
return return
...@@ -98,7 +99,7 @@ func (h *UsageHandler) List(c *gin.Context) { ...@@ -98,7 +99,7 @@ func (h *UsageHandler) List(c *gin.Context) {
} }
if endDateStr := c.Query("end_date"); endDateStr != "" { if endDateStr := c.Query("end_date"); endDateStr != "" {
t, err := timezone.ParseInLocation("2006-01-02", endDateStr) t, err := timezone.ParseInUserLocation("2006-01-02", endDateStr, userTZ)
if err != nil { if err != nil {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD") response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return return
...@@ -194,7 +195,8 @@ func (h *UsageHandler) Stats(c *gin.Context) { ...@@ -194,7 +195,8 @@ func (h *UsageHandler) Stats(c *gin.Context) {
} }
// 获取时间范围参数 // 获取时间范围参数
now := timezone.Now() userTZ := c.Query("timezone") // Get user's timezone from request
now := timezone.NowInUserLocation(userTZ)
var startTime, endTime time.Time var startTime, endTime time.Time
// 优先使用 start_date 和 end_date 参数 // 优先使用 start_date 和 end_date 参数
...@@ -204,12 +206,12 @@ func (h *UsageHandler) Stats(c *gin.Context) { ...@@ -204,12 +206,12 @@ func (h *UsageHandler) Stats(c *gin.Context) {
if startDateStr != "" && endDateStr != "" { if startDateStr != "" && endDateStr != "" {
// 使用自定义日期范围 // 使用自定义日期范围
var err error var err error
startTime, err = timezone.ParseInLocation("2006-01-02", startDateStr) startTime, err = timezone.ParseInUserLocation("2006-01-02", startDateStr, userTZ)
if err != nil { if err != nil {
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD") response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
return return
} }
endTime, err = timezone.ParseInLocation("2006-01-02", endDateStr) endTime, err = timezone.ParseInUserLocation("2006-01-02", endDateStr, userTZ)
if err != nil { if err != nil {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD") response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return return
...@@ -221,13 +223,13 @@ func (h *UsageHandler) Stats(c *gin.Context) { ...@@ -221,13 +223,13 @@ func (h *UsageHandler) Stats(c *gin.Context) {
period := c.DefaultQuery("period", "today") period := c.DefaultQuery("period", "today")
switch period { switch period {
case "today": case "today":
startTime = timezone.StartOfDay(now) startTime = timezone.StartOfDayInUserLocation(now, userTZ)
case "week": case "week":
startTime = now.AddDate(0, 0, -7) startTime = now.AddDate(0, 0, -7)
case "month": case "month":
startTime = now.AddDate(0, -1, 0) startTime = now.AddDate(0, -1, 0)
default: default:
startTime = timezone.StartOfDay(now) startTime = timezone.StartOfDayInUserLocation(now, userTZ)
} }
endTime = now endTime = now
} }
...@@ -248,31 +250,33 @@ func (h *UsageHandler) Stats(c *gin.Context) { ...@@ -248,31 +250,33 @@ func (h *UsageHandler) Stats(c *gin.Context) {
} }
// parseUserTimeRange parses start_date, end_date query parameters for user dashboard // parseUserTimeRange parses start_date, end_date query parameters for user dashboard
// Uses user's timezone if provided, otherwise falls back to server timezone
func parseUserTimeRange(c *gin.Context) (time.Time, time.Time) { func parseUserTimeRange(c *gin.Context) (time.Time, time.Time) {
now := timezone.Now() userTZ := c.Query("timezone") // Get user's timezone from request
now := timezone.NowInUserLocation(userTZ)
startDate := c.Query("start_date") startDate := c.Query("start_date")
endDate := c.Query("end_date") endDate := c.Query("end_date")
var startTime, endTime time.Time var startTime, endTime time.Time
if startDate != "" { if startDate != "" {
if t, err := timezone.ParseInLocation("2006-01-02", startDate); err == nil { if t, err := timezone.ParseInUserLocation("2006-01-02", startDate, userTZ); err == nil {
startTime = t startTime = t
} else { } else {
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7)) startTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, -7), userTZ)
} }
} else { } else {
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7)) startTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, -7), userTZ)
} }
if endDate != "" { if endDate != "" {
if t, err := timezone.ParseInLocation("2006-01-02", endDate); err == nil { if t, err := timezone.ParseInUserLocation("2006-01-02", endDate, userTZ); err == nil {
endTime = t.Add(24 * time.Hour) // Include the end date endTime = t.Add(24 * time.Hour) // Include the end date
} else { } else {
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1)) endTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, 1), userTZ)
} }
} else { } else {
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1)) endTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, 1), userTZ)
} }
return startTime, endTime return startTime, endTime
......
...@@ -67,6 +67,13 @@ type GeminiGenerationConfig struct { ...@@ -67,6 +67,13 @@ type GeminiGenerationConfig struct {
TopK *int `json:"topK,omitempty"` TopK *int `json:"topK,omitempty"`
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"` ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"` StopSequences []string `json:"stopSequences,omitempty"`
ImageConfig *GeminiImageConfig `json:"imageConfig,omitempty"`
}
// GeminiImageConfig Gemini 图片生成配置(仅 gemini-3-pro-image 支持)
type GeminiImageConfig struct {
AspectRatio string `json:"aspectRatio,omitempty"` // "1:1", "16:9", "9:16", "4:3", "3:4"
ImageSize string `json:"imageSize,omitempty"` // "1K", "2K", "4K"
} }
// GeminiThinkingConfig Gemini thinking 配置 // GeminiThinkingConfig Gemini thinking 配置
......
...@@ -122,3 +122,40 @@ func StartOfMonth(t time.Time) time.Time { ...@@ -122,3 +122,40 @@ func StartOfMonth(t time.Time) time.Time {
func ParseInLocation(layout, value string) (time.Time, error) { func ParseInLocation(layout, value string) (time.Time, error) {
return time.ParseInLocation(layout, value, Location()) return time.ParseInLocation(layout, value, Location())
} }
// ParseInUserLocation parses a time string in the user's timezone.
// If userTZ is empty or invalid, falls back to the configured server timezone.
func ParseInUserLocation(layout, value, userTZ string) (time.Time, error) {
loc := Location() // default to server timezone
if userTZ != "" {
if userLoc, err := time.LoadLocation(userTZ); err == nil {
loc = userLoc
}
}
return time.ParseInLocation(layout, value, loc)
}
// NowInUserLocation returns the current time in the user's timezone.
// If userTZ is empty or invalid, falls back to the configured server timezone.
func NowInUserLocation(userTZ string) time.Time {
if userTZ == "" {
return Now()
}
if userLoc, err := time.LoadLocation(userTZ); err == nil {
return time.Now().In(userLoc)
}
return Now()
}
// StartOfDayInUserLocation returns the start of the given day in the user's timezone.
// If userTZ is empty or invalid, falls back to the configured server timezone.
func StartOfDayInUserLocation(t time.Time, userTZ string) time.Time {
loc := Location()
if userTZ != "" {
if userLoc, err := time.LoadLocation(userTZ); err == nil {
loc = userLoc
}
}
t = t.In(loc)
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, loc)
}
...@@ -773,10 +773,15 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates ...@@ -773,10 +773,15 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
idx++ idx++
} }
if updates.ProxyID != nil { if updates.ProxyID != nil {
// 0 表示清除代理(前端发送 0 而不是 null 来表达清除意图)
if *updates.ProxyID == 0 {
setClauses = append(setClauses, "proxy_id = NULL")
} else {
setClauses = append(setClauses, "proxy_id = $"+itoa(idx)) setClauses = append(setClauses, "proxy_id = $"+itoa(idx))
args = append(args, *updates.ProxyID) args = append(args, *updates.ProxyID)
idx++ idx++
} }
}
if updates.Concurrency != nil { if updates.Concurrency != nil {
setClauses = append(setClauses, "concurrency = $"+itoa(idx)) setClauses = append(setClauses, "concurrency = $"+itoa(idx))
args = append(args, *updates.Concurrency) args = append(args, *updates.Concurrency)
......
...@@ -321,6 +321,9 @@ func groupEntityToService(g *dbent.Group) *service.Group { ...@@ -321,6 +321,9 @@ func groupEntityToService(g *dbent.Group) *service.Group {
DailyLimitUSD: g.DailyLimitUsd, DailyLimitUSD: g.DailyLimitUsd,
WeeklyLimitUSD: g.WeeklyLimitUsd, WeeklyLimitUSD: g.WeeklyLimitUsd,
MonthlyLimitUSD: g.MonthlyLimitUsd, MonthlyLimitUSD: g.MonthlyLimitUsd,
ImagePrice1K: g.ImagePrice1k,
ImagePrice2K: g.ImagePrice2k,
ImagePrice4K: g.ImagePrice4k,
DefaultValidityDays: g.DefaultValidityDays, DefaultValidityDays: g.DefaultValidityDays,
CreatedAt: g.CreatedAt, CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt, UpdatedAt: g.UpdatedAt,
......
...@@ -56,7 +56,7 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) { ...@@ -56,7 +56,7 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
// 确保数据库 schema 已准备就绪。 // 确保数据库 schema 已准备就绪。
// SQL 迁移文件是 schema 的权威来源(source of truth)。 // SQL 迁移文件是 schema 的权威来源(source of truth)。
// 这种方式比 Ent 的自动迁移更可控,支持复杂的迁移场景。 // 这种方式比 Ent 的自动迁移更可控,支持复杂的迁移场景。
migrationCtx, cancel := context.WithTimeout(context.Background(), 60*time.Second) migrationCtx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel() defer cancel()
if err := applyMigrationsFS(migrationCtx, drv.DB(), migrations.FS); err != nil { if err := applyMigrationsFS(migrationCtx, drv.DB(), migrations.FS); err != nil {
_ = drv.Close() // 迁移失败时关闭驱动,避免资源泄露 _ = drv.Close() // 迁移失败时关闭驱动,避免资源泄露
......
...@@ -43,6 +43,9 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er ...@@ -43,6 +43,9 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD). SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD). SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD). SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetDefaultValidityDays(groupIn.DefaultValidityDays) SetDefaultValidityDays(groupIn.DefaultValidityDays)
created, err := builder.Save(ctx) created, err := builder.Save(ctx)
...@@ -80,6 +83,9 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er ...@@ -80,6 +83,9 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD). SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD). SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD). SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetDefaultValidityDays(groupIn.DefaultValidityDays). SetDefaultValidityDays(groupIn.DefaultValidityDays).
Save(ctx) Save(ctx)
if err != nil { if err != nil {
......
...@@ -22,7 +22,7 @@ import ( ...@@ -22,7 +22,7 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
) )
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, created_at" const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, image_count, image_size, created_at"
type usageLogRepository struct { type usageLogRepository struct {
client *dbent.Client client *dbent.Client
...@@ -109,6 +109,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -109,6 +109,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
stream, stream,
duration_ms, duration_ms,
first_token_ms, first_token_ms,
image_count,
image_size,
created_at created_at
) VALUES ( ) VALUES (
$1, $2, $3, $4, $5, $1, $2, $3, $4, $5,
...@@ -116,7 +118,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -116,7 +118,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11, $8, $9, $10, $11,
$12, $13, $12, $13,
$14, $15, $16, $17, $18, $19, $14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25 $20, $21, $22, $23, $24,
$25, $26, $27
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at RETURNING id, created_at
...@@ -126,6 +129,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -126,6 +129,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
subscriptionID := nullInt64(log.SubscriptionID) subscriptionID := nullInt64(log.SubscriptionID)
duration := nullInt(log.DurationMs) duration := nullInt(log.DurationMs)
firstToken := nullInt(log.FirstTokenMs) firstToken := nullInt(log.FirstTokenMs)
imageSize := nullString(log.ImageSize)
var requestIDArg any var requestIDArg any
if requestID != "" { if requestID != "" {
...@@ -157,6 +161,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -157,6 +161,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
log.Stream, log.Stream,
duration, duration,
firstToken, firstToken,
log.ImageCount,
imageSize,
createdAt, createdAt,
} }
if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil { if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil {
...@@ -1789,6 +1795,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -1789,6 +1795,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
stream bool stream bool
durationMs sql.NullInt64 durationMs sql.NullInt64
firstTokenMs sql.NullInt64 firstTokenMs sql.NullInt64
imageCount int
imageSize sql.NullString
createdAt time.Time createdAt time.Time
) )
...@@ -1818,6 +1826,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -1818,6 +1826,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&stream, &stream,
&durationMs, &durationMs,
&firstTokenMs, &firstTokenMs,
&imageCount,
&imageSize,
&createdAt, &createdAt,
); err != nil { ); err != nil {
return nil, err return nil, err
...@@ -1844,6 +1854,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -1844,6 +1854,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
RateMultiplier: rateMultiplier, RateMultiplier: rateMultiplier,
BillingType: int8(billingType), BillingType: int8(billingType),
Stream: stream, Stream: stream,
ImageCount: imageCount,
CreatedAt: createdAt, CreatedAt: createdAt,
} }
...@@ -1866,6 +1877,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -1866,6 +1877,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
value := int(firstTokenMs.Int64) value := int(firstTokenMs.Int64)
log.FirstTokenMs = &value log.FirstTokenMs = &value
} }
if imageSize.Valid {
log.ImageSize = &imageSize.String
}
return log, nil return log, nil
} }
...@@ -1938,6 +1952,13 @@ func nullInt(v *int) sql.NullInt64 { ...@@ -1938,6 +1952,13 @@ func nullInt(v *int) sql.NullInt64 {
return sql.NullInt64{Int64: int64(*v), Valid: true} return sql.NullInt64{Int64: int64(*v), Valid: true}
} }
func nullString(v *string) sql.NullString {
if v == nil || *v == "" {
return sql.NullString{}
}
return sql.NullString{String: *v, Valid: true}
}
func setToSlice(set map[int64]struct{}) []int64 { func setToSlice(set map[int64]struct{}) []int64 {
out := make([]int64, 0, len(set)) out := make([]int64, 0, len(set))
for id := range set { for id := range set {
......
...@@ -329,17 +329,20 @@ func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount flo ...@@ -329,17 +329,20 @@ func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount flo
return nil return nil
} }
// DeductBalance 扣除用户余额
// 透支策略:允许余额变为负数,确保当前请求能够完成
// 中间件会阻止余额 <= 0 的用户发起后续请求
func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error { func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
client := clientFromContext(ctx, r.client) client := clientFromContext(ctx, r.client)
n, err := client.User.Update(). n, err := client.User.Update().
Where(dbuser.IDEQ(id), dbuser.BalanceGTE(amount)). Where(dbuser.IDEQ(id)).
AddBalance(-amount). AddBalance(-amount).
Save(ctx) Save(ctx)
if err != nil { if err != nil {
return err return err
} }
if n == 0 { if n == 0 {
return service.ErrInsufficientBalance return service.ErrUserNotFound
} }
return nil return nil
} }
......
...@@ -290,9 +290,14 @@ func (s *UserRepoSuite) TestDeductBalance() { ...@@ -290,9 +290,14 @@ func (s *UserRepoSuite) TestDeductBalance() {
func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() { func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
user := s.mustCreateUser(&service.User{Email: "insuf@test.com", Balance: 5}) user := s.mustCreateUser(&service.User{Email: "insuf@test.com", Balance: 5})
// 透支策略:允许扣除超过余额的金额
err := s.repo.DeductBalance(s.ctx, user.ID, 999) err := s.repo.DeductBalance(s.ctx, user.ID, 999)
s.Require().Error(err, "expected error for insufficient balance") s.Require().NoError(err, "DeductBalance should allow overdraft")
s.Require().ErrorIs(err, service.ErrInsufficientBalance)
// 验证余额变为负数
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().InDelta(-994.0, got.Balance, 1e-6, "Balance should be negative after overdraft")
} }
func (s *UserRepoSuite) TestDeductBalance_ExactAmount() { func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
...@@ -306,6 +311,19 @@ func (s *UserRepoSuite) TestDeductBalance_ExactAmount() { ...@@ -306,6 +311,19 @@ func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
s.Require().InDelta(0.0, got.Balance, 1e-6) s.Require().InDelta(0.0, got.Balance, 1e-6)
} }
func (s *UserRepoSuite) TestDeductBalance_AllowsOverdraft() {
user := s.mustCreateUser(&service.User{Email: "overdraft@test.com", Balance: 5.0})
// 扣除超过余额的金额 - 应该成功
err := s.repo.DeductBalance(s.ctx, user.ID, 10.0)
s.Require().NoError(err, "DeductBalance should allow overdraft")
// 验证余额为负
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().InDelta(-5.0, got.Balance, 1e-6, "Balance should be -5.0 after overdraft")
}
// --- Concurrency --- // --- Concurrency ---
func (s *UserRepoSuite) TestUpdateConcurrency() { func (s *UserRepoSuite) TestUpdateConcurrency() {
...@@ -477,9 +495,12 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() { ...@@ -477,9 +495,12 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
s.Require().NoError(err, "GetByID after DeductBalance") s.Require().NoError(err, "GetByID after DeductBalance")
s.Require().InDelta(7.5, got4.Balance, 1e-6) s.Require().InDelta(7.5, got4.Balance, 1e-6)
// 透支策略:允许扣除超过余额的金额
err = s.repo.DeductBalance(s.ctx, user1.ID, 999) err = s.repo.DeductBalance(s.ctx, user1.ID, 999)
s.Require().Error(err, "DeductBalance expected error for insufficient balance") s.Require().NoError(err, "DeductBalance should allow overdraft")
s.Require().ErrorIs(err, service.ErrInsufficientBalance, "DeductBalance unexpected error") gotOverdraft, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID after overdraft")
s.Require().Less(gotOverdraft.Balance, 0.0, "Balance should be negative after overdraft")
s.Require().NoError(s.repo.UpdateConcurrency(s.ctx, user1.ID, 3), "UpdateConcurrency") s.Require().NoError(s.repo.UpdateConcurrency(s.ctx, user1.ID, 3), "UpdateConcurrency")
got5, err := s.repo.GetByID(s.ctx, user1.ID) got5, err := s.repo.GetByID(s.ctx, user1.ID)
...@@ -511,6 +532,6 @@ func (s *UserRepoSuite) TestUpdateConcurrency_NotFound() { ...@@ -511,6 +532,6 @@ func (s *UserRepoSuite) TestUpdateConcurrency_NotFound() {
func (s *UserRepoSuite) TestDeductBalance_NotFound() { func (s *UserRepoSuite) TestDeductBalance_NotFound() {
err := s.repo.DeductBalance(s.ctx, 999999, 5) err := s.repo.DeductBalance(s.ctx, 999999, 5)
s.Require().Error(err, "expected error for non-existent user") s.Require().Error(err, "expected error for non-existent user")
// DeductBalance 在用户不存在时返回 ErrInsufficientBalance 因为 WHERE 条件不匹配 // DeductBalance 在用户不存在时返回 ErrUserNotFound
s.Require().ErrorIs(err, service.ErrInsufficientBalance) s.Require().ErrorIs(err, service.ErrUserNotFound)
} }
...@@ -241,6 +241,8 @@ func TestAPIContracts(t *testing.T) { ...@@ -241,6 +241,8 @@ func TestAPIContracts(t *testing.T) {
"stream": true, "stream": true,
"duration_ms": 100, "duration_ms": 100,
"first_token_ms": 50, "first_token_ms": 50,
"image_count": 0,
"image_size": null,
"created_at": "2025-01-02T03:04:05Z" "created_at": "2025-01-02T03:04:05Z"
} }
], ],
......
...@@ -98,6 +98,10 @@ type CreateGroupInput struct { ...@@ -98,6 +98,10 @@ type CreateGroupInput struct {
DailyLimitUSD *float64 // 日限额 (USD) DailyLimitUSD *float64 // 日限额 (USD)
WeeklyLimitUSD *float64 // 周限额 (USD) WeeklyLimitUSD *float64 // 周限额 (USD)
MonthlyLimitUSD *float64 // 月限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD)
// 图片生成计费配置(仅 antigravity 平台使用)
ImagePrice1K *float64
ImagePrice2K *float64
ImagePrice4K *float64
} }
type UpdateGroupInput struct { type UpdateGroupInput struct {
...@@ -111,6 +115,10 @@ type UpdateGroupInput struct { ...@@ -111,6 +115,10 @@ type UpdateGroupInput struct {
DailyLimitUSD *float64 // 日限额 (USD) DailyLimitUSD *float64 // 日限额 (USD)
WeeklyLimitUSD *float64 // 周限额 (USD) WeeklyLimitUSD *float64 // 周限额 (USD)
MonthlyLimitUSD *float64 // 月限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD)
// 图片生成计费配置(仅 antigravity 平台使用)
ImagePrice1K *float64
ImagePrice2K *float64
ImagePrice4K *float64
} }
type CreateAccountInput struct { type CreateAccountInput struct {
...@@ -498,6 +506,11 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ...@@ -498,6 +506,11 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
weeklyLimit := normalizeLimit(input.WeeklyLimitUSD) weeklyLimit := normalizeLimit(input.WeeklyLimitUSD)
monthlyLimit := normalizeLimit(input.MonthlyLimitUSD) monthlyLimit := normalizeLimit(input.MonthlyLimitUSD)
// 图片价格:负数表示清除(使用默认价格),0 保留(表示免费)
imagePrice1K := normalizePrice(input.ImagePrice1K)
imagePrice2K := normalizePrice(input.ImagePrice2K)
imagePrice4K := normalizePrice(input.ImagePrice4K)
group := &Group{ group := &Group{
Name: input.Name, Name: input.Name,
Description: input.Description, Description: input.Description,
...@@ -509,6 +522,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ...@@ -509,6 +522,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
DailyLimitUSD: dailyLimit, DailyLimitUSD: dailyLimit,
WeeklyLimitUSD: weeklyLimit, WeeklyLimitUSD: weeklyLimit,
MonthlyLimitUSD: monthlyLimit, MonthlyLimitUSD: monthlyLimit,
ImagePrice1K: imagePrice1K,
ImagePrice2K: imagePrice2K,
ImagePrice4K: imagePrice4K,
} }
if err := s.groupRepo.Create(ctx, group); err != nil { if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err return nil, err
...@@ -524,6 +540,14 @@ func normalizeLimit(limit *float64) *float64 { ...@@ -524,6 +540,14 @@ func normalizeLimit(limit *float64) *float64 {
return limit return limit
} }
// normalizePrice 将负数转换为 nil(表示使用默认价格),0 保留(表示免费)
func normalizePrice(price *float64) *float64 {
if price == nil || *price < 0 {
return nil
}
return price
}
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) { func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
group, err := s.groupRepo.GetByID(ctx, id) group, err := s.groupRepo.GetByID(ctx, id)
if err != nil { if err != nil {
...@@ -563,6 +587,16 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd ...@@ -563,6 +587,16 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.MonthlyLimitUSD != nil { if input.MonthlyLimitUSD != nil {
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD) group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
} }
// 图片生成计费配置:负数表示清除(使用默认价格)
if input.ImagePrice1K != nil {
group.ImagePrice1K = normalizePrice(input.ImagePrice1K)
}
if input.ImagePrice2K != nil {
group.ImagePrice2K = normalizePrice(input.ImagePrice2K)
}
if input.ImagePrice4K != nil {
group.ImagePrice4K = normalizePrice(input.ImagePrice4K)
}
if err := s.groupRepo.Update(ctx, group); err != nil { if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err return nil, err
...@@ -702,7 +736,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U ...@@ -702,7 +736,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
account.Extra = input.Extra account.Extra = input.Extra
} }
if input.ProxyID != nil { if input.ProxyID != nil {
// 0 表示清除代理(前端发送 0 而不是 null 来表达清除意图)
if *input.ProxyID == 0 {
account.ProxyID = nil
} else {
account.ProxyID = input.ProxyID account.ProxyID = input.ProxyID
}
account.Proxy = nil // 清除关联对象,防止 GORM Save 时根据 Proxy.ID 覆盖 ProxyID account.Proxy = nil // 清除关联对象,防止 GORM Save 时根据 Proxy.ID 覆盖 ProxyID
} }
// 只在指针非 nil 时更新 Concurrency(支持设置为 0) // 只在指针非 nil 时更新 Concurrency(支持设置为 0)
......
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// groupRepoStubForAdmin 用于测试 AdminService 的 GroupRepository Stub
type groupRepoStubForAdmin struct {
created *Group // 记录 Create 调用的参数
updated *Group // 记录 Update 调用的参数
getByID *Group // GetByID 返回值
getErr error // GetByID 返回的错误
}
func (s *groupRepoStubForAdmin) Create(_ context.Context, g *Group) error {
s.created = g
return nil
}
func (s *groupRepoStubForAdmin) Update(_ context.Context, g *Group) error {
s.updated = g
return nil
}
func (s *groupRepoStubForAdmin) GetByID(_ context.Context, _ int64) (*Group, error) {
if s.getErr != nil {
return nil, s.getErr
}
return s.getByID, nil
}
func (s *groupRepoStubForAdmin) Delete(_ context.Context, _ int64) error {
panic("unexpected Delete call")
}
func (s *groupRepoStubForAdmin) DeleteCascade(_ context.Context, _ int64) ([]int64, error) {
panic("unexpected DeleteCascade call")
}
func (s *groupRepoStubForAdmin) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (s *groupRepoStubForAdmin) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (s *groupRepoStubForAdmin) ListActive(_ context.Context) ([]Group, error) {
panic("unexpected ListActive call")
}
func (s *groupRepoStubForAdmin) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) {
panic("unexpected ListActiveByPlatform call")
}
func (s *groupRepoStubForAdmin) ExistsByName(_ context.Context, _ string) (bool, error) {
panic("unexpected ExistsByName call")
}
func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, error) {
panic("unexpected GetAccountCount call")
}
func (s *groupRepoStubForAdmin) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
panic("unexpected DeleteAccountGroupsByGroupID call")
}
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) {
repo := &groupRepoStubForAdmin{}
svc := &adminServiceImpl{groupRepo: repo}
price1K := 0.10
price2K := 0.15
price4K := 0.30
input := &CreateGroupInput{
Name: "test-group",
Description: "Test group",
Platform: PlatformAntigravity,
RateMultiplier: 1.0,
ImagePrice1K: &price1K,
ImagePrice2K: &price2K,
ImagePrice4K: &price4K,
}
group, err := svc.CreateGroup(context.Background(), input)
require.NoError(t, err)
require.NotNil(t, group)
// 验证 repo 收到了正确的字段
require.NotNil(t, repo.created)
require.NotNil(t, repo.created.ImagePrice1K)
require.NotNil(t, repo.created.ImagePrice2K)
require.NotNil(t, repo.created.ImagePrice4K)
require.InDelta(t, 0.10, *repo.created.ImagePrice1K, 0.0001)
require.InDelta(t, 0.15, *repo.created.ImagePrice2K, 0.0001)
require.InDelta(t, 0.30, *repo.created.ImagePrice4K, 0.0001)
}
// TestAdminService_CreateGroup_NilImagePricing 测试 ImagePrice 为 nil 时正常创建
func TestAdminService_CreateGroup_NilImagePricing(t *testing.T) {
repo := &groupRepoStubForAdmin{}
svc := &adminServiceImpl{groupRepo: repo}
input := &CreateGroupInput{
Name: "test-group",
Description: "Test group",
Platform: PlatformAntigravity,
RateMultiplier: 1.0,
// ImagePrice 字段全部为 nil
}
group, err := svc.CreateGroup(context.Background(), input)
require.NoError(t, err)
require.NotNil(t, group)
// 验证 ImagePrice 字段为 nil
require.NotNil(t, repo.created)
require.Nil(t, repo.created.ImagePrice1K)
require.Nil(t, repo.created.ImagePrice2K)
require.Nil(t, repo.created.ImagePrice4K)
}
// TestAdminService_UpdateGroup_WithImagePricing 测试更新分组时 ImagePrice 字段正确更新
func TestAdminService_UpdateGroup_WithImagePricing(t *testing.T) {
existingGroup := &Group{
ID: 1,
Name: "existing-group",
Platform: PlatformAntigravity,
Status: StatusActive,
}
repo := &groupRepoStubForAdmin{getByID: existingGroup}
svc := &adminServiceImpl{groupRepo: repo}
price1K := 0.12
price2K := 0.18
price4K := 0.36
input := &UpdateGroupInput{
ImagePrice1K: &price1K,
ImagePrice2K: &price2K,
ImagePrice4K: &price4K,
}
group, err := svc.UpdateGroup(context.Background(), 1, input)
require.NoError(t, err)
require.NotNil(t, group)
// 验证 repo 收到了更新后的字段
require.NotNil(t, repo.updated)
require.NotNil(t, repo.updated.ImagePrice1K)
require.NotNil(t, repo.updated.ImagePrice2K)
require.NotNil(t, repo.updated.ImagePrice4K)
require.InDelta(t, 0.12, *repo.updated.ImagePrice1K, 0.0001)
require.InDelta(t, 0.18, *repo.updated.ImagePrice2K, 0.0001)
require.InDelta(t, 0.36, *repo.updated.ImagePrice4K, 0.0001)
}
// TestAdminService_UpdateGroup_PartialImagePricing 测试仅更新部分 ImagePrice 字段
func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
oldPrice2K := 0.15
existingGroup := &Group{
ID: 1,
Name: "existing-group",
Platform: PlatformAntigravity,
Status: StatusActive,
ImagePrice2K: &oldPrice2K, // 已有 2K 价格
}
repo := &groupRepoStubForAdmin{getByID: existingGroup}
svc := &adminServiceImpl{groupRepo: repo}
// 只更新 1K 价格
price1K := 0.10
input := &UpdateGroupInput{
ImagePrice1K: &price1K,
// ImagePrice2K 和 ImagePrice4K 为 nil,不更新
}
group, err := svc.UpdateGroup(context.Background(), 1, input)
require.NoError(t, err)
require.NotNil(t, group)
// 验证:1K 被更新,2K 保持原值,4K 仍为 nil
require.NotNil(t, repo.updated)
require.NotNil(t, repo.updated.ImagePrice1K)
require.InDelta(t, 0.10, *repo.updated.ImagePrice1K, 0.0001)
require.NotNil(t, repo.updated.ImagePrice2K)
require.InDelta(t, 0.15, *repo.updated.ImagePrice2K, 0.0001) // 原值保持
require.Nil(t, repo.updated.ImagePrice4K)
}
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"fmt" "fmt"
"io" "io"
"log" "log"
mathrand "math/rand"
"net/http" "net/http"
"strings" "strings"
"sync/atomic" "sync/atomic"
...@@ -405,6 +406,14 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -405,6 +406,14 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// 重试循环 // 重试循环
var resp *http.Response var resp *http.Response
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
// 检查 context 是否已取消(客户端断开连接)
select {
case <-ctx.Done():
log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err())
return nil, ctx.Err()
default:
}
upstreamReq, err := antigravity.NewAPIRequest(ctx, action, accessToken, geminiBody) upstreamReq, err := antigravity.NewAPIRequest(ctx, action, accessToken, geminiBody)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -414,7 +423,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -414,7 +423,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
if err != nil { if err != nil {
if attempt < antigravityMaxRetries { if attempt < antigravityMaxRetries {
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
sleepAntigravityBackoff(attempt) if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err()
}
continue continue
} }
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err) log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
...@@ -427,7 +439,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -427,7 +439,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
if attempt < antigravityMaxRetries { if attempt < antigravityMaxRetries {
log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries) log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries)
sleepAntigravityBackoff(attempt) if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err()
}
continue continue
} }
// 所有重试都失败,标记限流状态 // 所有重试都失败,标记限流状态
...@@ -845,6 +860,9 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ...@@ -845,6 +860,9 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty") return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty")
} }
// 解析请求以获取 image_size(用于图片计费)
imageSize := s.extractImageSize(body)
switch action { switch action {
case "generateContent", "streamGenerateContent": case "generateContent", "streamGenerateContent":
// ok // ok
...@@ -901,6 +919,14 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ...@@ -901,6 +919,14 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
// 重试循环 // 重试循环
var resp *http.Response var resp *http.Response
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
// 检查 context 是否已取消(客户端断开连接)
select {
case <-ctx.Done():
log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err())
return nil, ctx.Err()
default:
}
upstreamReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, wrappedBody) upstreamReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, wrappedBody)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -910,7 +936,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ...@@ -910,7 +936,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if err != nil { if err != nil {
if attempt < antigravityMaxRetries { if attempt < antigravityMaxRetries {
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
sleepAntigravityBackoff(attempt) if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err()
}
continue continue
} }
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err) log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
...@@ -923,7 +952,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ...@@ -923,7 +952,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if attempt < antigravityMaxRetries { if attempt < antigravityMaxRetries {
log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries) log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries)
sleepAntigravityBackoff(attempt) if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err()
}
continue continue
} }
// 所有重试都失败,标记限流状态 // 所有重试都失败,标记限流状态
...@@ -1030,6 +1062,13 @@ handleSuccess: ...@@ -1030,6 +1062,13 @@ handleSuccess:
usage = &ClaudeUsage{} usage = &ClaudeUsage{}
} }
// 判断是否为图片生成模型
imageCount := 0
if isImageGenerationModel(mappedModel) {
// Gemini 图片生成 API 每次请求只生成一张图片(API 限制)
imageCount = 1
}
return &ForwardResult{ return &ForwardResult{
RequestID: requestID, RequestID: requestID,
Usage: *usage, Usage: *usage,
...@@ -1037,6 +1076,8 @@ handleSuccess: ...@@ -1037,6 +1076,8 @@ handleSuccess:
Stream: stream, Stream: stream,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
ImageCount: imageCount,
ImageSize: imageSize,
}, nil }, nil
} }
...@@ -1058,8 +1099,28 @@ func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int) ...@@ -1058,8 +1099,28 @@ func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int)
} }
} }
func sleepAntigravityBackoff(attempt int) { // sleepAntigravityBackoffWithContext 带 context 取消检查的退避等待
sleepGeminiBackoff(attempt) // 复用 Gemini 的退避逻辑 // 返回 true 表示正常完成等待,false 表示 context 已取消
func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
delay := geminiRetryBaseDelay * time.Duration(1<<uint(attempt-1))
if delay > geminiRetryMaxDelay {
delay = geminiRetryMaxDelay
}
// +/- 20% jitter
r := mathrand.New(mathrand.NewSource(time.Now().UnixNano()))
jitter := time.Duration(float64(delay) * 0.2 * (r.Float64()*2 - 1))
sleepFor := delay + jitter
if sleepFor < 0 {
sleepFor = 0
}
select {
case <-ctx.Done():
return false
case <-time.After(sleepFor):
return true
}
} }
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte) { func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte) {
...@@ -1523,3 +1584,36 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context ...@@ -1523,3 +1584,36 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
} }
} }
// extractImageSize 从 Gemini 请求中提取 image_size 参数
func (s *AntigravityGatewayService) extractImageSize(body []byte) string {
var req antigravity.GeminiRequest
if err := json.Unmarshal(body, &req); err != nil {
return "2K" // 默认 2K
}
if req.GenerationConfig != nil && req.GenerationConfig.ImageConfig != nil {
size := strings.ToUpper(strings.TrimSpace(req.GenerationConfig.ImageConfig.ImageSize))
if size == "1K" || size == "2K" || size == "4K" {
return size
}
}
return "2K" // 默认 2K
}
// isImageGenerationModel 判断模型是否为图片生成模型
// 支持的模型:gemini-3-pro-image, gemini-3-pro-image-preview, gemini-2.5-flash-image 等
func isImageGenerationModel(model string) bool {
modelLower := strings.ToLower(model)
// 移除 models/ 前缀
modelLower = strings.TrimPrefix(modelLower, "models/")
// 精确匹配或前缀匹配
return modelLower == "gemini-3-pro-image" ||
modelLower == "gemini-3-pro-image-preview" ||
strings.HasPrefix(modelLower, "gemini-3-pro-image-") ||
modelLower == "gemini-2.5-flash-image" ||
modelLower == "gemini-2.5-flash-image-preview" ||
strings.HasPrefix(modelLower, "gemini-2.5-flash-image-")
}
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// TestIsImageGenerationModel_GeminiProImage 测试 gemini-3-pro-image 识别
func TestIsImageGenerationModel_GeminiProImage(t *testing.T) {
require.True(t, isImageGenerationModel("gemini-3-pro-image"))
require.True(t, isImageGenerationModel("gemini-3-pro-image-preview"))
require.True(t, isImageGenerationModel("models/gemini-3-pro-image"))
}
// TestIsImageGenerationModel_GeminiFlashImage 测试 gemini-2.5-flash-image 识别
func TestIsImageGenerationModel_GeminiFlashImage(t *testing.T) {
require.True(t, isImageGenerationModel("gemini-2.5-flash-image"))
require.True(t, isImageGenerationModel("gemini-2.5-flash-image-preview"))
}
// TestIsImageGenerationModel_RegularModel 测试普通模型不被识别为图片模型
func TestIsImageGenerationModel_RegularModel(t *testing.T) {
require.False(t, isImageGenerationModel("claude-3-opus"))
require.False(t, isImageGenerationModel("claude-sonnet-4-20250514"))
require.False(t, isImageGenerationModel("gpt-4o"))
require.False(t, isImageGenerationModel("gemini-2.5-pro")) // 非图片模型
require.False(t, isImageGenerationModel("gemini-2.5-flash"))
// 验证不会误匹配包含关键词的自定义模型名
require.False(t, isImageGenerationModel("my-gemini-3-pro-image-test"))
require.False(t, isImageGenerationModel("custom-gemini-2.5-flash-image-wrapper"))
}
// TestIsImageGenerationModel_CaseInsensitive 测试大小写不敏感
func TestIsImageGenerationModel_CaseInsensitive(t *testing.T) {
require.True(t, isImageGenerationModel("GEMINI-3-PRO-IMAGE"))
require.True(t, isImageGenerationModel("Gemini-3-Pro-Image"))
require.True(t, isImageGenerationModel("GEMINI-2.5-FLASH-IMAGE"))
}
// TestExtractImageSize_ValidSizes 测试有效尺寸解析
func TestExtractImageSize_ValidSizes(t *testing.T) {
svc := &AntigravityGatewayService{}
// 1K
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":"1K"}}}`)
require.Equal(t, "1K", svc.extractImageSize(body))
// 2K
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"2K"}}}`)
require.Equal(t, "2K", svc.extractImageSize(body))
// 4K
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"4K"}}}`)
require.Equal(t, "4K", svc.extractImageSize(body))
}
// TestExtractImageSize_CaseInsensitive 测试大小写不敏感
func TestExtractImageSize_CaseInsensitive(t *testing.T) {
svc := &AntigravityGatewayService{}
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":"1k"}}}`)
require.Equal(t, "1K", svc.extractImageSize(body))
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"4k"}}}`)
require.Equal(t, "4K", svc.extractImageSize(body))
}
// TestExtractImageSize_Default 测试无 imageConfig 返回默认 2K
func TestExtractImageSize_Default(t *testing.T) {
svc := &AntigravityGatewayService{}
// 无 generationConfig
body := []byte(`{"contents":[]}`)
require.Equal(t, "2K", svc.extractImageSize(body))
// 有 generationConfig 但无 imageConfig
body = []byte(`{"generationConfig":{"temperature":0.7}}`)
require.Equal(t, "2K", svc.extractImageSize(body))
// 有 imageConfig 但无 imageSize
body = []byte(`{"generationConfig":{"imageConfig":{}}}`)
require.Equal(t, "2K", svc.extractImageSize(body))
}
// TestExtractImageSize_InvalidJSON 测试非法 JSON 返回默认 2K
func TestExtractImageSize_InvalidJSON(t *testing.T) {
svc := &AntigravityGatewayService{}
body := []byte(`not valid json`)
require.Equal(t, "2K", svc.extractImageSize(body))
body = []byte(`{"broken":`)
require.Equal(t, "2K", svc.extractImageSize(body))
}
// TestExtractImageSize_EmptySize 测试空 imageSize 返回默认 2K
func TestExtractImageSize_EmptySize(t *testing.T) {
svc := &AntigravityGatewayService{}
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":""}}}`)
require.Equal(t, "2K", svc.extractImageSize(body))
// 空格
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":" "}}}`)
require.Equal(t, "2K", svc.extractImageSize(body))
}
// TestExtractImageSize_InvalidSize 测试无效尺寸返回默认 2K
func TestExtractImageSize_InvalidSize(t *testing.T) {
svc := &AntigravityGatewayService{}
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":"3K"}}}`)
require.Equal(t, "2K", svc.extractImageSize(body))
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"8K"}}}`)
require.Equal(t, "2K", svc.extractImageSize(body))
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"invalid"}}}`)
require.Equal(t, "2K", svc.extractImageSize(body))
}
...@@ -295,3 +295,88 @@ func (s *BillingService) ForceUpdatePricing() error { ...@@ -295,3 +295,88 @@ func (s *BillingService) ForceUpdatePricing() error {
} }
return fmt.Errorf("pricing service not initialized") return fmt.Errorf("pricing service not initialized")
} }
// ImagePriceConfig 图片计费配置
type ImagePriceConfig struct {
Price1K *float64 // 1K 尺寸价格(nil 表示使用默认值)
Price2K *float64 // 2K 尺寸价格(nil 表示使用默认值)
Price4K *float64 // 4K 尺寸价格(nil 表示使用默认值)
}
// CalculateImageCost 计算图片生成费用
// model: 请求的模型名称(用于获取 LiteLLM 默认价格)
// imageSize: 图片尺寸 "1K", "2K", "4K"
// imageCount: 生成的图片数量
// groupConfig: 分组配置的价格(可能为 nil,表示使用默认值)
// rateMultiplier: 费率倍数
func (s *BillingService) CalculateImageCost(model string, imageSize string, imageCount int, groupConfig *ImagePriceConfig, rateMultiplier float64) *CostBreakdown {
if imageCount <= 0 {
return &CostBreakdown{}
}
// 获取单价
unitPrice := s.getImageUnitPrice(model, imageSize, groupConfig)
// 计算总费用
totalCost := unitPrice * float64(imageCount)
// 应用倍率
if rateMultiplier <= 0 {
rateMultiplier = 1.0
}
actualCost := totalCost * rateMultiplier
return &CostBreakdown{
TotalCost: totalCost,
ActualCost: actualCost,
}
}
// getImageUnitPrice 获取图片单价
func (s *BillingService) getImageUnitPrice(model string, imageSize string, groupConfig *ImagePriceConfig) float64 {
// 优先使用分组配置的价格
if groupConfig != nil {
switch imageSize {
case "1K":
if groupConfig.Price1K != nil {
return *groupConfig.Price1K
}
case "2K":
if groupConfig.Price2K != nil {
return *groupConfig.Price2K
}
case "4K":
if groupConfig.Price4K != nil {
return *groupConfig.Price4K
}
}
}
// 回退到 LiteLLM 默认价格
return s.getDefaultImagePrice(model, imageSize)
}
// getDefaultImagePrice 获取 LiteLLM 默认图片价格
func (s *BillingService) getDefaultImagePrice(model string, imageSize string) float64 {
basePrice := 0.0
// 从 PricingService 获取 output_cost_per_image
if s.pricingService != nil {
pricing := s.pricingService.GetModelPricing(model)
if pricing != nil && pricing.OutputCostPerImage > 0 {
basePrice = pricing.OutputCostPerImage
}
}
// 如果没有找到价格,使用硬编码默认值($0.134,来自 gemini-3-pro-image-preview)
if basePrice <= 0 {
basePrice = 0.134
}
// 4K 尺寸翻倍
if imageSize == "4K" {
return basePrice * 2
}
return basePrice
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment