Unverified Commit c6b3de11 authored by NepetaLemon's avatar NepetaLemon Committed by GitHub
Browse files

ci(backend): 添加 github actions (#10)

## 变更内容

### CI/CD
- 添加 GitHub Actions 工作流(test + golangci-lint)
- 添加 golangci-lint 配置,启用 errcheck/govet/staticcheck/unused/depguard
- 通过 depguard 强制 service 层不能直接导入 repository

### 错误处理修复
- 修复 CSV 写入、SSE 流式输出、随机数生成等未处理的错误
- GenerateRedeemCode() 现在返回 error

### 资源泄露修复
- 统一使用 defer func() { _ = xxx.Close() }() 模式

### 代码清理
- 移除未使用的常量
- 简化 nil map 检查
- 统一代码格式
parent f1325e9a
name: CI
on:
push:
pull_request:
permissions:
contents: read
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
check-latest: true
cache: true
- name: Run tests
run: go test ./...
golangci-lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
check-latest: true
cache: true
- name: golangci-lint
uses: golangci/golangci-lint-action@v6
with:
version: latest
args: --timeout=5m
version: "2"
linters:
default: none
enable:
- depguard
- errcheck
- govet
- ineffassign
- staticcheck
- unused
settings:
depguard:
rules:
# Enforce: service must not depend on repository.
service-no-repository:
list-mode: original
files:
- internal/service/**
deny:
- pkg: sub2api/internal/repository
desc: "service must not import repository"
formatters:
enable:
- gofmt
...@@ -52,7 +52,7 @@ type PricingConfig struct { ...@@ -52,7 +52,7 @@ type PricingConfig struct {
type ServerConfig struct { type ServerConfig struct {
Host string `mapstructure:"host"` Host string `mapstructure:"host"`
Port int `mapstructure:"port"` Port int `mapstructure:"port"`
Mode string `mapstructure:"mode"` // debug/release Mode string `mapstructure:"mode"` // debug/release
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒) ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒) IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
} }
...@@ -163,7 +163,7 @@ func setDefaults() { ...@@ -163,7 +163,7 @@ func setDefaults() {
viper.SetDefault("server.port", 8080) viper.SetDefault("server.port", 8080)
viper.SetDefault("server.mode", "debug") viper.SetDefault("server.mode", "debug")
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头 viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时 viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
// Database // Database
viper.SetDefault("database.host", "localhost") viper.SetDefault("database.host", "localhost")
...@@ -210,10 +210,10 @@ func setDefaults() { ...@@ -210,10 +210,10 @@ func setDefaults() {
// TokenRefresh // TokenRefresh
viper.SetDefault("token_refresh.enabled", true) viper.SetDefault("token_refresh.enabled", true)
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次 viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 1.5) // 提前1.5小时刷新 viper.SetDefault("token_refresh.refresh_before_expiry_hours", 1.5) // 提前1.5小时刷新
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次 viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒 viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
} }
func (c *Config) Validate() error { func (c *Config) Validate() error {
......
...@@ -573,7 +573,7 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) { ...@@ -573,7 +573,7 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
// For API Key accounts: return models based on model_mapping // For API Key accounts: return models based on model_mapping
mapping := account.GetModelMapping() mapping := account.GetModelMapping()
if mapping == nil || len(mapping) == 0 { if len(mapping) == 0 {
// No mapping configured, return default models // No mapping configured, return default models
response.Success(c, claude.DefaultModels) response.Success(c, claude.DefaultModels)
return return
......
...@@ -236,7 +236,6 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) { ...@@ -236,7 +236,6 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
response.Paginated(c, accounts, total, page, pageSize) response.Paginated(c, accounts, total, page, pageSize)
} }
// BatchCreateProxyItem represents a single proxy in batch create request // BatchCreateProxyItem represents a single proxy in batch create request
type BatchCreateProxyItem struct { type BatchCreateProxyItem struct {
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"` Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`
......
...@@ -156,10 +156,10 @@ func (h *RedeemHandler) Expire(c *gin.Context) { ...@@ -156,10 +156,10 @@ func (h *RedeemHandler) Expire(c *gin.Context) {
func (h *RedeemHandler) GetStats(c *gin.Context) { func (h *RedeemHandler) GetStats(c *gin.Context) {
// Return mock data for now // Return mock data for now
response.Success(c, gin.H{ response.Success(c, gin.H{
"total_codes": 0, "total_codes": 0,
"active_codes": 0, "active_codes": 0,
"used_codes": 0, "used_codes": 0,
"expired_codes": 0, "expired_codes": 0,
"total_value_distributed": 0.0, "total_value_distributed": 0.0,
"by_type": gin.H{ "by_type": gin.H{
"balance": 0, "balance": 0,
...@@ -187,7 +187,10 @@ func (h *RedeemHandler) Export(c *gin.Context) { ...@@ -187,7 +187,10 @@ func (h *RedeemHandler) Export(c *gin.Context) {
writer := csv.NewWriter(&buf) writer := csv.NewWriter(&buf)
// Write header // Write header
writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"}) if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"}); err != nil {
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
return
}
// Write data rows // Write data rows
for _, code := range codes { for _, code := range codes {
...@@ -199,7 +202,7 @@ func (h *RedeemHandler) Export(c *gin.Context) { ...@@ -199,7 +202,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
if code.UsedAt != nil { if code.UsedAt != nil {
usedAt = code.UsedAt.Format("2006-01-02 15:04:05") usedAt = code.UsedAt.Format("2006-01-02 15:04:05")
} }
writer.Write([]string{ if err := writer.Write([]string{
fmt.Sprintf("%d", code.ID), fmt.Sprintf("%d", code.ID),
code.Code, code.Code,
code.Type, code.Type,
...@@ -208,10 +211,17 @@ func (h *RedeemHandler) Export(c *gin.Context) { ...@@ -208,10 +211,17 @@ func (h *RedeemHandler) Export(c *gin.Context) {
usedBy, usedBy,
usedAt, usedAt,
code.CreatedAt.Format("2006-01-02 15:04:05"), code.CreatedAt.Format("2006-01-02 15:04:05"),
}) }); err != nil {
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
return
}
} }
writer.Flush() writer.Flush()
if err := writer.Error(); err != nil {
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
return
}
c.Header("Content-Type", "text/csv") c.Header("Content-Type", "text/csv")
c.Header("Content-Disposition", "attachment; filename=redeem_codes.csv") c.Header("Content-Disposition", "attachment; filename=redeem_codes.csv")
......
...@@ -268,7 +268,9 @@ func (h *GatewayHandler) waitForSlotWithPing(c *gin.Context, slotType string, id ...@@ -268,7 +268,9 @@ func (h *GatewayHandler) waitForSlotWithPing(c *gin.Context, slotType string, id
c.Header("X-Accel-Buffering", "no") c.Header("X-Accel-Buffering", "no")
*streamStarted = true *streamStarted = true
} }
fmt.Fprintf(c.Writer, "data: {\"type\": \"ping\"}\n\n") if _, err := fmt.Fprintf(c.Writer, "data: {\"type\": \"ping\"}\n\n"); err != nil {
return nil, err
}
flusher.Flush() flusher.Flush()
} }
...@@ -414,7 +416,9 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e ...@@ -414,7 +416,9 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
if ok { if ok {
// Send error event in SSE format // Send error event in SSE format
errorEvent := fmt.Sprintf(`data: {"type": "error", "error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message) errorEvent := fmt.Sprintf(`data: {"type": "error", "error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
fmt.Fprint(c.Writer, errorEvent) if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
_ = c.Error(err)
}
flusher.Flush() flusher.Flush()
} }
return return
...@@ -574,11 +578,11 @@ func sendMockWarmupStream(c *gin.Context, model string) { ...@@ -574,11 +578,11 @@ func sendMockWarmupStream(c *gin.Context, model string) {
// sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截) // sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截)
func sendMockWarmupResponse(c *gin.Context, model string) { func sendMockWarmupResponse(c *gin.Context, model string) {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"id": "msg_mock_warmup", "id": "msg_mock_warmup",
"type": "message", "type": "message",
"role": "assistant", "role": "assistant",
"model": model, "model": model,
"content": []gin.H{{"type": "text", "text": "New Conversation"}}, "content": []gin.H{{"type": "text", "text": "New Conversation"}},
"stop_reason": "end_turn", "stop_reason": "end_turn",
"usage": gin.H{ "usage": gin.H{
"input_tokens": 10, "input_tokens": 10,
......
...@@ -40,8 +40,8 @@ type Account struct { ...@@ -40,8 +40,8 @@ type Account struct {
Extra JSONB `gorm:"type:jsonb;default:'{}'" json:"extra"` // 扩展信息 Extra JSONB `gorm:"type:jsonb;default:'{}'" json:"extra"` // 扩展信息
ProxyID *int64 `gorm:"index" json:"proxy_id"` ProxyID *int64 `gorm:"index" json:"proxy_id"`
Concurrency int `gorm:"default:3;not null" json:"concurrency"` Concurrency int `gorm:"default:3;not null" json:"concurrency"`
Priority int `gorm:"default:50;not null" json:"priority"` // 1-100,越小越高 Priority int `gorm:"default:50;not null" json:"priority"` // 1-100,越小越高
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled/error Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled/error
ErrorMessage string `gorm:"type:text" json:"error_message"` ErrorMessage string `gorm:"type:text" json:"error_message"`
LastUsedAt *time.Time `gorm:"index" json:"last_used_at"` LastUsedAt *time.Time `gorm:"index" json:"last_used_at"`
CreatedAt time.Time `gorm:"not null" json:"created_at"` CreatedAt time.Time `gorm:"not null" json:"created_at"`
...@@ -163,7 +163,7 @@ func (a *Account) GetModelMapping() map[string]string { ...@@ -163,7 +163,7 @@ func (a *Account) GetModelMapping() map[string]string {
// 如果没有设置模型映射,则支持所有模型 // 如果没有设置模型映射,则支持所有模型
func (a *Account) IsModelSupported(requestedModel string) bool { func (a *Account) IsModelSupported(requestedModel string) bool {
mapping := a.GetModelMapping() mapping := a.GetModelMapping()
if mapping == nil || len(mapping) == 0 { if len(mapping) == 0 {
return true // 没有映射配置,支持所有模型 return true // 没有映射配置,支持所有模型
} }
_, exists := mapping[requestedModel] _, exists := mapping[requestedModel]
...@@ -174,7 +174,7 @@ func (a *Account) IsModelSupported(requestedModel string) bool { ...@@ -174,7 +174,7 @@ func (a *Account) IsModelSupported(requestedModel string) bool {
// 如果没有映射,返回原始模型名 // 如果没有映射,返回原始模型名
func (a *Account) GetMappedModel(requestedModel string) string { func (a *Account) GetMappedModel(requestedModel string) string {
mapping := a.GetModelMapping() mapping := a.GetModelMapping()
if mapping == nil || len(mapping) == 0 { if len(mapping) == 0 {
return requestedModel return requestedModel
} }
if mappedModel, exists := mapping[requestedModel]; exists { if mappedModel, exists := mapping[requestedModel]; exists {
......
...@@ -13,13 +13,13 @@ const ( ...@@ -13,13 +13,13 @@ const (
) )
type Group struct { type Group struct {
ID int64 `gorm:"primaryKey" json:"id"` ID int64 `gorm:"primaryKey" json:"id"`
Name string `gorm:"uniqueIndex;size:100;not null" json:"name"` Name string `gorm:"uniqueIndex;size:100;not null" json:"name"`
Description string `gorm:"type:text" json:"description"` Description string `gorm:"type:text" json:"description"`
Platform string `gorm:"size:50;default:anthropic;not null" json:"platform"` // anthropic/openai/gemini Platform string `gorm:"size:50;default:anthropic;not null" json:"platform"` // anthropic/openai/gemini
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"` RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"`
IsExclusive bool `gorm:"default:false;not null" json:"is_exclusive"` IsExclusive bool `gorm:"default:false;not null" json:"is_exclusive"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
// 订阅功能字段 // 订阅功能字段
SubscriptionType string `gorm:"size:20;default:standard;not null" json:"subscription_type"` // standard/subscription SubscriptionType string `gorm:"size:20;default:standard;not null" json:"subscription_type"` // standard/subscription
......
...@@ -9,15 +9,15 @@ import ( ...@@ -9,15 +9,15 @@ import (
type RedeemCode struct { type RedeemCode struct {
ID int64 `gorm:"primaryKey" json:"id"` ID int64 `gorm:"primaryKey" json:"id"`
Code string `gorm:"uniqueIndex;size:32;not null" json:"code"` Code string `gorm:"uniqueIndex;size:32;not null" json:"code"`
Type string `gorm:"size:20;default:balance;not null" json:"type"` // balance/concurrency/subscription Type string `gorm:"size:20;default:balance;not null" json:"type"` // balance/concurrency/subscription
Value float64 `gorm:"type:decimal(20,8);not null" json:"value"` // 面值(USD)或并发数或有效天数 Value float64 `gorm:"type:decimal(20,8);not null" json:"value"` // 面值(USD)或并发数或有效天数
Status string `gorm:"size:20;default:unused;not null" json:"status"` // unused/used Status string `gorm:"size:20;default:unused;not null" json:"status"` // unused/used
UsedBy *int64 `gorm:"index" json:"used_by"` UsedBy *int64 `gorm:"index" json:"used_by"`
UsedAt *time.Time `json:"used_at"` UsedAt *time.Time `json:"used_at"`
CreatedAt time.Time `gorm:"not null" json:"created_at"` CreatedAt time.Time `gorm:"not null" json:"created_at"`
// 订阅类型专用字段 // 订阅类型专用字段
GroupID *int64 `gorm:"index" json:"group_id"` // 订阅分组ID (仅subscription类型使用) GroupID *int64 `gorm:"index" json:"group_id"` // 订阅分组ID (仅subscription类型使用)
ValidityDays int `gorm:"default:30" json:"validity_days"` // 订阅有效天数 (仅subscription类型使用) ValidityDays int `gorm:"default:30" json:"validity_days"` // 订阅有效天数 (仅subscription类型使用)
// 关联 // 关联
...@@ -40,8 +40,10 @@ func (r *RedeemCode) CanUse() bool { ...@@ -40,8 +40,10 @@ func (r *RedeemCode) CanUse() bool {
} }
// GenerateRedeemCode 生成唯一的兑换码 // GenerateRedeemCode 生成唯一的兑换码
func GenerateRedeemCode() string { func GenerateRedeemCode() (string, error) {
b := make([]byte, 16) b := make([]byte, 16)
rand.Read(b) if _, err := rand.Read(b); err != nil {
return hex.EncodeToString(b) return "", err
}
return hex.EncodeToString(b), nil
} }
...@@ -19,17 +19,17 @@ func (Setting) TableName() string { ...@@ -19,17 +19,17 @@ func (Setting) TableName() string {
// 设置Key常量 // 设置Key常量
const ( const (
// 注册设置 // 注册设置
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册 SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证 SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
// 邮件服务设置 // 邮件服务设置
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址 SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
SettingKeySmtpPort = "smtp_port" // SMTP端口 SettingKeySmtpPort = "smtp_port" // SMTP端口
SettingKeySmtpUsername = "smtp_username" // SMTP用户名 SettingKeySmtpUsername = "smtp_username" // SMTP用户名
SettingKeySmtpPassword = "smtp_password" // SMTP密码(加密存储) SettingKeySmtpPassword = "smtp_password" // SMTP密码(加密存储)
SettingKeySmtpFrom = "smtp_from" // 发件人地址 SettingKeySmtpFrom = "smtp_from" // 发件人地址
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称 SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
// Cloudflare Turnstile 设置 // Cloudflare Turnstile 设置
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证 SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
......
...@@ -37,7 +37,7 @@ type UsageLog struct { ...@@ -37,7 +37,7 @@ type UsageLog struct {
OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"output_cost"` OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"output_cost"`
CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_creation_cost"` CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_creation_cost"`
CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_read_cost"` CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_read_cost"`
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"` // 原始总费用 TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"` // 原始总费用
ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"actual_cost"` // 实际扣除费用 ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"actual_cost"` // 实际扣除费用
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null" json:"rate_multiplier"` // 计费倍率 RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null" json:"rate_multiplier"` // 计费倍率
......
...@@ -9,8 +9,8 @@ import ( ...@@ -9,8 +9,8 @@ import (
) )
type User struct { type User struct {
ID int64 `gorm:"primaryKey" json:"id"` ID int64 `gorm:"primaryKey" json:"id"`
Email string `gorm:"uniqueIndex;size:255;not null" json:"email"` Email string `gorm:"uniqueIndex;size:255;not null" json:"email"`
PasswordHash string `gorm:"size:255;not null" json:"-"` PasswordHash string `gorm:"size:255;not null" json:"-"`
Role string `gorm:"size:20;default:user;not null" json:"role"` // admin/user Role string `gorm:"size:20;default:user;not null" json:"role"` // admin/user
Balance float64 `gorm:"type:decimal(20,8);default:0;not null" json:"balance"` Balance float64 `gorm:"type:decimal(20,8);default:0;not null" json:"balance"`
......
...@@ -37,11 +37,15 @@ func TestInitInvalidTimezone(t *testing.T) { ...@@ -37,11 +37,15 @@ func TestInitInvalidTimezone(t *testing.T) {
func TestTimeNowAffected(t *testing.T) { func TestTimeNowAffected(t *testing.T) {
// Reset to UTC first // Reset to UTC first
Init("UTC") if err := Init("UTC"); err != nil {
t.Fatalf("Init failed with UTC: %v", err)
}
utcNow := time.Now() utcNow := time.Now()
// Switch to Shanghai (UTC+8) // Switch to Shanghai (UTC+8)
Init("Asia/Shanghai") if err := Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
}
shanghaiNow := time.Now() shanghaiNow := time.Now()
// The times should be the same instant, but different timezone representation // The times should be the same instant, but different timezone representation
...@@ -58,7 +62,9 @@ func TestTimeNowAffected(t *testing.T) { ...@@ -58,7 +62,9 @@ func TestTimeNowAffected(t *testing.T) {
} }
func TestToday(t *testing.T) { func TestToday(t *testing.T) {
Init("Asia/Shanghai") if err := Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
}
today := Today() today := Today()
now := Now() now := Now()
...@@ -75,7 +81,9 @@ func TestToday(t *testing.T) { ...@@ -75,7 +81,9 @@ func TestToday(t *testing.T) {
} }
func TestStartOfDay(t *testing.T) { func TestStartOfDay(t *testing.T) {
Init("Asia/Shanghai") if err := Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
}
// Create a time at 15:30:45 // Create a time at 15:30:45
testTime := time.Date(2024, 6, 15, 15, 30, 45, 123456789, Location()) testTime := time.Date(2024, 6, 15, 15, 30, 45, 123456789, Location())
...@@ -91,7 +99,9 @@ func TestTruncateVsStartOfDay(t *testing.T) { ...@@ -91,7 +99,9 @@ func TestTruncateVsStartOfDay(t *testing.T) {
// This test demonstrates why Truncate(24*time.Hour) can be problematic // This test demonstrates why Truncate(24*time.Hour) can be problematic
// and why StartOfDay is more reliable for timezone-aware code // and why StartOfDay is more reliable for timezone-aware code
Init("Asia/Shanghai") if err := Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
}
now := Now() now := Now()
......
...@@ -43,7 +43,7 @@ func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyU ...@@ -43,7 +43,7 @@ func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyU
if err != nil { if err != nil {
return nil, fmt.Errorf("request failed: %w", err) return nil, fmt.Errorf("request failed: %w", err)
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
......
...@@ -38,7 +38,7 @@ func (c *githubReleaseClient) FetchLatestRelease(ctx context.Context, repo strin ...@@ -38,7 +38,7 @@ func (c *githubReleaseClient) FetchLatestRelease(ctx context.Context, repo strin
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode) return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode)
...@@ -63,7 +63,7 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string ...@@ -63,7 +63,7 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
if err != nil { if err != nil {
return err return err
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download returned %d", resp.StatusCode) return fmt.Errorf("download returned %d", resp.StatusCode)
...@@ -78,7 +78,7 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string ...@@ -78,7 +78,7 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
if err != nil { if err != nil {
return err return err
} }
defer out.Close() defer func() { _ = out.Close() }()
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong // SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
limited := io.LimitReader(resp.Body, maxSize+1) limited := io.LimitReader(resp.Body, maxSize+1)
...@@ -89,7 +89,7 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string ...@@ -89,7 +89,7 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
// Check if we hit the limit (downloaded more than maxSize) // Check if we hit the limit (downloaded more than maxSize)
if written > maxSize { if written > maxSize {
os.Remove(dest) // Clean up partial file _ = os.Remove(dest) // Clean up partial file (best-effort)
return fmt.Errorf("download exceeded maximum size of %d bytes", maxSize) return fmt.Errorf("download exceeded maximum size of %d bytes", maxSize)
} }
...@@ -106,7 +106,7 @@ func (c *githubReleaseClient) FetchChecksumFile(ctx context.Context, url string) ...@@ -106,7 +106,7 @@ func (c *githubReleaseClient) FetchChecksumFile(ctx context.Context, url string)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("HTTP %d", resp.StatusCode) return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
......
...@@ -33,7 +33,7 @@ func (c *pricingRemoteClient) FetchPricingJSON(ctx context.Context, url string) ...@@ -33,7 +33,7 @@ func (c *pricingRemoteClient) FetchPricingJSON(ctx context.Context, url string)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("HTTP %d", resp.StatusCode) return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
...@@ -52,7 +52,7 @@ func (c *pricingRemoteClient) FetchHashText(ctx context.Context, url string) (st ...@@ -52,7 +52,7 @@ func (c *pricingRemoteClient) FetchHashText(ctx context.Context, url string) (st
if err != nil { if err != nil {
return "", err return "", err
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("HTTP %d", resp.StatusCode) return "", fmt.Errorf("HTTP %d", resp.StatusCode)
......
...@@ -43,7 +43,7 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s ...@@ -43,7 +43,7 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("proxy connection failed: %w", err) return nil, 0, fmt.Errorf("proxy connection failed: %w", err)
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
latencyMs := time.Since(startTime).Milliseconds() latencyMs := time.Since(startTime).Milliseconds()
......
...@@ -44,7 +44,7 @@ func (v *turnstileVerifier) VerifyToken(ctx context.Context, secretKey, token, r ...@@ -44,7 +44,7 @@ func (v *turnstileVerifier) VerifyToken(ctx context.Context, secretKey, token, r
if err != nil { if err != nil {
return nil, fmt.Errorf("send request: %w", err) return nil, fmt.Errorf("send request: %w", err)
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
var result service.TurnstileVerifyResponse var result service.TurnstileVerifyResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
......
...@@ -51,16 +51,23 @@ func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OA ...@@ -51,16 +51,23 @@ func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OA
} }
// generateSessionString generates a Claude Code style session string // generateSessionString generates a Claude Code style session string
func generateSessionString() string { func generateSessionString() (string, error) {
bytes := make([]byte, 32) bytes := make([]byte, 32)
rand.Read(bytes) if _, err := rand.Read(bytes); err != nil {
return "", err
}
hex64 := hex.EncodeToString(bytes) hex64 := hex.EncodeToString(bytes)
sessionUUID := uuid.New().String() sessionUUID := uuid.New().String()
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID) return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID), nil
} }
// createTestPayload creates a Claude Code style test request payload // createTestPayload creates a Claude Code style test request payload
func createTestPayload(modelID string) map[string]interface{} { func createTestPayload(modelID string) (map[string]interface{}, error) {
sessionID, err := generateSessionString()
if err != nil {
return nil, err
}
return map[string]interface{}{ return map[string]interface{}{
"model": modelID, "model": modelID,
"messages": []map[string]interface{}{ "messages": []map[string]interface{}{
...@@ -87,12 +94,12 @@ func createTestPayload(modelID string) map[string]interface{} { ...@@ -87,12 +94,12 @@ func createTestPayload(modelID string) map[string]interface{} {
}, },
}, },
"metadata": map[string]string{ "metadata": map[string]string{
"user_id": generateSessionString(), "user_id": sessionID,
}, },
"max_tokens": 1024, "max_tokens": 1024,
"temperature": 1, "temperature": 1,
"stream": true, "stream": true,
} }, nil
} }
// TestAccountConnection tests an account's connection by sending a test request // TestAccountConnection tests an account's connection by sending a test request
...@@ -116,7 +123,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int ...@@ -116,7 +123,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
// For API Key accounts with model mapping, map the model // For API Key accounts with model mapping, map the model
if account.Type == "apikey" { if account.Type == "apikey" {
mapping := account.GetModelMapping() mapping := account.GetModelMapping()
if mapping != nil && len(mapping) > 0 { if len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists { if mappedModel, exists := mapping[testModelID]; exists {
testModelID = mappedModel testModelID = mappedModel
} }
...@@ -178,7 +185,10 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int ...@@ -178,7 +185,10 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
c.Writer.Flush() c.Writer.Flush()
// Create Claude Code style payload (same for all account types) // Create Claude Code style payload (same for all account types)
payload := createTestPayload(testModelID) payload, err := createTestPayload(testModelID)
if err != nil {
return s.sendErrorAndEnd(c, "Failed to create test payload")
}
payloadBytes, _ := json.Marshal(payload) payloadBytes, _ := json.Marshal(payload)
// Send test_start event // Send test_start event
...@@ -216,7 +226,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int ...@@ -216,7 +226,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
if err != nil { if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
...@@ -284,7 +294,10 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error ...@@ -284,7 +294,10 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
// sendEvent sends a SSE event to the client // sendEvent sends a SSE event to the client
func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) { func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
eventJSON, _ := json.Marshal(event) eventJSON, _ := json.Marshal(event)
fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON) if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil {
log.Printf("failed to write SSE event: %v", err)
return
}
c.Writer.Flush() c.Writer.Flush()
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment