Commit 31fe0178 authored by yangjianbo's avatar yangjianbo
Browse files
parents d9e345f2 ba5a0d47
...@@ -277,3 +277,44 @@ func (h *UserHandler) GetUserUsage(c *gin.Context) { ...@@ -277,3 +277,44 @@ func (h *UserHandler) GetUserUsage(c *gin.Context) {
response.Success(c, stats) response.Success(c, stats)
} }
// GetBalanceHistory handles getting user's balance/concurrency change history
// GET /api/v1/admin/users/:id/balance-history
// Query params:
// - type: filter by record type (balance, admin_balance, concurrency, admin_concurrency, subscription)
func (h *UserHandler) GetBalanceHistory(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
page, pageSize := response.ParsePagination(c)
codeType := c.Query("type")
codes, total, totalRecharged, err := h.adminService.GetUserBalanceHistory(c.Request.Context(), userID, page, pageSize, codeType)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Convert to admin DTO (includes notes field for admin visibility)
out := make([]dto.AdminRedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
}
// Custom response with total_recharged alongside pagination
pages := int((total + int64(pageSize) - 1) / int64(pageSize))
if pages < 1 {
pages = 1
}
response.Success(c, gin.H{
"items": out,
"total": total,
"page": page,
"page_size": pageSize,
"pages": pages,
"total_recharged": totalRecharged,
})
}
package handler
import (
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// AnnouncementHandler handles user announcement operations
type AnnouncementHandler struct {
announcementService *service.AnnouncementService
}
// NewAnnouncementHandler creates a new user announcement handler
func NewAnnouncementHandler(announcementService *service.AnnouncementService) *AnnouncementHandler {
return &AnnouncementHandler{
announcementService: announcementService,
}
}
// List handles listing announcements visible to current user
// GET /api/v1/announcements
func (h *AnnouncementHandler) List(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not found in context")
return
}
unreadOnly := parseBoolQuery(c.Query("unread_only"))
items, err := h.announcementService.ListForUser(c.Request.Context(), subject.UserID, unreadOnly)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.UserAnnouncement, 0, len(items))
for i := range items {
out = append(out, *dto.UserAnnouncementFromService(&items[i]))
}
response.Success(c, out)
}
// MarkRead marks an announcement as read for current user
// POST /api/v1/announcements/:id/read
func (h *AnnouncementHandler) MarkRead(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not found in context")
return
}
announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || announcementID <= 0 {
response.BadRequest(c, "Invalid announcement ID")
return
}
if err := h.announcementService.MarkRead(c.Request.Context(), subject.UserID, announcementID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "ok"})
}
func parseBoolQuery(v string) bool {
switch strings.TrimSpace(strings.ToLower(v)) {
case "1", "true", "yes", "y", "on":
return true
default:
return false
}
}
package handler package handler
import ( import (
"log/slog"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/ip"
...@@ -13,21 +15,25 @@ import ( ...@@ -13,21 +15,25 @@ import (
// AuthHandler handles authentication-related requests // AuthHandler handles authentication-related requests
type AuthHandler struct { type AuthHandler struct {
cfg *config.Config cfg *config.Config
authService *service.AuthService authService *service.AuthService
userService *service.UserService userService *service.UserService
settingSvc *service.SettingService settingSvc *service.SettingService
promoService *service.PromoService promoService *service.PromoService
redeemService *service.RedeemService
totpService *service.TotpService
} }
// NewAuthHandler creates a new AuthHandler // NewAuthHandler creates a new AuthHandler
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService) *AuthHandler { func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, redeemService *service.RedeemService, totpService *service.TotpService) *AuthHandler {
return &AuthHandler{ return &AuthHandler{
cfg: cfg, cfg: cfg,
authService: authService, authService: authService,
userService: userService, userService: userService,
settingSvc: settingService, settingSvc: settingService,
promoService: promoService, promoService: promoService,
redeemService: redeemService,
totpService: totpService,
} }
} }
...@@ -37,7 +43,8 @@ type RegisterRequest struct { ...@@ -37,7 +43,8 @@ type RegisterRequest struct {
Password string `json:"password" binding:"required,min=6"` Password string `json:"password" binding:"required,min=6"`
VerifyCode string `json:"verify_code"` VerifyCode string `json:"verify_code"`
TurnstileToken string `json:"turnstile_token"` TurnstileToken string `json:"turnstile_token"`
PromoCode string `json:"promo_code"` // 注册优惠码 PromoCode string `json:"promo_code"` // 注册优惠码
InvitationCode string `json:"invitation_code"` // 邀请码
} }
// SendVerifyCodeRequest 发送验证码请求 // SendVerifyCodeRequest 发送验证码请求
...@@ -83,7 +90,7 @@ func (h *AuthHandler) Register(c *gin.Context) { ...@@ -83,7 +90,7 @@ func (h *AuthHandler) Register(c *gin.Context) {
} }
} }
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode) token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -144,6 +151,100 @@ func (h *AuthHandler) Login(c *gin.Context) { ...@@ -144,6 +151,100 @@ func (h *AuthHandler) Login(c *gin.Context) {
return return
} }
// Check if TOTP 2FA is enabled for this user
if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
// Create a temporary login session for 2FA
tempToken, err := h.totpService.CreateLoginSession(c.Request.Context(), user.ID, user.Email)
if err != nil {
response.InternalError(c, "Failed to create 2FA session")
return
}
response.Success(c, TotpLoginResponse{
Requires2FA: true,
TempToken: tempToken,
UserEmailMasked: service.MaskEmail(user.Email),
})
return
}
response.Success(c, AuthResponse{
AccessToken: token,
TokenType: "Bearer",
User: dto.UserFromService(user),
})
}
// TotpLoginResponse represents the response when 2FA is required
type TotpLoginResponse struct {
Requires2FA bool `json:"requires_2fa"`
TempToken string `json:"temp_token,omitempty"`
UserEmailMasked string `json:"user_email_masked,omitempty"`
}
// Login2FARequest represents the 2FA login request
type Login2FARequest struct {
TempToken string `json:"temp_token" binding:"required"`
TotpCode string `json:"totp_code" binding:"required,len=6"`
}
// Login2FA completes the login with 2FA verification
// POST /api/v1/auth/login/2fa
func (h *AuthHandler) Login2FA(c *gin.Context) {
var req Login2FARequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
slog.Debug("login_2fa_request",
"temp_token_len", len(req.TempToken),
"totp_code_len", len(req.TotpCode))
// Get the login session
session, err := h.totpService.GetLoginSession(c.Request.Context(), req.TempToken)
if err != nil || session == nil {
tokenPrefix := ""
if len(req.TempToken) >= 8 {
tokenPrefix = req.TempToken[:8]
}
slog.Debug("login_2fa_session_invalid",
"temp_token_prefix", tokenPrefix,
"error", err)
response.BadRequest(c, "Invalid or expired 2FA session")
return
}
slog.Debug("login_2fa_session_found",
"user_id", session.UserID,
"email", session.Email)
// Verify the TOTP code
if err := h.totpService.VerifyCode(c.Request.Context(), session.UserID, req.TotpCode); err != nil {
slog.Debug("login_2fa_verify_failed",
"user_id", session.UserID,
"error", err)
response.ErrorFrom(c, err)
return
}
// Delete the login session
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
// Get the user
user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Generate the JWT token
token, err := h.authService.GenerateToken(user)
if err != nil {
response.InternalError(c, "Failed to generate token")
return
}
response.Success(c, AuthResponse{ response.Success(c, AuthResponse{
AccessToken: token, AccessToken: token,
TokenType: "Bearer", TokenType: "Bearer",
...@@ -247,3 +348,146 @@ func (h *AuthHandler) ValidatePromoCode(c *gin.Context) { ...@@ -247,3 +348,146 @@ func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
BonusAmount: promoCode.BonusAmount, BonusAmount: promoCode.BonusAmount,
}) })
} }
// ValidateInvitationCodeRequest 验证邀请码请求
type ValidateInvitationCodeRequest struct {
Code string `json:"code" binding:"required"`
}
// ValidateInvitationCodeResponse 验证邀请码响应
type ValidateInvitationCodeResponse struct {
Valid bool `json:"valid"`
ErrorCode string `json:"error_code,omitempty"`
}
// ValidateInvitationCode 验证邀请码(公开接口,注册前调用)
// POST /api/v1/auth/validate-invitation-code
func (h *AuthHandler) ValidateInvitationCode(c *gin.Context) {
// 检查邀请码功能是否启用
if h.settingSvc == nil || !h.settingSvc.IsInvitationCodeEnabled(c.Request.Context()) {
response.Success(c, ValidateInvitationCodeResponse{
Valid: false,
ErrorCode: "INVITATION_CODE_DISABLED",
})
return
}
var req ValidateInvitationCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// 验证邀请码
redeemCode, err := h.redeemService.GetByCode(c.Request.Context(), req.Code)
if err != nil {
response.Success(c, ValidateInvitationCodeResponse{
Valid: false,
ErrorCode: "INVITATION_CODE_NOT_FOUND",
})
return
}
// 检查类型和状态
if redeemCode.Type != service.RedeemTypeInvitation {
response.Success(c, ValidateInvitationCodeResponse{
Valid: false,
ErrorCode: "INVITATION_CODE_INVALID",
})
return
}
if redeemCode.Status != service.StatusUnused {
response.Success(c, ValidateInvitationCodeResponse{
Valid: false,
ErrorCode: "INVITATION_CODE_USED",
})
return
}
response.Success(c, ValidateInvitationCodeResponse{
Valid: true,
})
}
// ForgotPasswordRequest 忘记密码请求
type ForgotPasswordRequest struct {
Email string `json:"email" binding:"required,email"`
TurnstileToken string `json:"turnstile_token"`
}
// ForgotPasswordResponse 忘记密码响应
type ForgotPasswordResponse struct {
Message string `json:"message"`
}
// ForgotPassword 请求密码重置
// POST /api/v1/auth/forgot-password
func (h *AuthHandler) ForgotPassword(c *gin.Context) {
var req ForgotPasswordRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Turnstile 验证
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
response.ErrorFrom(c, err)
return
}
// Build frontend base URL from request
scheme := "https"
if c.Request.TLS == nil {
// Check X-Forwarded-Proto header (common in reverse proxy setups)
if proto := c.GetHeader("X-Forwarded-Proto"); proto != "" {
scheme = proto
} else {
scheme = "http"
}
}
frontendBaseURL := scheme + "://" + c.Request.Host
// Request password reset (async)
// Note: This returns success even if email doesn't exist (to prevent enumeration)
if err := h.authService.RequestPasswordResetAsync(c.Request.Context(), req.Email, frontendBaseURL); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, ForgotPasswordResponse{
Message: "If your email is registered, you will receive a password reset link shortly.",
})
}
// ResetPasswordRequest 重置密码请求
type ResetPasswordRequest struct {
Email string `json:"email" binding:"required,email"`
Token string `json:"token" binding:"required"`
NewPassword string `json:"new_password" binding:"required,min=6"`
}
// ResetPasswordResponse 重置密码响应
type ResetPasswordResponse struct {
Message string `json:"message"`
}
// ResetPassword 重置密码
// POST /api/v1/auth/reset-password
func (h *AuthHandler) ResetPassword(c *gin.Context) {
var req ResetPasswordRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Reset password
if err := h.authService.ResetPassword(c.Request.Context(), req.Email, req.Token, req.NewPassword); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, ResetPasswordResponse{
Message: "Your password has been reset successfully. You can now log in with your new password.",
})
}
package dto
import (
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type Announcement struct {
ID int64 `json:"id"`
Title string `json:"title"`
Content string `json:"content"`
Status string `json:"status"`
Targeting service.AnnouncementTargeting `json:"targeting"`
StartsAt *time.Time `json:"starts_at,omitempty"`
EndsAt *time.Time `json:"ends_at,omitempty"`
CreatedBy *int64 `json:"created_by,omitempty"`
UpdatedBy *int64 `json:"updated_by,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type UserAnnouncement struct {
ID int64 `json:"id"`
Title string `json:"title"`
Content string `json:"content"`
StartsAt *time.Time `json:"starts_at,omitempty"`
EndsAt *time.Time `json:"ends_at,omitempty"`
ReadAt *time.Time `json:"read_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
func AnnouncementFromService(a *service.Announcement) *Announcement {
if a == nil {
return nil
}
return &Announcement{
ID: a.ID,
Title: a.Title,
Content: a.Content,
Status: a.Status,
Targeting: a.Targeting,
StartsAt: a.StartsAt,
EndsAt: a.EndsAt,
CreatedBy: a.CreatedBy,
UpdatedBy: a.UpdatedBy,
CreatedAt: a.CreatedAt,
UpdatedAt: a.UpdatedAt,
}
}
func UserAnnouncementFromService(a *service.UserAnnouncement) *UserAnnouncement {
if a == nil {
return nil
}
return &UserAnnouncement{
ID: a.Announcement.ID,
Title: a.Announcement.Title,
Content: a.Announcement.Content,
StartsAt: a.Announcement.StartsAt,
EndsAt: a.Announcement.EndsAt,
ReadAt: a.ReadAt,
CreatedAt: a.Announcement.CreatedAt,
UpdatedAt: a.Announcement.UpdatedAt,
}
}
...@@ -204,6 +204,17 @@ func AccountFromServiceShallow(a *service.Account) *Account { ...@@ -204,6 +204,17 @@ func AccountFromServiceShallow(a *service.Account) *Account {
} }
} }
if scopeLimits := a.GetAntigravityScopeRateLimits(); len(scopeLimits) > 0 {
out.ScopeRateLimits = make(map[string]ScopeRateLimitInfo, len(scopeLimits))
now := time.Now()
for scope, remainingSec := range scopeLimits {
out.ScopeRateLimits[scope] = ScopeRateLimitInfo{
ResetAt: now.Add(time.Duration(remainingSec) * time.Second),
RemainingSec: remainingSec,
}
}
}
return out return out
} }
...@@ -321,7 +332,7 @@ func RedeemCodeFromServiceAdmin(rc *service.RedeemCode) *AdminRedeemCode { ...@@ -321,7 +332,7 @@ func RedeemCodeFromServiceAdmin(rc *service.RedeemCode) *AdminRedeemCode {
} }
func redeemCodeFromServiceBase(rc *service.RedeemCode) RedeemCode { func redeemCodeFromServiceBase(rc *service.RedeemCode) RedeemCode {
return RedeemCode{ out := RedeemCode{
ID: rc.ID, ID: rc.ID,
Code: rc.Code, Code: rc.Code,
Type: rc.Type, Type: rc.Type,
...@@ -335,6 +346,14 @@ func redeemCodeFromServiceBase(rc *service.RedeemCode) RedeemCode { ...@@ -335,6 +346,14 @@ func redeemCodeFromServiceBase(rc *service.RedeemCode) RedeemCode {
User: UserFromServiceShallow(rc.User), User: UserFromServiceShallow(rc.User),
Group: GroupFromServiceShallow(rc.Group), Group: GroupFromServiceShallow(rc.Group),
} }
// For admin_balance/admin_concurrency types, include notes so users can see
// why they were charged or credited by admin
if (rc.Type == "admin_balance" || rc.Type == "admin_concurrency") && rc.Notes != "" {
out.Notes = &rc.Notes
}
return out
} }
// AccountSummaryFromService returns a minimal AccountSummary for usage log display. // AccountSummaryFromService returns a minimal AccountSummary for usage log display.
...@@ -358,6 +377,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog { ...@@ -358,6 +377,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
AccountID: l.AccountID, AccountID: l.AccountID,
RequestID: l.RequestID, RequestID: l.RequestID,
Model: l.Model, Model: l.Model,
ReasoningEffort: l.ReasoningEffort,
GroupID: l.GroupID, GroupID: l.GroupID,
SubscriptionID: l.SubscriptionID, SubscriptionID: l.SubscriptionID,
InputTokens: l.InputTokens, InputTokens: l.InputTokens,
......
...@@ -2,9 +2,13 @@ package dto ...@@ -2,9 +2,13 @@ package dto
// SystemSettings represents the admin settings API response payload. // SystemSettings represents the admin settings API response payload.
type SystemSettings struct { type SystemSettings struct {
RegistrationEnabled bool `json:"registration_enabled"` RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"` EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"` PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
SMTPHost string `json:"smtp_host"` SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"` SMTPPort int `json:"smtp_port"`
...@@ -23,14 +27,16 @@ type SystemSettings struct { ...@@ -23,14 +27,16 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"` LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
SiteName string `json:"site_name"` SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"` SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"` SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"` APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"` ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"` DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"` HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"` HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
DefaultConcurrency int `json:"default_concurrency"` DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"` DefaultBalance float64 `json:"default_balance"`
...@@ -54,21 +60,26 @@ type SystemSettings struct { ...@@ -54,21 +60,26 @@ type SystemSettings struct {
} }
type PublicSettings struct { type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"` RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"` EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"` PromoCodeEnabled bool `json:"promo_code_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"` PasswordResetEnabled bool `json:"password_reset_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"` InvitationCodeEnabled bool `json:"invitation_code_enabled"`
SiteName string `json:"site_name"` TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
SiteLogo string `json:"site_logo"` TurnstileEnabled bool `json:"turnstile_enabled"`
SiteSubtitle string `json:"site_subtitle"` TurnstileSiteKey string `json:"turnstile_site_key"`
APIBaseURL string `json:"api_base_url"` SiteName string `json:"site_name"`
ContactInfo string `json:"contact_info"` SiteLogo string `json:"site_logo"`
DocURL string `json:"doc_url"` SiteSubtitle string `json:"site_subtitle"`
HomeContent string `json:"home_content"` APIBaseURL string `json:"api_base_url"`
HideCcsImportButton bool `json:"hide_ccs_import_button"` ContactInfo string `json:"contact_info"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` DocURL string `json:"doc_url"`
Version string `json:"version"` HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version"`
} }
// StreamTimeoutSettings 流超时处理配置 DTO // StreamTimeoutSettings 流超时处理配置 DTO
......
...@@ -2,6 +2,11 @@ package dto ...@@ -2,6 +2,11 @@ package dto
import "time" import "time"
type ScopeRateLimitInfo struct {
ResetAt time.Time `json:"reset_at"`
RemainingSec int64 `json:"remaining_sec"`
}
type User struct { type User struct {
ID int64 `json:"id"` ID int64 `json:"id"`
Email string `json:"email"` Email string `json:"email"`
...@@ -108,6 +113,9 @@ type Account struct { ...@@ -108,6 +113,9 @@ type Account struct {
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"` RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
OverloadUntil *time.Time `json:"overload_until"` OverloadUntil *time.Time `json:"overload_until"`
// Antigravity scope 级限流状态(从 extra 提取)
ScopeRateLimits map[string]ScopeRateLimitInfo `json:"scope_rate_limits,omitempty"`
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until"` TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until"`
TempUnschedulableReason string `json:"temp_unschedulable_reason"` TempUnschedulableReason string `json:"temp_unschedulable_reason"`
...@@ -198,6 +206,10 @@ type RedeemCode struct { ...@@ -198,6 +206,10 @@ type RedeemCode struct {
GroupID *int64 `json:"group_id"` GroupID *int64 `json:"group_id"`
ValidityDays int `json:"validity_days"` ValidityDays int `json:"validity_days"`
// Notes is only populated for admin_balance/admin_concurrency types
// so users can see why they were charged or credited
Notes *string `json:"notes,omitempty"`
User *User `json:"user,omitempty"` User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"` Group *Group `json:"group,omitempty"`
} }
...@@ -218,6 +230,9 @@ type UsageLog struct { ...@@ -218,6 +230,9 @@ type UsageLog struct {
AccountID int64 `json:"account_id"` AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"` RequestID string `json:"request_id"`
Model string `json:"model"` Model string `json:"model"`
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API).
// nil means not provided / not applicable.
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
GroupID *int64 `json:"group_id"` GroupID *int64 `json:"group_id"`
SubscriptionID *int64 `json:"subscription_id"` SubscriptionID *int64 `json:"subscription_id"`
......
...@@ -30,6 +30,7 @@ type GatewayHandler struct { ...@@ -30,6 +30,7 @@ type GatewayHandler struct {
antigravityGatewayService *service.AntigravityGatewayService antigravityGatewayService *service.AntigravityGatewayService
userService *service.UserService userService *service.UserService
billingCacheService *service.BillingCacheService billingCacheService *service.BillingCacheService
usageService *service.UsageService
concurrencyHelper *ConcurrencyHelper concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int maxAccountSwitches int
maxAccountSwitchesGemini int maxAccountSwitchesGemini int
...@@ -43,6 +44,7 @@ func NewGatewayHandler( ...@@ -43,6 +44,7 @@ func NewGatewayHandler(
userService *service.UserService, userService *service.UserService,
concurrencyService *service.ConcurrencyService, concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService, billingCacheService *service.BillingCacheService,
usageService *service.UsageService,
cfg *config.Config, cfg *config.Config,
) *GatewayHandler { ) *GatewayHandler {
pingInterval := time.Duration(0) pingInterval := time.Duration(0)
...@@ -63,6 +65,7 @@ func NewGatewayHandler( ...@@ -63,6 +65,7 @@ func NewGatewayHandler(
antigravityGatewayService: antigravityGatewayService, antigravityGatewayService: antigravityGatewayService,
userService: userService, userService: userService,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
usageService: usageService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
maxAccountSwitches: maxAccountSwitches, maxAccountSwitches: maxAccountSwitches,
maxAccountSwitchesGemini: maxAccountSwitchesGemini, maxAccountSwitchesGemini: maxAccountSwitchesGemini,
...@@ -524,7 +527,7 @@ func (h *GatewayHandler) AntigravityModels(c *gin.Context) { ...@@ -524,7 +527,7 @@ func (h *GatewayHandler) AntigravityModels(c *gin.Context) {
}) })
} }
// Usage handles getting account balance for CC Switch integration // Usage handles getting account balance and usage statistics for CC Switch integration
// GET /v1/usage // GET /v1/usage
func (h *GatewayHandler) Usage(c *gin.Context) { func (h *GatewayHandler) Usage(c *gin.Context) {
apiKey, ok := middleware2.GetAPIKeyFromContext(c) apiKey, ok := middleware2.GetAPIKeyFromContext(c)
...@@ -539,7 +542,40 @@ func (h *GatewayHandler) Usage(c *gin.Context) { ...@@ -539,7 +542,40 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
return return
} }
// 订阅模式:返回订阅限额信息 // Best-effort: 获取用量统计,失败不影响基础响应
var usageData gin.H
if h.usageService != nil {
dashStats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), subject.UserID)
if err == nil && dashStats != nil {
usageData = gin.H{
"today": gin.H{
"requests": dashStats.TodayRequests,
"input_tokens": dashStats.TodayInputTokens,
"output_tokens": dashStats.TodayOutputTokens,
"cache_creation_tokens": dashStats.TodayCacheCreationTokens,
"cache_read_tokens": dashStats.TodayCacheReadTokens,
"total_tokens": dashStats.TodayTokens,
"cost": dashStats.TodayCost,
"actual_cost": dashStats.TodayActualCost,
},
"total": gin.H{
"requests": dashStats.TotalRequests,
"input_tokens": dashStats.TotalInputTokens,
"output_tokens": dashStats.TotalOutputTokens,
"cache_creation_tokens": dashStats.TotalCacheCreationTokens,
"cache_read_tokens": dashStats.TotalCacheReadTokens,
"total_tokens": dashStats.TotalTokens,
"cost": dashStats.TotalCost,
"actual_cost": dashStats.TotalActualCost,
},
"average_duration_ms": dashStats.AverageDurationMs,
"rpm": dashStats.Rpm,
"tpm": dashStats.Tpm,
}
}
}
// 订阅模式:返回订阅限额信息 + 用量统计
if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() { if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() {
subscription, ok := middleware2.GetSubscriptionFromContext(c) subscription, ok := middleware2.GetSubscriptionFromContext(c)
if !ok { if !ok {
...@@ -548,28 +584,46 @@ func (h *GatewayHandler) Usage(c *gin.Context) { ...@@ -548,28 +584,46 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
} }
remaining := h.calculateSubscriptionRemaining(apiKey.Group, subscription) remaining := h.calculateSubscriptionRemaining(apiKey.Group, subscription)
c.JSON(http.StatusOK, gin.H{ resp := gin.H{
"isValid": true, "isValid": true,
"planName": apiKey.Group.Name, "planName": apiKey.Group.Name,
"remaining": remaining, "remaining": remaining,
"unit": "USD", "unit": "USD",
}) "subscription": gin.H{
"daily_usage_usd": subscription.DailyUsageUSD,
"weekly_usage_usd": subscription.WeeklyUsageUSD,
"monthly_usage_usd": subscription.MonthlyUsageUSD,
"daily_limit_usd": apiKey.Group.DailyLimitUSD,
"weekly_limit_usd": apiKey.Group.WeeklyLimitUSD,
"monthly_limit_usd": apiKey.Group.MonthlyLimitUSD,
"expires_at": subscription.ExpiresAt,
},
}
if usageData != nil {
resp["usage"] = usageData
}
c.JSON(http.StatusOK, resp)
return return
} }
// 余额模式:返回钱包余额 // 余额模式:返回钱包余额 + 用量统计
latestUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID) latestUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if err != nil { if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info") h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info")
return return
} }
c.JSON(http.StatusOK, gin.H{ resp := gin.H{
"isValid": true, "isValid": true,
"planName": "钱包余额", "planName": "钱包余额",
"remaining": latestUser.Balance, "remaining": latestUser.Balance,
"unit": "USD", "unit": "USD",
}) "balance": latestUser.Balance,
}
if usageData != nil {
resp["usage"] = usageData
}
c.JSON(http.StatusOK, resp)
} }
// calculateSubscriptionRemaining 计算订阅剩余可用额度 // calculateSubscriptionRemaining 计算订阅剩余可用额度
...@@ -725,6 +779,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { ...@@ -725,6 +779,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return return
} }
// 检查是否为 Claude Code 客户端,设置到 context 中
SetClaudeCodeClientContext(c, body)
setOpsRequestContext(c, "", false, body) setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body) parsedReq, err := service.ParseGatewayRequest(body)
......
//go:build unit
package handler
import (
"crypto/sha256"
"encoding/hex"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestExtractGeminiCLISessionHash(t *testing.T) {
tests := []struct {
name string
body string
privilegedUserID string
wantEmpty bool
wantHash string
}{
{
name: "with privileged-user-id and tmp dir",
body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
wantEmpty: false,
wantHash: func() string {
combined := "90785f52-8bbe-4b17-b111-a1ddea1636c3:f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"
hash := sha256.Sum256([]byte(combined))
return hex.EncodeToString(hash[:])
}(),
},
{
name: "without privileged-user-id but with tmp dir",
body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
privilegedUserID: "",
wantEmpty: false,
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
},
{
name: "without tmp dir",
body: `{"contents":[{"parts":[{"text":"Hello world"}]}]}`,
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
wantEmpty: true,
},
{
name: "empty body",
body: "",
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
wantEmpty: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 创建测试上下文
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/test", nil)
if tt.privilegedUserID != "" {
c.Request.Header.Set("x-gemini-api-privileged-user-id", tt.privilegedUserID)
}
// 调用函数
result := extractGeminiCLISessionHash(c, []byte(tt.body))
// 验证结果
if tt.wantEmpty {
require.Empty(t, result, "expected empty session hash")
} else {
require.NotEmpty(t, result, "expected non-empty session hash")
require.Equal(t, tt.wantHash, result, "session hash mismatch")
}
})
}
}
func TestGeminiCLITmpDirRegex(t *testing.T) {
tests := []struct {
name string
input string
wantMatch bool
wantHash string
}{
{
name: "valid tmp dir path",
input: "/Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
wantMatch: true,
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
},
{
name: "valid tmp dir path in text",
input: "The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740\nOther text",
wantMatch: true,
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
},
{
name: "invalid hash length",
input: "/Users/ianshaw/.gemini/tmp/abc123",
wantMatch: false,
},
{
name: "no tmp dir",
input: "Hello world",
wantMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
match := geminiCLITmpDirRegex.FindStringSubmatch(tt.input)
if tt.wantMatch {
require.NotNil(t, match, "expected regex to match")
require.Len(t, match, 2, "expected 2 capture groups")
require.Equal(t, tt.wantHash, match[1], "hash mismatch")
} else {
require.Nil(t, match, "expected regex not to match")
}
})
}
}
package handler package handler
import ( import (
"bytes"
"context" "context"
"crypto/sha256"
"encoding/hex"
"errors" "errors"
"io" "io"
"log" "log"
"net/http" "net/http"
"regexp"
"strings" "strings"
"time" "time"
...@@ -19,6 +23,17 @@ import ( ...@@ -19,6 +23,17 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// geminiCLITmpDirRegex 用于从 Gemini CLI 请求体中提取 tmp 目录的哈希值
// 匹配格式: /Users/xxx/.gemini/tmp/[64位十六进制哈希]
var geminiCLITmpDirRegex = regexp.MustCompile(`/\.gemini/tmp/([A-Fa-f0-9]{64})`)
func isGeminiCLIRequest(c *gin.Context, body []byte) bool {
if strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id")) != "" {
return true
}
return geminiCLITmpDirRegex.Match(body)
}
// GeminiV1BetaListModels proxies: // GeminiV1BetaListModels proxies:
// GET /v1beta/models // GET /v1beta/models
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) { func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
...@@ -214,12 +229,26 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -214,12 +229,26 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
} }
// 3) select account (sticky session based on request body) // 3) select account (sticky session based on request body)
parsedReq, _ := service.ParseGatewayRequest(body) // 优先使用 Gemini CLI 的会话标识(privileged-user-id + tmp 目录哈希)
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) sessionHash := extractGeminiCLISessionHash(c, body)
if sessionHash == "" {
// Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端)
parsedReq, _ := service.ParseGatewayRequest(body)
sessionHash = h.gatewayService.GenerateSessionHash(parsedReq)
}
sessionKey := sessionHash sessionKey := sessionHash
if sessionHash != "" { if sessionHash != "" {
sessionKey = "gemini:" + sessionHash sessionKey = "gemini:" + sessionHash
} }
// 查询粘性会话绑定的账号 ID(用于检测账号切换)
var sessionBoundAccountID int64
if sessionKey != "" {
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
}
isCLI := isGeminiCLIRequest(c, body)
cleanedForUnknownBinding := false
maxAccountSwitches := h.maxAccountSwitchesGemini maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0 switchCount := 0
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})
...@@ -238,6 +267,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -238,6 +267,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
account := selection.Account account := selection.Account
setOpsSelectedAccount(c, account.ID) setOpsSelectedAccount(c, account.ID)
// 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature
// 注意:Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。
if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID {
log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
body = service.CleanGeminiNativeThoughtSignatures(body)
sessionBoundAccountID = account.ID
} else if sessionKey != "" && sessionBoundAccountID == 0 && isCLI && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
// 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,CLI 继续携带旧签名。
// 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。
log.Printf("[Gemini] Sticky session binding missing for CLI request, cleaning thoughtSignature proactively")
body = service.CleanGeminiNativeThoughtSignatures(body)
cleanedForUnknownBinding = true
sessionBoundAccountID = account.ID
} else if sessionBoundAccountID == 0 {
// 记录本次请求中首次选择到的账号,便于同一请求内 failover 时检测切换。
sessionBoundAccountID = account.ID
}
// 4) account concurrency slot // 4) account concurrency slot
accountReleaseFunc := selection.ReleaseFunc accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired { if !selection.Acquired {
...@@ -319,18 +366,21 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -319,18 +366,21 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
userAgent := c.GetHeader("User-Agent") userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c) clientIP := ip.GetClientIP(c)
// 6) record usage async // 6) record usage async (Gemini 使用长上下文双倍计费)
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) { go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result, if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
APIKey: apiKey, Result: result,
User: apiKey.User, APIKey: apiKey,
Account: usedAccount, User: apiKey.User,
Subscription: subscription, Account: usedAccount,
UserAgent: ua, Subscription: subscription,
IPAddress: ip, UserAgent: ua,
IPAddress: ip,
LongContextThreshold: 200000, // Gemini 200K 阈值
LongContextMultiplier: 2.0, // 超出部分双倍计费
}); err != nil { }); err != nil {
log.Printf("Record usage failed: %v", err) log.Printf("Record usage failed: %v", err)
} }
...@@ -433,3 +483,38 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool { ...@@ -433,3 +483,38 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
} }
return false return false
} }
// extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。
// 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。
//
// 会话标识生成策略:
// 1. 从请求体中提取 tmp 目录哈希(64位十六进制)
// 2. 从 header 中提取 privileged-user-id(UUID)
// 3. 组合两者生成 SHA256 哈希作为最终的会话标识
//
// 如果找不到 tmp 目录哈希,返回空字符串(不使用粘性会话)。
//
// extractGeminiCLISessionHash extracts session identifier from Gemini CLI requests.
// Combines x-gemini-api-privileged-user-id header with tmp directory hash from request body.
func extractGeminiCLISessionHash(c *gin.Context, body []byte) string {
// 1. 从请求体中提取 tmp 目录哈希
match := geminiCLITmpDirRegex.FindSubmatch(body)
if len(match) < 2 {
return "" // 没有找到 tmp 目录,不使用粘性会话
}
tmpDirHash := string(match[1])
// 2. 提取 privileged-user-id
privilegedUserID := strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id"))
// 3. 组合生成最终的 session hash
if privilegedUserID != "" {
// 组合两个标识符:privileged-user-id + tmp 目录哈希
combined := privilegedUserID + ":" + tmpDirHash
hash := sha256.Sum256([]byte(combined))
return hex.EncodeToString(hash[:])
}
// 如果没有 privileged-user-id,直接使用 tmp 目录哈希
return tmpDirHash
}
...@@ -10,6 +10,7 @@ type AdminHandlers struct { ...@@ -10,6 +10,7 @@ type AdminHandlers struct {
User *admin.UserHandler User *admin.UserHandler
Group *admin.GroupHandler Group *admin.GroupHandler
Account *admin.AccountHandler Account *admin.AccountHandler
Announcement *admin.AnnouncementHandler
OAuth *admin.OAuthHandler OAuth *admin.OAuthHandler
OpenAIOAuth *admin.OpenAIOAuthHandler OpenAIOAuth *admin.OpenAIOAuthHandler
GeminiOAuth *admin.GeminiOAuthHandler GeminiOAuth *admin.GeminiOAuthHandler
...@@ -33,10 +34,12 @@ type Handlers struct { ...@@ -33,10 +34,12 @@ type Handlers struct {
Usage *UsageHandler Usage *UsageHandler
Redeem *RedeemHandler Redeem *RedeemHandler
Subscription *SubscriptionHandler Subscription *SubscriptionHandler
Announcement *AnnouncementHandler
Admin *AdminHandlers Admin *AdminHandlers
Gateway *GatewayHandler Gateway *GatewayHandler
OpenAIGateway *OpenAIGatewayHandler OpenAIGateway *OpenAIGatewayHandler
Setting *SettingHandler Setting *SettingHandler
Totp *TotpHandler
} }
// BuildInfo contains build-time information // BuildInfo contains build-time information
......
...@@ -905,7 +905,7 @@ func classifyOpsIsRetryable(errType string, statusCode int) bool { ...@@ -905,7 +905,7 @@ func classifyOpsIsRetryable(errType string, statusCode int) bool {
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool { func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
switch strings.TrimSpace(code) { switch strings.TrimSpace(code) {
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID": case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID", "USER_INACTIVE":
return true return true
} }
if phase == "billing" || phase == "concurrency" { if phase == "billing" || phase == "concurrency" {
...@@ -1011,5 +1011,12 @@ func shouldSkipOpsErrorLog(ctx context.Context, ops *service.OpsService, message ...@@ -1011,5 +1011,12 @@ func shouldSkipOpsErrorLog(ctx context.Context, ops *service.OpsService, message
} }
} }
// Check if invalid/missing API key errors should be ignored (user misconfiguration)
if settings.IgnoreInvalidApiKeyErrors {
if strings.Contains(bodyLower, "invalid_api_key") || strings.Contains(bodyLower, "api_key_required") {
return true
}
}
return false return false
} }
...@@ -32,20 +32,25 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { ...@@ -32,20 +32,25 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
} }
response.Success(c, dto.PublicSettings{ response.Success(c, dto.PublicSettings{
RegistrationEnabled: settings.RegistrationEnabled, RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled, PromoCodeEnabled: settings.PromoCodeEnabled,
TurnstileEnabled: settings.TurnstileEnabled, PasswordResetEnabled: settings.PasswordResetEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey, InvitationCodeEnabled: settings.InvitationCodeEnabled,
SiteName: settings.SiteName, TotpEnabled: settings.TotpEnabled,
SiteLogo: settings.SiteLogo, TurnstileEnabled: settings.TurnstileEnabled,
SiteSubtitle: settings.SiteSubtitle, TurnstileSiteKey: settings.TurnstileSiteKey,
APIBaseURL: settings.APIBaseURL, SiteName: settings.SiteName,
ContactInfo: settings.ContactInfo, SiteLogo: settings.SiteLogo,
DocURL: settings.DocURL, SiteSubtitle: settings.SiteSubtitle,
HomeContent: settings.HomeContent, APIBaseURL: settings.APIBaseURL,
HideCcsImportButton: settings.HideCcsImportButton, ContactInfo: settings.ContactInfo,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, DocURL: settings.DocURL,
Version: h.version, HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: h.version,
}) })
} }
package handler
import (
"github.com/gin-gonic/gin"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// TotpHandler handles TOTP-related requests
type TotpHandler struct {
totpService *service.TotpService
}
// NewTotpHandler creates a new TotpHandler
func NewTotpHandler(totpService *service.TotpService) *TotpHandler {
return &TotpHandler{
totpService: totpService,
}
}
// TotpStatusResponse represents the TOTP status response
type TotpStatusResponse struct {
Enabled bool `json:"enabled"`
EnabledAt *int64 `json:"enabled_at,omitempty"` // Unix timestamp
FeatureEnabled bool `json:"feature_enabled"`
}
// GetStatus returns the TOTP status for the current user
// GET /api/v1/user/totp/status
func (h *TotpHandler) GetStatus(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
status, err := h.totpService.GetStatus(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
resp := TotpStatusResponse{
Enabled: status.Enabled,
FeatureEnabled: status.FeatureEnabled,
}
if status.EnabledAt != nil {
ts := status.EnabledAt.Unix()
resp.EnabledAt = &ts
}
response.Success(c, resp)
}
// TotpSetupRequest represents the request to initiate TOTP setup
type TotpSetupRequest struct {
EmailCode string `json:"email_code"`
Password string `json:"password"`
}
// TotpSetupResponse represents the TOTP setup response
type TotpSetupResponse struct {
Secret string `json:"secret"`
QRCodeURL string `json:"qr_code_url"`
SetupToken string `json:"setup_token"`
Countdown int `json:"countdown"`
}
// InitiateSetup starts the TOTP setup process
// POST /api/v1/user/totp/setup
func (h *TotpHandler) InitiateSetup(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req TotpSetupRequest
if err := c.ShouldBindJSON(&req); err != nil {
// Allow empty body (optional params)
req = TotpSetupRequest{}
}
result, err := h.totpService.InitiateSetup(c.Request.Context(), subject.UserID, req.EmailCode, req.Password)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, TotpSetupResponse{
Secret: result.Secret,
QRCodeURL: result.QRCodeURL,
SetupToken: result.SetupToken,
Countdown: result.Countdown,
})
}
// TotpEnableRequest represents the request to enable TOTP
type TotpEnableRequest struct {
TotpCode string `json:"totp_code" binding:"required,len=6"`
SetupToken string `json:"setup_token" binding:"required"`
}
// Enable completes the TOTP setup
// POST /api/v1/user/totp/enable
func (h *TotpHandler) Enable(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req TotpEnableRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := h.totpService.CompleteSetup(c.Request.Context(), subject.UserID, req.TotpCode, req.SetupToken); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"success": true})
}
// TotpDisableRequest represents the request to disable TOTP
type TotpDisableRequest struct {
EmailCode string `json:"email_code"`
Password string `json:"password"`
}
// Disable disables TOTP for the current user
// POST /api/v1/user/totp/disable
func (h *TotpHandler) Disable(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req TotpDisableRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := h.totpService.Disable(c.Request.Context(), subject.UserID, req.EmailCode, req.Password); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"success": true})
}
// GetVerificationMethod returns the verification method for TOTP operations
// GET /api/v1/user/totp/verification-method
func (h *TotpHandler) GetVerificationMethod(c *gin.Context) {
method := h.totpService.GetVerificationMethod(c.Request.Context())
response.Success(c, method)
}
// SendVerifyCode sends an email verification code for TOTP operations
// POST /api/v1/user/totp/send-code
func (h *TotpHandler) SendVerifyCode(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
if err := h.totpService.SendVerifyCode(c.Request.Context(), subject.UserID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"success": true})
}
...@@ -13,6 +13,7 @@ func ProvideAdminHandlers( ...@@ -13,6 +13,7 @@ func ProvideAdminHandlers(
userHandler *admin.UserHandler, userHandler *admin.UserHandler,
groupHandler *admin.GroupHandler, groupHandler *admin.GroupHandler,
accountHandler *admin.AccountHandler, accountHandler *admin.AccountHandler,
announcementHandler *admin.AnnouncementHandler,
oauthHandler *admin.OAuthHandler, oauthHandler *admin.OAuthHandler,
openaiOAuthHandler *admin.OpenAIOAuthHandler, openaiOAuthHandler *admin.OpenAIOAuthHandler,
geminiOAuthHandler *admin.GeminiOAuthHandler, geminiOAuthHandler *admin.GeminiOAuthHandler,
...@@ -32,6 +33,7 @@ func ProvideAdminHandlers( ...@@ -32,6 +33,7 @@ func ProvideAdminHandlers(
User: userHandler, User: userHandler,
Group: groupHandler, Group: groupHandler,
Account: accountHandler, Account: accountHandler,
Announcement: announcementHandler,
OAuth: oauthHandler, OAuth: oauthHandler,
OpenAIOAuth: openaiOAuthHandler, OpenAIOAuth: openaiOAuthHandler,
GeminiOAuth: geminiOAuthHandler, GeminiOAuth: geminiOAuthHandler,
...@@ -66,10 +68,12 @@ func ProvideHandlers( ...@@ -66,10 +68,12 @@ func ProvideHandlers(
usageHandler *UsageHandler, usageHandler *UsageHandler,
redeemHandler *RedeemHandler, redeemHandler *RedeemHandler,
subscriptionHandler *SubscriptionHandler, subscriptionHandler *SubscriptionHandler,
announcementHandler *AnnouncementHandler,
adminHandlers *AdminHandlers, adminHandlers *AdminHandlers,
gatewayHandler *GatewayHandler, gatewayHandler *GatewayHandler,
openaiGatewayHandler *OpenAIGatewayHandler, openaiGatewayHandler *OpenAIGatewayHandler,
settingHandler *SettingHandler, settingHandler *SettingHandler,
totpHandler *TotpHandler,
) *Handlers { ) *Handlers {
return &Handlers{ return &Handlers{
Auth: authHandler, Auth: authHandler,
...@@ -78,10 +82,12 @@ func ProvideHandlers( ...@@ -78,10 +82,12 @@ func ProvideHandlers(
Usage: usageHandler, Usage: usageHandler,
Redeem: redeemHandler, Redeem: redeemHandler,
Subscription: subscriptionHandler, Subscription: subscriptionHandler,
Announcement: announcementHandler,
Admin: adminHandlers, Admin: adminHandlers,
Gateway: gatewayHandler, Gateway: gatewayHandler,
OpenAIGateway: openaiGatewayHandler, OpenAIGateway: openaiGatewayHandler,
Setting: settingHandler, Setting: settingHandler,
Totp: totpHandler,
} }
} }
...@@ -94,8 +100,10 @@ var ProviderSet = wire.NewSet( ...@@ -94,8 +100,10 @@ var ProviderSet = wire.NewSet(
NewUsageHandler, NewUsageHandler,
NewRedeemHandler, NewRedeemHandler,
NewSubscriptionHandler, NewSubscriptionHandler,
NewAnnouncementHandler,
NewGatewayHandler, NewGatewayHandler,
NewOpenAIGatewayHandler, NewOpenAIGatewayHandler,
NewTotpHandler,
ProvideSettingHandler, ProvideSettingHandler,
// Admin handlers // Admin handlers
...@@ -103,6 +111,7 @@ var ProviderSet = wire.NewSet( ...@@ -103,6 +111,7 @@ var ProviderSet = wire.NewSet(
admin.NewUserHandler, admin.NewUserHandler,
admin.NewGroupHandler, admin.NewGroupHandler,
admin.NewAccountHandler, admin.NewAccountHandler,
admin.NewAnnouncementHandler,
admin.NewOAuthHandler, admin.NewOAuthHandler,
admin.NewOpenAIOAuthHandler, admin.NewOpenAIOAuthHandler,
admin.NewGeminiOAuthHandler, admin.NewGeminiOAuthHandler,
......
...@@ -33,7 +33,7 @@ const ( ...@@ -33,7 +33,7 @@ const (
"https://www.googleapis.com/auth/experimentsandconfigs" "https://www.googleapis.com/auth/experimentsandconfigs"
// User-Agent(与 Antigravity-Manager 保持一致) // User-Agent(与 Antigravity-Manager 保持一致)
UserAgent = "antigravity/1.11.9 windows/amd64" UserAgent = "antigravity/1.15.8 windows/amd64"
// Session 过期时间 // Session 过期时间
SessionTTL = 30 * time.Minute SessionTTL = 30 * time.Minute
......
...@@ -367,8 +367,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu ...@@ -367,8 +367,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
Text: block.Thinking, Text: block.Thinking,
Thought: true, Thought: true,
} }
// 保留原有 signature(Claude 模型需要有效的 signature) // signature 处理:
if block.Signature != "" { // - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失)
// - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature
if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
part.ThoughtSignature = block.Signature part.ThoughtSignature = block.Signature
} else if !allowDummyThought { } else if !allowDummyThought {
// Claude 模型需要有效 signature;在缺失时降级为普通文本,并在上层禁用 thinking mode。 // Claude 模型需要有效 signature;在缺失时降级为普通文本,并在上层禁用 thinking mode。
...@@ -407,12 +409,12 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu ...@@ -407,12 +409,12 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
}, },
} }
// tool_use 的 signature 处理: // tool_use 的 signature 处理:
// - Gemini 模型:使用 dummy signature(跳过 thought_signature 校验) // - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失)
// - Claude 模型:透传上游返回的真实 signature(Vertex/Google 需要完整签名链路) // - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature
if allowDummyThought { if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
part.ThoughtSignature = dummyThoughtSignature
} else if block.Signature != "" && block.Signature != dummyThoughtSignature {
part.ThoughtSignature = block.Signature part.ThoughtSignature = block.Signature
} else if allowDummyThought {
part.ThoughtSignature = dummyThoughtSignature
} }
parts = append(parts, part) parts = append(parts, part)
......
...@@ -100,7 +100,7 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) { ...@@ -100,7 +100,7 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}, "signature": "sig_tool_abc"} {"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}, "signature": "sig_tool_abc"}
]` ]`
t.Run("Gemini uses dummy tool_use signature", func(t *testing.T) { t.Run("Gemini preserves provided tool_use signature", func(t *testing.T) {
toolIDToName := make(map[string]string) toolIDToName := make(map[string]string)
parts, _, err := buildParts(json.RawMessage(content), toolIDToName, true) parts, _, err := buildParts(json.RawMessage(content), toolIDToName, true)
if err != nil { if err != nil {
...@@ -109,6 +109,23 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) { ...@@ -109,6 +109,23 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
if len(parts) != 1 || parts[0].FunctionCall == nil { if len(parts) != 1 || parts[0].FunctionCall == nil {
t.Fatalf("expected 1 functionCall part, got %+v", parts) t.Fatalf("expected 1 functionCall part, got %+v", parts)
} }
if parts[0].ThoughtSignature != "sig_tool_abc" {
t.Fatalf("expected preserved tool signature %q, got %q", "sig_tool_abc", parts[0].ThoughtSignature)
}
})
t.Run("Gemini falls back to dummy tool_use signature when missing", func(t *testing.T) {
contentNoSig := `[
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}}
]`
toolIDToName := make(map[string]string)
parts, _, err := buildParts(json.RawMessage(contentNoSig), toolIDToName, true)
if err != nil {
t.Fatalf("buildParts() error = %v", err)
}
if len(parts) != 1 || parts[0].FunctionCall == nil {
t.Fatalf("expected 1 functionCall part, got %+v", parts)
}
if parts[0].ThoughtSignature != dummyThoughtSignature { if parts[0].ThoughtSignature != dummyThoughtSignature {
t.Fatalf("expected dummy tool signature %q, got %q", dummyThoughtSignature, parts[0].ThoughtSignature) t.Fatalf("expected dummy tool signature %q, got %q", dummyThoughtSignature, parts[0].ThoughtSignature)
} }
......
...@@ -9,11 +9,26 @@ const ( ...@@ -9,11 +9,26 @@ const (
BetaClaudeCode = "claude-code-20250219" BetaClaudeCode = "claude-code-20250219"
BetaInterleavedThinking = "interleaved-thinking-2025-05-14" BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14" BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
BetaTokenCounting = "token-counting-2024-11-01"
) )
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header // DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// MessageBetaHeaderNoTools /v1/messages 在无工具时的 beta header
//
// NOTE: Claude Code OAuth credentials are scoped to Claude Code. When we "mimic"
// Claude Code for non-Claude-Code clients, we must include the claude-code beta
// even if the request doesn't use tools, otherwise upstream may reject the
// request as a non-Claude-Code API request.
const MessageBetaHeaderNoTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking
// MessageBetaHeaderWithTools /v1/messages 在有工具时的 beta header
const MessageBetaHeaderWithTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking
// CountTokensBetaHeader count_tokens 请求使用的 anthropic-beta header
const CountTokensBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaTokenCounting
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta) // HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
...@@ -25,15 +40,17 @@ const APIKeyHaikuBetaHeader = BetaInterleavedThinking ...@@ -25,15 +40,17 @@ const APIKeyHaikuBetaHeader = BetaInterleavedThinking
// DefaultHeaders 是 Claude Code 客户端默认请求头。 // DefaultHeaders 是 Claude Code 客户端默认请求头。
var DefaultHeaders = map[string]string{ var DefaultHeaders = map[string]string{
"User-Agent": "claude-cli/2.0.62 (external, cli)", // Keep these in sync with recent Claude CLI traffic to reduce the chance
// that Claude Code-scoped OAuth credentials are rejected as "non-CLI" usage.
"User-Agent": "claude-cli/2.1.22 (external, cli)",
"X-Stainless-Lang": "js", "X-Stainless-Lang": "js",
"X-Stainless-Package-Version": "0.52.0", "X-Stainless-Package-Version": "0.70.0",
"X-Stainless-OS": "Linux", "X-Stainless-OS": "Linux",
"X-Stainless-Arch": "x64", "X-Stainless-Arch": "arm64",
"X-Stainless-Runtime": "node", "X-Stainless-Runtime": "node",
"X-Stainless-Runtime-Version": "v22.14.0", "X-Stainless-Runtime-Version": "v24.13.0",
"X-Stainless-Retry-Count": "0", "X-Stainless-Retry-Count": "0",
"X-Stainless-Timeout": "60", "X-Stainless-Timeout": "600",
"X-App": "cli", "X-App": "cli",
"Anthropic-Dangerous-Direct-Browser-Access": "true", "Anthropic-Dangerous-Direct-Browser-Access": "true",
} }
...@@ -79,3 +96,39 @@ func DefaultModelIDs() []string { ...@@ -79,3 +96,39 @@ func DefaultModelIDs() []string {
// DefaultTestModel 测试时使用的默认模型 // DefaultTestModel 测试时使用的默认模型
const DefaultTestModel = "claude-sonnet-4-5-20250929" const DefaultTestModel = "claude-sonnet-4-5-20250929"
// ModelIDOverrides Claude OAuth 请求需要的模型 ID 映射
var ModelIDOverrides = map[string]string{
"claude-sonnet-4-5": "claude-sonnet-4-5-20250929",
"claude-opus-4-5": "claude-opus-4-5-20251101",
"claude-haiku-4-5": "claude-haiku-4-5-20251001",
}
// ModelIDReverseOverrides 用于将上游模型 ID 还原为短名
var ModelIDReverseOverrides = map[string]string{
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
"claude-opus-4-5-20251101": "claude-opus-4-5",
"claude-haiku-4-5-20251001": "claude-haiku-4-5",
}
// NormalizeModelID 根据 Claude OAuth 规则映射模型
func NormalizeModelID(id string) string {
if id == "" {
return id
}
if mapped, ok := ModelIDOverrides[id]; ok {
return mapped
}
return id
}
// DenormalizeModelID 将上游模型 ID 转换为短名
func DenormalizeModelID(id string) string {
if id == "" {
return id
}
if mapped, ok := ModelIDReverseOverrides[id]; ok {
return mapped
}
return id
}
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
package response package response
import ( import (
"log"
"math" "math"
"net/http" "net/http"
...@@ -74,6 +75,12 @@ func ErrorFrom(c *gin.Context, err error) bool { ...@@ -74,6 +75,12 @@ func ErrorFrom(c *gin.Context, err error) bool {
} }
statusCode, status := infraerrors.ToHTTP(err) statusCode, status := infraerrors.ToHTTP(err)
// Log internal errors with full details for debugging
if statusCode >= 500 && c.Request != nil {
log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, err.Error())
}
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata) ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
return true return true
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment