Commit 3b7a5fff authored by 陈曦's avatar 陈曦
Browse files

补充openai、gemini以及流失请求的采集数据以及nfs落库

parent 8519a8eb
Pipeline #82284 failed with stage
in 2 minutes and 21 seconds
......@@ -91,6 +91,7 @@ var ProviderSet = wire.NewSet(
NewChannelRepository,
NewChannelMonitorRepository,
NewChannelMonitorRequestTemplateRepository,
NewAffiliateRepository,
NewRequestCaptureLogRepository,
// Cache implementations
......
......@@ -715,6 +715,10 @@ func TestAPIContracts(t *testing.T) {
"force_email_on_third_party_signup": false,
"default_concurrency": 5,
"default_balance": 1.25,
"affiliate_rebate_rate": 20,
"affiliate_rebate_freeze_hours": 0,
"affiliate_rebate_duration_days": 0,
"affiliate_rebate_per_invitee_cap": 0,
"default_user_rpm_limit": 0,
"default_subscriptions": [],
"enable_model_fallback": false,
......@@ -774,6 +778,7 @@ func TestAPIContracts(t *testing.T) {
"channel_monitor_enabled": true,
"channel_monitor_default_interval_seconds": 60,
"available_channels_enabled": false,
"affiliate_enabled": false,
"wechat_connect_enabled": false,
"wechat_connect_app_id": "",
"wechat_connect_app_secret_configured": false,
......@@ -895,6 +900,10 @@ func TestAPIContracts(t *testing.T) {
"custom_endpoints": [],
"default_concurrency": 0,
"default_balance": 0,
"affiliate_rebate_rate": 20,
"affiliate_rebate_freeze_hours": 0,
"affiliate_rebate_duration_days": 0,
"affiliate_rebate_per_invitee_cap": 0,
"default_user_rpm_limit": 0,
"default_subscriptions": [],
"enable_model_fallback": false,
......@@ -949,6 +958,7 @@ func TestAPIContracts(t *testing.T) {
"channel_monitor_enabled": true,
"channel_monitor_default_interval_seconds": 60,
"available_channels_enabled": false,
"affiliate_enabled": false,
"wechat_connect_enabled": true,
"wechat_connect_app_id": "wx-open-config",
"wechat_connect_app_secret_configured": true,
......
......@@ -20,7 +20,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
admin := &service.User{
ID: 1,
......
......@@ -60,7 +60,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer
cfg.JWT.AccessTokenExpireMinutes = 60
userRepo := &stubJWTUserRepo{users: users}
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
userSvc := service.NewUserService(userRepo, nil, nil, nil)
mw := NewJWTAuthMiddleware(authSvc, userSvc)
......@@ -143,7 +143,7 @@ func TestJWTAuth_ValidToken_TouchesLastActive(t *testing.T) {
cfg.JWT.AccessTokenExpireMinutes = 60
userRepo := &stubJWTUserRepo{users: map[int64]*service.User{1: user}}
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
userSvc := service.NewUserService(userRepo, nil, nil, nil)
toucher := &recordingActivityToucher{}
......
......@@ -91,6 +91,9 @@ func RegisterAdminRoutes(
// 渠道监控
registerChannelMonitorRoutes(admin, h)
// 邀请返利(专属用户管理)
registerAffiliateRoutes(admin, h)
}
}
......@@ -594,3 +597,18 @@ func registerChannelMonitorRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
templates.POST("/:id/apply", h.Admin.ChannelMonitorTemplate.Apply)
}
}
// registerAffiliateRoutes 注册邀请返利的管理端路由(专属用户配置)
func registerAffiliateRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
affiliates := admin.Group("/affiliates")
{
users := affiliates.Group("/users")
{
users.GET("", h.Admin.Affiliate.ListUsers)
users.GET("/lookup", h.Admin.Affiliate.LookupUsers)
users.POST("/batch-rate", h.Admin.Affiliate.BatchSetRate)
users.PUT("/:user_id", h.Admin.Affiliate.UpdateUserSettings)
users.DELETE("/:user_id", h.Admin.Affiliate.ClearUserSettings)
}
}
}
......@@ -25,6 +25,8 @@ func RegisterUserRoutes(
user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword)
user.PUT("", h.User.UpdateProfile)
user.GET("/aff", h.User.GetAffiliate)
user.POST("/aff/transfer", h.User.TransferAffiliateQuota)
user.POST("/account-bindings/email/send-code", h.User.SendEmailBindingCode)
user.POST("/account-bindings/email", h.User.BindEmailIdentity)
user.DELETE("/account-bindings/:provider", h.User.UnbindIdentity)
......
......@@ -393,6 +393,56 @@ func parseTempUnschedInt(value any) int {
return 0
}
const (
// OpenAICompactModeAuto follows compact-probe results when deciding compact eligibility.
OpenAICompactModeAuto = "auto"
// OpenAICompactModeForceOn always treats the account as compact-supported.
OpenAICompactModeForceOn = "force_on"
// OpenAICompactModeForceOff always treats the account as compact-unsupported.
OpenAICompactModeForceOff = "force_off"
)
func normalizeOpenAICompactMode(mode string) string {
switch strings.ToLower(strings.TrimSpace(mode)) {
case OpenAICompactModeForceOn:
return OpenAICompactModeForceOn
case OpenAICompactModeForceOff:
return OpenAICompactModeForceOff
default:
return OpenAICompactModeAuto
}
}
func stringMappingFromRaw(raw any) map[string]string {
switch mapping := raw.(type) {
case map[string]any:
if len(mapping) == 0 {
return nil
}
result := make(map[string]string, len(mapping))
for key, value := range mapping {
if str, ok := value.(string); ok {
result[key] = str
}
}
if len(result) == 0 {
return nil
}
return result
case map[string]string:
if len(mapping) == 0 {
return nil
}
result := make(map[string]string, len(mapping))
for key, value := range mapping {
result[key] = value
}
return result
default:
return nil
}
}
func (a *Account) GetModelMapping() map[string]string {
credentialsPtr := mapPtr(a.Credentials)
rawMapping, _ := a.Credentials["model_mapping"].(map[string]any)
......@@ -598,6 +648,77 @@ func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string,
return requestedModel, false
}
// GetOpenAICompactMode returns the compact routing mode for an OpenAI account.
// Missing or invalid values fall back to "auto".
func (a *Account) GetOpenAICompactMode() string {
if a == nil || !a.IsOpenAI() || a.Extra == nil {
return OpenAICompactModeAuto
}
mode, _ := a.Extra["openai_compact_mode"].(string)
return normalizeOpenAICompactMode(mode)
}
// OpenAICompactSupportKnown reports whether compact capability is known for this
// account and, when known, whether it is supported.
func (a *Account) OpenAICompactSupportKnown() (supported bool, known bool) {
if a == nil || !a.IsOpenAI() {
return false, false
}
switch a.GetOpenAICompactMode() {
case OpenAICompactModeForceOn:
return true, true
case OpenAICompactModeForceOff:
return false, true
}
if a.Extra == nil {
return false, false
}
supported, ok := a.Extra["openai_compact_supported"].(bool)
if !ok {
return false, false
}
return supported, true
}
// AllowsOpenAICompact reports whether the account may be considered for compact
// requests. Unknown capability remains allowed to avoid breaking older accounts
// before an explicit probe has been run.
func (a *Account) AllowsOpenAICompact() bool {
if a == nil || !a.IsOpenAI() {
return false
}
supported, known := a.OpenAICompactSupportKnown()
if !known {
return true
}
return supported
}
// GetCompactModelMapping returns compact-only model remapping configuration.
// This mapping is intended for /responses/compact only and does not affect
// normal /responses traffic.
func (a *Account) GetCompactModelMapping() map[string]string {
if a == nil || a.Credentials == nil {
return nil
}
return stringMappingFromRaw(a.Credentials["compact_model_mapping"])
}
// ResolveCompactMappedModel resolves compact-only model remapping and reports
// whether a compact-specific mapping rule matched.
func (a *Account) ResolveCompactMappedModel(requestedModel string) (mappedModel string, matched bool) {
mapping := a.GetCompactModelMapping()
if len(mapping) == 0 {
return requestedModel, false
}
if mappedModel, matched := resolveRequestedModelInMapping(mapping, requestedModel); matched {
return mappedModel, true
}
return requestedModel, false
}
func (a *Account) GetBaseURL() string {
if a.Type != AccountTypeAPIKey {
return ""
......
package service
import "testing"
func TestAccountGetOpenAICompactMode(t *testing.T) {
tests := []struct {
name string
account *Account
want string
}{
{
name: "nil account defaults to auto",
want: OpenAICompactModeAuto,
},
{
name: "non openai account defaults to auto",
account: &Account{
Platform: PlatformAnthropic,
Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOn},
},
want: OpenAICompactModeAuto,
},
{
name: "missing extra defaults to auto",
account: &Account{
Platform: PlatformOpenAI,
},
want: OpenAICompactModeAuto,
},
{
name: "invalid mode falls back to auto",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_mode": " invalid "},
},
want: OpenAICompactModeAuto,
},
{
name: "force on is normalized",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_mode": " FORCE_ON "},
},
want: OpenAICompactModeForceOn,
},
{
name: "force off is normalized",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_mode": "force_off"},
},
want: OpenAICompactModeForceOff,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.account.GetOpenAICompactMode(); got != tt.want {
t.Fatalf("GetOpenAICompactMode() = %q, want %q", got, tt.want)
}
})
}
}
func TestAccountOpenAICompactSupportKnown(t *testing.T) {
tests := []struct {
name string
account *Account
wantSupported bool
wantKnown bool
}{
{
name: "nil account is unknown",
wantSupported: false,
wantKnown: false,
},
{
name: "non openai account is unknown",
account: &Account{
Platform: PlatformAnthropic,
Extra: map[string]any{"openai_compact_supported": true},
},
wantSupported: false,
wantKnown: false,
},
{
name: "force on overrides probe state",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{
"openai_compact_mode": OpenAICompactModeForceOn,
"openai_compact_supported": false,
},
},
wantSupported: true,
wantKnown: true,
},
{
name: "force off overrides probe state",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{
"openai_compact_mode": OpenAICompactModeForceOff,
"openai_compact_supported": true,
},
},
wantSupported: false,
wantKnown: true,
},
{
name: "auto true is known supported",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_supported": true},
},
wantSupported: true,
wantKnown: true,
},
{
name: "auto false is known unsupported",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_supported": false},
},
wantSupported: false,
wantKnown: true,
},
{
name: "auto without probe state remains unknown",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{},
},
wantSupported: false,
wantKnown: false,
},
{
name: "invalid probe field remains unknown",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_supported": "true"},
},
wantSupported: false,
wantKnown: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotSupported, gotKnown := tt.account.OpenAICompactSupportKnown()
if gotSupported != tt.wantSupported || gotKnown != tt.wantKnown {
t.Fatalf("OpenAICompactSupportKnown() = (%v, %v), want (%v, %v)", gotSupported, gotKnown, tt.wantSupported, tt.wantKnown)
}
})
}
}
func TestAccountAllowsOpenAICompact(t *testing.T) {
tests := []struct {
name string
account *Account
want bool
}{
{
name: "nil account does not allow compact",
want: false,
},
{
name: "non openai account does not allow compact",
account: &Account{
Platform: PlatformAnthropic,
},
want: false,
},
{
name: "unknown openai account remains allowed",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{},
},
want: true,
},
{
name: "supported openai account is allowed",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_supported": true},
},
want: true,
},
{
name: "unsupported openai account is rejected",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_supported": false},
},
want: false,
},
{
name: "force on is allowed",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOn},
},
want: true,
},
{
name: "force off is rejected",
account: &Account{
Platform: PlatformOpenAI,
Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOff},
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.account.AllowsOpenAICompact(); got != tt.want {
t.Fatalf("AllowsOpenAICompact() = %v, want %v", got, tt.want)
}
})
}
}
func TestAccountGetCompactModelMapping(t *testing.T) {
tests := []struct {
name string
account *Account
want map[string]string
}{
{
name: "nil account returns nil",
want: nil,
},
{
name: "missing credentials returns nil",
account: &Account{
Platform: PlatformOpenAI,
},
want: nil,
},
{
name: "map any is converted",
account: &Account{
Credentials: map[string]any{
"compact_model_mapping": map[string]any{
"gpt-5.4": "gpt-5.4-openai-compact",
"invalid": 1,
},
},
},
want: map[string]string{
"gpt-5.4": "gpt-5.4-openai-compact",
},
},
{
name: "map string string is copied",
account: &Account{
Credentials: map[string]any{
"compact_model_mapping": map[string]string{
"gpt-*": "compact-*",
},
},
},
want: map[string]string{
"gpt-*": "compact-*",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.account.GetCompactModelMapping()
if !equalStringMap(got, tt.want) {
t.Fatalf("GetCompactModelMapping() = %#v, want %#v", got, tt.want)
}
})
}
}
func TestAccountResolveCompactMappedModel(t *testing.T) {
tests := []struct {
name string
credentials map[string]any
requestedModel string
expectedModel string
expectedMatch bool
}{
{
name: "no compact mapping reports unmatched",
credentials: nil,
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: false,
},
{
name: "exact compact mapping matches",
credentials: map[string]any{
"compact_model_mapping": map[string]any{
"gpt-5.4": "gpt-5.4-openai-compact",
},
},
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4-openai-compact",
expectedMatch: true,
},
{
name: "exact passthrough counts as match",
credentials: map[string]any{
"compact_model_mapping": map[string]any{
"gpt-5.4": "gpt-5.4",
},
},
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: true,
},
{
name: "longest wildcard wins",
credentials: map[string]any{
"compact_model_mapping": map[string]any{
"gpt-*": "fallback-compact",
"gpt-5.4*": "gpt-5.4-openai-compact",
"gpt-5.4-mini*": "gpt-5.4-mini-openai-compact",
},
},
requestedModel: "gpt-5.4-mini",
expectedModel: "gpt-5.4-mini-openai-compact",
expectedMatch: true,
},
{
name: "missing compact mapping reports unmatched",
credentials: map[string]any{
"compact_model_mapping": map[string]any{
"gpt-5.3": "gpt-5.3-openai-compact",
},
},
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Credentials: tt.credentials,
}
gotModel, gotMatch := account.ResolveCompactMappedModel(tt.requestedModel)
if gotModel != tt.expectedModel || gotMatch != tt.expectedMatch {
t.Fatalf("ResolveCompactMappedModel(%q) = (%q, %v), want (%q, %v)", tt.requestedModel, gotModel, gotMatch, tt.expectedModel, tt.expectedMatch)
}
})
}
}
func equalStringMap(left, right map[string]string) bool {
if len(left) != len(right) {
return false
}
for key, want := range right {
if got, ok := left[key]; !ok || got != want {
return false
}
}
return true
}
......@@ -165,7 +165,8 @@ func createTestPayload(modelID string) (map[string]any, error) {
// TestAccountConnection tests an account's connection by sending a test request
// All account types use full Claude Code client characteristics, only auth header differs
// modelID is optional - if empty, defaults to claude.DefaultTestModel
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string, prompt string) error {
// mode is optional - "compact" routes OpenAI accounts to the /responses/compact probe path
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string, prompt string, mode string) error {
ctx := c.Request.Context()
// Get account
......@@ -176,7 +177,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
// Route to platform-specific test method
if account.IsOpenAI() {
return s.testOpenAIAccountConnection(c, account, modelID, prompt)
return s.testOpenAIAccountConnection(c, account, modelID, prompt, normalizeAccountTestMode(mode))
}
if account.IsGemini() {
......@@ -416,9 +417,10 @@ func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx co
}
// testOpenAIAccountConnection tests an OpenAI account's connection
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error {
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string, prompt string, mode string) error {
ctx := c.Request.Context()
_ = prompt
mode = normalizeAccountTestMode(mode)
// Default to openai.DefaultTestModel for OpenAI testing
testModelID := modelID
......@@ -426,14 +428,12 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
testModelID = openai.DefaultTestModel
}
// For API Key accounts with model mapping, map the model
if account.Type == "apikey" {
mapping := account.GetModelMapping()
if len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists {
testModelID = mappedModel
}
}
// Align test routing with gateway behavior: OpenAI accounts apply normal
// account model mapping, and compact mode applies compact-only mapping on top.
testModelID = account.GetMappedModel(testModelID)
if mode == AccountTestModeCompact {
testModelID = resolveOpenAICompactForwardModel(account, testModelID)
return s.testOpenAICompactConnection(c, account, testModelID)
}
// Route to image generation test if an image model is selected
......@@ -538,6 +538,9 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode == http.StatusTooManyRequests {
s.reconcileOpenAI429State(ctx, account, resp.Header, body)
}
// 401 Unauthorized: 标记账号为永久错误
if resp.StatusCode == http.StatusUnauthorized && s.accountRepo != nil {
errMsg := fmt.Sprintf("Authentication failed (401): %s", string(body))
......@@ -550,6 +553,154 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
return s.processOpenAIStream(c, resp.Body)
}
// testOpenAICompactConnection probes /responses/compact and persists the
// resulting capability state on the account.
func (s *AccountTestService) testOpenAICompactConnection(c *gin.Context, account *Account, testModelID string) error {
ctx := c.Request.Context()
authToken := ""
apiURL := ""
isOAuth := false
chatgptAccountID := ""
switch {
case account.IsOAuth():
isOAuth = true
authToken = account.GetOpenAIAccessToken()
if authToken == "" {
return s.sendErrorAndEnd(c, "No access token available")
}
apiURL = chatgptCodexAPIURL + "/compact"
chatgptAccountID = account.GetChatGPTAccountID()
case account.Type == AccountTypeAPIKey:
authToken = account.GetOpenAIApiKey()
if authToken == "" {
return s.sendErrorAndEnd(c, "No API key available")
}
baseURL := account.GetOpenAIBaseURL()
if baseURL == "" {
baseURL = "https://api.openai.com"
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
}
apiURL = appendOpenAIResponsesRequestPathSuffix(buildOpenAIResponsesURL(normalizedBaseURL), "/compact")
default:
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
}
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
payloadBytes, _ := json.Marshal(createOpenAICompactProbePayload(testModelID))
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes))
if err != nil {
return s.sendErrorAndEnd(c, "Failed to create request")
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+authToken)
req.Header.Set("OpenAI-Beta", "responses=experimental")
req.Header.Set("Originator", "codex_cli_rs")
req.Header.Set("User-Agent", codexCLIUserAgent)
req.Header.Set("Version", codexCLIVersion)
probeSessionID := compactProbeSessionID(account.ID)
req.Header.Set("Session_ID", probeSessionID)
req.Header.Set("Conversation_ID", probeSessionID)
if isOAuth {
req.Host = "chatgpt.com"
if chatgptAccountID != "" {
req.Header.Set("chatgpt-account-id", chatgptAccountID)
}
}
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
if err != nil {
if s.accountRepo != nil {
updates := buildOpenAICompactProbeExtraUpdates(nil, nil, err, time.Now())
_ = s.accountRepo.UpdateExtra(ctx, account.ID, updates)
mergeAccountExtra(account, updates)
}
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if s.accountRepo != nil {
updates := buildOpenAICompactProbeExtraUpdates(resp, body, nil, time.Now())
if codexUpdates, err := extractOpenAICodexProbeUpdates(resp); err == nil && len(codexUpdates) > 0 {
updates = mergeExtraUpdates(updates, codexUpdates)
}
if len(updates) > 0 {
_ = s.accountRepo.UpdateExtra(ctx, account.ID, updates)
mergeAccountExtra(account, updates)
}
// 探测如返回 429,主动同步限流状态,避免后续短时间内继续选中。
if resp.StatusCode == http.StatusTooManyRequests {
s.reconcileOpenAI429State(ctx, account, resp.Header, body)
}
}
if resp.StatusCode != http.StatusOK {
if resp.StatusCode == http.StatusUnauthorized && s.accountRepo != nil {
errMsg := fmt.Sprintf("Authentication failed (401): %s", string(body))
_ = s.accountRepo.SetError(ctx, account.ID, errMsg)
}
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
}
s.sendEvent(c, TestEvent{Type: "content", Text: "Compact probe succeeded"})
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
func (s *AccountTestService) reconcileOpenAI429State(ctx context.Context, account *Account, headers http.Header, body []byte) {
if s == nil || s.accountRepo == nil || account == nil {
return
}
var resetAt *time.Time
if calculated := calculateOpenAI429ResetTime(headers); calculated != nil {
resetAt = calculated
} else if unixTs := parseOpenAIRateLimitResetTime(body); unixTs != nil {
t := time.Unix(*unixTs, 0)
resetAt = &t
}
if resetAt == nil {
return
}
if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil {
return
}
now := time.Now()
account.RateLimitedAt = &now
account.RateLimitResetAt = resetAt
if account.Status == StatusError {
if err := s.accountRepo.ClearError(ctx, account.ID); err != nil {
return
}
account.Status = StatusActive
account.ErrorMessage = ""
}
}
// testGeminiAccountConnection tests a Gemini account's connection
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error {
ctx := c.Request.Context()
......@@ -994,13 +1145,17 @@ func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader)
// processOpenAIStream processes the SSE stream from OpenAI Responses API
func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader) error {
reader := bufio.NewReader(body)
seenCompleted := false
for {
line, err := reader.ReadString('\n')
if err != nil {
if err == io.EOF {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
if seenCompleted {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
return s.sendErrorAndEnd(c, "Stream ended before response.completed")
}
return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error()))
}
......@@ -1012,8 +1167,11 @@ func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader)
jsonStr := sseDataPrefix.ReplaceAllString(line, "")
if jsonStr == "[DONE]" {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
if seenCompleted {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
return s.sendErrorAndEnd(c, "Stream ended before response.completed")
}
var data map[string]any
......@@ -1029,9 +1187,19 @@ func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader)
if delta, ok := data["delta"].(string); ok && delta != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: delta})
}
case "response.completed":
case "response.completed", "response.done":
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
case "response.failed":
errorMsg := "OpenAI response failed"
if responseData, ok := data["response"].(map[string]any); ok {
if errData, ok := responseData["error"].(map[string]any); ok {
if msg, ok := errData["message"].(string); ok && msg != "" {
errorMsg = msg
}
}
}
return s.sendErrorAndEnd(c, errorMsg)
case "error":
errorMsg := "Unknown error"
if errData, ok := data["error"].(map[string]any); ok {
......@@ -1261,7 +1429,7 @@ func (s *AccountTestService) RunTestBackground(ctx context.Context, accountID in
ginCtx, _ := gin.CreateTestContext(w)
ginCtx.Request = (&http.Request{}).WithContext(ctx)
testErr := s.TestAccountConnection(ginCtx, accountID, modelID, "")
testErr := s.TestAccountConnection(ginCtx, accountID, modelID, "", AccountTestModeDefault)
finishedAt := time.Now()
body := w.Body.String()
......
package service
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestAccountTestService_TestAccountConnection_OpenAICompactOAuthSuccessPersistsSupport(t *testing.T) {
gin.SetMode(gin.TestMode)
updateCalls := make(chan map[string]any, 1)
account := Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
repo := &snapshotUpdateAccountRepo{
stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
updateExtraCalls: updateCalls,
}
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid-probe"}},
Body: io.NopCloser(strings.NewReader(`{"id":"cmp_probe","status":"completed"}`)),
}}
svc := &AccountTestService{
accountRepo: repo,
httpUpstream: upstream,
}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", bytes.NewReader(nil))
err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
require.NoError(t, err)
require.Equal(t, chatgptCodexAPIURL+"/compact", upstream.lastReq.URL.String())
require.Equal(t, "chatgpt.com", upstream.lastReq.Host)
require.Equal(t, "application/json", upstream.lastReq.Header.Get("Accept"))
require.Equal(t, codexCLIVersion, upstream.lastReq.Header.Get("Version"))
require.NotEmpty(t, upstream.lastReq.Header.Get("Session_Id"))
require.Equal(t, codexCLIUserAgent, upstream.lastReq.Header.Get("User-Agent"))
require.Equal(t, "chatgpt-acc", upstream.lastReq.Header.Get("chatgpt-account-id"))
require.Equal(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String())
updates := <-updateCalls
require.Equal(t, true, updates["openai_compact_supported"])
require.Equal(t, http.StatusOK, updates["openai_compact_last_status"])
require.Contains(t, rec.Body.String(), `"type":"test_complete"`)
}
func TestAccountTestService_TestAccountConnection_OpenAICompactOAuth404MarksUnsupported(t *testing.T) {
gin.SetMode(gin.TestMode)
updateCalls := make(chan map[string]any, 1)
account := Account{
ID: 2,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
repo := &snapshotUpdateAccountRepo{
stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
updateExtraCalls: updateCalls,
}
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusNotFound,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`404 page not found`)),
}}
svc := &AccountTestService{
accountRepo: repo,
httpUpstream: upstream,
}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/2/test", bytes.NewReader(nil))
err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
require.Error(t, err)
updates := <-updateCalls
require.Equal(t, false, updates["openai_compact_supported"])
require.Equal(t, http.StatusNotFound, updates["openai_compact_last_status"])
require.Contains(t, rec.Body.String(), `"type":"error"`)
}
func TestAccountTestService_TestAccountConnection_OpenAICompactAPIKeyUsesCompactPath(t *testing.T) {
gin.SetMode(gin.TestMode)
updateCalls := make(chan map[string]any, 1)
account := Account{
ID: 3,
Name: "openai-apikey",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": "https://example.com/v1",
"compact_model_mapping": map[string]any{"gpt-5.4": "gpt-5.4-openai-compact"},
},
}
repo := &snapshotUpdateAccountRepo{
stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
updateExtraCalls: updateCalls,
}
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"id":"cmp_probe_apikey","status":"completed"}`)),
}}
svc := &AccountTestService{
accountRepo: repo,
httpUpstream: upstream,
cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/3/test", bytes.NewReader(nil))
err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
require.NoError(t, err)
require.Equal(t, "https://example.com/v1/responses/compact", upstream.lastReq.URL.String())
require.Equal(t, "gpt-5.4-openai-compact", gjson.GetBytes(upstream.lastBody, "model").String())
updates := <-updateCalls
require.Equal(t, true, updates["openai_compact_supported"])
}
func TestAccountTestService_TestAccountConnection_OpenAICompactAPIKeyDefaultBaseURLUsesV1Path(t *testing.T) {
gin.SetMode(gin.TestMode)
updateCalls := make(chan map[string]any, 1)
account := Account{
ID: 4,
Name: "openai-apikey-default",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
},
}
repo := &snapshotUpdateAccountRepo{
stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
updateExtraCalls: updateCalls,
}
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"id":"cmp_probe_apikey_default","status":"completed"}`)),
}}
svc := &AccountTestService{
accountRepo: repo,
httpUpstream: upstream,
cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/4/test", bytes.NewReader(nil))
err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
require.NoError(t, err)
require.Equal(t, "https://api.openai.com/v1/responses/compact", upstream.lastReq.URL.String())
<-updateCalls
}
......@@ -61,9 +61,12 @@ func newTestContext() (*gin.Context, *httptest.ResponseRecorder) {
type openAIAccountTestRepo struct {
mockAccountRepoForGemini
updatedExtra map[string]any
rateLimitedID int64
rateLimitedAt *time.Time
updatedExtra map[string]any
rateLimitedID int64
rateLimitedAt *time.Time
clearedErrorID int64
setErrorID int64
setErrorMsg string
}
func (r *openAIAccountTestRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
......@@ -77,6 +80,17 @@ func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, rese
return nil
}
func (r *openAIAccountTestRepo) ClearError(_ context.Context, id int64) error {
r.clearedErrorID = id
return nil
}
func (r *openAIAccountTestRepo) SetError(_ context.Context, id int64, errorMsg string) error {
r.setErrorID = id
r.setErrorMsg = errorMsg
return nil
}
func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, recorder := newTestContext()
......@@ -103,7 +117,7 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
Credentials: map[string]any{"access_token": "test-token"},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "")
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.NoError(t, err)
require.NotEmpty(t, repo.updatedExtra)
require.Equal(t, 42.0, repo.updatedExtra["codex_5h_used_percent"])
......@@ -111,11 +125,36 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
require.Contains(t, recorder.Body.String(), "test_complete")
}
func TestAccountTestService_OpenAI429PersistsSnapshotWithoutRateLimit(t *testing.T) {
func TestAccountTestService_OpenAIStreamEOFBeforeCompletedFails(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, recorder := newTestContext()
resp := newJSONResponse(http.StatusOK, "")
resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.output_text.delta","delta":"hi"}
`))
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 90,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{"access_token": "test-token"},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.Error(t, err)
require.Contains(t, recorder.Body.String(), "response.completed")
require.NotContains(t, recorder.Body.String(), `"success":true`)
}
func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimitState(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`)
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached","resets_at":1777283883}}`)
resp.Header.Set("x-codex-primary-used-percent", "100")
resp.Header.Set("x-codex-primary-reset-after-seconds", "604800")
resp.Header.Set("x-codex-primary-window-minutes", "10080")
......@@ -130,15 +169,132 @@ func TestAccountTestService_OpenAI429PersistsSnapshotWithoutRateLimit(t *testing
ID: 88,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusError,
Concurrency: 1,
Credentials: map[string]any{"access_token": "test-token"},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "")
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.Error(t, err)
require.NotEmpty(t, repo.updatedExtra)
require.Equal(t, 100.0, repo.updatedExtra["codex_5h_used_percent"])
require.Equal(t, account.ID, repo.rateLimitedID)
require.NotNil(t, repo.rateLimitedAt)
require.Equal(t, account.ID, repo.clearedErrorID)
require.Equal(t, StatusActive, account.Status)
require.Empty(t, account.ErrorMessage)
require.NotNil(t, account.RateLimitResetAt)
}
func TestAccountTestService_OpenAI429BodyOnlyPersistsRateLimitAndClearsStaleError(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached","resets_at":"1777283883"}}`)
repo := &openAIAccountTestRepo{}
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
account := &Account{
ID: 77,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusError,
ErrorMessage: "Access forbidden (403): account may be suspended or lack permissions",
Concurrency: 1,
Credentials: map[string]any{"access_token": "test-token"},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.Error(t, err)
require.Equal(t, account.ID, repo.rateLimitedID)
require.NotNil(t, repo.rateLimitedAt)
require.Equal(t, account.ID, repo.clearedErrorID)
require.Equal(t, StatusActive, account.Status)
require.Empty(t, account.ErrorMessage)
require.NotNil(t, account.RateLimitResetAt)
require.Empty(t, repo.updatedExtra)
}
func TestAccountTestService_OpenAI429ActiveAccountDoesNotClearError(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached","resets_in_seconds":3600}}`)
repo := &openAIAccountTestRepo{}
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
account := &Account{
ID: 78,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{"access_token": "test-token"},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.Error(t, err)
require.Equal(t, account.ID, repo.rateLimitedID)
require.NotNil(t, repo.rateLimitedAt)
require.Zero(t, repo.clearedErrorID)
require.Equal(t, StatusActive, account.Status)
require.NotNil(t, account.RateLimitResetAt)
}
func TestAccountTestService_OpenAI429WithoutResetSignalDoesNotMutateRuntimeState(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`)
repo := &openAIAccountTestRepo{}
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
account := &Account{
ID: 79,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusError,
ErrorMessage: "stale 403",
Concurrency: 1,
Credentials: map[string]any{"access_token": "test-token"},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.Error(t, err)
require.Zero(t, repo.rateLimitedID)
require.Nil(t, repo.rateLimitedAt)
require.Zero(t, repo.clearedErrorID)
require.Equal(t, StatusError, account.Status)
require.Equal(t, "stale 403", account.ErrorMessage)
require.Nil(t, account.RateLimitResetAt)
}
func TestAccountTestService_OpenAI401SetsPermanentErrorOnly(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
resp := newJSONResponse(http.StatusUnauthorized, `{"error":"bad token"}`)
repo := &openAIAccountTestRepo{}
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
account := &Account{
ID: 80,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{"access_token": "test-token"},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.Error(t, err)
require.Equal(t, account.ID, repo.setErrorID)
require.Contains(t, repo.setErrorMsg, "Authentication failed (401)")
require.Zero(t, repo.rateLimitedID)
require.Zero(t, repo.clearedErrorID)
require.Nil(t, account.RateLimitResetAt)
}
......@@ -110,7 +110,7 @@ const (
apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟
windowStatsCacheTTL = 1 * time.Minute
openAIProbeCacheTTL = 10 * time.Minute
openAICodexProbeVersion = "0.104.0"
openAICodexProbeVersion = "0.125.0"
)
// UsageCache 封装账户使用量相关的缓存
......
This diff is collapsed.
//go:build unit
package service
import (
"context"
"math"
"testing"
"github.com/stretchr/testify/require"
)
// TestResolveRebateRatePercent_PerUserOverride verifies that per-inviter
// AffRebateRatePercent overrides the global rate, that NULL falls back to the
// global rate, and that out-of-range exclusive rates are clamped silently.
//
// SettingService is left nil here so globalRebateRatePercent returns the
// documented default (AffiliateRebateRateDefault = 20%) — this exercises the
// fallback path without spinning up a settings stub.
func TestResolveRebateRatePercent_PerUserOverride(t *testing.T) {
t.Parallel()
svc := &AffiliateService{}
// nil exclusive rate → falls back to global default (20%)
require.InDelta(t, AffiliateRebateRateDefault,
svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{}), 1e-9)
// exclusive rate set → overrides global
rate := 50.0
require.InDelta(t, 50.0,
svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &rate}), 1e-9)
// exclusive rate 0 → returns 0 (no rebate, intentional)
zero := 0.0
require.InDelta(t, 0.0,
svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &zero}), 1e-9)
// exclusive rate above max → clamped to Max
tooHigh := 250.0
require.InDelta(t, AffiliateRebateRateMax,
svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &tooHigh}), 1e-9)
// exclusive rate below min → clamped to Min
tooLow := -5.0
require.InDelta(t, AffiliateRebateRateMin,
svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &tooLow}), 1e-9)
}
// TestIsEnabled_NilSettingServiceReturnsDefault verifies that IsEnabled
// safely handles a nil settingService dependency by returning the default
// (off). This protects callers from nil-pointer crashes in misconfigured
// environments.
func TestIsEnabled_NilSettingServiceReturnsDefault(t *testing.T) {
t.Parallel()
svc := &AffiliateService{}
require.False(t, svc.IsEnabled(context.Background()))
require.Equal(t, AffiliateEnabledDefault, svc.IsEnabled(context.Background()))
}
// TestValidateExclusiveRate_BoundaryAndInvalid covers the validator used by
// admin-facing rate setters: nil is always valid (clear), in-range values
// are accepted, NaN/Inf and out-of-range values produce a typed BadRequest.
func TestValidateExclusiveRate_BoundaryAndInvalid(t *testing.T) {
t.Parallel()
require.NoError(t, validateExclusiveRate(nil))
for _, v := range []float64{0, 0.01, 50, 99.99, 100} {
v := v
require.NoError(t, validateExclusiveRate(&v), "value %v should be valid", v)
}
for _, v := range []float64{-0.01, 100.01, -100, 200} {
v := v
require.Error(t, validateExclusiveRate(&v), "value %v should be rejected", v)
}
nan := math.NaN()
require.Error(t, validateExclusiveRate(&nan))
posInf := math.Inf(1)
require.Error(t, validateExclusiveRate(&posInf))
negInf := math.Inf(-1)
require.Error(t, validateExclusiveRate(&negInf))
}
func TestMaskEmail(t *testing.T) {
t.Parallel()
require.Equal(t, "a***@g***.com", maskEmail("alice@gmail.com"))
require.Equal(t, "x***@d***", maskEmail("x@domain"))
require.Equal(t, "", maskEmail(""))
}
func TestIsValidAffiliateCodeFormat(t *testing.T) {
t.Parallel()
// 邀请码格式校验同时服务于:
// 1) 系统自动生成的 12 位随机码(A-Z 去 I/O,2-9 去 0/1)
// 2) 管理员设置的自定义专属码(如 "VIP2026"、"NEW_USER-1")
// 因此校验放宽到 [A-Z0-9_-]{4,32}(要求调用方先 ToUpper)。
cases := []struct {
name string
in string
want bool
}{
{"valid canonical 12-char", "ABCDEFGHJKLM", true},
{"valid all digits 2-9", "234567892345", true},
{"valid mixed", "A2B3C4D5E6F7", true},
{"valid admin custom short", "VIP1", true},
{"valid admin custom with hyphen", "NEW-USER", true},
{"valid admin custom with underscore", "VIP_2026", true},
{"valid 32-char max", "ABCDEFGHIJKLMNOPQRSTUVWXYZ012345", true},
// Previously-excluded chars (I/O/0/1) are now allowed since admins may use them.
{"letter I now allowed", "IBCDEFGHJKLM", true},
{"letter O now allowed", "OBCDEFGHJKLM", true},
{"digit 0 now allowed", "0BCDEFGHJKLM", true},
{"digit 1 now allowed", "1BCDEFGHJKLM", true},
{"too short (3 chars)", "ABC", false},
{"too long (33 chars)", "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456", false},
{"lowercase rejected (caller must ToUpper first)", "abcdefghjklm", false},
{"empty", "", false},
{"utf8 non-ascii", "ÄÄÄÄÄÄ", false}, // bytes out of charset
{"ascii punctuation .", "ABCDEFGHJK.M", false},
{"whitespace", "ABCDEFGHJK M", false},
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, tc.want, isValidAffiliateCodeFormat(tc.in))
})
}
}
......@@ -175,6 +175,7 @@ func (s *AuthService) FinalizeOAuthEmailAccount(
user *User,
invitationCode string,
signupSource string,
affiliateCode string,
) error {
if s == nil || user == nil || user.ID <= 0 {
return ErrServiceUnavailable
......@@ -194,6 +195,7 @@ func (s *AuthService) FinalizeOAuthEmailAccount(
s.updateOAuthSignupSource(ctx, user.ID, signupSource)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
return nil
}
......
......@@ -137,6 +137,7 @@ func newOAuthEmailFlowAuthService(
nil,
nil,
nil,
nil,
)
}
......
This diff is collapsed.
......@@ -110,7 +110,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants (
emailSvc = service.NewEmailService(settingRepo, emailCache)
}
svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner)
svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner, nil)
return svc, repo, client
}
......@@ -467,7 +467,7 @@ func TestAuthServiceBindEmailIdentity_RevokesExistingAccessAndRefreshTokens(t *t
},
}
emailService := service.NewEmailService(nil, cache)
svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil)
svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil, nil)
oldTokenPair, err := svc.GenerateTokenPair(ctx, &service.User{
ID: 41,
......
......@@ -137,7 +137,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants (
values: settings,
}, cfg)
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner)
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner, nil)
return svc, repo, client
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment