Commit 3b7a5fff authored by 陈曦's avatar 陈曦
Browse files

补充openai、gemini以及流失请求的采集数据以及nfs落库

parent 8519a8eb
Pipeline #82284 failed with stage
in 2 minutes and 21 seconds
......@@ -91,6 +91,7 @@ var ProviderSet = wire.NewSet(
NewChannelRepository,
NewChannelMonitorRepository,
NewChannelMonitorRequestTemplateRepository,
NewAffiliateRepository,
NewRequestCaptureLogRepository,
// Cache implementations
......
......@@ -715,6 +715,10 @@ func TestAPIContracts(t *testing.T) {
"force_email_on_third_party_signup": false,
"default_concurrency": 5,
"default_balance": 1.25,
"affiliate_rebate_rate": 20,
"affiliate_rebate_freeze_hours": 0,
"affiliate_rebate_duration_days": 0,
"affiliate_rebate_per_invitee_cap": 0,
"default_user_rpm_limit": 0,
"default_subscriptions": [],
"enable_model_fallback": false,
......@@ -774,6 +778,7 @@ func TestAPIContracts(t *testing.T) {
"channel_monitor_enabled": true,
"channel_monitor_default_interval_seconds": 60,
"available_channels_enabled": false,
"affiliate_enabled": false,
"wechat_connect_enabled": false,
"wechat_connect_app_id": "",
"wechat_connect_app_secret_configured": false,
......@@ -895,6 +900,10 @@ func TestAPIContracts(t *testing.T) {
"custom_endpoints": [],
"default_concurrency": 0,
"default_balance": 0,
"affiliate_rebate_rate": 20,
"affiliate_rebate_freeze_hours": 0,
"affiliate_rebate_duration_days": 0,
"affiliate_rebate_per_invitee_cap": 0,
"default_user_rpm_limit": 0,
"default_subscriptions": [],
"enable_model_fallback": false,
......@@ -949,6 +958,7 @@ func TestAPIContracts(t *testing.T) {
"channel_monitor_enabled": true,
"channel_monitor_default_interval_seconds": 60,
"available_channels_enabled": false,
"affiliate_enabled": false,
"wechat_connect_enabled": true,
"wechat_connect_app_id": "wx-open-config",
"wechat_connect_app_secret_configured": true,
......
......@@ -20,7 +20,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
admin := &service.User{
ID: 1,
......
......@@ -60,7 +60,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer
cfg.JWT.AccessTokenExpireMinutes = 60
userRepo := &stubJWTUserRepo{users: users}
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
userSvc := service.NewUserService(userRepo, nil, nil, nil)
mw := NewJWTAuthMiddleware(authSvc, userSvc)
......@@ -143,7 +143,7 @@ func TestJWTAuth_ValidToken_TouchesLastActive(t *testing.T) {
cfg.JWT.AccessTokenExpireMinutes = 60
userRepo := &stubJWTUserRepo{users: map[int64]*service.User{1: user}}
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
userSvc := service.NewUserService(userRepo, nil, nil, nil)
toucher := &recordingActivityToucher{}
......
......@@ -91,6 +91,9 @@ func RegisterAdminRoutes(
// 渠道监控
registerChannelMonitorRoutes(admin, h)
// 邀请返利(专属用户管理)
registerAffiliateRoutes(admin, h)
}
}
......@@ -594,3 +597,18 @@ func registerChannelMonitorRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
templates.POST("/:id/apply", h.Admin.ChannelMonitorTemplate.Apply)
}
}
// registerAffiliateRoutes 注册邀请返利的管理端路由(专属用户配置)
func registerAffiliateRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
affiliates := admin.Group("/affiliates")
{
users := affiliates.Group("/users")
{
users.GET("", h.Admin.Affiliate.ListUsers)
users.GET("/lookup", h.Admin.Affiliate.LookupUsers)
users.POST("/batch-rate", h.Admin.Affiliate.BatchSetRate)
users.PUT("/:user_id", h.Admin.Affiliate.UpdateUserSettings)
users.DELETE("/:user_id", h.Admin.Affiliate.ClearUserSettings)
}
}
}
......@@ -25,6 +25,8 @@ func RegisterUserRoutes(
user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword)
user.PUT("", h.User.UpdateProfile)
user.GET("/aff", h.User.GetAffiliate)
user.POST("/aff/transfer", h.User.TransferAffiliateQuota)
user.POST("/account-bindings/email/send-code", h.User.SendEmailBindingCode)
user.POST("/account-bindings/email", h.User.BindEmailIdentity)
user.DELETE("/account-bindings/:provider", h.User.UnbindIdentity)
......
......@@ -393,6 +393,56 @@ func parseTempUnschedInt(value any) int {
return 0
}
const (
// OpenAICompactModeAuto follows compact-probe results when deciding compact eligibility.
OpenAICompactModeAuto = "auto"
// OpenAICompactModeForceOn always treats the account as compact-supported.
OpenAICompactModeForceOn = "force_on"
// OpenAICompactModeForceOff always treats the account as compact-unsupported.
OpenAICompactModeForceOff = "force_off"
)
func normalizeOpenAICompactMode(mode string) string {
switch strings.ToLower(strings.TrimSpace(mode)) {
case OpenAICompactModeForceOn:
return OpenAICompactModeForceOn
case OpenAICompactModeForceOff:
return OpenAICompactModeForceOff
default:
return OpenAICompactModeAuto
}
}
func stringMappingFromRaw(raw any) map[string]string {
switch mapping := raw.(type) {
case map[string]any:
if len(mapping) == 0 {
return nil
}
result := make(map[string]string, len(mapping))
for key, value := range mapping {
if str, ok := value.(string); ok {
result[key] = str
}
}
if len(result) == 0 {
return nil
}
return result
case map[string]string:
if len(mapping) == 0 {
return nil
}
result := make(map[string]string, len(mapping))
for key, value := range mapping {
result[key] = value
}
return result
default:
return nil
}
}
func (a *Account) GetModelMapping() map[string]string {
credentialsPtr := mapPtr(a.Credentials)
rawMapping, _ := a.Credentials["model_mapping"].(map[string]any)
......@@ -598,6 +648,77 @@ func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string,
return requestedModel, false
}
// GetOpenAICompactMode returns the compact routing mode for an OpenAI account.
// Missing or invalid values fall back to "auto".
func (a *Account) GetOpenAICompactMode() string {
if a == nil || !a.IsOpenAI() || a.Extra == nil {
return OpenAICompactModeAuto
}
mode, _ := a.Extra["openai_compact_mode"].(string)
return normalizeOpenAICompactMode(mode)
}
// OpenAICompactSupportKnown reports whether compact capability is known for this
// account and, when known, whether it is supported.
func (a *Account) OpenAICompactSupportKnown() (supported bool, known bool) {
if a == nil || !a.IsOpenAI() {
return false, false
}
switch a.GetOpenAICompactMode() {
case OpenAICompactModeForceOn:
return true, true
case OpenAICompactModeForceOff:
return false, true
}
if a.Extra == nil {
return false, false
}
supported, ok := a.Extra["openai_compact_supported"].(bool)
if !ok {
return false, false
}
return supported, true
}
// AllowsOpenAICompact reports whether the account may be considered for compact
// requests. Unknown capability remains allowed to avoid breaking older accounts
// before an explicit probe has been run.
func (a *Account) AllowsOpenAICompact() bool {
if a == nil || !a.IsOpenAI() {
return false
}
supported, known := a.OpenAICompactSupportKnown()
if !known {
return true
}
return supported
}
// GetCompactModelMapping returns compact-only model remapping configuration.
// This mapping is intended for /responses/compact only and does not affect
// normal /responses traffic.
func (a *Account) GetCompactModelMapping() map[string]string {
if a == nil || a.Credentials == nil {
return nil
}
return stringMappingFromRaw(a.Credentials["compact_model_mapping"])
}
// ResolveCompactMappedModel resolves compact-only model remapping and reports
// whether a compact-specific mapping rule matched.
func (a *Account) ResolveCompactMappedModel(requestedModel string) (mappedModel string, matched bool) {
mapping := a.GetCompactModelMapping()
if len(mapping) == 0 {
return requestedModel, false
}
if mappedModel, matched := resolveRequestedModelInMapping(mapping, requestedModel); matched {
return mappedModel, true
}
return requestedModel, false
}
func (a *Account) GetBaseURL() string {
if a.Type != AccountTypeAPIKey {
return ""
......
package service
import "testing"
func TestAccountGetOpenAICompactMode(t *testing.T) {
tests := []struct {
name string
account *Account
want string
}{
{
name: "nil account defaults to auto",
want: OpenAICompactModeAuto,
},
{
name: "non openai account defaults to auto",
account: &Account{
Platform: PlatformAnthropic,
Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOn},
},
want: OpenAICompactModeAuto,
},
{
name: "missing extra defaults to auto",
account: &Account{
Platform: PlatformOpenAI,
},
want: OpenAICompactModeAuto,
},
{
name: "invalid mode falls back to auto",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_mode": " invalid "},
},
want: OpenAICompactModeAuto,
},
{
name: "force on is normalized",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_mode": " FORCE_ON "},
},
want: OpenAICompactModeForceOn,
},
{
name: "force off is normalized",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_mode": "force_off"},
},
want: OpenAICompactModeForceOff,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.account.GetOpenAICompactMode(); got != tt.want {
t.Fatalf("GetOpenAICompactMode() = %q, want %q", got, tt.want)
}
})
}
}
func TestAccountOpenAICompactSupportKnown(t *testing.T) {
tests := []struct {
name string
account *Account
wantSupported bool
wantKnown bool
}{
{
name: "nil account is unknown",
wantSupported: false,
wantKnown: false,
},
{
name: "non openai account is unknown",
account: &Account{
Platform: PlatformAnthropic,
Extra: map[string]any{"openai_compact_supported": true},
},
wantSupported: false,
wantKnown: false,
},
{
name: "force on overrides probe state",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{
"openai_compact_mode": OpenAICompactModeForceOn,
"openai_compact_supported": false,
},
},
wantSupported: true,
wantKnown: true,
},
{
name: "force off overrides probe state",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{
"openai_compact_mode": OpenAICompactModeForceOff,
"openai_compact_supported": true,
},
},
wantSupported: false,
wantKnown: true,
},
{
name: "auto true is known supported",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_supported": true},
},
wantSupported: true,
wantKnown: true,
},
{
name: "auto false is known unsupported",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_supported": false},
},
wantSupported: false,
wantKnown: true,
},
{
name: "auto without probe state remains unknown",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{},
},
wantSupported: false,
wantKnown: false,
},
{
name: "invalid probe field remains unknown",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_supported": "true"},
},
wantSupported: false,
wantKnown: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotSupported, gotKnown := tt.account.OpenAICompactSupportKnown()
if gotSupported != tt.wantSupported || gotKnown != tt.wantKnown {
t.Fatalf("OpenAICompactSupportKnown() = (%v, %v), want (%v, %v)", gotSupported, gotKnown, tt.wantSupported, tt.wantKnown)
}
})
}
}
func TestAccountAllowsOpenAICompact(t *testing.T) {
tests := []struct {
name string
account *Account
want bool
}{
{
name: "nil account does not allow compact",
want: false,
},
{
name: "non openai account does not allow compact",
account: &Account{
Platform: PlatformAnthropic,
},
want: false,
},
{
name: "unknown openai account remains allowed",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{},
},
want: true,
},
{
name: "supported openai account is allowed",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_supported": true},
},
want: true,
},
{
name: "unsupported openai account is rejected",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_supported": false},
},
want: false,
},
{
name: "force on is allowed",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOn},
},
want: true,
},
{
name: "force off is rejected",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOff},
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.account.AllowsOpenAICompact(); got != tt.want {
t.Fatalf("AllowsOpenAICompact() = %v, want %v", got, tt.want)
}
})
}
}
func TestAccountGetCompactModelMapping(t *testing.T) {
tests := []struct {
name string
account *Account
want map[string]string
}{
{
name: "nil account returns nil",
want: nil,
},
{
name: "missing credentials returns nil",
account: &Account{
Platform: PlatformOpenAI,
},
want: nil,
},
{
name: "map any is converted",
account: &Account{
Credentials: map[string]any{
"compact_model_mapping": map[string]any{
"gpt-5.4": "gpt-5.4-openai-compact",
"invalid": 1,
},
},
},
want: map[string]string{
"gpt-5.4": "gpt-5.4-openai-compact",
},
},
{
name: "map string string is copied",
account: &Account{
Credentials: map[string]any{
"compact_model_mapping": map[string]string{
"gpt-*": "compact-*",
},
},
},
want: map[string]string{
"gpt-*": "compact-*",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.account.GetCompactModelMapping()
if !equalStringMap(got, tt.want) {
t.Fatalf("GetCompactModelMapping() = %#v, want %#v", got, tt.want)
}
})
}
}
func TestAccountResolveCompactMappedModel(t *testing.T) {
tests := []struct {
name string
credentials map[string]any
requestedModel string
expectedModel string
expectedMatch bool
}{
{
name: "no compact mapping reports unmatched",
credentials: nil,
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: false,
},
{
name: "exact compact mapping matches",
credentials: map[string]any{
"compact_model_mapping": map[string]any{
"gpt-5.4": "gpt-5.4-openai-compact",
},
},
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4-openai-compact",
expectedMatch: true,
},
{
name: "exact passthrough counts as match",
credentials: map[string]any{
"compact_model_mapping": map[string]any{
"gpt-5.4": "gpt-5.4",
},
},
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: true,
},
{
name: "longest wildcard wins",
credentials: map[string]any{
"compact_model_mapping": map[string]any{
"gpt-*": "fallback-compact",
"gpt-5.4*": "gpt-5.4-openai-compact",
"gpt-5.4-mini*": "gpt-5.4-mini-openai-compact",
},
},
requestedModel: "gpt-5.4-mini",
expectedModel: "gpt-5.4-mini-openai-compact",
expectedMatch: true,
},
{
name: "missing compact mapping reports unmatched",
credentials: map[string]any{
"compact_model_mapping": map[string]any{
"gpt-5.3": "gpt-5.3-openai-compact",
},
},
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Credentials: tt.credentials,
}
gotModel, gotMatch := account.ResolveCompactMappedModel(tt.requestedModel)
if gotModel != tt.expectedModel || gotMatch != tt.expectedMatch {
t.Fatalf("ResolveCompactMappedModel(%q) = (%q, %v), want (%q, %v)", tt.requestedModel, gotModel, gotMatch, tt.expectedModel, tt.expectedMatch)
}
})
}
}
func equalStringMap(left, right map[string]string) bool {
if len(left) != len(right) {
return false
}
for key, want := range right {
if got, ok := left[key]; !ok || got != want {
return false
}
}
return true
}
......@@ -165,7 +165,8 @@ func createTestPayload(modelID string) (map[string]any, error) {
// TestAccountConnection tests an account's connection by sending a test request
// All account types use full Claude Code client characteristics, only auth header differs
// modelID is optional - if empty, defaults to claude.DefaultTestModel
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string, prompt string) error {
// mode is optional - "compact" routes OpenAI accounts to the /responses/compact probe path
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string, prompt string, mode string) error {
ctx := c.Request.Context()
// Get account
......@@ -176,7 +177,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
// Route to platform-specific test method
if account.IsOpenAI() {
return s.testOpenAIAccountConnection(c, account, modelID, prompt)
return s.testOpenAIAccountConnection(c, account, modelID, prompt, normalizeAccountTestMode(mode))
}
if account.IsGemini() {
......@@ -416,9 +417,10 @@ func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx co
}
// testOpenAIAccountConnection tests an OpenAI account's connection
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error {
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string, prompt string, mode string) error {
ctx := c.Request.Context()
_ = prompt
mode = normalizeAccountTestMode(mode)
// Default to openai.DefaultTestModel for OpenAI testing
testModelID := modelID
......@@ -426,14 +428,12 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
testModelID = openai.DefaultTestModel
}
// For API Key accounts with model mapping, map the model
if account.Type == "apikey" {
mapping := account.GetModelMapping()
if len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists {
testModelID = mappedModel
}
}
// Align test routing with gateway behavior: OpenAI accounts apply normal
// account model mapping, and compact mode applies compact-only mapping on top.
testModelID = account.GetMappedModel(testModelID)
if mode == AccountTestModeCompact {
testModelID = resolveOpenAICompactForwardModel(account, testModelID)
return s.testOpenAICompactConnection(c, account, testModelID)
}
// Route to image generation test if an image model is selected
......@@ -538,6 +538,9 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode == http.StatusTooManyRequests {
s.reconcileOpenAI429State(ctx, account, resp.Header, body)
}
// 401 Unauthorized: 标记账号为永久错误
if resp.StatusCode == http.StatusUnauthorized && s.accountRepo != nil {
errMsg := fmt.Sprintf("Authentication failed (401): %s", string(body))
......@@ -550,6 +553,154 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
return s.processOpenAIStream(c, resp.Body)
}
// testOpenAICompactConnection probes /responses/compact and persists the
// resulting capability state on the account.
func (s *AccountTestService) testOpenAICompactConnection(c *gin.Context, account *Account, testModelID string) error {
ctx := c.Request.Context()
authToken := ""
apiURL := ""
isOAuth := false
chatgptAccountID := ""
switch {
case account.IsOAuth():
isOAuth = true
authToken = account.GetOpenAIAccessToken()
if authToken == "" {
return s.sendErrorAndEnd(c, "No access token available")
}
apiURL = chatgptCodexAPIURL + "/compact"
chatgptAccountID = account.GetChatGPTAccountID()
case account.Type == AccountTypeAPIKey:
authToken = account.GetOpenAIApiKey()
if authToken == "" {
return s.sendErrorAndEnd(c, "No API key available")
}
baseURL := account.GetOpenAIBaseURL()
if baseURL == "" {
baseURL = "https://api.openai.com"
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
}
apiURL = appendOpenAIResponsesRequestPathSuffix(buildOpenAIResponsesURL(normalizedBaseURL), "/compact")
default:
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
}
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
payloadBytes, _ := json.Marshal(createOpenAICompactProbePayload(testModelID))
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes))
if err != nil {
return s.sendErrorAndEnd(c, "Failed to create request")
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+authToken)
req.Header.Set("OpenAI-Beta", "responses=experimental")
req.Header.Set("Originator", "codex_cli_rs")
req.Header.Set("User-Agent", codexCLIUserAgent)
req.Header.Set("Version", codexCLIVersion)
probeSessionID := compactProbeSessionID(account.ID)
req.Header.Set("Session_ID", probeSessionID)
req.Header.Set("Conversation_ID", probeSessionID)
if isOAuth {
req.Host = "chatgpt.com"
if chatgptAccountID != "" {
req.Header.Set("chatgpt-account-id", chatgptAccountID)
}
}
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
if err != nil {
if s.accountRepo != nil {
updates := buildOpenAICompactProbeExtraUpdates(nil, nil, err, time.Now())
_ = s.accountRepo.UpdateExtra(ctx, account.ID, updates)
mergeAccountExtra(account, updates)
}
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if s.accountRepo != nil {
updates := buildOpenAICompactProbeExtraUpdates(resp, body, nil, time.Now())
if codexUpdates, err := extractOpenAICodexProbeUpdates(resp); err == nil && len(codexUpdates) > 0 {
updates = mergeExtraUpdates(updates, codexUpdates)
}
if len(updates) > 0 {
_ = s.accountRepo.UpdateExtra(ctx, account.ID, updates)
mergeAccountExtra(account, updates)
}
// 探测如返回 429,主动同步限流状态,避免后续短时间内继续选中。
if resp.StatusCode == http.StatusTooManyRequests {
s.reconcileOpenAI429State(ctx, account, resp.Header, body)
}
}
if resp.StatusCode != http.StatusOK {
if resp.StatusCode == http.StatusUnauthorized && s.accountRepo != nil {
errMsg := fmt.Sprintf("Authentication failed (401): %s", string(body))
_ = s.accountRepo.SetError(ctx, account.ID, errMsg)
}
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
}
s.sendEvent(c, TestEvent{Type: "content", Text: "Compact probe succeeded"})
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
func (s *AccountTestService) reconcileOpenAI429State(ctx context.Context, account *Account, headers http.Header, body []byte) {
if s == nil || s.accountRepo == nil || account == nil {
return
}
var resetAt *time.Time
if calculated := calculateOpenAI429ResetTime(headers); calculated != nil {
resetAt = calculated
} else if unixTs := parseOpenAIRateLimitResetTime(body); unixTs != nil {
t := time.Unix(*unixTs, 0)
resetAt = &t
}
if resetAt == nil {
return
}
if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil {
return
}
now := time.Now()
account.RateLimitedAt = &now
account.RateLimitResetAt = resetAt
if account.Status == StatusError {
if err := s.accountRepo.ClearError(ctx, account.ID); err != nil {
return
}
account.Status = StatusActive
account.ErrorMessage = ""
}
}
// testGeminiAccountConnection tests a Gemini account's connection
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error {
ctx := c.Request.Context()
......@@ -994,13 +1145,17 @@ func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader)
// processOpenAIStream processes the SSE stream from OpenAI Responses API
func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader) error {
reader := bufio.NewReader(body)
seenCompleted := false
for {
line, err := reader.ReadString('\n')
if err != nil {
if err == io.EOF {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
if seenCompleted {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
return s.sendErrorAndEnd(c, "Stream ended before response.completed")
}
return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error()))
}
......@@ -1012,8 +1167,11 @@ func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader)
jsonStr := sseDataPrefix.ReplaceAllString(line, "")
if jsonStr == "[DONE]" {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
if seenCompleted {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
return s.sendErrorAndEnd(c, "Stream ended before response.completed")
}
var data map[string]any
......@@ -1029,9 +1187,19 @@ func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader)
if delta, ok := data["delta"].(string); ok && delta != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: delta})
}
case "response.completed":
case "response.completed", "response.done":
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
case "response.failed":
errorMsg := "OpenAI response failed"
if responseData, ok := data["response"].(map[string]any); ok {
if errData, ok := responseData["error"].(map[string]any); ok {
if msg, ok := errData["message"].(string); ok && msg != "" {
errorMsg = msg
}
}
}
return s.sendErrorAndEnd(c, errorMsg)
case "error":
errorMsg := "Unknown error"
if errData, ok := data["error"].(map[string]any); ok {
......@@ -1261,7 +1429,7 @@ func (s *AccountTestService) RunTestBackground(ctx context.Context, accountID in
ginCtx, _ := gin.CreateTestContext(w)
ginCtx.Request = (&http.Request{}).WithContext(ctx)
testErr := s.TestAccountConnection(ginCtx, accountID, modelID, "")
testErr := s.TestAccountConnection(ginCtx, accountID, modelID, "", AccountTestModeDefault)
finishedAt := time.Now()
body := w.Body.String()
......
package service
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestAccountTestService_TestAccountConnection_OpenAICompactOAuthSuccessPersistsSupport(t *testing.T) {
gin.SetMode(gin.TestMode)
updateCalls := make(chan map[string]any, 1)
account := Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
repo := &snapshotUpdateAccountRepo{
stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
updateExtraCalls: updateCalls,
}
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid-probe"}},
Body: io.NopCloser(strings.NewReader(`{"id":"cmp_probe","status":"completed"}`)),
}}
svc := &AccountTestService{
accountRepo: repo,
httpUpstream: upstream,
}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", bytes.NewReader(nil))
err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
require.NoError(t, err)
require.Equal(t, chatgptCodexAPIURL+"/compact", upstream.lastReq.URL.String())
require.Equal(t, "chatgpt.com", upstream.lastReq.Host)
require.Equal(t, "application/json", upstream.lastReq.Header.Get("Accept"))
require.Equal(t, codexCLIVersion, upstream.lastReq.Header.Get("Version"))
require.NotEmpty(t, upstream.lastReq.Header.Get("Session_Id"))
require.Equal(t, codexCLIUserAgent, upstream.lastReq.Header.Get("User-Agent"))
require.Equal(t, "chatgpt-acc", upstream.lastReq.Header.Get("chatgpt-account-id"))
require.Equal(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String())
updates := <-updateCalls
require.Equal(t, true, updates["openai_compact_supported"])
require.Equal(t, http.StatusOK, updates["openai_compact_last_status"])
require.Contains(t, rec.Body.String(), `"type":"test_complete"`)
}
func TestAccountTestService_TestAccountConnection_OpenAICompactOAuth404MarksUnsupported(t *testing.T) {
gin.SetMode(gin.TestMode)
updateCalls := make(chan map[string]any, 1)
account := Account{
ID: 2,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
repo := &snapshotUpdateAccountRepo{
stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
updateExtraCalls: updateCalls,
}
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusNotFound,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`404 page not found`)),
}}
svc := &AccountTestService{
accountRepo: repo,
httpUpstream: upstream,
}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/2/test", bytes.NewReader(nil))
err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
require.Error(t, err)
updates := <-updateCalls
require.Equal(t, false, updates["openai_compact_supported"])
require.Equal(t, http.StatusNotFound, updates["openai_compact_last_status"])
require.Contains(t, rec.Body.String(), `"type":"error"`)
}
func TestAccountTestService_TestAccountConnection_OpenAICompactAPIKeyUsesCompactPath(t *testing.T) {
gin.SetMode(gin.TestMode)
updateCalls := make(chan map[string]any, 1)
account := Account{
ID: 3,
Name: "openai-apikey",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": "https://example.com/v1",
"compact_model_mapping": map[string]any{"gpt-5.4": "gpt-5.4-openai-compact"},
},
}
repo := &snapshotUpdateAccountRepo{
stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
updateExtraCalls: updateCalls,
}
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"id":"cmp_probe_apikey","status":"completed"}`)),
}}
svc := &AccountTestService{
accountRepo: repo,
httpUpstream: upstream,
cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/3/test", bytes.NewReader(nil))
err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
require.NoError(t, err)
require.Equal(t, "https://example.com/v1/responses/compact", upstream.lastReq.URL.String())
require.Equal(t, "gpt-5.4-openai-compact", gjson.GetBytes(upstream.lastBody, "model").String())
updates := <-updateCalls
require.Equal(t, true, updates["openai_compact_supported"])
}
func TestAccountTestService_TestAccountConnection_OpenAICompactAPIKeyDefaultBaseURLUsesV1Path(t *testing.T) {
gin.SetMode(gin.TestMode)
updateCalls := make(chan map[string]any, 1)
account := Account{
ID: 4,
Name: "openai-apikey-default",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
},
}
repo := &snapshotUpdateAccountRepo{
stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
updateExtraCalls: updateCalls,
}
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"id":"cmp_probe_apikey_default","status":"completed"}`)),
}}
svc := &AccountTestService{
accountRepo: repo,
httpUpstream: upstream,
cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/4/test", bytes.NewReader(nil))
err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
require.NoError(t, err)
require.Equal(t, "https://api.openai.com/v1/responses/compact", upstream.lastReq.URL.String())
<-updateCalls
}
......@@ -61,9 +61,12 @@ func newTestContext() (*gin.Context, *httptest.ResponseRecorder) {
type openAIAccountTestRepo struct {
mockAccountRepoForGemini
updatedExtra map[string]any
rateLimitedID int64
rateLimitedAt *time.Time
updatedExtra map[string]any
rateLimitedID int64
rateLimitedAt *time.Time
clearedErrorID int64
setErrorID int64
setErrorMsg string
}
func (r *openAIAccountTestRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
......@@ -77,6 +80,17 @@ func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, rese
return nil
}
func (r *openAIAccountTestRepo) ClearError(_ context.Context, id int64) error {
r.clearedErrorID = id
return nil
}
func (r *openAIAccountTestRepo) SetError(_ context.Context, id int64, errorMsg string) error {
r.setErrorID = id
r.setErrorMsg = errorMsg
return nil
}
func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, recorder := newTestContext()
......@@ -103,7 +117,7 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
Credentials: map[string]any{"access_token": "test-token"},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "")
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.NoError(t, err)
require.NotEmpty(t, repo.updatedExtra)
require.Equal(t, 42.0, repo.updatedExtra["codex_5h_used_percent"])
......@@ -111,11 +125,36 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
require.Contains(t, recorder.Body.String(), "test_complete")
}
func TestAccountTestService_OpenAI429PersistsSnapshotWithoutRateLimit(t *testing.T) {
func TestAccountTestService_OpenAIStreamEOFBeforeCompletedFails(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, recorder := newTestContext()
resp := newJSONResponse(http.StatusOK, "")
resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.output_text.delta","delta":"hi"}
`))
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 90,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{"access_token": "test-token"},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.Error(t, err)
require.Contains(t, recorder.Body.String(), "response.completed")
require.NotContains(t, recorder.Body.String(), `"success":true`)
}
func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimitState(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`)
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached","resets_at":1777283883}}`)
resp.Header.Set("x-codex-primary-used-percent", "100")
resp.Header.Set("x-codex-primary-reset-after-seconds", "604800")
resp.Header.Set("x-codex-primary-window-minutes", "10080")
......@@ -130,15 +169,132 @@ func TestAccountTestService_OpenAI429PersistsSnapshotWithoutRateLimit(t *testing
ID: 88,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusError,
Concurrency: 1,
Credentials: map[string]any{"access_token": "test-token"},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "")
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.Error(t, err)
require.NotEmpty(t, repo.updatedExtra)
require.Equal(t, 100.0, repo.updatedExtra["codex_5h_used_percent"])
require.Equal(t, account.ID, repo.rateLimitedID)
require.NotNil(t, repo.rateLimitedAt)
require.Equal(t, account.ID, repo.clearedErrorID)
require.Equal(t, StatusActive, account.Status)
require.Empty(t, account.ErrorMessage)
require.NotNil(t, account.RateLimitResetAt)
}
func TestAccountTestService_OpenAI429BodyOnlyPersistsRateLimitAndClearsStaleError(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached","resets_at":"1777283883"}}`)
repo := &openAIAccountTestRepo{}
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
account := &Account{
ID: 77,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusError,
ErrorMessage: "Access forbidden (403): account may be suspended or lack permissions",
Concurrency: 1,
Credentials: map[string]any{"access_token": "test-token"},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.Error(t, err)
require.Equal(t, account.ID, repo.rateLimitedID)
require.NotNil(t, repo.rateLimitedAt)
require.Equal(t, account.ID, repo.clearedErrorID)
require.Equal(t, StatusActive, account.Status)
require.Empty(t, account.ErrorMessage)
require.NotNil(t, account.RateLimitResetAt)
require.Empty(t, repo.updatedExtra)
}
func TestAccountTestService_OpenAI429ActiveAccountDoesNotClearError(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached","resets_in_seconds":3600}}`)
repo := &openAIAccountTestRepo{}
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
account := &Account{
ID: 78,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{"access_token": "test-token"},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.Error(t, err)
require.Equal(t, account.ID, repo.rateLimitedID)
require.NotNil(t, repo.rateLimitedAt)
require.Zero(t, repo.clearedErrorID)
require.Equal(t, StatusActive, account.Status)
require.NotNil(t, account.RateLimitResetAt)
}
func TestAccountTestService_OpenAI429WithoutResetSignalDoesNotMutateRuntimeState(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`)
repo := &openAIAccountTestRepo{}
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
account := &Account{
ID: 79,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusError,
ErrorMessage: "stale 403",
Concurrency: 1,
Credentials: map[string]any{"access_token": "test-token"},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.Error(t, err)
require.Zero(t, repo.rateLimitedID)
require.Nil(t, repo.rateLimitedAt)
require.Zero(t, repo.clearedErrorID)
require.Equal(t, StatusError, account.Status)
require.Equal(t, "stale 403", account.ErrorMessage)
require.Nil(t, account.RateLimitResetAt)
}
func TestAccountTestService_OpenAI401SetsPermanentErrorOnly(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
resp := newJSONResponse(http.StatusUnauthorized, `{"error":"bad token"}`)
repo := &openAIAccountTestRepo{}
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
account := &Account{
ID: 80,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{"access_token": "test-token"},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.Error(t, err)
require.Equal(t, account.ID, repo.setErrorID)
require.Contains(t, repo.setErrorMsg, "Authentication failed (401)")
require.Zero(t, repo.rateLimitedID)
require.Zero(t, repo.clearedErrorID)
require.Nil(t, account.RateLimitResetAt)
}
......@@ -110,7 +110,7 @@ const (
apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟
windowStatsCacheTTL = 1 * time.Minute
openAIProbeCacheTTL = 10 * time.Minute
openAICodexProbeVersion = "0.104.0"
openAICodexProbeVersion = "0.125.0"
)
// UsageCache 封装账户使用量相关的缓存
......
package service
import (
"context"
"errors"
"math"
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
var (
ErrAffiliateProfileNotFound = infraerrors.NotFound("AFFILIATE_PROFILE_NOT_FOUND", "affiliate profile not found")
ErrAffiliateCodeInvalid = infraerrors.BadRequest("AFFILIATE_CODE_INVALID", "invalid affiliate code")
ErrAffiliateCodeTaken = infraerrors.Conflict("AFFILIATE_CODE_TAKEN", "affiliate code already in use")
ErrAffiliateAlreadyBound = infraerrors.Conflict("AFFILIATE_ALREADY_BOUND", "affiliate inviter already bound")
ErrAffiliateQuotaEmpty = infraerrors.BadRequest("AFFILIATE_QUOTA_EMPTY", "no affiliate quota available to transfer")
)
const (
affiliateInviteesLimit = 100
// AffiliateCodeMinLength / AffiliateCodeMaxLength bound both system-generated
// 12-char codes and admin-customized codes (e.g. "VIP2026").
AffiliateCodeMinLength = 4
AffiliateCodeMaxLength = 32
)
// affiliateCodeValidChar accepts uppercase letters, digits, underscore and dash.
// All input passes through strings.ToUpper before validation, so lowercase from
// users is normalized — admins may supply mixed case in their UI.
var affiliateCodeValidChar = func() [256]bool {
var tbl [256]bool
for c := byte('A'); c <= 'Z'; c++ {
tbl[c] = true
}
for c := byte('0'); c <= '9'; c++ {
tbl[c] = true
}
tbl['_'] = true
tbl['-'] = true
return tbl
}()
// isValidAffiliateCodeFormat validates code format for both binding (user input)
// and admin updates. Caller is expected to upper-case the input first.
func isValidAffiliateCodeFormat(code string) bool {
if len(code) < AffiliateCodeMinLength || len(code) > AffiliateCodeMaxLength {
return false
}
for i := 0; i < len(code); i++ {
if !affiliateCodeValidChar[code[i]] {
return false
}
}
return true
}
type AffiliateSummary struct {
UserID int64 `json:"user_id"`
AffCode string `json:"aff_code"`
AffCodeCustom bool `json:"aff_code_custom"`
AffRebateRatePercent *float64 `json:"aff_rebate_rate_percent,omitempty"`
InviterID *int64 `json:"inviter_id,omitempty"`
AffCount int `json:"aff_count"`
AffQuota float64 `json:"aff_quota"`
AffFrozenQuota float64 `json:"aff_frozen_quota"`
AffHistoryQuota float64 `json:"aff_history_quota"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type AffiliateInvitee struct {
UserID int64 `json:"user_id"`
Email string `json:"email"`
Username string `json:"username"`
CreatedAt *time.Time `json:"created_at,omitempty"`
TotalRebate float64 `json:"total_rebate"`
}
type AffiliateDetail struct {
UserID int64 `json:"user_id"`
AffCode string `json:"aff_code"`
InviterID *int64 `json:"inviter_id,omitempty"`
AffCount int `json:"aff_count"`
AffQuota float64 `json:"aff_quota"`
AffFrozenQuota float64 `json:"aff_frozen_quota"`
AffHistoryQuota float64 `json:"aff_history_quota"`
// EffectiveRebateRatePercent 是当前用户作为邀请人时实际生效的返利比例:
// 优先用户自己的专属比例(aff_rebate_rate_percent),否则回退到全局比例。
// 用于在用户的 /affiliate 页面直观展示「分享后能拿到多少」。
EffectiveRebateRatePercent float64 `json:"effective_rebate_rate_percent"`
Invitees []AffiliateInvitee `json:"invitees"`
}
type AffiliateRepository interface {
EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error)
GetAffiliateByCode(ctx context.Context, code string) (*AffiliateSummary, error)
BindInviter(ctx context.Context, userID, inviterID int64) (bool, error)
AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error)
GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error)
ThawFrozenQuota(ctx context.Context, userID int64) (float64, error)
TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error)
ListInvitees(ctx context.Context, inviterID int64, limit int) ([]AffiliateInvitee, error)
// 管理端:用户级专属配置
UpdateUserAffCode(ctx context.Context, userID int64, newCode string) error
ResetUserAffCode(ctx context.Context, userID int64) (string, error)
SetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error
BatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error
ListUsersWithCustomSettings(ctx context.Context, filter AffiliateAdminFilter) ([]AffiliateAdminEntry, int64, error)
}
// AffiliateAdminFilter 列表筛选条件
type AffiliateAdminFilter struct {
Search string
Page int
PageSize int
}
// AffiliateAdminEntry 专属用户列表条目
type AffiliateAdminEntry struct {
UserID int64 `json:"user_id"`
Email string `json:"email"`
Username string `json:"username"`
AffCode string `json:"aff_code"`
AffCodeCustom bool `json:"aff_code_custom"`
AffRebateRatePercent *float64 `json:"aff_rebate_rate_percent,omitempty"`
AffCount int `json:"aff_count"`
}
type AffiliateService struct {
repo AffiliateRepository
settingService *SettingService
authCacheInvalidator APIKeyAuthCacheInvalidator
billingCacheService *BillingCacheService
}
func NewAffiliateService(repo AffiliateRepository, settingService *SettingService, authCacheInvalidator APIKeyAuthCacheInvalidator, billingCacheService *BillingCacheService) *AffiliateService {
return &AffiliateService{
repo: repo,
settingService: settingService,
authCacheInvalidator: authCacheInvalidator,
billingCacheService: billingCacheService,
}
}
// IsEnabled reports whether the affiliate (邀请返利) feature is turned on.
func (s *AffiliateService) IsEnabled(ctx context.Context) bool {
if s == nil || s.settingService == nil {
return AffiliateEnabledDefault
}
return s.settingService.IsAffiliateEnabled(ctx)
}
func (s *AffiliateService) EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error) {
if userID <= 0 {
return nil, infraerrors.BadRequest("INVALID_USER", "invalid user")
}
if s == nil || s.repo == nil {
return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
return s.repo.EnsureUserAffiliate(ctx, userID)
}
func (s *AffiliateService) GetAffiliateDetail(ctx context.Context, userID int64) (*AffiliateDetail, error) {
// Lazy thaw: move any matured frozen quota to available before reading.
if s != nil && s.repo != nil {
// best-effort: thaw failure is non-fatal
_, _ = s.repo.ThawFrozenQuota(ctx, userID)
}
summary, err := s.EnsureUserAffiliate(ctx, userID)
if err != nil {
return nil, err
}
invitees, err := s.listInvitees(ctx, userID)
if err != nil {
return nil, err
}
return &AffiliateDetail{
UserID: summary.UserID,
AffCode: summary.AffCode,
InviterID: summary.InviterID,
AffCount: summary.AffCount,
AffQuota: summary.AffQuota,
AffFrozenQuota: summary.AffFrozenQuota,
AffHistoryQuota: summary.AffHistoryQuota,
EffectiveRebateRatePercent: s.resolveRebateRatePercent(ctx, summary),
Invitees: invitees,
}, nil
}
func (s *AffiliateService) BindInviterByCode(ctx context.Context, userID int64, rawCode string) error {
code := strings.ToUpper(strings.TrimSpace(rawCode))
if code == "" {
return nil
}
if s == nil || s.repo == nil {
return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
// 总开关关闭时,注册阶段静默忽略 aff 参数(不报错,避免阻断注册流程)
if !s.IsEnabled(ctx) {
return nil
}
if !isValidAffiliateCodeFormat(code) {
return ErrAffiliateCodeInvalid
}
selfSummary, err := s.repo.EnsureUserAffiliate(ctx, userID)
if err != nil {
return err
}
if selfSummary.InviterID != nil {
return nil
}
inviterSummary, err := s.repo.GetAffiliateByCode(ctx, code)
if err != nil {
if errors.Is(err, ErrAffiliateProfileNotFound) {
return ErrAffiliateCodeInvalid
}
return err
}
if inviterSummary == nil || inviterSummary.UserID <= 0 || inviterSummary.UserID == userID {
return ErrAffiliateCodeInvalid
}
bound, err := s.repo.BindInviter(ctx, userID, inviterSummary.UserID)
if err != nil {
return err
}
if !bound {
return ErrAffiliateAlreadyBound
}
return nil
}
func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64) (float64, error) {
if s == nil || s.repo == nil {
return 0, nil
}
if inviteeUserID <= 0 || baseRechargeAmount <= 0 || math.IsNaN(baseRechargeAmount) || math.IsInf(baseRechargeAmount, 0) {
return 0, nil
}
// 总开关关闭时,新充值不再产生返利
if !s.IsEnabled(ctx) {
return 0, nil
}
inviteeSummary, err := s.repo.EnsureUserAffiliate(ctx, inviteeUserID)
if err != nil {
return 0, err
}
if inviteeSummary.InviterID == nil || *inviteeSummary.InviterID <= 0 {
return 0, nil
}
// 加载邀请人 profile,优先使用专属比例(覆盖全局)
inviterSummary, err := s.repo.EnsureUserAffiliate(ctx, *inviteeSummary.InviterID)
if err != nil {
return 0, err
}
// 有效期检查:超过返利有效期后不再产生返利
if s.settingService != nil {
if durationDays := s.settingService.GetAffiliateRebateDurationDays(ctx); durationDays > 0 {
if time.Now().After(inviteeSummary.CreatedAt.AddDate(0, 0, durationDays)) {
return 0, nil
}
}
}
rebateRatePercent := s.resolveRebateRatePercent(ctx, inviterSummary)
rebate := roundTo(baseRechargeAmount*(rebateRatePercent/100), 8)
if rebate <= 0 {
return 0, nil
}
// 单人上限检查:精确截断到剩余额度
if s.settingService != nil {
if perInviteeCap := s.settingService.GetAffiliateRebatePerInviteeCap(ctx); perInviteeCap > 0 {
existing, err := s.repo.GetAccruedRebateFromInvitee(ctx, *inviteeSummary.InviterID, inviteeUserID)
if err != nil {
return 0, err
}
if existing >= perInviteeCap {
return 0, nil
}
if remaining := perInviteeCap - existing; rebate > remaining {
rebate = roundTo(remaining, 8)
}
}
}
var freezeHours int
if s.settingService != nil {
freezeHours = s.settingService.GetAffiliateRebateFreezeHours(ctx)
}
applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate, freezeHours)
if err != nil {
return 0, err
}
if !applied {
return 0, nil
}
return rebate, nil
}
// resolveRebateRatePercent returns the inviter's exclusive rate when set,
// otherwise the global setting value (clamped to [Min, Max]).
func (s *AffiliateService) resolveRebateRatePercent(ctx context.Context, inviter *AffiliateSummary) float64 {
if inviter != nil && inviter.AffRebateRatePercent != nil {
v := *inviter.AffRebateRatePercent
if math.IsNaN(v) || math.IsInf(v, 0) {
return s.globalRebateRatePercent(ctx)
}
return clampAffiliateRebateRate(v)
}
return s.globalRebateRatePercent(ctx)
}
// globalRebateRatePercent reads the system-wide rebate rate via SettingService,
// returning the documented default when SettingService is unavailable.
func (s *AffiliateService) globalRebateRatePercent(ctx context.Context) float64 {
if s == nil || s.settingService == nil {
return AffiliateRebateRateDefault
}
return s.settingService.GetAffiliateRebateRatePercent(ctx)
}
func (s *AffiliateService) TransferAffiliateQuota(ctx context.Context, userID int64) (float64, float64, error) {
if s == nil || s.repo == nil {
return 0, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
transferred, balance, err := s.repo.TransferQuotaToBalance(ctx, userID)
if err != nil {
return 0, 0, err
}
if transferred > 0 {
s.invalidateAffiliateCaches(ctx, userID)
}
return transferred, balance, nil
}
func (s *AffiliateService) listInvitees(ctx context.Context, inviterID int64) ([]AffiliateInvitee, error) {
if s == nil || s.repo == nil {
return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
invitees, err := s.repo.ListInvitees(ctx, inviterID, affiliateInviteesLimit)
if err != nil {
return nil, err
}
for i := range invitees {
invitees[i].Email = maskEmail(invitees[i].Email)
}
return invitees, nil
}
func roundTo(v float64, scale int) float64 {
factor := math.Pow10(scale)
return math.Round(v*factor) / factor
}
func maskEmail(email string) string {
email = strings.TrimSpace(email)
if email == "" {
return ""
}
at := strings.Index(email, "@")
if at <= 0 || at >= len(email)-1 {
return "***"
}
local := email[:at]
domain := email[at+1:]
dot := strings.LastIndex(domain, ".")
maskedLocal := maskSegment(local)
if dot <= 0 || dot >= len(domain)-1 {
return maskedLocal + "@" + maskSegment(domain)
}
domainName := domain[:dot]
tld := domain[dot:]
return maskedLocal + "@" + maskSegment(domainName) + tld
}
func maskSegment(s string) string {
r := []rune(s)
if len(r) == 0 {
return "***"
}
if len(r) == 1 {
return string(r[0]) + "***"
}
return string(r[0]) + "***"
}
func (s *AffiliateService) invalidateAffiliateCaches(ctx context.Context, userID int64) {
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
if s.billingCacheService != nil {
if err := s.billingCacheService.InvalidateUserBalance(ctx, userID); err != nil {
logger.LegacyPrintf("service.affiliate", "[Affiliate] Failed to invalidate billing cache for user %d: %v", userID, err)
}
}
}
// =========================
// Admin: 专属配置管理
// =========================
// validateExclusiveRate ensures a per-user override is finite and within
// [Min, Max]. nil is always valid (means "clear / fall back to global").
func validateExclusiveRate(ratePercent *float64) error {
if ratePercent == nil {
return nil
}
v := *ratePercent
if math.IsNaN(v) || math.IsInf(v, 0) {
return infraerrors.BadRequest("INVALID_RATE", "invalid rebate rate")
}
if v < AffiliateRebateRateMin || v > AffiliateRebateRateMax {
return infraerrors.BadRequest("INVALID_RATE", "rebate rate out of range")
}
return nil
}
// AdminUpdateUserAffCode 管理员改写用户的邀请码(专属邀请码)。
func (s *AffiliateService) AdminUpdateUserAffCode(ctx context.Context, userID int64, rawCode string) error {
if s == nil || s.repo == nil {
return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
code := strings.ToUpper(strings.TrimSpace(rawCode))
if !isValidAffiliateCodeFormat(code) {
return ErrAffiliateCodeInvalid
}
return s.repo.UpdateUserAffCode(ctx, userID, code)
}
// AdminResetUserAffCode 重置用户邀请码为系统随机码。
func (s *AffiliateService) AdminResetUserAffCode(ctx context.Context, userID int64) (string, error) {
if s == nil || s.repo == nil {
return "", infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
return s.repo.ResetUserAffCode(ctx, userID)
}
// AdminSetUserRebateRate 设置/清除用户专属返利比例。ratePercent==nil 表示清除。
func (s *AffiliateService) AdminSetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error {
if s == nil || s.repo == nil {
return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
if err := validateExclusiveRate(ratePercent); err != nil {
return err
}
return s.repo.SetUserRebateRate(ctx, userID, ratePercent)
}
// AdminBatchSetUserRebateRate 批量设置/清除用户专属返利比例。
func (s *AffiliateService) AdminBatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error {
if s == nil || s.repo == nil {
return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
if err := validateExclusiveRate(ratePercent); err != nil {
return err
}
cleaned := make([]int64, 0, len(userIDs))
for _, uid := range userIDs {
if uid > 0 {
cleaned = append(cleaned, uid)
}
}
if len(cleaned) == 0 {
return nil
}
return s.repo.BatchSetUserRebateRate(ctx, cleaned, ratePercent)
}
// AdminListCustomUsers 列出有专属配置的用户。
func (s *AffiliateService) AdminListCustomUsers(ctx context.Context, filter AffiliateAdminFilter) ([]AffiliateAdminEntry, int64, error) {
if s == nil || s.repo == nil {
return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
return s.repo.ListUsersWithCustomSettings(ctx, filter)
}
//go:build unit
package service
import (
"context"
"math"
"testing"
"github.com/stretchr/testify/require"
)
// TestResolveRebateRatePercent_PerUserOverride verifies that per-inviter
// AffRebateRatePercent overrides the global rate, that NULL falls back to the
// global rate, and that out-of-range exclusive rates are clamped silently.
//
// SettingService is left nil here so globalRebateRatePercent returns the
// documented default (AffiliateRebateRateDefault = 20%) — this exercises the
// fallback path without spinning up a settings stub.
func TestResolveRebateRatePercent_PerUserOverride(t *testing.T) {
t.Parallel()
svc := &AffiliateService{}
// nil exclusive rate → falls back to global default (20%)
require.InDelta(t, AffiliateRebateRateDefault,
svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{}), 1e-9)
// exclusive rate set → overrides global
rate := 50.0
require.InDelta(t, 50.0,
svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &rate}), 1e-9)
// exclusive rate 0 → returns 0 (no rebate, intentional)
zero := 0.0
require.InDelta(t, 0.0,
svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &zero}), 1e-9)
// exclusive rate above max → clamped to Max
tooHigh := 250.0
require.InDelta(t, AffiliateRebateRateMax,
svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &tooHigh}), 1e-9)
// exclusive rate below min → clamped to Min
tooLow := -5.0
require.InDelta(t, AffiliateRebateRateMin,
svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &tooLow}), 1e-9)
}
// TestIsEnabled_NilSettingServiceReturnsDefault verifies that IsEnabled
// safely handles a nil settingService dependency by returning the default
// (off). This protects callers from nil-pointer crashes in misconfigured
// environments.
func TestIsEnabled_NilSettingServiceReturnsDefault(t *testing.T) {
t.Parallel()
svc := &AffiliateService{}
require.False(t, svc.IsEnabled(context.Background()))
require.Equal(t, AffiliateEnabledDefault, svc.IsEnabled(context.Background()))
}
// TestValidateExclusiveRate_BoundaryAndInvalid covers the validator used by
// admin-facing rate setters: nil is always valid (clear), in-range values
// are accepted, NaN/Inf and out-of-range values produce a typed BadRequest.
func TestValidateExclusiveRate_BoundaryAndInvalid(t *testing.T) {
t.Parallel()
require.NoError(t, validateExclusiveRate(nil))
for _, v := range []float64{0, 0.01, 50, 99.99, 100} {
v := v
require.NoError(t, validateExclusiveRate(&v), "value %v should be valid", v)
}
for _, v := range []float64{-0.01, 100.01, -100, 200} {
v := v
require.Error(t, validateExclusiveRate(&v), "value %v should be rejected", v)
}
nan := math.NaN()
require.Error(t, validateExclusiveRate(&nan))
posInf := math.Inf(1)
require.Error(t, validateExclusiveRate(&posInf))
negInf := math.Inf(-1)
require.Error(t, validateExclusiveRate(&negInf))
}
func TestMaskEmail(t *testing.T) {
t.Parallel()
require.Equal(t, "a***@g***.com", maskEmail("alice@gmail.com"))
require.Equal(t, "x***@d***", maskEmail("x@domain"))
require.Equal(t, "", maskEmail(""))
}
func TestIsValidAffiliateCodeFormat(t *testing.T) {
t.Parallel()
// 邀请码格式校验同时服务于:
// 1) 系统自动生成的 12 位随机码(A-Z 去 I/O,2-9 去 0/1)
// 2) 管理员设置的自定义专属码(如 "VIP2026"、"NEW_USER-1")
// 因此校验放宽到 [A-Z0-9_-]{4,32}(要求调用方先 ToUpper)。
cases := []struct {
name string
in string
want bool
}{
{"valid canonical 12-char", "ABCDEFGHJKLM", true},
{"valid all digits 2-9", "234567892345", true},
{"valid mixed", "A2B3C4D5E6F7", true},
{"valid admin custom short", "VIP1", true},
{"valid admin custom with hyphen", "NEW-USER", true},
{"valid admin custom with underscore", "VIP_2026", true},
{"valid 32-char max", "ABCDEFGHIJKLMNOPQRSTUVWXYZ012345", true},
// Previously-excluded chars (I/O/0/1) are now allowed since admins may use them.
{"letter I now allowed", "IBCDEFGHJKLM", true},
{"letter O now allowed", "OBCDEFGHJKLM", true},
{"digit 0 now allowed", "0BCDEFGHJKLM", true},
{"digit 1 now allowed", "1BCDEFGHJKLM", true},
{"too short (3 chars)", "ABC", false},
{"too long (33 chars)", "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456", false},
{"lowercase rejected (caller must ToUpper first)", "abcdefghjklm", false},
{"empty", "", false},
{"utf8 non-ascii", "ÄÄÄÄÄÄ", false}, // bytes out of charset
{"ascii punctuation .", "ABCDEFGHJK.M", false},
{"whitespace", "ABCDEFGHJK M", false},
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, tc.want, isValidAffiliateCodeFormat(tc.in))
})
}
}
......@@ -21,6 +21,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
......@@ -1739,6 +1740,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
var usage *ClaudeUsage
var firstTokenMs *int
var clientDisconnect bool
var responseBody string
if claudeReq.Stream {
// 客户端要求流式,直接透传转换
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel)
......@@ -1749,6 +1751,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
clientDisconnect = streamRes.clientDisconnect
// 流式:从 context buffer 读取采集的文本
if captureBuilder, ok := ctx.Value(ctxkey.ResponseCaptureBuffer).(*strings.Builder); ok && captureBuilder != nil {
responseBody = captureBuilder.String()
}
} else {
// 客户端要求非流式,收集流式响应后转换返回
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel)
......@@ -1758,6 +1764,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
}
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
responseBody = streamRes.responseBody
}
return &ForwardResult{
......@@ -1769,6 +1776,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ClientDisconnect: clientDisconnect,
ResponseBody: responseBody,
}, nil
}
......@@ -2421,6 +2429,7 @@ handleSuccess:
var usage *ClaudeUsage
var firstTokenMs *int
var clientDisconnect bool
var responseBody string
if stream {
// 客户端要求流式,直接透传
......@@ -2432,6 +2441,10 @@ handleSuccess:
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
clientDisconnect = streamRes.clientDisconnect
// 流式:从 context buffer 读取采集的文本
if captureBuilder, ok := ctx.Value(ctxkey.ResponseCaptureBuffer).(*strings.Builder); ok && captureBuilder != nil {
responseBody = captureBuilder.String()
}
} else {
// 客户端要求非流式,收集流式响应后返回
streamRes, err := s.handleGeminiStreamToNonStreaming(c, resp, startTime)
......@@ -2441,6 +2454,7 @@ handleSuccess:
}
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
responseBody = streamRes.responseBody
}
if usage == nil {
......@@ -2465,6 +2479,7 @@ handleSuccess:
ClientDisconnect: clientDisconnect,
ImageCount: imageCount,
ImageSize: imageSize,
ResponseBody: responseBody,
}, nil
}
......@@ -2968,7 +2983,8 @@ func (s *AntigravityGatewayService) resolveResetTime(resetAt *int64, defaultDur
type antigravityStreamResult struct {
usage *ClaudeUsage
firstTokenMs *int
clientDisconnect bool // 客户端是否在流式传输过程中断开
clientDisconnect bool // 客户端是否在流式传输过程中断开
responseBody string // 响应体内容(非流式:完整 JSON;流式:从上下文 buffer 读取)
}
// antigravityClientWriter 封装流式响应的客户端写入,自动检测断开并标记。
......@@ -3124,6 +3140,9 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity gemini")
// 响应体文本采集:若上下文注入了 ResponseCaptureBuffer,则写入文本内容
captureBuilder, _ := c.Request.Context().Value(ctxkey.ResponseCaptureBuffer).(*strings.Builder)
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent := false
sendErrorEvent := func(reason string) {
......@@ -3197,6 +3216,16 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
firstTokenMs = &ms
}
// 采集文本用于响应捕获
if captureBuilder != nil && len(inner) > 0 {
gjson.GetBytes(inner, "candidates.0.content.parts").ForEach(func(_, v gjson.Result) bool {
if t := v.Get("text").String(); t != "" {
captureBuilder.WriteString(t)
}
return true
})
}
cw.Fprintf("data: %s\n\n", payload)
continue
}
......@@ -3418,7 +3447,7 @@ returnResponse:
}
c.Data(http.StatusOK, "application/json", respBody)
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, responseBody: strings.Join(collectedTextParts, "")}, nil
}
// getOrCreateGeminiParts 获取 Gemini 响应的 parts 结构,返回深拷贝和更新回调
......@@ -3867,7 +3896,7 @@ returnResponse:
CacheReadInputTokens: agUsage.CacheReadInputTokens,
}
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, responseBody: string(claudeResp)}, nil
}
// handleClaudeStreamingResponse 处理 Claude 流式响应(Gemini SSE → Claude SSE 转换)
......@@ -3971,6 +4000,9 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity claude")
// 响应体文本采集:若上下文注入了 ResponseCaptureBuffer,则写入文本 delta
captureBuilder, _ := c.Request.Context().Value(ctxkey.ResponseCaptureBuffer).(*strings.Builder)
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent := false
sendErrorEvent := func(reason string) {
......@@ -4024,7 +4056,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
lastDataAt = time.Now()
// 处理 SSE 行,转换为 Claude 格式
claudeEvents := processor.ProcessLine(strings.TrimRight(ev.line, "\r\n"))
trimmedLine := strings.TrimRight(ev.line, "\r\n")
claudeEvents := processor.ProcessLine(trimmedLine)
if len(claudeEvents) > 0 {
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
......@@ -4033,6 +4066,22 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
cw.Write(claudeEvents)
}
// 采集文本用于响应捕获
if captureBuilder != nil && strings.HasPrefix(trimmedLine, "data:") {
data := strings.TrimSpace(strings.TrimPrefix(trimmedLine, "data:"))
if data != "" && data != "[DONE]" {
// v1internal 格式: response.candidates.0.content.parts.0.text
// 直接 Gemini 格式: candidates.0.content.parts.0.text
text := gjson.Get(data, "response.candidates.0.content.parts.0.text").String()
if text == "" {
text = gjson.Get(data, "candidates.0.content.parts.0.text").String()
}
if text != "" {
captureBuilder.WriteString(text)
}
}
}
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
......
......@@ -175,6 +175,7 @@ func (s *AuthService) FinalizeOAuthEmailAccount(
user *User,
invitationCode string,
signupSource string,
affiliateCode string,
) error {
if s == nil || user == nil || user.ID <= 0 {
return ErrServiceUnavailable
......@@ -194,6 +195,7 @@ func (s *AuthService) FinalizeOAuthEmailAccount(
s.updateOAuthSignupSource(ctx, user.ID, signupSource)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
return nil
}
......
......@@ -137,6 +137,7 @@ func newOAuthEmailFlowAuthService(
nil,
nil,
nil,
nil,
)
}
......
......@@ -72,6 +72,7 @@ type AuthService struct {
turnstileService *TurnstileService
emailQueueService *EmailQueueService
promoService *PromoService
affiliateService *AffiliateService
defaultSubAssigner DefaultSubscriptionAssigner
}
......@@ -98,6 +99,7 @@ func NewAuthService(
emailQueueService *EmailQueueService,
promoService *PromoService,
defaultSubAssigner DefaultSubscriptionAssigner,
affiliateService *AffiliateService,
) *AuthService {
return &AuthService{
entClient: entClient,
......@@ -110,6 +112,7 @@ func NewAuthService(
turnstileService: turnstileService,
emailQueueService: emailQueueService,
promoService: promoService,
affiliateService: affiliateService,
defaultSubAssigner: defaultSubAssigner,
}
}
......@@ -123,11 +126,11 @@ func (s *AuthService) EntClient() *dbent.Client {
// Register 用户注册,返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
return s.RegisterWithVerification(ctx, email, password, "", "", "")
return s.RegisterWithVerification(ctx, email, password, "", "", "", "")
}
// RegisterWithVerification 用户注册(支持邮件验证、优惠码和邀请码),返回token和用户
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode string) (string, *User, error) {
// RegisterWithVerification 用户注册(支持邮件验证、优惠码、邀请码和邀请返利码),返回token和用户
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode, affiliateCode string) (string, *User, error) {
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled
......@@ -223,6 +226,17 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
}
s.postAuthUserBootstrap(ctx, user, "email", true)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
if s.affiliateService != nil {
if _, err := s.affiliateService.EnsureUserAffiliate(ctx, user.ID); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", user.ID, err)
}
if code := strings.TrimSpace(affiliateCode); code != "" {
if err := s.affiliateService.BindInviterByCode(ctx, user.ID, code); err != nil {
// 邀请返利码绑定失败不影响注册,只记录日志
logger.LegacyPrintf("service.auth", "[Auth] Failed to bind affiliate inviter for user %d: %v", user.ID, err)
}
}
}
// 标记邀请码为已使用(如果使用了邀请码)
if invitationRedeemCode != nil {
......@@ -549,7 +563,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair。
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token。
// invitationCode 仅在邀请码注册模式下新用户注册时使用;已有账号登录时忽略。
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode string) (*TokenPair, *User, error) {
// affiliateCode 用于邀请返利绑定,仅在新用户注册时使用。
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode, affiliateCode string) (*TokenPair, *User, error) {
// 检查 refreshTokenCache 是否可用
if s.refreshTokenCache == nil {
return nil, nil, errors.New("refresh token cache not configured")
......@@ -652,6 +667,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
user = newUser
s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
}
} else {
if err := s.userRepo.Create(ctx, newUser); err != nil {
......@@ -669,6 +685,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
user = newUser
s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
return nil, nil, ErrInvitationCodeInvalid
......@@ -763,6 +780,22 @@ func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource
}
}
// bindOAuthAffiliate initializes the affiliate profile and binds the inviter
// for an OAuth-registered user. Failures are logged but never block registration.
func (s *AuthService) bindOAuthAffiliate(ctx context.Context, userID int64, affiliateCode string) {
if s.affiliateService == nil || userID <= 0 {
return
}
if _, err := s.affiliateService.EnsureUserAffiliate(ctx, userID); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", userID, err)
}
if code := strings.TrimSpace(affiliateCode); code != "" {
if err := s.affiliateService.BindInviterByCode(ctx, userID, code); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to bind affiliate inviter for user %d: %v", userID, err)
}
}
}
func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) {
if user == nil || user.ID <= 0 {
return
......
......@@ -110,7 +110,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants (
emailSvc = service.NewEmailService(settingRepo, emailCache)
}
svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner)
svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner, nil)
return svc, repo, client
}
......@@ -467,7 +467,7 @@ func TestAuthServiceBindEmailIdentity_RevokesExistingAccessAndRefreshTokens(t *t
},
}
emailService := service.NewEmailService(nil, cache)
svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil)
svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil, nil)
oldTokenPair, err := svc.GenerateTokenPair(ctx, &service.User{
ID: 41,
......
......@@ -137,7 +137,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants (
values: settings,
}, cfg)
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner)
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner, nil)
return svc, repo, client
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment