Commit 0170d19f authored by song's avatar song
Browse files

merge upstream main

parent 7ade9baa
...@@ -11,7 +11,6 @@ type User struct { ...@@ -11,7 +11,6 @@ type User struct {
ID int64 `json:"id"` ID int64 `json:"id"`
Email string `json:"email"` Email string `json:"email"`
Username string `json:"username"` Username string `json:"username"`
Notes string `json:"notes"`
Role string `json:"role"` Role string `json:"role"`
Balance float64 `json:"balance"` Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"` Concurrency int `json:"concurrency"`
...@@ -24,6 +23,14 @@ type User struct { ...@@ -24,6 +23,14 @@ type User struct {
Subscriptions []UserSubscription `json:"subscriptions,omitempty"` Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
} }
// AdminUser 是管理员接口使用的 user DTO(包含敏感/内部字段)。
// 注意:普通用户接口不得返回 notes 等管理员备注信息。
type AdminUser struct {
User
Notes string `json:"notes"`
}
type APIKey struct { type APIKey struct {
ID int64 `json:"id"` ID int64 `json:"id"`
UserID int64 `json:"user_id"` UserID int64 `json:"user_id"`
...@@ -65,6 +72,15 @@ type Group struct { ...@@ -65,6 +72,15 @@ type Group struct {
// 无效请求兜底分组 // 无效请求兜底分组
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// AdminGroup 是管理员接口使用的 group DTO(包含敏感/内部字段)。
// 注意:普通用户接口不得返回 model_routing/account_count/account_groups 等内部信息。
type AdminGroup struct {
Group
// 模型路由配置(仅 anthropic 平台使用) // 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 `json:"model_routing"` ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled bool `json:"model_routing_enabled"` ModelRoutingEnabled bool `json:"model_routing_enabled"`
...@@ -72,9 +88,6 @@ type Group struct { ...@@ -72,9 +88,6 @@ type Group struct {
// MCP XML 协议注入(仅 antigravity 平台使用) // MCP XML 协议注入(仅 antigravity 平台使用)
MCPXMLInject bool `json:"mcp_xml_inject"` MCPXMLInject bool `json:"mcp_xml_inject"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"` AccountGroups []AccountGroup `json:"account_groups,omitempty"`
AccountCount int64 `json:"account_count,omitempty"` AccountCount int64 `json:"account_count,omitempty"`
} }
...@@ -125,6 +138,15 @@ type Account struct { ...@@ -125,6 +138,15 @@ type Account struct {
MaxSessions *int `json:"max_sessions,omitempty"` MaxSessions *int `json:"max_sessions,omitempty"`
SessionIdleTimeoutMin *int `json:"session_idle_timeout_minutes,omitempty"` SessionIdleTimeoutMin *int `json:"session_idle_timeout_minutes,omitempty"`
// TLS指纹伪装(仅 Anthropic OAuth/SetupToken 账号有效)
// 从 extra 字段提取,方便前端显示和编辑
EnableTLSFingerprint *bool `json:"enable_tls_fingerprint,omitempty"`
// 会话ID伪装(仅 Anthropic OAuth/SetupToken 账号有效)
// 启用后将在15分钟内固定 metadata.user_id 中的 session ID
// 从 extra 字段提取,方便前端显示和编辑
EnableSessionIDMasking *bool `json:"session_id_masking_enabled,omitempty"`
Proxy *Proxy `json:"proxy,omitempty"` Proxy *Proxy `json:"proxy,omitempty"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"` AccountGroups []AccountGroup `json:"account_groups,omitempty"`
...@@ -184,16 +206,28 @@ type RedeemCode struct { ...@@ -184,16 +206,28 @@ type RedeemCode struct {
Status string `json:"status"` Status string `json:"status"`
UsedBy *int64 `json:"used_by"` UsedBy *int64 `json:"used_by"`
UsedAt *time.Time `json:"used_at"` UsedAt *time.Time `json:"used_at"`
Notes string `json:"notes"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
GroupID *int64 `json:"group_id"` GroupID *int64 `json:"group_id"`
ValidityDays int `json:"validity_days"` ValidityDays int `json:"validity_days"`
// Notes is only populated for admin_balance/admin_concurrency types
// so users can see why they were charged or credited
Notes *string `json:"notes,omitempty"`
User *User `json:"user,omitempty"` User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"` Group *Group `json:"group,omitempty"`
} }
// AdminRedeemCode 是管理员接口使用的 redeem code DTO(包含 notes 等字段)。
// 注意:普通用户接口不得返回 notes 等内部信息。
type AdminRedeemCode struct {
RedeemCode
Notes string `json:"notes"`
}
// UsageLog 是普通用户接口使用的 usage log DTO(不包含管理员字段)。
type UsageLog struct { type UsageLog struct {
ID int64 `json:"id"` ID int64 `json:"id"`
UserID int64 `json:"user_id"` UserID int64 `json:"user_id"`
...@@ -213,14 +247,13 @@ type UsageLog struct { ...@@ -213,14 +247,13 @@ type UsageLog struct {
CacheCreation5mTokens int `json:"cache_creation_5m_tokens"` CacheCreation5mTokens int `json:"cache_creation_5m_tokens"`
CacheCreation1hTokens int `json:"cache_creation_1h_tokens"` CacheCreation1hTokens int `json:"cache_creation_1h_tokens"`
InputCost float64 `json:"input_cost"` InputCost float64 `json:"input_cost"`
OutputCost float64 `json:"output_cost"` OutputCost float64 `json:"output_cost"`
CacheCreationCost float64 `json:"cache_creation_cost"` CacheCreationCost float64 `json:"cache_creation_cost"`
CacheReadCost float64 `json:"cache_read_cost"` CacheReadCost float64 `json:"cache_read_cost"`
TotalCost float64 `json:"total_cost"` TotalCost float64 `json:"total_cost"`
ActualCost float64 `json:"actual_cost"` ActualCost float64 `json:"actual_cost"`
RateMultiplier float64 `json:"rate_multiplier"` RateMultiplier float64 `json:"rate_multiplier"`
AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
BillingType int8 `json:"billing_type"` BillingType int8 `json:"billing_type"`
Stream bool `json:"stream"` Stream bool `json:"stream"`
...@@ -234,18 +267,55 @@ type UsageLog struct { ...@@ -234,18 +267,55 @@ type UsageLog struct {
// User-Agent // User-Agent
UserAgent *string `json:"user_agent"` UserAgent *string `json:"user_agent"`
// IP 地址(仅管理员可见)
IPAddress *string `json:"ip_address,omitempty"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
User *User `json:"user,omitempty"` User *User `json:"user,omitempty"`
APIKey *APIKey `json:"api_key,omitempty"` APIKey *APIKey `json:"api_key,omitempty"`
Account *AccountSummary `json:"account,omitempty"` // Use minimal AccountSummary to prevent data leakage
Group *Group `json:"group,omitempty"` Group *Group `json:"group,omitempty"`
Subscription *UserSubscription `json:"subscription,omitempty"` Subscription *UserSubscription `json:"subscription,omitempty"`
} }
// AdminUsageLog 是管理员接口使用的 usage log DTO(包含管理员字段)。
type AdminUsageLog struct {
UsageLog
// AccountRateMultiplier 账号计费倍率快照(nil 表示按 1.0 处理)
AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
// IPAddress 用户请求 IP(仅管理员可见)
IPAddress *string `json:"ip_address,omitempty"`
// Account 最小账号信息(避免泄露敏感字段)
Account *AccountSummary `json:"account,omitempty"`
}
type UsageCleanupFilters struct {
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"`
UserID *int64 `json:"user_id,omitempty"`
APIKeyID *int64 `json:"api_key_id,omitempty"`
AccountID *int64 `json:"account_id,omitempty"`
GroupID *int64 `json:"group_id,omitempty"`
Model *string `json:"model,omitempty"`
Stream *bool `json:"stream,omitempty"`
BillingType *int8 `json:"billing_type,omitempty"`
}
type UsageCleanupTask struct {
ID int64 `json:"id"`
Status string `json:"status"`
Filters UsageCleanupFilters `json:"filters"`
CreatedBy int64 `json:"created_by"`
DeletedRows int64 `json:"deleted_rows"`
ErrorMessage *string `json:"error_message,omitempty"`
CanceledBy *int64 `json:"canceled_by,omitempty"`
CanceledAt *time.Time `json:"canceled_at,omitempty"`
StartedAt *time.Time `json:"started_at,omitempty"`
FinishedAt *time.Time `json:"finished_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// AccountSummary is a minimal account info for usage log display. // AccountSummary is a minimal account info for usage log display.
// It intentionally excludes sensitive fields like Credentials, Proxy, etc. // It intentionally excludes sensitive fields like Credentials, Proxy, etc.
type AccountSummary struct { type AccountSummary struct {
...@@ -277,23 +347,30 @@ type UserSubscription struct { ...@@ -277,23 +347,30 @@ type UserSubscription struct {
WeeklyUsageUSD float64 `json:"weekly_usage_usd"` WeeklyUsageUSD float64 `json:"weekly_usage_usd"`
MonthlyUsageUSD float64 `json:"monthly_usage_usd"` MonthlyUsageUSD float64 `json:"monthly_usage_usd"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
}
// AdminUserSubscription 是管理员接口使用的订阅 DTO(包含分配信息/备注等字段)。
// 注意:普通用户接口不得返回 assigned_by/assigned_at/notes/assigned_by_user 等管理员字段。
type AdminUserSubscription struct {
UserSubscription
AssignedBy *int64 `json:"assigned_by"` AssignedBy *int64 `json:"assigned_by"`
AssignedAt time.Time `json:"assigned_at"` AssignedAt time.Time `json:"assigned_at"`
Notes string `json:"notes"` Notes string `json:"notes"`
CreatedAt time.Time `json:"created_at"` AssignedByUser *User `json:"assigned_by_user,omitempty"`
UpdatedAt time.Time `json:"updated_at"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
AssignedByUser *User `json:"assigned_by_user,omitempty"`
} }
type BulkAssignResult struct { type BulkAssignResult struct {
SuccessCount int `json:"success_count"` SuccessCount int `json:"success_count"`
FailedCount int `json:"failed_count"` FailedCount int `json:"failed_count"`
Subscriptions []UserSubscription `json:"subscriptions"` Subscriptions []AdminUserSubscription `json:"subscriptions"`
Errors []string `json:"errors"` Errors []string `json:"errors"`
} }
// PromoCode 注册优惠码 // PromoCode 注册优惠码
......
...@@ -31,6 +31,7 @@ type GatewayHandler struct { ...@@ -31,6 +31,7 @@ type GatewayHandler struct {
antigravityGatewayService *service.AntigravityGatewayService antigravityGatewayService *service.AntigravityGatewayService
userService *service.UserService userService *service.UserService
billingCacheService *service.BillingCacheService billingCacheService *service.BillingCacheService
usageService *service.UsageService
concurrencyHelper *ConcurrencyHelper concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int maxAccountSwitches int
maxAccountSwitchesGemini int maxAccountSwitchesGemini int
...@@ -44,6 +45,7 @@ func NewGatewayHandler( ...@@ -44,6 +45,7 @@ func NewGatewayHandler(
userService *service.UserService, userService *service.UserService,
concurrencyService *service.ConcurrencyService, concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService, billingCacheService *service.BillingCacheService,
usageService *service.UsageService,
cfg *config.Config, cfg *config.Config,
) *GatewayHandler { ) *GatewayHandler {
pingInterval := time.Duration(0) pingInterval := time.Duration(0)
...@@ -64,6 +66,7 @@ func NewGatewayHandler( ...@@ -64,6 +66,7 @@ func NewGatewayHandler(
antigravityGatewayService: antigravityGatewayService, antigravityGatewayService: antigravityGatewayService,
userService: userService, userService: userService,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
usageService: usageService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
maxAccountSwitches: maxAccountSwitches, maxAccountSwitches: maxAccountSwitches,
maxAccountSwitchesGemini: maxAccountSwitchesGemini, maxAccountSwitchesGemini: maxAccountSwitchesGemini,
...@@ -210,17 +213,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -210,17 +213,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
account := selection.Account account := selection.Account
setOpsSelectedAccount(c, account.ID) setOpsSelectedAccount(c, account.ID)
// 检查预热请求拦截(在账号选择后、转发前检查) // 检查请求拦截(预热请求、SUGGESTION MODE等)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { if account.IsInterceptWarmupEnabled() {
if selection.Acquired && selection.ReleaseFunc != nil { interceptType := detectInterceptType(body)
selection.ReleaseFunc() if interceptType != InterceptTypeNone {
} if selection.Acquired && selection.ReleaseFunc != nil {
if reqStream { selection.ReleaseFunc()
sendMockWarmupStream(c, reqModel) }
} else { if reqStream {
sendMockWarmupResponse(c, reqModel) sendMockInterceptStream(c, reqModel, interceptType)
} else {
sendMockInterceptResponse(c, reqModel, interceptType)
}
return
} }
return
} }
// 3. 获取账号并发槽位 // 3. 获取账号并发槽位
...@@ -359,17 +365,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -359,17 +365,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
account := selection.Account account := selection.Account
setOpsSelectedAccount(c, account.ID) setOpsSelectedAccount(c, account.ID)
// 检查预热请求拦截(在账号选择后、转发前检查) // 检查请求拦截(预热请求、SUGGESTION MODE等)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { if account.IsInterceptWarmupEnabled() {
if selection.Acquired && selection.ReleaseFunc != nil { interceptType := detectInterceptType(body)
selection.ReleaseFunc() if interceptType != InterceptTypeNone {
} if selection.Acquired && selection.ReleaseFunc != nil {
if reqStream { selection.ReleaseFunc()
sendMockWarmupStream(c, reqModel) }
} else { if reqStream {
sendMockWarmupResponse(c, reqModel) sendMockInterceptStream(c, reqModel, interceptType)
} else {
sendMockInterceptResponse(c, reqModel, interceptType)
}
return
} }
return
} }
// 3. 获取账号并发槽位 // 3. 获取账号并发槽位
...@@ -588,7 +597,7 @@ func cloneAPIKeyWithGroup(apiKey *service.APIKey, group *service.Group) *service ...@@ -588,7 +597,7 @@ func cloneAPIKeyWithGroup(apiKey *service.APIKey, group *service.Group) *service
return &cloned return &cloned
} }
// Usage handles getting account balance for CC Switch integration // Usage handles getting account balance and usage statistics for CC Switch integration
// GET /v1/usage // GET /v1/usage
func (h *GatewayHandler) Usage(c *gin.Context) { func (h *GatewayHandler) Usage(c *gin.Context) {
apiKey, ok := middleware2.GetAPIKeyFromContext(c) apiKey, ok := middleware2.GetAPIKeyFromContext(c)
...@@ -603,7 +612,40 @@ func (h *GatewayHandler) Usage(c *gin.Context) { ...@@ -603,7 +612,40 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
return return
} }
// 订阅模式:返回订阅限额信息 // Best-effort: 获取用量统计,失败不影响基础响应
var usageData gin.H
if h.usageService != nil {
dashStats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), subject.UserID)
if err == nil && dashStats != nil {
usageData = gin.H{
"today": gin.H{
"requests": dashStats.TodayRequests,
"input_tokens": dashStats.TodayInputTokens,
"output_tokens": dashStats.TodayOutputTokens,
"cache_creation_tokens": dashStats.TodayCacheCreationTokens,
"cache_read_tokens": dashStats.TodayCacheReadTokens,
"total_tokens": dashStats.TodayTokens,
"cost": dashStats.TodayCost,
"actual_cost": dashStats.TodayActualCost,
},
"total": gin.H{
"requests": dashStats.TotalRequests,
"input_tokens": dashStats.TotalInputTokens,
"output_tokens": dashStats.TotalOutputTokens,
"cache_creation_tokens": dashStats.TotalCacheCreationTokens,
"cache_read_tokens": dashStats.TotalCacheReadTokens,
"total_tokens": dashStats.TotalTokens,
"cost": dashStats.TotalCost,
"actual_cost": dashStats.TotalActualCost,
},
"average_duration_ms": dashStats.AverageDurationMs,
"rpm": dashStats.Rpm,
"tpm": dashStats.Tpm,
}
}
}
// 订阅模式:返回订阅限额信息 + 用量统计
if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() { if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() {
subscription, ok := middleware2.GetSubscriptionFromContext(c) subscription, ok := middleware2.GetSubscriptionFromContext(c)
if !ok { if !ok {
...@@ -612,28 +654,46 @@ func (h *GatewayHandler) Usage(c *gin.Context) { ...@@ -612,28 +654,46 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
} }
remaining := h.calculateSubscriptionRemaining(apiKey.Group, subscription) remaining := h.calculateSubscriptionRemaining(apiKey.Group, subscription)
c.JSON(http.StatusOK, gin.H{ resp := gin.H{
"isValid": true, "isValid": true,
"planName": apiKey.Group.Name, "planName": apiKey.Group.Name,
"remaining": remaining, "remaining": remaining,
"unit": "USD", "unit": "USD",
}) "subscription": gin.H{
"daily_usage_usd": subscription.DailyUsageUSD,
"weekly_usage_usd": subscription.WeeklyUsageUSD,
"monthly_usage_usd": subscription.MonthlyUsageUSD,
"daily_limit_usd": apiKey.Group.DailyLimitUSD,
"weekly_limit_usd": apiKey.Group.WeeklyLimitUSD,
"monthly_limit_usd": apiKey.Group.MonthlyLimitUSD,
"expires_at": subscription.ExpiresAt,
},
}
if usageData != nil {
resp["usage"] = usageData
}
c.JSON(http.StatusOK, resp)
return return
} }
// 余额模式:返回钱包余额 // 余额模式:返回钱包余额 + 用量统计
latestUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID) latestUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if err != nil { if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info") h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info")
return return
} }
c.JSON(http.StatusOK, gin.H{ resp := gin.H{
"isValid": true, "isValid": true,
"planName": "钱包余额", "planName": "钱包余额",
"remaining": latestUser.Balance, "remaining": latestUser.Balance,
"unit": "USD", "unit": "USD",
}) "balance": latestUser.Balance,
}
if usageData != nil {
resp["usage"] = usageData
}
c.JSON(http.StatusOK, resp)
} }
// calculateSubscriptionRemaining 计算订阅剩余可用额度 // calculateSubscriptionRemaining 计算订阅剩余可用额度
...@@ -835,17 +895,30 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { ...@@ -835,17 +895,30 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
} }
} }
// isWarmupRequest 检测是否为预热请求(标题生成、Warmup等) // InterceptType 表示请求拦截类型
func isWarmupRequest(body []byte) bool { type InterceptType int
// 快速检查:如果body不包含关键字,直接返回false
const (
InterceptTypeNone InterceptType = iota
InterceptTypeWarmup // 预热请求(返回 "New Conversation")
InterceptTypeSuggestionMode // SUGGESTION MODE(返回空字符串)
)
// detectInterceptType 检测请求是否需要拦截,返回拦截类型
func detectInterceptType(body []byte) InterceptType {
// 快速检查:如果不包含任何关键字,直接返回
bodyStr := string(body) bodyStr := string(body)
if !strings.Contains(bodyStr, "title") && !strings.Contains(bodyStr, "Warmup") { hasSuggestionMode := strings.Contains(bodyStr, "[SUGGESTION MODE:")
return false hasWarmupKeyword := strings.Contains(bodyStr, "title") || strings.Contains(bodyStr, "Warmup")
if !hasSuggestionMode && !hasWarmupKeyword {
return InterceptTypeNone
} }
// 解析完整请求 // 解析请求(只解析一次)
var req struct { var req struct {
Messages []struct { Messages []struct {
Role string `json:"role"`
Content []struct { Content []struct {
Type string `json:"type"` Type string `json:"type"`
Text string `json:"text"` Text string `json:"text"`
...@@ -856,43 +929,71 @@ func isWarmupRequest(body []byte) bool { ...@@ -856,43 +929,71 @@ func isWarmupRequest(body []byte) bool {
} `json:"system"` } `json:"system"`
} }
if err := json.Unmarshal(body, &req); err != nil { if err := json.Unmarshal(body, &req); err != nil {
return false return InterceptTypeNone
} }
// 检查 messages 中的标题提示模式 // 检查 SUGGESTION MODE(最后一条 user 消息)
for _, msg := range req.Messages { if hasSuggestionMode && len(req.Messages) > 0 {
for _, content := range msg.Content { lastMsg := req.Messages[len(req.Messages)-1]
if content.Type == "text" { if lastMsg.Role == "user" && len(lastMsg.Content) > 0 &&
if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") || lastMsg.Content[0].Type == "text" &&
content.Text == "Warmup" { strings.HasPrefix(lastMsg.Content[0].Text, "[SUGGESTION MODE:") {
return true return InterceptTypeSuggestionMode
}
}
} }
} }
// 检查 system 中的标题提取模式 // 检查 Warmup 请求
for _, system := range req.System { if hasWarmupKeyword {
if strings.Contains(system.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") { // 检查 messages 中的标题提示模式
return true for _, msg := range req.Messages {
for _, content := range msg.Content {
if content.Type == "text" {
if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") ||
content.Text == "Warmup" {
return InterceptTypeWarmup
}
}
}
}
// 检查 system 中的标题提取模式
for _, sys := range req.System {
if strings.Contains(sys.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") {
return InterceptTypeWarmup
}
} }
} }
return false return InterceptTypeNone
} }
// sendMockWarmupStream 发送流式 mock 响应(用于预热请求拦截) // sendMockInterceptStream 发送流式 mock 响应(用于请求拦截)
func sendMockWarmupStream(c *gin.Context, model string) { func sendMockInterceptStream(c *gin.Context, model string, interceptType InterceptType) {
c.Header("Content-Type", "text/event-stream") c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache") c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive") c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no") c.Header("X-Accel-Buffering", "no")
// 根据拦截类型决定响应内容
var msgID string
var outputTokens int
var textDeltas []string
switch interceptType {
case InterceptTypeSuggestionMode:
msgID = "msg_mock_suggestion"
outputTokens = 1
textDeltas = []string{""} // 空内容
default: // InterceptTypeWarmup
msgID = "msg_mock_warmup"
outputTokens = 2
textDeltas = []string{"New", " Conversation"}
}
// Build message_start event with proper JSON marshaling // Build message_start event with proper JSON marshaling
messageStart := map[string]any{ messageStart := map[string]any{
"type": "message_start", "type": "message_start",
"message": map[string]any{ "message": map[string]any{
"id": "msg_mock_warmup", "id": msgID,
"type": "message", "type": "message",
"role": "assistant", "role": "assistant",
"model": model, "model": model,
...@@ -907,16 +1008,46 @@ func sendMockWarmupStream(c *gin.Context, model string) { ...@@ -907,16 +1008,46 @@ func sendMockWarmupStream(c *gin.Context, model string) {
} }
messageStartJSON, _ := json.Marshal(messageStart) messageStartJSON, _ := json.Marshal(messageStart)
// Build events
events := []string{ events := []string{
`event: message_start` + "\n" + `data: ` + string(messageStartJSON), `event: message_start` + "\n" + `data: ` + string(messageStartJSON),
`event: content_block_start` + "\n" + `data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`, `event: content_block_start` + "\n" + `data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`,
`event: content_block_delta` + "\n" + `data: {"delta":{"text":"New","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
`event: content_block_delta` + "\n" + `data: {"delta":{"text":" Conversation","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
`event: content_block_stop` + "\n" + `data: {"index":0,"type":"content_block_stop"}`,
`event: message_delta` + "\n" + `data: {"delta":{"stop_reason":"end_turn","stop_sequence":null},"type":"message_delta","usage":{"input_tokens":10,"output_tokens":2}}`,
`event: message_stop` + "\n" + `data: {"type":"message_stop"}`,
} }
// Add text deltas
for _, text := range textDeltas {
delta := map[string]any{
"type": "content_block_delta",
"index": 0,
"delta": map[string]string{
"type": "text_delta",
"text": text,
},
}
deltaJSON, _ := json.Marshal(delta)
events = append(events, `event: content_block_delta`+"\n"+`data: `+string(deltaJSON))
}
// Add final events
messageDelta := map[string]any{
"type": "message_delta",
"delta": map[string]any{
"stop_reason": "end_turn",
"stop_sequence": nil,
},
"usage": map[string]int{
"input_tokens": 10,
"output_tokens": outputTokens,
},
}
messageDeltaJSON, _ := json.Marshal(messageDelta)
events = append(events,
`event: content_block_stop`+"\n"+`data: {"index":0,"type":"content_block_stop"}`,
`event: message_delta`+"\n"+`data: `+string(messageDeltaJSON),
`event: message_stop`+"\n"+`data: {"type":"message_stop"}`,
)
for _, event := range events { for _, event := range events {
_, _ = c.Writer.WriteString(event + "\n\n") _, _ = c.Writer.WriteString(event + "\n\n")
c.Writer.Flush() c.Writer.Flush()
...@@ -924,18 +1055,32 @@ func sendMockWarmupStream(c *gin.Context, model string) { ...@@ -924,18 +1055,32 @@ func sendMockWarmupStream(c *gin.Context, model string) {
} }
} }
// sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截) // sendMockInterceptResponse 发送非流式 mock 响应(用于请求拦截)
func sendMockWarmupResponse(c *gin.Context, model string) { func sendMockInterceptResponse(c *gin.Context, model string, interceptType InterceptType) {
var msgID, text string
var outputTokens int
switch interceptType {
case InterceptTypeSuggestionMode:
msgID = "msg_mock_suggestion"
text = ""
outputTokens = 1
default: // InterceptTypeWarmup
msgID = "msg_mock_warmup"
text = "New Conversation"
outputTokens = 2
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"id": "msg_mock_warmup", "id": msgID,
"type": "message", "type": "message",
"role": "assistant", "role": "assistant",
"model": model, "model": model,
"content": []gin.H{{"type": "text", "text": "New Conversation"}}, "content": []gin.H{{"type": "text", "text": text}},
"stop_reason": "end_turn", "stop_reason": "end_turn",
"usage": gin.H{ "usage": gin.H{
"input_tokens": 10, "input_tokens": 10,
"output_tokens": 2, "output_tokens": outputTokens,
}, },
}) })
} }
......
//go:build unit
package handler
import (
"crypto/sha256"
"encoding/hex"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestExtractGeminiCLISessionHash(t *testing.T) {
tests := []struct {
name string
body string
privilegedUserID string
wantEmpty bool
wantHash string
}{
{
name: "with privileged-user-id and tmp dir",
body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
wantEmpty: false,
wantHash: func() string {
combined := "90785f52-8bbe-4b17-b111-a1ddea1636c3:f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"
hash := sha256.Sum256([]byte(combined))
return hex.EncodeToString(hash[:])
}(),
},
{
name: "without privileged-user-id but with tmp dir",
body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
privilegedUserID: "",
wantEmpty: false,
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
},
{
name: "without tmp dir",
body: `{"contents":[{"parts":[{"text":"Hello world"}]}]}`,
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
wantEmpty: true,
},
{
name: "empty body",
body: "",
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
wantEmpty: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 创建测试上下文
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/test", nil)
if tt.privilegedUserID != "" {
c.Request.Header.Set("x-gemini-api-privileged-user-id", tt.privilegedUserID)
}
// 调用函数
result := extractGeminiCLISessionHash(c, []byte(tt.body))
// 验证结果
if tt.wantEmpty {
require.Empty(t, result, "expected empty session hash")
} else {
require.NotEmpty(t, result, "expected non-empty session hash")
require.Equal(t, tt.wantHash, result, "session hash mismatch")
}
})
}
}
func TestGeminiCLITmpDirRegex(t *testing.T) {
tests := []struct {
name string
input string
wantMatch bool
wantHash string
}{
{
name: "valid tmp dir path",
input: "/Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
wantMatch: true,
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
},
{
name: "valid tmp dir path in text",
input: "The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740\nOther text",
wantMatch: true,
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
},
{
name: "invalid hash length",
input: "/Users/ianshaw/.gemini/tmp/abc123",
wantMatch: false,
},
{
name: "no tmp dir",
input: "Hello world",
wantMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
match := geminiCLITmpDirRegex.FindStringSubmatch(tt.input)
if tt.wantMatch {
require.NotNil(t, match, "expected regex to match")
require.Len(t, match, 2, "expected 2 capture groups")
require.Equal(t, tt.wantHash, match[1], "hash mismatch")
} else {
require.Nil(t, match, "expected regex not to match")
}
})
}
}
package handler package handler
import ( import (
"bytes"
"context" "context"
"crypto/sha256"
"encoding/hex"
"errors" "errors"
"io" "io"
"log" "log"
"net/http" "net/http"
"regexp"
"strings" "strings"
"time" "time"
...@@ -20,6 +24,17 @@ import ( ...@@ -20,6 +24,17 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// geminiCLITmpDirRegex 用于从 Gemini CLI 请求体中提取 tmp 目录的哈希值
// 匹配格式: /Users/xxx/.gemini/tmp/[64位十六进制哈希]
var geminiCLITmpDirRegex = regexp.MustCompile(`/\.gemini/tmp/([A-Fa-f0-9]{64})`)
func isGeminiCLIRequest(c *gin.Context, body []byte) bool {
if strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id")) != "" {
return true
}
return geminiCLITmpDirRegex.Match(body)
}
// GeminiV1BetaListModels proxies: // GeminiV1BetaListModels proxies:
// GET /v1beta/models // GET /v1beta/models
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) { func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
...@@ -215,12 +230,26 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -215,12 +230,26 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
} }
// 3) select account (sticky session based on request body) // 3) select account (sticky session based on request body)
parsedReq, _ := service.ParseGatewayRequest(body) // 优先使用 Gemini CLI 的会话标识(privileged-user-id + tmp 目录哈希)
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) sessionHash := extractGeminiCLISessionHash(c, body)
if sessionHash == "" {
// Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端)
parsedReq, _ := service.ParseGatewayRequest(body)
sessionHash = h.gatewayService.GenerateSessionHash(parsedReq)
}
sessionKey := sessionHash sessionKey := sessionHash
if sessionHash != "" { if sessionHash != "" {
sessionKey = "gemini:" + sessionHash sessionKey = "gemini:" + sessionHash
} }
// 查询粘性会话绑定的账号 ID(用于检测账号切换)
var sessionBoundAccountID int64
if sessionKey != "" {
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
}
isCLI := isGeminiCLIRequest(c, body)
cleanedForUnknownBinding := false
maxAccountSwitches := h.maxAccountSwitchesGemini maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0 switchCount := 0
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})
...@@ -239,6 +268,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -239,6 +268,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
account := selection.Account account := selection.Account
setOpsSelectedAccount(c, account.ID) setOpsSelectedAccount(c, account.ID)
// 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature
// 注意:Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。
if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID {
log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
body = service.CleanGeminiNativeThoughtSignatures(body)
sessionBoundAccountID = account.ID
} else if sessionKey != "" && sessionBoundAccountID == 0 && isCLI && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
// 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,CLI 继续携带旧签名。
// 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。
log.Printf("[Gemini] Sticky session binding missing for CLI request, cleaning thoughtSignature proactively")
body = service.CleanGeminiNativeThoughtSignatures(body)
cleanedForUnknownBinding = true
sessionBoundAccountID = account.ID
} else if sessionBoundAccountID == 0 {
// 记录本次请求中首次选择到的账号,便于同一请求内 failover 时检测切换。
sessionBoundAccountID = account.ID
}
// 4) account concurrency slot // 4) account concurrency slot
accountReleaseFunc := selection.ReleaseFunc accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired { if !selection.Acquired {
...@@ -438,3 +485,38 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool { ...@@ -438,3 +485,38 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
} }
return false return false
} }
// extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。
// 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。
//
// 会话标识生成策略:
// 1. 从请求体中提取 tmp 目录哈希(64位十六进制)
// 2. 从 header 中提取 privileged-user-id(UUID)
// 3. 组合两者生成 SHA256 哈希作为最终的会话标识
//
// 如果找不到 tmp 目录哈希,返回空字符串(不使用粘性会话)。
//
// extractGeminiCLISessionHash extracts session identifier from Gemini CLI requests.
// Combines x-gemini-api-privileged-user-id header with tmp directory hash from request body.
func extractGeminiCLISessionHash(c *gin.Context, body []byte) string {
// 1. 从请求体中提取 tmp 目录哈希
match := geminiCLITmpDirRegex.FindSubmatch(body)
if len(match) < 2 {
return "" // 没有找到 tmp 目录,不使用粘性会话
}
tmpDirHash := string(match[1])
// 2. 提取 privileged-user-id
privilegedUserID := strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id"))
// 3. 组合生成最终的 session hash
if privilegedUserID != "" {
// 组合两个标识符:privileged-user-id + tmp 目录哈希
combined := privilegedUserID + ":" + tmpDirHash
hash := sha256.Sum256([]byte(combined))
return hex.EncodeToString(hash[:])
}
// 如果没有 privileged-user-id,直接使用 tmp 目录哈希
return tmpDirHash
}
...@@ -10,6 +10,7 @@ type AdminHandlers struct { ...@@ -10,6 +10,7 @@ type AdminHandlers struct {
User *admin.UserHandler User *admin.UserHandler
Group *admin.GroupHandler Group *admin.GroupHandler
Account *admin.AccountHandler Account *admin.AccountHandler
Announcement *admin.AnnouncementHandler
OAuth *admin.OAuthHandler OAuth *admin.OAuthHandler
OpenAIOAuth *admin.OpenAIOAuthHandler OpenAIOAuth *admin.OpenAIOAuthHandler
GeminiOAuth *admin.GeminiOAuthHandler GeminiOAuth *admin.GeminiOAuthHandler
...@@ -33,10 +34,12 @@ type Handlers struct { ...@@ -33,10 +34,12 @@ type Handlers struct {
Usage *UsageHandler Usage *UsageHandler
Redeem *RedeemHandler Redeem *RedeemHandler
Subscription *SubscriptionHandler Subscription *SubscriptionHandler
Announcement *AnnouncementHandler
Admin *AdminHandlers Admin *AdminHandlers
Gateway *GatewayHandler Gateway *GatewayHandler
OpenAIGateway *OpenAIGatewayHandler OpenAIGateway *OpenAIGatewayHandler
Setting *SettingHandler Setting *SettingHandler
Totp *TotpHandler
} }
// BuildInfo contains build-time information // BuildInfo contains build-time information
......
...@@ -192,8 +192,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -192,8 +192,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return return
} }
// Generate session hash (from header for OpenAI) // Generate session hash (header first; fallback to prompt_cache_key)
sessionHash := h.gatewayService.GenerateSessionHash(c) sessionHash := h.gatewayService.GenerateSessionHash(c, reqBody)
maxAccountSwitches := h.maxAccountSwitches maxAccountSwitches := h.maxAccountSwitches
switchCount := 0 switchCount := 0
......
...@@ -905,7 +905,7 @@ func classifyOpsIsRetryable(errType string, statusCode int) bool { ...@@ -905,7 +905,7 @@ func classifyOpsIsRetryable(errType string, statusCode int) bool {
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool { func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
switch strings.TrimSpace(code) { switch strings.TrimSpace(code) {
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID": case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID", "USER_INACTIVE":
return true return true
} }
if phase == "billing" || phase == "concurrency" { if phase == "billing" || phase == "concurrency" {
...@@ -1011,5 +1011,12 @@ func shouldSkipOpsErrorLog(ctx context.Context, ops *service.OpsService, message ...@@ -1011,5 +1011,12 @@ func shouldSkipOpsErrorLog(ctx context.Context, ops *service.OpsService, message
} }
} }
// Check if invalid/missing API key errors should be ignored (user misconfiguration)
if settings.IgnoreInvalidApiKeyErrors {
if strings.Contains(bodyLower, "invalid_api_key") || strings.Contains(bodyLower, "api_key_required") {
return true
}
}
return false return false
} }
...@@ -32,18 +32,24 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { ...@@ -32,18 +32,24 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
} }
response.Success(c, dto.PublicSettings{ response.Success(c, dto.PublicSettings{
RegistrationEnabled: settings.RegistrationEnabled, RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled,
TurnstileEnabled: settings.TurnstileEnabled, PromoCodeEnabled: settings.PromoCodeEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey, PasswordResetEnabled: settings.PasswordResetEnabled,
SiteName: settings.SiteName, TotpEnabled: settings.TotpEnabled,
SiteLogo: settings.SiteLogo, TurnstileEnabled: settings.TurnstileEnabled,
SiteSubtitle: settings.SiteSubtitle, TurnstileSiteKey: settings.TurnstileSiteKey,
APIBaseURL: settings.APIBaseURL, SiteName: settings.SiteName,
ContactInfo: settings.ContactInfo, SiteLogo: settings.SiteLogo,
DocURL: settings.DocURL, SiteSubtitle: settings.SiteSubtitle,
HomeContent: settings.HomeContent, APIBaseURL: settings.APIBaseURL,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, ContactInfo: settings.ContactInfo,
Version: h.version, DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: h.version,
}) })
} }
package handler
import (
"github.com/gin-gonic/gin"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// TotpHandler handles TOTP-related requests
type TotpHandler struct {
totpService *service.TotpService
}
// NewTotpHandler creates a new TotpHandler
func NewTotpHandler(totpService *service.TotpService) *TotpHandler {
return &TotpHandler{
totpService: totpService,
}
}
// TotpStatusResponse represents the TOTP status response
type TotpStatusResponse struct {
Enabled bool `json:"enabled"`
EnabledAt *int64 `json:"enabled_at,omitempty"` // Unix timestamp
FeatureEnabled bool `json:"feature_enabled"`
}
// GetStatus returns the TOTP status for the current user
// GET /api/v1/user/totp/status
func (h *TotpHandler) GetStatus(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
status, err := h.totpService.GetStatus(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
resp := TotpStatusResponse{
Enabled: status.Enabled,
FeatureEnabled: status.FeatureEnabled,
}
if status.EnabledAt != nil {
ts := status.EnabledAt.Unix()
resp.EnabledAt = &ts
}
response.Success(c, resp)
}
// TotpSetupRequest represents the request to initiate TOTP setup
type TotpSetupRequest struct {
EmailCode string `json:"email_code"`
Password string `json:"password"`
}
// TotpSetupResponse represents the TOTP setup response
type TotpSetupResponse struct {
Secret string `json:"secret"`
QRCodeURL string `json:"qr_code_url"`
SetupToken string `json:"setup_token"`
Countdown int `json:"countdown"`
}
// InitiateSetup starts the TOTP setup process
// POST /api/v1/user/totp/setup
func (h *TotpHandler) InitiateSetup(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req TotpSetupRequest
if err := c.ShouldBindJSON(&req); err != nil {
// Allow empty body (optional params)
req = TotpSetupRequest{}
}
result, err := h.totpService.InitiateSetup(c.Request.Context(), subject.UserID, req.EmailCode, req.Password)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, TotpSetupResponse{
Secret: result.Secret,
QRCodeURL: result.QRCodeURL,
SetupToken: result.SetupToken,
Countdown: result.Countdown,
})
}
// TotpEnableRequest represents the request to enable TOTP
type TotpEnableRequest struct {
TotpCode string `json:"totp_code" binding:"required,len=6"`
SetupToken string `json:"setup_token" binding:"required"`
}
// Enable completes the TOTP setup
// POST /api/v1/user/totp/enable
func (h *TotpHandler) Enable(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req TotpEnableRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := h.totpService.CompleteSetup(c.Request.Context(), subject.UserID, req.TotpCode, req.SetupToken); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"success": true})
}
// TotpDisableRequest represents the request to disable TOTP
type TotpDisableRequest struct {
EmailCode string `json:"email_code"`
Password string `json:"password"`
}
// Disable disables TOTP for the current user
// POST /api/v1/user/totp/disable
func (h *TotpHandler) Disable(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req TotpDisableRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := h.totpService.Disable(c.Request.Context(), subject.UserID, req.EmailCode, req.Password); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"success": true})
}
// GetVerificationMethod returns the verification method for TOTP operations
// GET /api/v1/user/totp/verification-method
func (h *TotpHandler) GetVerificationMethod(c *gin.Context) {
method := h.totpService.GetVerificationMethod(c.Request.Context())
response.Success(c, method)
}
// SendVerifyCode sends an email verification code for TOTP operations
// POST /api/v1/user/totp/send-code
func (h *TotpHandler) SendVerifyCode(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
if err := h.totpService.SendVerifyCode(c.Request.Context(), subject.UserID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"success": true})
}
...@@ -47,9 +47,6 @@ func (h *UserHandler) GetProfile(c *gin.Context) { ...@@ -47,9 +47,6 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
return return
} }
// 清空notes字段,普通用户不应看到备注
userData.Notes = ""
response.Success(c, dto.UserFromService(userData)) response.Success(c, dto.UserFromService(userData))
} }
...@@ -105,8 +102,5 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { ...@@ -105,8 +102,5 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
return return
} }
// 清空notes字段,普通用户不应看到备注
updatedUser.Notes = ""
response.Success(c, dto.UserFromService(updatedUser)) response.Success(c, dto.UserFromService(updatedUser))
} }
...@@ -13,6 +13,7 @@ func ProvideAdminHandlers( ...@@ -13,6 +13,7 @@ func ProvideAdminHandlers(
userHandler *admin.UserHandler, userHandler *admin.UserHandler,
groupHandler *admin.GroupHandler, groupHandler *admin.GroupHandler,
accountHandler *admin.AccountHandler, accountHandler *admin.AccountHandler,
announcementHandler *admin.AnnouncementHandler,
oauthHandler *admin.OAuthHandler, oauthHandler *admin.OAuthHandler,
openaiOAuthHandler *admin.OpenAIOAuthHandler, openaiOAuthHandler *admin.OpenAIOAuthHandler,
geminiOAuthHandler *admin.GeminiOAuthHandler, geminiOAuthHandler *admin.GeminiOAuthHandler,
...@@ -32,6 +33,7 @@ func ProvideAdminHandlers( ...@@ -32,6 +33,7 @@ func ProvideAdminHandlers(
User: userHandler, User: userHandler,
Group: groupHandler, Group: groupHandler,
Account: accountHandler, Account: accountHandler,
Announcement: announcementHandler,
OAuth: oauthHandler, OAuth: oauthHandler,
OpenAIOAuth: openaiOAuthHandler, OpenAIOAuth: openaiOAuthHandler,
GeminiOAuth: geminiOAuthHandler, GeminiOAuth: geminiOAuthHandler,
...@@ -66,10 +68,12 @@ func ProvideHandlers( ...@@ -66,10 +68,12 @@ func ProvideHandlers(
usageHandler *UsageHandler, usageHandler *UsageHandler,
redeemHandler *RedeemHandler, redeemHandler *RedeemHandler,
subscriptionHandler *SubscriptionHandler, subscriptionHandler *SubscriptionHandler,
announcementHandler *AnnouncementHandler,
adminHandlers *AdminHandlers, adminHandlers *AdminHandlers,
gatewayHandler *GatewayHandler, gatewayHandler *GatewayHandler,
openaiGatewayHandler *OpenAIGatewayHandler, openaiGatewayHandler *OpenAIGatewayHandler,
settingHandler *SettingHandler, settingHandler *SettingHandler,
totpHandler *TotpHandler,
) *Handlers { ) *Handlers {
return &Handlers{ return &Handlers{
Auth: authHandler, Auth: authHandler,
...@@ -78,10 +82,12 @@ func ProvideHandlers( ...@@ -78,10 +82,12 @@ func ProvideHandlers(
Usage: usageHandler, Usage: usageHandler,
Redeem: redeemHandler, Redeem: redeemHandler,
Subscription: subscriptionHandler, Subscription: subscriptionHandler,
Announcement: announcementHandler,
Admin: adminHandlers, Admin: adminHandlers,
Gateway: gatewayHandler, Gateway: gatewayHandler,
OpenAIGateway: openaiGatewayHandler, OpenAIGateway: openaiGatewayHandler,
Setting: settingHandler, Setting: settingHandler,
Totp: totpHandler,
} }
} }
...@@ -94,8 +100,10 @@ var ProviderSet = wire.NewSet( ...@@ -94,8 +100,10 @@ var ProviderSet = wire.NewSet(
NewUsageHandler, NewUsageHandler,
NewRedeemHandler, NewRedeemHandler,
NewSubscriptionHandler, NewSubscriptionHandler,
NewAnnouncementHandler,
NewGatewayHandler, NewGatewayHandler,
NewOpenAIGatewayHandler, NewOpenAIGatewayHandler,
NewTotpHandler,
ProvideSettingHandler, ProvideSettingHandler,
// Admin handlers // Admin handlers
...@@ -103,6 +111,7 @@ var ProviderSet = wire.NewSet( ...@@ -103,6 +111,7 @@ var ProviderSet = wire.NewSet(
admin.NewUserHandler, admin.NewUserHandler,
admin.NewGroupHandler, admin.NewGroupHandler,
admin.NewAccountHandler, admin.NewAccountHandler,
admin.NewAnnouncementHandler,
admin.NewOAuthHandler, admin.NewOAuthHandler,
admin.NewOpenAIOAuthHandler, admin.NewOpenAIOAuthHandler,
admin.NewGeminiOAuthHandler, admin.NewGeminiOAuthHandler,
......
...@@ -7,6 +7,9 @@ import ( ...@@ -7,6 +7,9 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os"
"path/filepath"
"strconv"
"testing" "testing"
"time" "time"
...@@ -88,6 +91,7 @@ func performRequest(router *gin.Engine) *httptest.ResponseRecorder { ...@@ -88,6 +91,7 @@ func performRequest(router *gin.Engine) *httptest.ResponseRecorder {
func startRedis(t *testing.T, ctx context.Context) *redis.Client { func startRedis(t *testing.T, ctx context.Context) *redis.Client {
t.Helper() t.Helper()
ensureDockerAvailable(t)
redisContainer, err := tcredis.Run(ctx, redisImageTag) redisContainer, err := tcredis.Run(ctx, redisImageTag)
require.NoError(t, err) require.NoError(t, err)
...@@ -112,3 +116,43 @@ func startRedis(t *testing.T, ctx context.Context) *redis.Client { ...@@ -112,3 +116,43 @@ func startRedis(t *testing.T, ctx context.Context) *redis.Client {
return rdb return rdb
} }
func ensureDockerAvailable(t *testing.T) {
t.Helper()
if dockerAvailable() {
return
}
t.Skip("Docker 未启用,跳过依赖 testcontainers 的集成测试")
}
func dockerAvailable() bool {
if os.Getenv("DOCKER_HOST") != "" {
return true
}
socketCandidates := []string{
"/var/run/docker.sock",
filepath.Join(os.Getenv("XDG_RUNTIME_DIR"), "docker.sock"),
filepath.Join(userHomeDir(), ".docker", "run", "docker.sock"),
filepath.Join(userHomeDir(), ".docker", "desktop", "docker.sock"),
filepath.Join("/run/user", strconv.Itoa(os.Getuid()), "docker.sock"),
}
for _, socket := range socketCandidates {
if socket == "" {
continue
}
if _, err := os.Stat(socket); err == nil {
return true
}
}
return false
}
func userHomeDir() string {
home, err := os.UserHomeDir()
if err != nil {
return ""
}
return home
}
...@@ -33,7 +33,7 @@ const ( ...@@ -33,7 +33,7 @@ const (
"https://www.googleapis.com/auth/experimentsandconfigs" "https://www.googleapis.com/auth/experimentsandconfigs"
// User-Agent(与 Antigravity-Manager 保持一致) // User-Agent(与 Antigravity-Manager 保持一致)
UserAgent = "antigravity/1.11.9 windows/amd64" UserAgent = "antigravity/1.15.8 windows/amd64"
// Session 过期时间 // Session 过期时间
SessionTTL = 30 * time.Minute SessionTTL = 30 * time.Minute
......
...@@ -369,8 +369,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu ...@@ -369,8 +369,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
Text: block.Thinking, Text: block.Thinking,
Thought: true, Thought: true,
} }
// 保留原有 signature(Claude 模型需要有效的 signature) // signature 处理:
if block.Signature != "" { // - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失)
// - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature
if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
part.ThoughtSignature = block.Signature part.ThoughtSignature = block.Signature
} else if !allowDummyThought { } else if !allowDummyThought {
// Claude 模型需要有效 signature;在缺失时降级为普通文本,并在上层禁用 thinking mode。 // Claude 模型需要有效 signature;在缺失时降级为普通文本,并在上层禁用 thinking mode。
...@@ -409,12 +411,12 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu ...@@ -409,12 +411,12 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
}, },
} }
// tool_use 的 signature 处理: // tool_use 的 signature 处理:
// - Gemini 模型:使用 dummy signature(跳过 thought_signature 校验) // - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失)
// - Claude 模型:透传上游返回的真实 signature(Vertex/Google 需要完整签名链路) // - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature
if allowDummyThought { if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
part.ThoughtSignature = dummyThoughtSignature
} else if block.Signature != "" && block.Signature != dummyThoughtSignature {
part.ThoughtSignature = block.Signature part.ThoughtSignature = block.Signature
} else if allowDummyThought {
part.ThoughtSignature = dummyThoughtSignature
} }
parts = append(parts, part) parts = append(parts, part)
......
...@@ -100,7 +100,7 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) { ...@@ -100,7 +100,7 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}, "signature": "sig_tool_abc"} {"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}, "signature": "sig_tool_abc"}
]` ]`
t.Run("Gemini uses dummy tool_use signature", func(t *testing.T) { t.Run("Gemini preserves provided tool_use signature", func(t *testing.T) {
toolIDToName := make(map[string]string) toolIDToName := make(map[string]string)
parts, _, err := buildParts(json.RawMessage(content), toolIDToName, true) parts, _, err := buildParts(json.RawMessage(content), toolIDToName, true)
if err != nil { if err != nil {
...@@ -109,6 +109,23 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) { ...@@ -109,6 +109,23 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
if len(parts) != 1 || parts[0].FunctionCall == nil { if len(parts) != 1 || parts[0].FunctionCall == nil {
t.Fatalf("expected 1 functionCall part, got %+v", parts) t.Fatalf("expected 1 functionCall part, got %+v", parts)
} }
if parts[0].ThoughtSignature != "sig_tool_abc" {
t.Fatalf("expected preserved tool signature %q, got %q", "sig_tool_abc", parts[0].ThoughtSignature)
}
})
t.Run("Gemini falls back to dummy tool_use signature when missing", func(t *testing.T) {
contentNoSig := `[
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}}
]`
toolIDToName := make(map[string]string)
parts, _, err := buildParts(json.RawMessage(contentNoSig), toolIDToName, true)
if err != nil {
t.Fatalf("buildParts() error = %v", err)
}
if len(parts) != 1 || parts[0].FunctionCall == nil {
t.Fatalf("expected 1 functionCall part, got %+v", parts)
}
if parts[0].ThoughtSignature != dummyThoughtSignature { if parts[0].ThoughtSignature != dummyThoughtSignature {
t.Fatalf("expected dummy tool signature %q, got %q", dummyThoughtSignature, parts[0].ThoughtSignature) t.Fatalf("expected dummy tool signature %q, got %q", dummyThoughtSignature, parts[0].ThoughtSignature)
} }
......
...@@ -20,6 +20,15 @@ func TransformGeminiToClaude(geminiResp []byte, originalModel string) ([]byte, * ...@@ -20,6 +20,15 @@ func TransformGeminiToClaude(geminiResp []byte, originalModel string) ([]byte, *
v1Resp.Response = directResp v1Resp.Response = directResp
v1Resp.ResponseID = directResp.ResponseID v1Resp.ResponseID = directResp.ResponseID
v1Resp.ModelVersion = directResp.ModelVersion v1Resp.ModelVersion = directResp.ModelVersion
} else if len(v1Resp.Response.Candidates) == 0 {
// 第一次解析成功但 candidates 为空,说明是直接的 GeminiResponse 格式
var directResp GeminiResponse
if err2 := json.Unmarshal(geminiResp, &directResp); err2 != nil {
return nil, nil, fmt.Errorf("parse gemini response as direct: %w", err2)
}
v1Resp.Response = directResp
v1Resp.ResponseID = directResp.ResponseID
v1Resp.ModelVersion = directResp.ModelVersion
} }
// 使用处理器转换 // 使用处理器转换
...@@ -174,16 +183,20 @@ func (p *NonStreamingProcessor) processPart(part *GeminiPart) { ...@@ -174,16 +183,20 @@ func (p *NonStreamingProcessor) processPart(part *GeminiPart) {
p.trailingSignature = "" p.trailingSignature = ""
} }
p.textBuilder += part.Text // 非空 text 带签名 - 特殊处理:先输出 text,再输出空 thinking 块
// 非空 text 带签名 - 立即刷新并输出空 thinking 块
if signature != "" { if signature != "" {
p.flushText() p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "text",
Text: part.Text,
})
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{ p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "thinking", Type: "thinking",
Thinking: "", Thinking: "",
Signature: signature, Signature: signature,
}) })
} else {
// 普通 text (无签名) - 累积到 builder
p.textBuilder += part.Text
} }
} }
} }
......
...@@ -16,14 +16,11 @@ type ModelsListResponse struct { ...@@ -16,14 +16,11 @@ type ModelsListResponse struct {
func DefaultModels() []Model { func DefaultModels() []Model {
methods := []string{"generateContent", "streamGenerateContent"} methods := []string{"generateContent", "streamGenerateContent"}
return []Model{ return []Model{
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods}, {Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-1.5-pro", SupportedGenerationMethods: methods}, {Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-1.5-flash", SupportedGenerationMethods: methods}, {Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
{Name: "models/gemini-1.5-flash-8b", SupportedGenerationMethods: methods}, {Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
} }
} }
......
...@@ -12,10 +12,10 @@ type Model struct { ...@@ -12,10 +12,10 @@ type Model struct {
// DefaultModels is the curated Gemini model list used by the admin UI "test account" flow. // DefaultModels is the curated Gemini model list used by the admin UI "test account" flow.
var DefaultModels = []Model{ var DefaultModels = []Model{
{ID: "gemini-2.0-flash", Type: "model", DisplayName: "Gemini 2.0 Flash", CreatedAt: ""}, {ID: "gemini-2.0-flash", Type: "model", DisplayName: "Gemini 2.0 Flash", CreatedAt: ""},
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""}, {ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""}, {ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""}, {ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
} }
// DefaultTestModel is the default model to preselect in test flows. // DefaultTestModel is the default model to preselect in test flows.
......
...@@ -13,20 +13,26 @@ import ( ...@@ -13,20 +13,26 @@ import (
"time" "time"
) )
// Claude OAuth Constants (from CRS project) // Claude OAuth Constants
const ( const (
// OAuth Client ID for Claude // OAuth Client ID for Claude
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e" ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
// OAuth endpoints // OAuth endpoints
AuthorizeURL = "https://claude.ai/oauth/authorize" AuthorizeURL = "https://claude.ai/oauth/authorize"
TokenURL = "https://console.anthropic.com/v1/oauth/token" TokenURL = "https://platform.claude.com/v1/oauth/token"
RedirectURI = "https://console.anthropic.com/oauth/code/callback" RedirectURI = "https://platform.claude.com/oauth/code/callback"
// Scopes // Scopes - Browser URL (includes org:create_api_key for user authorization)
ScopeProfile = "user:profile" ScopeOAuth = "org:create_api_key user:profile user:inference user:sessions:claude_code user:mcp_servers"
// Scopes - Internal API call (org:create_api_key not supported in API)
ScopeAPI = "user:profile user:inference user:sessions:claude_code user:mcp_servers"
// Scopes - Setup token (inference only)
ScopeInference = "user:inference" ScopeInference = "user:inference"
// Code Verifier character set (RFC 7636 compliant)
codeVerifierCharset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
// Session TTL // Session TTL
SessionTTL = 30 * time.Minute SessionTTL = 30 * time.Minute
) )
...@@ -53,7 +59,6 @@ func NewSessionStore() *SessionStore { ...@@ -53,7 +59,6 @@ func NewSessionStore() *SessionStore {
sessions: make(map[string]*OAuthSession), sessions: make(map[string]*OAuthSession),
stopCh: make(chan struct{}), stopCh: make(chan struct{}),
} }
// Start cleanup goroutine
go store.cleanup() go store.cleanup()
return store return store
} }
...@@ -78,7 +83,6 @@ func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) { ...@@ -78,7 +83,6 @@ func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
if !ok { if !ok {
return nil, false return nil, false
} }
// Check if expired
if time.Since(session.CreatedAt) > SessionTTL { if time.Since(session.CreatedAt) > SessionTTL {
return nil, false return nil, false
} }
...@@ -122,13 +126,13 @@ func GenerateRandomBytes(n int) ([]byte, error) { ...@@ -122,13 +126,13 @@ func GenerateRandomBytes(n int) ([]byte, error) {
return b, nil return b, nil
} }
// GenerateState generates a random state string for OAuth // GenerateState generates a random state string for OAuth (base64url encoded)
func GenerateState() (string, error) { func GenerateState() (string, error) {
bytes, err := GenerateRandomBytes(32) bytes, err := GenerateRandomBytes(32)
if err != nil { if err != nil {
return "", err return "", err
} }
return hex.EncodeToString(bytes), nil return base64URLEncode(bytes), nil
} }
// GenerateSessionID generates a unique session ID // GenerateSessionID generates a unique session ID
...@@ -140,13 +144,30 @@ func GenerateSessionID() (string, error) { ...@@ -140,13 +144,30 @@ func GenerateSessionID() (string, error) {
return hex.EncodeToString(bytes), nil return hex.EncodeToString(bytes), nil
} }
// GenerateCodeVerifier generates a PKCE code verifier (32 bytes -> base64url) // GenerateCodeVerifier generates a PKCE code verifier using character set method
func GenerateCodeVerifier() (string, error) { func GenerateCodeVerifier() (string, error) {
bytes, err := GenerateRandomBytes(32) const targetLen = 32
if err != nil { charsetLen := len(codeVerifierCharset)
return "", err limit := 256 - (256 % charsetLen)
result := make([]byte, 0, targetLen)
randBuf := make([]byte, targetLen*2)
for len(result) < targetLen {
if _, err := rand.Read(randBuf); err != nil {
return "", err
}
for _, b := range randBuf {
if int(b) < limit {
result = append(result, codeVerifierCharset[int(b)%charsetLen])
if len(result) >= targetLen {
break
}
}
}
} }
return base64URLEncode(bytes), nil
return base64URLEncode(result), nil
} }
// GenerateCodeChallenge generates a PKCE code challenge using S256 method // GenerateCodeChallenge generates a PKCE code challenge using S256 method
...@@ -158,42 +179,31 @@ func GenerateCodeChallenge(verifier string) string { ...@@ -158,42 +179,31 @@ func GenerateCodeChallenge(verifier string) string {
// base64URLEncode encodes bytes to base64url without padding // base64URLEncode encodes bytes to base64url without padding
func base64URLEncode(data []byte) string { func base64URLEncode(data []byte) string {
encoded := base64.URLEncoding.EncodeToString(data) encoded := base64.URLEncoding.EncodeToString(data)
// Remove padding
return strings.TrimRight(encoded, "=") return strings.TrimRight(encoded, "=")
} }
// BuildAuthorizationURL builds the OAuth authorization URL // BuildAuthorizationURL builds the OAuth authorization URL with correct parameter order
func BuildAuthorizationURL(state, codeChallenge, scope string) string { func BuildAuthorizationURL(state, codeChallenge, scope string) string {
params := url.Values{} encodedRedirectURI := url.QueryEscape(RedirectURI)
params.Set("response_type", "code") encodedScope := strings.ReplaceAll(url.QueryEscape(scope), "%20", "+")
params.Set("client_id", ClientID)
params.Set("redirect_uri", RedirectURI)
params.Set("scope", scope)
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
}
// TokenRequest represents the token exchange request body return fmt.Sprintf("%s?code=true&client_id=%s&response_type=code&redirect_uri=%s&scope=%s&code_challenge=%s&code_challenge_method=S256&state=%s",
type TokenRequest struct { AuthorizeURL,
GrantType string `json:"grant_type"` ClientID,
ClientID string `json:"client_id"` encodedRedirectURI,
Code string `json:"code"` encodedScope,
RedirectURI string `json:"redirect_uri"` codeChallenge,
CodeVerifier string `json:"code_verifier"` state,
State string `json:"state"` )
} }
// TokenResponse represents the token response from OAuth provider // TokenResponse represents the token response from OAuth provider
type TokenResponse struct { type TokenResponse struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
TokenType string `json:"token_type"` TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"` ExpiresIn int64 `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"` RefreshToken string `json:"refresh_token,omitempty"`
Scope string `json:"scope,omitempty"` Scope string `json:"scope,omitempty"`
// Organization and Account info from OAuth response
Organization *OrgInfo `json:"organization,omitempty"` Organization *OrgInfo `json:"organization,omitempty"`
Account *AccountInfo `json:"account,omitempty"` Account *AccountInfo `json:"account,omitempty"`
} }
...@@ -205,33 +215,6 @@ type OrgInfo struct { ...@@ -205,33 +215,6 @@ type OrgInfo struct {
// AccountInfo represents account info from OAuth response // AccountInfo represents account info from OAuth response
type AccountInfo struct { type AccountInfo struct {
UUID string `json:"uuid"` UUID string `json:"uuid"`
} EmailAddress string `json:"email_address"`
// RefreshTokenRequest represents the refresh token request
type RefreshTokenRequest struct {
GrantType string `json:"grant_type"`
RefreshToken string `json:"refresh_token"`
ClientID string `json:"client_id"`
}
// BuildTokenRequest creates a token exchange request
func BuildTokenRequest(code, codeVerifier, state string) *TokenRequest {
return &TokenRequest{
GrantType: "authorization_code",
ClientID: ClientID,
Code: code,
RedirectURI: RedirectURI,
CodeVerifier: codeVerifier,
State: state,
}
}
// BuildRefreshTokenRequest creates a refresh token request
func BuildRefreshTokenRequest(refreshToken string) *RefreshTokenRequest {
return &RefreshTokenRequest{
GrantType: "refresh_token",
RefreshToken: refreshToken,
ClientID: ClientID,
}
} }
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
package response package response
import ( import (
"log"
"math" "math"
"net/http" "net/http"
...@@ -74,6 +75,12 @@ func ErrorFrom(c *gin.Context, err error) bool { ...@@ -74,6 +75,12 @@ func ErrorFrom(c *gin.Context, err error) bool {
} }
statusCode, status := infraerrors.ToHTTP(err) statusCode, status := infraerrors.ToHTTP(err)
// Log internal errors with full details for debugging
if statusCode >= 500 && c.Request != nil {
log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, err.Error())
}
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata) ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
return true return true
} }
...@@ -162,11 +169,11 @@ func ParsePagination(c *gin.Context) (page, pageSize int) { ...@@ -162,11 +169,11 @@ func ParsePagination(c *gin.Context) (page, pageSize int) {
// 支持 page_size 和 limit 两种参数名 // 支持 page_size 和 limit 两种参数名
if ps := c.Query("page_size"); ps != "" { if ps := c.Query("page_size"); ps != "" {
if val, err := parseInt(ps); err == nil && val > 0 && val <= 100 { if val, err := parseInt(ps); err == nil && val > 0 && val <= 1000 {
pageSize = val pageSize = val
} }
} else if l := c.Query("limit"); l != "" { } else if l := c.Query("limit"); l != "" {
if val, err := parseInt(l); err == nil && val > 0 && val <= 100 { if val, err := parseInt(l); err == nil && val > 0 && val <= 1000 {
pageSize = val pageSize = val
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment