Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
3b7a5fff
Commit
3b7a5fff
authored
Apr 27, 2026
by
陈曦
Browse files
补充openai、gemini以及流失请求的采集数据以及nfs落库
parent
8519a8eb
Pipeline
#82284
failed with stage
in 2 minutes and 21 seconds
Changes
180
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
backend/internal/repository/wire.go
View file @
3b7a5fff
...
@@ -91,6 +91,7 @@ var ProviderSet = wire.NewSet(
...
@@ -91,6 +91,7 @@ var ProviderSet = wire.NewSet(
NewChannelRepository
,
NewChannelRepository
,
NewChannelMonitorRepository
,
NewChannelMonitorRepository
,
NewChannelMonitorRequestTemplateRepository
,
NewChannelMonitorRequestTemplateRepository
,
NewAffiliateRepository
,
NewRequestCaptureLogRepository
,
NewRequestCaptureLogRepository
,
// Cache implementations
// Cache implementations
...
...
backend/internal/server/api_contract_test.go
View file @
3b7a5fff
...
@@ -715,6 +715,10 @@ func TestAPIContracts(t *testing.T) {
...
@@ -715,6 +715,10 @@ func TestAPIContracts(t *testing.T) {
"force_email_on_third_party_signup": false,
"force_email_on_third_party_signup": false,
"default_concurrency": 5,
"default_concurrency": 5,
"default_balance": 1.25,
"default_balance": 1.25,
"affiliate_rebate_rate": 20,
"affiliate_rebate_freeze_hours": 0,
"affiliate_rebate_duration_days": 0,
"affiliate_rebate_per_invitee_cap": 0,
"default_user_rpm_limit": 0,
"default_user_rpm_limit": 0,
"default_subscriptions": [],
"default_subscriptions": [],
"enable_model_fallback": false,
"enable_model_fallback": false,
...
@@ -774,6 +778,7 @@ func TestAPIContracts(t *testing.T) {
...
@@ -774,6 +778,7 @@ func TestAPIContracts(t *testing.T) {
"channel_monitor_enabled": true,
"channel_monitor_enabled": true,
"channel_monitor_default_interval_seconds": 60,
"channel_monitor_default_interval_seconds": 60,
"available_channels_enabled": false,
"available_channels_enabled": false,
"affiliate_enabled": false,
"wechat_connect_enabled": false,
"wechat_connect_enabled": false,
"wechat_connect_app_id": "",
"wechat_connect_app_id": "",
"wechat_connect_app_secret_configured": false,
"wechat_connect_app_secret_configured": false,
...
@@ -895,6 +900,10 @@ func TestAPIContracts(t *testing.T) {
...
@@ -895,6 +900,10 @@ func TestAPIContracts(t *testing.T) {
"custom_endpoints": [],
"custom_endpoints": [],
"default_concurrency": 0,
"default_concurrency": 0,
"default_balance": 0,
"default_balance": 0,
"affiliate_rebate_rate": 20,
"affiliate_rebate_freeze_hours": 0,
"affiliate_rebate_duration_days": 0,
"affiliate_rebate_per_invitee_cap": 0,
"default_user_rpm_limit": 0,
"default_user_rpm_limit": 0,
"default_subscriptions": [],
"default_subscriptions": [],
"enable_model_fallback": false,
"enable_model_fallback": false,
...
@@ -949,6 +958,7 @@ func TestAPIContracts(t *testing.T) {
...
@@ -949,6 +958,7 @@ func TestAPIContracts(t *testing.T) {
"channel_monitor_enabled": true,
"channel_monitor_enabled": true,
"channel_monitor_default_interval_seconds": 60,
"channel_monitor_default_interval_seconds": 60,
"available_channels_enabled": false,
"available_channels_enabled": false,
"affiliate_enabled": false,
"wechat_connect_enabled": true,
"wechat_connect_enabled": true,
"wechat_connect_app_id": "wx-open-config",
"wechat_connect_app_id": "wx-open-config",
"wechat_connect_app_secret_configured": true,
"wechat_connect_app_secret_configured": true,
...
...
backend/internal/server/middleware/admin_auth_test.go
View file @
3b7a5fff
...
@@ -20,7 +20,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
...
@@ -20,7 +20,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
gin
.
SetMode
(
gin
.
TestMode
)
gin
.
SetMode
(
gin
.
TestMode
)
cfg
:=
&
config
.
Config
{
JWT
:
config
.
JWTConfig
{
Secret
:
"test-secret"
,
ExpireHour
:
1
}}
cfg
:=
&
config
.
Config
{
JWT
:
config
.
JWTConfig
{
Secret
:
"test-secret"
,
ExpireHour
:
1
}}
authService
:=
service
.
NewAuthService
(
nil
,
nil
,
nil
,
nil
,
cfg
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
authService
:=
service
.
NewAuthService
(
nil
,
nil
,
nil
,
nil
,
cfg
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
admin
:=
&
service
.
User
{
admin
:=
&
service
.
User
{
ID
:
1
,
ID
:
1
,
...
...
backend/internal/server/middleware/jwt_auth_test.go
View file @
3b7a5fff
...
@@ -60,7 +60,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer
...
@@ -60,7 +60,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer
cfg
.
JWT
.
AccessTokenExpireMinutes
=
60
cfg
.
JWT
.
AccessTokenExpireMinutes
=
60
userRepo
:=
&
stubJWTUserRepo
{
users
:
users
}
userRepo
:=
&
stubJWTUserRepo
{
users
:
users
}
authSvc
:=
service
.
NewAuthService
(
nil
,
userRepo
,
nil
,
nil
,
cfg
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
authSvc
:=
service
.
NewAuthService
(
nil
,
userRepo
,
nil
,
nil
,
cfg
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
userSvc
:=
service
.
NewUserService
(
userRepo
,
nil
,
nil
,
nil
)
userSvc
:=
service
.
NewUserService
(
userRepo
,
nil
,
nil
,
nil
)
mw
:=
NewJWTAuthMiddleware
(
authSvc
,
userSvc
)
mw
:=
NewJWTAuthMiddleware
(
authSvc
,
userSvc
)
...
@@ -143,7 +143,7 @@ func TestJWTAuth_ValidToken_TouchesLastActive(t *testing.T) {
...
@@ -143,7 +143,7 @@ func TestJWTAuth_ValidToken_TouchesLastActive(t *testing.T) {
cfg
.
JWT
.
AccessTokenExpireMinutes
=
60
cfg
.
JWT
.
AccessTokenExpireMinutes
=
60
userRepo
:=
&
stubJWTUserRepo
{
users
:
map
[
int64
]
*
service
.
User
{
1
:
user
}}
userRepo
:=
&
stubJWTUserRepo
{
users
:
map
[
int64
]
*
service
.
User
{
1
:
user
}}
authSvc
:=
service
.
NewAuthService
(
nil
,
userRepo
,
nil
,
nil
,
cfg
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
authSvc
:=
service
.
NewAuthService
(
nil
,
userRepo
,
nil
,
nil
,
cfg
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
userSvc
:=
service
.
NewUserService
(
userRepo
,
nil
,
nil
,
nil
)
userSvc
:=
service
.
NewUserService
(
userRepo
,
nil
,
nil
,
nil
)
toucher
:=
&
recordingActivityToucher
{}
toucher
:=
&
recordingActivityToucher
{}
...
...
backend/internal/server/routes/admin.go
View file @
3b7a5fff
...
@@ -91,6 +91,9 @@ func RegisterAdminRoutes(
...
@@ -91,6 +91,9 @@ func RegisterAdminRoutes(
// 渠道监控
// 渠道监控
registerChannelMonitorRoutes
(
admin
,
h
)
registerChannelMonitorRoutes
(
admin
,
h
)
// 邀请返利(专属用户管理)
registerAffiliateRoutes
(
admin
,
h
)
}
}
}
}
...
@@ -594,3 +597,18 @@ func registerChannelMonitorRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
...
@@ -594,3 +597,18 @@ func registerChannelMonitorRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
templates
.
POST
(
"/:id/apply"
,
h
.
Admin
.
ChannelMonitorTemplate
.
Apply
)
templates
.
POST
(
"/:id/apply"
,
h
.
Admin
.
ChannelMonitorTemplate
.
Apply
)
}
}
}
}
// registerAffiliateRoutes 注册邀请返利的管理端路由(专属用户配置)
func
registerAffiliateRoutes
(
admin
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
)
{
affiliates
:=
admin
.
Group
(
"/affiliates"
)
{
users
:=
affiliates
.
Group
(
"/users"
)
{
users
.
GET
(
""
,
h
.
Admin
.
Affiliate
.
ListUsers
)
users
.
GET
(
"/lookup"
,
h
.
Admin
.
Affiliate
.
LookupUsers
)
users
.
POST
(
"/batch-rate"
,
h
.
Admin
.
Affiliate
.
BatchSetRate
)
users
.
PUT
(
"/:user_id"
,
h
.
Admin
.
Affiliate
.
UpdateUserSettings
)
users
.
DELETE
(
"/:user_id"
,
h
.
Admin
.
Affiliate
.
ClearUserSettings
)
}
}
}
backend/internal/server/routes/user.go
View file @
3b7a5fff
...
@@ -25,6 +25,8 @@ func RegisterUserRoutes(
...
@@ -25,6 +25,8 @@ func RegisterUserRoutes(
user
.
GET
(
"/profile"
,
h
.
User
.
GetProfile
)
user
.
GET
(
"/profile"
,
h
.
User
.
GetProfile
)
user
.
PUT
(
"/password"
,
h
.
User
.
ChangePassword
)
user
.
PUT
(
"/password"
,
h
.
User
.
ChangePassword
)
user
.
PUT
(
""
,
h
.
User
.
UpdateProfile
)
user
.
PUT
(
""
,
h
.
User
.
UpdateProfile
)
user
.
GET
(
"/aff"
,
h
.
User
.
GetAffiliate
)
user
.
POST
(
"/aff/transfer"
,
h
.
User
.
TransferAffiliateQuota
)
user
.
POST
(
"/account-bindings/email/send-code"
,
h
.
User
.
SendEmailBindingCode
)
user
.
POST
(
"/account-bindings/email/send-code"
,
h
.
User
.
SendEmailBindingCode
)
user
.
POST
(
"/account-bindings/email"
,
h
.
User
.
BindEmailIdentity
)
user
.
POST
(
"/account-bindings/email"
,
h
.
User
.
BindEmailIdentity
)
user
.
DELETE
(
"/account-bindings/:provider"
,
h
.
User
.
UnbindIdentity
)
user
.
DELETE
(
"/account-bindings/:provider"
,
h
.
User
.
UnbindIdentity
)
...
...
backend/internal/service/account.go
View file @
3b7a5fff
...
@@ -393,6 +393,56 @@ func parseTempUnschedInt(value any) int {
...
@@ -393,6 +393,56 @@ func parseTempUnschedInt(value any) int {
return
0
return
0
}
}
const
(
// OpenAICompactModeAuto follows compact-probe results when deciding compact eligibility.
OpenAICompactModeAuto
=
"auto"
// OpenAICompactModeForceOn always treats the account as compact-supported.
OpenAICompactModeForceOn
=
"force_on"
// OpenAICompactModeForceOff always treats the account as compact-unsupported.
OpenAICompactModeForceOff
=
"force_off"
)
func
normalizeOpenAICompactMode
(
mode
string
)
string
{
switch
strings
.
ToLower
(
strings
.
TrimSpace
(
mode
))
{
case
OpenAICompactModeForceOn
:
return
OpenAICompactModeForceOn
case
OpenAICompactModeForceOff
:
return
OpenAICompactModeForceOff
default
:
return
OpenAICompactModeAuto
}
}
func
stringMappingFromRaw
(
raw
any
)
map
[
string
]
string
{
switch
mapping
:=
raw
.
(
type
)
{
case
map
[
string
]
any
:
if
len
(
mapping
)
==
0
{
return
nil
}
result
:=
make
(
map
[
string
]
string
,
len
(
mapping
))
for
key
,
value
:=
range
mapping
{
if
str
,
ok
:=
value
.
(
string
);
ok
{
result
[
key
]
=
str
}
}
if
len
(
result
)
==
0
{
return
nil
}
return
result
case
map
[
string
]
string
:
if
len
(
mapping
)
==
0
{
return
nil
}
result
:=
make
(
map
[
string
]
string
,
len
(
mapping
))
for
key
,
value
:=
range
mapping
{
result
[
key
]
=
value
}
return
result
default
:
return
nil
}
}
func
(
a
*
Account
)
GetModelMapping
()
map
[
string
]
string
{
func
(
a
*
Account
)
GetModelMapping
()
map
[
string
]
string
{
credentialsPtr
:=
mapPtr
(
a
.
Credentials
)
credentialsPtr
:=
mapPtr
(
a
.
Credentials
)
rawMapping
,
_
:=
a
.
Credentials
[
"model_mapping"
]
.
(
map
[
string
]
any
)
rawMapping
,
_
:=
a
.
Credentials
[
"model_mapping"
]
.
(
map
[
string
]
any
)
...
@@ -598,6 +648,77 @@ func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string,
...
@@ -598,6 +648,77 @@ func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string,
return
requestedModel
,
false
return
requestedModel
,
false
}
}
// GetOpenAICompactMode returns the compact routing mode for an OpenAI account.
// Missing or invalid values fall back to "auto".
func
(
a
*
Account
)
GetOpenAICompactMode
()
string
{
if
a
==
nil
||
!
a
.
IsOpenAI
()
||
a
.
Extra
==
nil
{
return
OpenAICompactModeAuto
}
mode
,
_
:=
a
.
Extra
[
"openai_compact_mode"
]
.
(
string
)
return
normalizeOpenAICompactMode
(
mode
)
}
// OpenAICompactSupportKnown reports whether compact capability is known for this
// account and, when known, whether it is supported.
func
(
a
*
Account
)
OpenAICompactSupportKnown
()
(
supported
bool
,
known
bool
)
{
if
a
==
nil
||
!
a
.
IsOpenAI
()
{
return
false
,
false
}
switch
a
.
GetOpenAICompactMode
()
{
case
OpenAICompactModeForceOn
:
return
true
,
true
case
OpenAICompactModeForceOff
:
return
false
,
true
}
if
a
.
Extra
==
nil
{
return
false
,
false
}
supported
,
ok
:=
a
.
Extra
[
"openai_compact_supported"
]
.
(
bool
)
if
!
ok
{
return
false
,
false
}
return
supported
,
true
}
// AllowsOpenAICompact reports whether the account may be considered for compact
// requests. Unknown capability remains allowed to avoid breaking older accounts
// before an explicit probe has been run.
func
(
a
*
Account
)
AllowsOpenAICompact
()
bool
{
if
a
==
nil
||
!
a
.
IsOpenAI
()
{
return
false
}
supported
,
known
:=
a
.
OpenAICompactSupportKnown
()
if
!
known
{
return
true
}
return
supported
}
// GetCompactModelMapping returns compact-only model remapping configuration.
// This mapping is intended for /responses/compact only and does not affect
// normal /responses traffic.
func
(
a
*
Account
)
GetCompactModelMapping
()
map
[
string
]
string
{
if
a
==
nil
||
a
.
Credentials
==
nil
{
return
nil
}
return
stringMappingFromRaw
(
a
.
Credentials
[
"compact_model_mapping"
])
}
// ResolveCompactMappedModel resolves compact-only model remapping and reports
// whether a compact-specific mapping rule matched.
func
(
a
*
Account
)
ResolveCompactMappedModel
(
requestedModel
string
)
(
mappedModel
string
,
matched
bool
)
{
mapping
:=
a
.
GetCompactModelMapping
()
if
len
(
mapping
)
==
0
{
return
requestedModel
,
false
}
if
mappedModel
,
matched
:=
resolveRequestedModelInMapping
(
mapping
,
requestedModel
);
matched
{
return
mappedModel
,
true
}
return
requestedModel
,
false
}
func
(
a
*
Account
)
GetBaseURL
()
string
{
func
(
a
*
Account
)
GetBaseURL
()
string
{
if
a
.
Type
!=
AccountTypeAPIKey
{
if
a
.
Type
!=
AccountTypeAPIKey
{
return
""
return
""
...
...
backend/internal/service/account_openai_compact_test.go
0 → 100644
View file @
3b7a5fff
package
service
import
"testing"
func
TestAccountGetOpenAICompactMode
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
account
*
Account
want
string
}{
{
name
:
"nil account defaults to auto"
,
want
:
OpenAICompactModeAuto
,
},
{
name
:
"non openai account defaults to auto"
,
account
:
&
Account
{
Platform
:
PlatformAnthropic
,
Extra
:
map
[
string
]
any
{
"openai_compact_mode"
:
OpenAICompactModeForceOn
},
},
want
:
OpenAICompactModeAuto
,
},
{
name
:
"missing extra defaults to auto"
,
account
:
&
Account
{
Platform
:
PlatformOpenAI
,
},
want
:
OpenAICompactModeAuto
,
},
{
name
:
"invalid mode falls back to auto"
,
account
:
&
Account
{
Platform
:
PlatformOpenAI
,
Extra
:
map
[
string
]
any
{
"openai_compact_mode"
:
" invalid "
},
},
want
:
OpenAICompactModeAuto
,
},
{
name
:
"force on is normalized"
,
account
:
&
Account
{
Platform
:
PlatformOpenAI
,
Extra
:
map
[
string
]
any
{
"openai_compact_mode"
:
" FORCE_ON "
},
},
want
:
OpenAICompactModeForceOn
,
},
{
name
:
"force off is normalized"
,
account
:
&
Account
{
Platform
:
PlatformOpenAI
,
Extra
:
map
[
string
]
any
{
"openai_compact_mode"
:
"force_off"
},
},
want
:
OpenAICompactModeForceOff
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
if
got
:=
tt
.
account
.
GetOpenAICompactMode
();
got
!=
tt
.
want
{
t
.
Fatalf
(
"GetOpenAICompactMode() = %q, want %q"
,
got
,
tt
.
want
)
}
})
}
}
func
TestAccountOpenAICompactSupportKnown
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
account
*
Account
wantSupported
bool
wantKnown
bool
}{
{
name
:
"nil account is unknown"
,
wantSupported
:
false
,
wantKnown
:
false
,
},
{
name
:
"non openai account is unknown"
,
account
:
&
Account
{
Platform
:
PlatformAnthropic
,
Extra
:
map
[
string
]
any
{
"openai_compact_supported"
:
true
},
},
wantSupported
:
false
,
wantKnown
:
false
,
},
{
name
:
"force on overrides probe state"
,
account
:
&
Account
{
Platform
:
PlatformOpenAI
,
Extra
:
map
[
string
]
any
{
"openai_compact_mode"
:
OpenAICompactModeForceOn
,
"openai_compact_supported"
:
false
,
},
},
wantSupported
:
true
,
wantKnown
:
true
,
},
{
name
:
"force off overrides probe state"
,
account
:
&
Account
{
Platform
:
PlatformOpenAI
,
Extra
:
map
[
string
]
any
{
"openai_compact_mode"
:
OpenAICompactModeForceOff
,
"openai_compact_supported"
:
true
,
},
},
wantSupported
:
false
,
wantKnown
:
true
,
},
{
name
:
"auto true is known supported"
,
account
:
&
Account
{
Platform
:
PlatformOpenAI
,
Extra
:
map
[
string
]
any
{
"openai_compact_supported"
:
true
},
},
wantSupported
:
true
,
wantKnown
:
true
,
},
{
name
:
"auto false is known unsupported"
,
account
:
&
Account
{
Platform
:
PlatformOpenAI
,
Extra
:
map
[
string
]
any
{
"openai_compact_supported"
:
false
},
},
wantSupported
:
false
,
wantKnown
:
true
,
},
{
name
:
"auto without probe state remains unknown"
,
account
:
&
Account
{
Platform
:
PlatformOpenAI
,
Extra
:
map
[
string
]
any
{},
},
wantSupported
:
false
,
wantKnown
:
false
,
},
{
name
:
"invalid probe field remains unknown"
,
account
:
&
Account
{
Platform
:
PlatformOpenAI
,
Extra
:
map
[
string
]
any
{
"openai_compact_supported"
:
"true"
},
},
wantSupported
:
false
,
wantKnown
:
false
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
gotSupported
,
gotKnown
:=
tt
.
account
.
OpenAICompactSupportKnown
()
if
gotSupported
!=
tt
.
wantSupported
||
gotKnown
!=
tt
.
wantKnown
{
t
.
Fatalf
(
"OpenAICompactSupportKnown() = (%v, %v), want (%v, %v)"
,
gotSupported
,
gotKnown
,
tt
.
wantSupported
,
tt
.
wantKnown
)
}
})
}
}
func
TestAccountAllowsOpenAICompact
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
account
*
Account
want
bool
}{
{
name
:
"nil account does not allow compact"
,
want
:
false
,
},
{
name
:
"non openai account does not allow compact"
,
account
:
&
Account
{
Platform
:
PlatformAnthropic
,
},
want
:
false
,
},
{
name
:
"unknown openai account remains allowed"
,
account
:
&
Account
{
Platform
:
PlatformOpenAI
,
Extra
:
map
[
string
]
any
{},
},
want
:
true
,
},
{
name
:
"supported openai account is allowed"
,
account
:
&
Account
{
Platform
:
PlatformOpenAI
,
Extra
:
map
[
string
]
any
{
"openai_compact_supported"
:
true
},
},
want
:
true
,
},
{
name
:
"unsupported openai account is rejected"
,
account
:
&
Account
{
Platform
:
PlatformOpenAI
,
Extra
:
map
[
string
]
any
{
"openai_compact_supported"
:
false
},
},
want
:
false
,
},
{
name
:
"force on is allowed"
,
account
:
&
Account
{
Platform
:
PlatformOpenAI
,
Extra
:
map
[
string
]
any
{
"openai_compact_mode"
:
OpenAICompactModeForceOn
},
},
want
:
true
,
},
{
name
:
"force off is rejected"
,
account
:
&
Account
{
Platform
:
PlatformOpenAI
,
Extra
:
map
[
string
]
any
{
"openai_compact_mode"
:
OpenAICompactModeForceOff
},
},
want
:
false
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
if
got
:=
tt
.
account
.
AllowsOpenAICompact
();
got
!=
tt
.
want
{
t
.
Fatalf
(
"AllowsOpenAICompact() = %v, want %v"
,
got
,
tt
.
want
)
}
})
}
}
func
TestAccountGetCompactModelMapping
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
account
*
Account
want
map
[
string
]
string
}{
{
name
:
"nil account returns nil"
,
want
:
nil
,
},
{
name
:
"missing credentials returns nil"
,
account
:
&
Account
{
Platform
:
PlatformOpenAI
,
},
want
:
nil
,
},
{
name
:
"map any is converted"
,
account
:
&
Account
{
Credentials
:
map
[
string
]
any
{
"compact_model_mapping"
:
map
[
string
]
any
{
"gpt-5.4"
:
"gpt-5.4-openai-compact"
,
"invalid"
:
1
,
},
},
},
want
:
map
[
string
]
string
{
"gpt-5.4"
:
"gpt-5.4-openai-compact"
,
},
},
{
name
:
"map string string is copied"
,
account
:
&
Account
{
Credentials
:
map
[
string
]
any
{
"compact_model_mapping"
:
map
[
string
]
string
{
"gpt-*"
:
"compact-*"
,
},
},
},
want
:
map
[
string
]
string
{
"gpt-*"
:
"compact-*"
,
},
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
tt
.
account
.
GetCompactModelMapping
()
if
!
equalStringMap
(
got
,
tt
.
want
)
{
t
.
Fatalf
(
"GetCompactModelMapping() = %#v, want %#v"
,
got
,
tt
.
want
)
}
})
}
}
func
TestAccountResolveCompactMappedModel
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
credentials
map
[
string
]
any
requestedModel
string
expectedModel
string
expectedMatch
bool
}{
{
name
:
"no compact mapping reports unmatched"
,
credentials
:
nil
,
requestedModel
:
"gpt-5.4"
,
expectedModel
:
"gpt-5.4"
,
expectedMatch
:
false
,
},
{
name
:
"exact compact mapping matches"
,
credentials
:
map
[
string
]
any
{
"compact_model_mapping"
:
map
[
string
]
any
{
"gpt-5.4"
:
"gpt-5.4-openai-compact"
,
},
},
requestedModel
:
"gpt-5.4"
,
expectedModel
:
"gpt-5.4-openai-compact"
,
expectedMatch
:
true
,
},
{
name
:
"exact passthrough counts as match"
,
credentials
:
map
[
string
]
any
{
"compact_model_mapping"
:
map
[
string
]
any
{
"gpt-5.4"
:
"gpt-5.4"
,
},
},
requestedModel
:
"gpt-5.4"
,
expectedModel
:
"gpt-5.4"
,
expectedMatch
:
true
,
},
{
name
:
"longest wildcard wins"
,
credentials
:
map
[
string
]
any
{
"compact_model_mapping"
:
map
[
string
]
any
{
"gpt-*"
:
"fallback-compact"
,
"gpt-5.4*"
:
"gpt-5.4-openai-compact"
,
"gpt-5.4-mini*"
:
"gpt-5.4-mini-openai-compact"
,
},
},
requestedModel
:
"gpt-5.4-mini"
,
expectedModel
:
"gpt-5.4-mini-openai-compact"
,
expectedMatch
:
true
,
},
{
name
:
"missing compact mapping reports unmatched"
,
credentials
:
map
[
string
]
any
{
"compact_model_mapping"
:
map
[
string
]
any
{
"gpt-5.3"
:
"gpt-5.3-openai-compact"
,
},
},
requestedModel
:
"gpt-5.4"
,
expectedModel
:
"gpt-5.4"
,
expectedMatch
:
false
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Credentials
:
tt
.
credentials
,
}
gotModel
,
gotMatch
:=
account
.
ResolveCompactMappedModel
(
tt
.
requestedModel
)
if
gotModel
!=
tt
.
expectedModel
||
gotMatch
!=
tt
.
expectedMatch
{
t
.
Fatalf
(
"ResolveCompactMappedModel(%q) = (%q, %v), want (%q, %v)"
,
tt
.
requestedModel
,
gotModel
,
gotMatch
,
tt
.
expectedModel
,
tt
.
expectedMatch
)
}
})
}
}
func
equalStringMap
(
left
,
right
map
[
string
]
string
)
bool
{
if
len
(
left
)
!=
len
(
right
)
{
return
false
}
for
key
,
want
:=
range
right
{
if
got
,
ok
:=
left
[
key
];
!
ok
||
got
!=
want
{
return
false
}
}
return
true
}
backend/internal/service/account_test_service.go
View file @
3b7a5fff
...
@@ -165,7 +165,8 @@ func createTestPayload(modelID string) (map[string]any, error) {
...
@@ -165,7 +165,8 @@ func createTestPayload(modelID string) (map[string]any, error) {
// TestAccountConnection tests an account's connection by sending a test request
// TestAccountConnection tests an account's connection by sending a test request
// All account types use full Claude Code client characteristics, only auth header differs
// All account types use full Claude Code client characteristics, only auth header differs
// modelID is optional - if empty, defaults to claude.DefaultTestModel
// modelID is optional - if empty, defaults to claude.DefaultTestModel
func
(
s
*
AccountTestService
)
TestAccountConnection
(
c
*
gin
.
Context
,
accountID
int64
,
modelID
string
,
prompt
string
)
error
{
// mode is optional - "compact" routes OpenAI accounts to the /responses/compact probe path
func
(
s
*
AccountTestService
)
TestAccountConnection
(
c
*
gin
.
Context
,
accountID
int64
,
modelID
string
,
prompt
string
,
mode
string
)
error
{
ctx
:=
c
.
Request
.
Context
()
ctx
:=
c
.
Request
.
Context
()
// Get account
// Get account
...
@@ -176,7 +177,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
...
@@ -176,7 +177,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
// Route to platform-specific test method
// Route to platform-specific test method
if
account
.
IsOpenAI
()
{
if
account
.
IsOpenAI
()
{
return
s
.
testOpenAIAccountConnection
(
c
,
account
,
modelID
,
prompt
)
return
s
.
testOpenAIAccountConnection
(
c
,
account
,
modelID
,
prompt
,
normalizeAccountTestMode
(
mode
)
)
}
}
if
account
.
IsGemini
()
{
if
account
.
IsGemini
()
{
...
@@ -416,9 +417,10 @@ func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx co
...
@@ -416,9 +417,10 @@ func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx co
}
}
// testOpenAIAccountConnection tests an OpenAI account's connection
// testOpenAIAccountConnection tests an OpenAI account's connection
func
(
s
*
AccountTestService
)
testOpenAIAccountConnection
(
c
*
gin
.
Context
,
account
*
Account
,
modelID
string
,
prompt
string
)
error
{
func
(
s
*
AccountTestService
)
testOpenAIAccountConnection
(
c
*
gin
.
Context
,
account
*
Account
,
modelID
string
,
prompt
string
,
mode
string
)
error
{
ctx
:=
c
.
Request
.
Context
()
ctx
:=
c
.
Request
.
Context
()
_
=
prompt
_
=
prompt
mode
=
normalizeAccountTestMode
(
mode
)
// Default to openai.DefaultTestModel for OpenAI testing
// Default to openai.DefaultTestModel for OpenAI testing
testModelID
:=
modelID
testModelID
:=
modelID
...
@@ -426,14 +428,12 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
...
@@ -426,14 +428,12 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
testModelID
=
openai
.
DefaultTestModel
testModelID
=
openai
.
DefaultTestModel
}
}
// For API Key accounts with model mapping, map the model
// Align test routing with gateway behavior: OpenAI accounts apply normal
if
account
.
Type
==
"apikey"
{
// account model mapping, and compact mode applies compact-only mapping on top.
mapping
:=
account
.
GetModelMapping
()
testModelID
=
account
.
GetMappedModel
(
testModelID
)
if
len
(
mapping
)
>
0
{
if
mode
==
AccountTestModeCompact
{
if
mappedModel
,
exists
:=
mapping
[
testModelID
];
exists
{
testModelID
=
resolveOpenAICompactForwardModel
(
account
,
testModelID
)
testModelID
=
mappedModel
return
s
.
testOpenAICompactConnection
(
c
,
account
,
testModelID
)
}
}
}
}
// Route to image generation test if an image model is selected
// Route to image generation test if an image model is selected
...
@@ -538,6 +538,9 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
...
@@ -538,6 +538,9 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
if
resp
.
StatusCode
!=
http
.
StatusOK
{
if
resp
.
StatusCode
!=
http
.
StatusOK
{
body
,
_
:=
io
.
ReadAll
(
resp
.
Body
)
body
,
_
:=
io
.
ReadAll
(
resp
.
Body
)
if
resp
.
StatusCode
==
http
.
StatusTooManyRequests
{
s
.
reconcileOpenAI429State
(
ctx
,
account
,
resp
.
Header
,
body
)
}
// 401 Unauthorized: 标记账号为永久错误
// 401 Unauthorized: 标记账号为永久错误
if
resp
.
StatusCode
==
http
.
StatusUnauthorized
&&
s
.
accountRepo
!=
nil
{
if
resp
.
StatusCode
==
http
.
StatusUnauthorized
&&
s
.
accountRepo
!=
nil
{
errMsg
:=
fmt
.
Sprintf
(
"Authentication failed (401): %s"
,
string
(
body
))
errMsg
:=
fmt
.
Sprintf
(
"Authentication failed (401): %s"
,
string
(
body
))
...
@@ -550,6 +553,154 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
...
@@ -550,6 +553,154 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
return
s
.
processOpenAIStream
(
c
,
resp
.
Body
)
return
s
.
processOpenAIStream
(
c
,
resp
.
Body
)
}
}
// testOpenAICompactConnection probes /responses/compact and persists the
// resulting capability state on the account.
func
(
s
*
AccountTestService
)
testOpenAICompactConnection
(
c
*
gin
.
Context
,
account
*
Account
,
testModelID
string
)
error
{
ctx
:=
c
.
Request
.
Context
()
authToken
:=
""
apiURL
:=
""
isOAuth
:=
false
chatgptAccountID
:=
""
switch
{
case
account
.
IsOAuth
()
:
isOAuth
=
true
authToken
=
account
.
GetOpenAIAccessToken
()
if
authToken
==
""
{
return
s
.
sendErrorAndEnd
(
c
,
"No access token available"
)
}
apiURL
=
chatgptCodexAPIURL
+
"/compact"
chatgptAccountID
=
account
.
GetChatGPTAccountID
()
case
account
.
Type
==
AccountTypeAPIKey
:
authToken
=
account
.
GetOpenAIApiKey
()
if
authToken
==
""
{
return
s
.
sendErrorAndEnd
(
c
,
"No API key available"
)
}
baseURL
:=
account
.
GetOpenAIBaseURL
()
if
baseURL
==
""
{
baseURL
=
"https://api.openai.com"
}
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Invalid base URL: %s"
,
err
.
Error
()))
}
apiURL
=
appendOpenAIResponsesRequestPathSuffix
(
buildOpenAIResponsesURL
(
normalizedBaseURL
),
"/compact"
)
default
:
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Unsupported account type: %s"
,
account
.
Type
))
}
c
.
Writer
.
Header
()
.
Set
(
"Content-Type"
,
"text/event-stream"
)
c
.
Writer
.
Header
()
.
Set
(
"Cache-Control"
,
"no-cache"
)
c
.
Writer
.
Header
()
.
Set
(
"Connection"
,
"keep-alive"
)
c
.
Writer
.
Header
()
.
Set
(
"X-Accel-Buffering"
,
"no"
)
c
.
Writer
.
Flush
()
payloadBytes
,
_
:=
json
.
Marshal
(
createOpenAICompactProbePayload
(
testModelID
))
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"test_start"
,
Model
:
testModelID
})
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
"POST"
,
apiURL
,
bytes
.
NewReader
(
payloadBytes
))
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
"Failed to create request"
)
}
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
Header
.
Set
(
"Accept"
,
"application/json"
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
authToken
)
req
.
Header
.
Set
(
"OpenAI-Beta"
,
"responses=experimental"
)
req
.
Header
.
Set
(
"Originator"
,
"codex_cli_rs"
)
req
.
Header
.
Set
(
"User-Agent"
,
codexCLIUserAgent
)
req
.
Header
.
Set
(
"Version"
,
codexCLIVersion
)
probeSessionID
:=
compactProbeSessionID
(
account
.
ID
)
req
.
Header
.
Set
(
"Session_ID"
,
probeSessionID
)
req
.
Header
.
Set
(
"Conversation_ID"
,
probeSessionID
)
if
isOAuth
{
req
.
Host
=
"chatgpt.com"
if
chatgptAccountID
!=
""
{
req
.
Header
.
Set
(
"chatgpt-account-id"
,
chatgptAccountID
)
}
}
proxyURL
:=
""
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
proxyURL
=
account
.
Proxy
.
URL
()
}
resp
,
err
:=
s
.
httpUpstream
.
DoWithTLS
(
req
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
s
.
tlsFPProfileService
.
ResolveTLSProfile
(
account
))
if
err
!=
nil
{
if
s
.
accountRepo
!=
nil
{
updates
:=
buildOpenAICompactProbeExtraUpdates
(
nil
,
nil
,
err
,
time
.
Now
())
_
=
s
.
accountRepo
.
UpdateExtra
(
ctx
,
account
.
ID
,
updates
)
mergeAccountExtra
(
account
,
updates
)
}
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Request failed: %s"
,
err
.
Error
()))
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
body
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
if
s
.
accountRepo
!=
nil
{
updates
:=
buildOpenAICompactProbeExtraUpdates
(
resp
,
body
,
nil
,
time
.
Now
())
if
codexUpdates
,
err
:=
extractOpenAICodexProbeUpdates
(
resp
);
err
==
nil
&&
len
(
codexUpdates
)
>
0
{
updates
=
mergeExtraUpdates
(
updates
,
codexUpdates
)
}
if
len
(
updates
)
>
0
{
_
=
s
.
accountRepo
.
UpdateExtra
(
ctx
,
account
.
ID
,
updates
)
mergeAccountExtra
(
account
,
updates
)
}
// 探测如返回 429,主动同步限流状态,避免后续短时间内继续选中。
if
resp
.
StatusCode
==
http
.
StatusTooManyRequests
{
s
.
reconcileOpenAI429State
(
ctx
,
account
,
resp
.
Header
,
body
)
}
}
if
resp
.
StatusCode
!=
http
.
StatusOK
{
if
resp
.
StatusCode
==
http
.
StatusUnauthorized
&&
s
.
accountRepo
!=
nil
{
errMsg
:=
fmt
.
Sprintf
(
"Authentication failed (401): %s"
,
string
(
body
))
_
=
s
.
accountRepo
.
SetError
(
ctx
,
account
.
ID
,
errMsg
)
}
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"API returned %d: %s"
,
resp
.
StatusCode
,
string
(
body
)))
}
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"content"
,
Text
:
"Compact probe succeeded"
})
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"test_complete"
,
Success
:
true
})
return
nil
}
func
(
s
*
AccountTestService
)
reconcileOpenAI429State
(
ctx
context
.
Context
,
account
*
Account
,
headers
http
.
Header
,
body
[]
byte
)
{
if
s
==
nil
||
s
.
accountRepo
==
nil
||
account
==
nil
{
return
}
var
resetAt
*
time
.
Time
if
calculated
:=
calculateOpenAI429ResetTime
(
headers
);
calculated
!=
nil
{
resetAt
=
calculated
}
else
if
unixTs
:=
parseOpenAIRateLimitResetTime
(
body
);
unixTs
!=
nil
{
t
:=
time
.
Unix
(
*
unixTs
,
0
)
resetAt
=
&
t
}
if
resetAt
==
nil
{
return
}
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
*
resetAt
);
err
!=
nil
{
return
}
now
:=
time
.
Now
()
account
.
RateLimitedAt
=
&
now
account
.
RateLimitResetAt
=
resetAt
if
account
.
Status
==
StatusError
{
if
err
:=
s
.
accountRepo
.
ClearError
(
ctx
,
account
.
ID
);
err
!=
nil
{
return
}
account
.
Status
=
StatusActive
account
.
ErrorMessage
=
""
}
}
// testGeminiAccountConnection tests a Gemini account's connection
// testGeminiAccountConnection tests a Gemini account's connection
func
(
s
*
AccountTestService
)
testGeminiAccountConnection
(
c
*
gin
.
Context
,
account
*
Account
,
modelID
string
,
prompt
string
)
error
{
func
(
s
*
AccountTestService
)
testGeminiAccountConnection
(
c
*
gin
.
Context
,
account
*
Account
,
modelID
string
,
prompt
string
)
error
{
ctx
:=
c
.
Request
.
Context
()
ctx
:=
c
.
Request
.
Context
()
...
@@ -994,13 +1145,17 @@ func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader)
...
@@ -994,13 +1145,17 @@ func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader)
// processOpenAIStream processes the SSE stream from OpenAI Responses API
// processOpenAIStream processes the SSE stream from OpenAI Responses API
func
(
s
*
AccountTestService
)
processOpenAIStream
(
c
*
gin
.
Context
,
body
io
.
Reader
)
error
{
func
(
s
*
AccountTestService
)
processOpenAIStream
(
c
*
gin
.
Context
,
body
io
.
Reader
)
error
{
reader
:=
bufio
.
NewReader
(
body
)
reader
:=
bufio
.
NewReader
(
body
)
seenCompleted
:=
false
for
{
for
{
line
,
err
:=
reader
.
ReadString
(
'\n'
)
line
,
err
:=
reader
.
ReadString
(
'\n'
)
if
err
!=
nil
{
if
err
!=
nil
{
if
err
==
io
.
EOF
{
if
err
==
io
.
EOF
{
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"test_complete"
,
Success
:
true
})
if
seenCompleted
{
return
nil
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"test_complete"
,
Success
:
true
})
return
nil
}
return
s
.
sendErrorAndEnd
(
c
,
"Stream ended before response.completed"
)
}
}
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Stream read error: %s"
,
err
.
Error
()))
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Stream read error: %s"
,
err
.
Error
()))
}
}
...
@@ -1012,8 +1167,11 @@ func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader)
...
@@ -1012,8 +1167,11 @@ func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader)
jsonStr
:=
sseDataPrefix
.
ReplaceAllString
(
line
,
""
)
jsonStr
:=
sseDataPrefix
.
ReplaceAllString
(
line
,
""
)
if
jsonStr
==
"[DONE]"
{
if
jsonStr
==
"[DONE]"
{
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"test_complete"
,
Success
:
true
})
if
seenCompleted
{
return
nil
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"test_complete"
,
Success
:
true
})
return
nil
}
return
s
.
sendErrorAndEnd
(
c
,
"Stream ended before response.completed"
)
}
}
var
data
map
[
string
]
any
var
data
map
[
string
]
any
...
@@ -1029,9 +1187,19 @@ func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader)
...
@@ -1029,9 +1187,19 @@ func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader)
if
delta
,
ok
:=
data
[
"delta"
]
.
(
string
);
ok
&&
delta
!=
""
{
if
delta
,
ok
:=
data
[
"delta"
]
.
(
string
);
ok
&&
delta
!=
""
{
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"content"
,
Text
:
delta
})
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"content"
,
Text
:
delta
})
}
}
case
"response.completed"
:
case
"response.completed"
,
"response.done"
:
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"test_complete"
,
Success
:
true
})
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"test_complete"
,
Success
:
true
})
return
nil
return
nil
case
"response.failed"
:
errorMsg
:=
"OpenAI response failed"
if
responseData
,
ok
:=
data
[
"response"
]
.
(
map
[
string
]
any
);
ok
{
if
errData
,
ok
:=
responseData
[
"error"
]
.
(
map
[
string
]
any
);
ok
{
if
msg
,
ok
:=
errData
[
"message"
]
.
(
string
);
ok
&&
msg
!=
""
{
errorMsg
=
msg
}
}
}
return
s
.
sendErrorAndEnd
(
c
,
errorMsg
)
case
"error"
:
case
"error"
:
errorMsg
:=
"Unknown error"
errorMsg
:=
"Unknown error"
if
errData
,
ok
:=
data
[
"error"
]
.
(
map
[
string
]
any
);
ok
{
if
errData
,
ok
:=
data
[
"error"
]
.
(
map
[
string
]
any
);
ok
{
...
@@ -1261,7 +1429,7 @@ func (s *AccountTestService) RunTestBackground(ctx context.Context, accountID in
...
@@ -1261,7 +1429,7 @@ func (s *AccountTestService) RunTestBackground(ctx context.Context, accountID in
ginCtx
,
_
:=
gin
.
CreateTestContext
(
w
)
ginCtx
,
_
:=
gin
.
CreateTestContext
(
w
)
ginCtx
.
Request
=
(
&
http
.
Request
{})
.
WithContext
(
ctx
)
ginCtx
.
Request
=
(
&
http
.
Request
{})
.
WithContext
(
ctx
)
testErr
:=
s
.
TestAccountConnection
(
ginCtx
,
accountID
,
modelID
,
""
)
testErr
:=
s
.
TestAccountConnection
(
ginCtx
,
accountID
,
modelID
,
""
,
AccountTestModeDefault
)
finishedAt
:=
time
.
Now
()
finishedAt
:=
time
.
Now
()
body
:=
w
.
Body
.
String
()
body
:=
w
.
Body
.
String
()
...
...
backend/internal/service/account_test_service_openai_compact_test.go
0 → 100644
View file @
3b7a5fff
package
service
import
(
"bytes"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func
TestAccountTestService_TestAccountConnection_OpenAICompactOAuthSuccessPersistsSupport
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
updateCalls
:=
make
(
chan
map
[
string
]
any
,
1
)
account
:=
Account
{
ID
:
1
,
Name
:
"openai-oauth"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"oauth-token"
,
"chatgpt_account_id"
:
"chatgpt-acc"
,
},
}
repo
:=
&
snapshotUpdateAccountRepo
{
stubOpenAIAccountRepo
:
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
account
}},
updateExtraCalls
:
updateCalls
,
}
upstream
:=
&
httpUpstreamRecorder
{
resp
:
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"application/json"
},
"x-request-id"
:
[]
string
{
"rid-probe"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
`{"id":"cmp_probe","status":"completed"}`
)),
}}
svc
:=
&
AccountTestService
{
accountRepo
:
repo
,
httpUpstream
:
upstream
,
}
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/admin/accounts/1/test"
,
bytes
.
NewReader
(
nil
))
err
:=
svc
.
TestAccountConnection
(
c
,
account
.
ID
,
"gpt-5.4"
,
""
,
AccountTestModeCompact
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
chatgptCodexAPIURL
+
"/compact"
,
upstream
.
lastReq
.
URL
.
String
())
require
.
Equal
(
t
,
"chatgpt.com"
,
upstream
.
lastReq
.
Host
)
require
.
Equal
(
t
,
"application/json"
,
upstream
.
lastReq
.
Header
.
Get
(
"Accept"
))
require
.
Equal
(
t
,
codexCLIVersion
,
upstream
.
lastReq
.
Header
.
Get
(
"Version"
))
require
.
NotEmpty
(
t
,
upstream
.
lastReq
.
Header
.
Get
(
"Session_Id"
))
require
.
Equal
(
t
,
codexCLIUserAgent
,
upstream
.
lastReq
.
Header
.
Get
(
"User-Agent"
))
require
.
Equal
(
t
,
"chatgpt-acc"
,
upstream
.
lastReq
.
Header
.
Get
(
"chatgpt-account-id"
))
require
.
Equal
(
t
,
"gpt-5.4"
,
gjson
.
GetBytes
(
upstream
.
lastBody
,
"model"
)
.
String
())
updates
:=
<-
updateCalls
require
.
Equal
(
t
,
true
,
updates
[
"openai_compact_supported"
])
require
.
Equal
(
t
,
http
.
StatusOK
,
updates
[
"openai_compact_last_status"
])
require
.
Contains
(
t
,
rec
.
Body
.
String
(),
`"type":"test_complete"`
)
}
func
TestAccountTestService_TestAccountConnection_OpenAICompactOAuth404MarksUnsupported
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
updateCalls
:=
make
(
chan
map
[
string
]
any
,
1
)
account
:=
Account
{
ID
:
2
,
Name
:
"openai-oauth"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"oauth-token"
,
"chatgpt_account_id"
:
"chatgpt-acc"
,
},
}
repo
:=
&
snapshotUpdateAccountRepo
{
stubOpenAIAccountRepo
:
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
account
}},
updateExtraCalls
:
updateCalls
,
}
upstream
:=
&
httpUpstreamRecorder
{
resp
:
&
http
.
Response
{
StatusCode
:
http
.
StatusNotFound
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"application/json"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
`404 page not found`
)),
}}
svc
:=
&
AccountTestService
{
accountRepo
:
repo
,
httpUpstream
:
upstream
,
}
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/admin/accounts/2/test"
,
bytes
.
NewReader
(
nil
))
err
:=
svc
.
TestAccountConnection
(
c
,
account
.
ID
,
"gpt-5.4"
,
""
,
AccountTestModeCompact
)
require
.
Error
(
t
,
err
)
updates
:=
<-
updateCalls
require
.
Equal
(
t
,
false
,
updates
[
"openai_compact_supported"
])
require
.
Equal
(
t
,
http
.
StatusNotFound
,
updates
[
"openai_compact_last_status"
])
require
.
Contains
(
t
,
rec
.
Body
.
String
(),
`"type":"error"`
)
}
func
TestAccountTestService_TestAccountConnection_OpenAICompactAPIKeyUsesCompactPath
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
updateCalls
:=
make
(
chan
map
[
string
]
any
,
1
)
account
:=
Account
{
ID
:
3
,
Name
:
"openai-apikey"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
"base_url"
:
"https://example.com/v1"
,
"compact_model_mapping"
:
map
[
string
]
any
{
"gpt-5.4"
:
"gpt-5.4-openai-compact"
},
},
}
repo
:=
&
snapshotUpdateAccountRepo
{
stubOpenAIAccountRepo
:
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
account
}},
updateExtraCalls
:
updateCalls
,
}
upstream
:=
&
httpUpstreamRecorder
{
resp
:
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"application/json"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
`{"id":"cmp_probe_apikey","status":"completed"}`
)),
}}
svc
:=
&
AccountTestService
{
accountRepo
:
repo
,
httpUpstream
:
upstream
,
cfg
:
&
config
.
Config
{
Security
:
config
.
SecurityConfig
{
URLAllowlist
:
config
.
URLAllowlistConfig
{
Enabled
:
false
}}},
}
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/admin/accounts/3/test"
,
bytes
.
NewReader
(
nil
))
err
:=
svc
.
TestAccountConnection
(
c
,
account
.
ID
,
"gpt-5.4"
,
""
,
AccountTestModeCompact
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"https://example.com/v1/responses/compact"
,
upstream
.
lastReq
.
URL
.
String
())
require
.
Equal
(
t
,
"gpt-5.4-openai-compact"
,
gjson
.
GetBytes
(
upstream
.
lastBody
,
"model"
)
.
String
())
updates
:=
<-
updateCalls
require
.
Equal
(
t
,
true
,
updates
[
"openai_compact_supported"
])
}
func
TestAccountTestService_TestAccountConnection_OpenAICompactAPIKeyDefaultBaseURLUsesV1Path
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
updateCalls
:=
make
(
chan
map
[
string
]
any
,
1
)
account
:=
Account
{
ID
:
4
,
Name
:
"openai-apikey-default"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"api_key"
:
"sk-test"
,
},
}
repo
:=
&
snapshotUpdateAccountRepo
{
stubOpenAIAccountRepo
:
stubOpenAIAccountRepo
{
accounts
:
[]
Account
{
account
}},
updateExtraCalls
:
updateCalls
,
}
upstream
:=
&
httpUpstreamRecorder
{
resp
:
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"application/json"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
`{"id":"cmp_probe_apikey_default","status":"completed"}`
)),
}}
svc
:=
&
AccountTestService
{
accountRepo
:
repo
,
httpUpstream
:
upstream
,
cfg
:
&
config
.
Config
{
Security
:
config
.
SecurityConfig
{
URLAllowlist
:
config
.
URLAllowlistConfig
{
Enabled
:
false
}}},
}
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/admin/accounts/4/test"
,
bytes
.
NewReader
(
nil
))
err
:=
svc
.
TestAccountConnection
(
c
,
account
.
ID
,
"gpt-5.4"
,
""
,
AccountTestModeCompact
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"https://api.openai.com/v1/responses/compact"
,
upstream
.
lastReq
.
URL
.
String
())
<-
updateCalls
}
backend/internal/service/account_test_service_openai_test.go
View file @
3b7a5fff
...
@@ -61,9 +61,12 @@ func newTestContext() (*gin.Context, *httptest.ResponseRecorder) {
...
@@ -61,9 +61,12 @@ func newTestContext() (*gin.Context, *httptest.ResponseRecorder) {
type
openAIAccountTestRepo
struct
{
type
openAIAccountTestRepo
struct
{
mockAccountRepoForGemini
mockAccountRepoForGemini
updatedExtra
map
[
string
]
any
updatedExtra
map
[
string
]
any
rateLimitedID
int64
rateLimitedID
int64
rateLimitedAt
*
time
.
Time
rateLimitedAt
*
time
.
Time
clearedErrorID
int64
setErrorID
int64
setErrorMsg
string
}
}
func
(
r
*
openAIAccountTestRepo
)
UpdateExtra
(
_
context
.
Context
,
_
int64
,
updates
map
[
string
]
any
)
error
{
func
(
r
*
openAIAccountTestRepo
)
UpdateExtra
(
_
context
.
Context
,
_
int64
,
updates
map
[
string
]
any
)
error
{
...
@@ -77,6 +80,17 @@ func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, rese
...
@@ -77,6 +80,17 @@ func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, rese
return
nil
return
nil
}
}
func
(
r
*
openAIAccountTestRepo
)
ClearError
(
_
context
.
Context
,
id
int64
)
error
{
r
.
clearedErrorID
=
id
return
nil
}
func
(
r
*
openAIAccountTestRepo
)
SetError
(
_
context
.
Context
,
id
int64
,
errorMsg
string
)
error
{
r
.
setErrorID
=
id
r
.
setErrorMsg
=
errorMsg
return
nil
}
func
TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders
(
t
*
testing
.
T
)
{
func
TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
gin
.
SetMode
(
gin
.
TestMode
)
ctx
,
recorder
:=
newTestContext
()
ctx
,
recorder
:=
newTestContext
()
...
@@ -103,7 +117,7 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
...
@@ -103,7 +117,7 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
Credentials
:
map
[
string
]
any
{
"access_token"
:
"test-token"
},
Credentials
:
map
[
string
]
any
{
"access_token"
:
"test-token"
},
}
}
err
:=
svc
.
testOpenAIAccountConnection
(
ctx
,
account
,
"gpt-5.4"
,
""
)
err
:=
svc
.
testOpenAIAccountConnection
(
ctx
,
account
,
"gpt-5.4"
,
""
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
NotEmpty
(
t
,
repo
.
updatedExtra
)
require
.
NotEmpty
(
t
,
repo
.
updatedExtra
)
require
.
Equal
(
t
,
42.0
,
repo
.
updatedExtra
[
"codex_5h_used_percent"
])
require
.
Equal
(
t
,
42.0
,
repo
.
updatedExtra
[
"codex_5h_used_percent"
])
...
@@ -111,11 +125,36 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
...
@@ -111,11 +125,36 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
require
.
Contains
(
t
,
recorder
.
Body
.
String
(),
"test_complete"
)
require
.
Contains
(
t
,
recorder
.
Body
.
String
(),
"test_complete"
)
}
}
func
TestAccountTestService_OpenAI429PersistsSnapshotWithoutRateLimit
(
t
*
testing
.
T
)
{
func
TestAccountTestService_OpenAIStreamEOFBeforeCompletedFails
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
ctx
,
recorder
:=
newTestContext
()
resp
:=
newJSONResponse
(
http
.
StatusOK
,
""
)
resp
.
Body
=
io
.
NopCloser
(
strings
.
NewReader
(
`data: {"type":"response.output_text.delta","delta":"hi"}
`
))
upstream
:=
&
queuedHTTPUpstream
{
responses
:
[]
*
http
.
Response
{
resp
}}
svc
:=
&
AccountTestService
{
httpUpstream
:
upstream
}
account
:=
&
Account
{
ID
:
90
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"test-token"
},
}
err
:=
svc
.
testOpenAIAccountConnection
(
ctx
,
account
,
"gpt-5.4"
,
""
,
""
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
recorder
.
Body
.
String
(),
"response.completed"
)
require
.
NotContains
(
t
,
recorder
.
Body
.
String
(),
`"success":true`
)
}
func
TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimitState
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
gin
.
SetMode
(
gin
.
TestMode
)
ctx
,
_
:=
newTestContext
()
ctx
,
_
:=
newTestContext
()
resp
:=
newJSONResponse
(
http
.
StatusTooManyRequests
,
`{"error":{"type":"usage_limit_reached","message":"limit reached"}}`
)
resp
:=
newJSONResponse
(
http
.
StatusTooManyRequests
,
`{"error":{"type":"usage_limit_reached","message":"limit reached"
,"resets_at":1777283883
}}`
)
resp
.
Header
.
Set
(
"x-codex-primary-used-percent"
,
"100"
)
resp
.
Header
.
Set
(
"x-codex-primary-used-percent"
,
"100"
)
resp
.
Header
.
Set
(
"x-codex-primary-reset-after-seconds"
,
"604800"
)
resp
.
Header
.
Set
(
"x-codex-primary-reset-after-seconds"
,
"604800"
)
resp
.
Header
.
Set
(
"x-codex-primary-window-minutes"
,
"10080"
)
resp
.
Header
.
Set
(
"x-codex-primary-window-minutes"
,
"10080"
)
...
@@ -130,15 +169,132 @@ func TestAccountTestService_OpenAI429PersistsSnapshotWithoutRateLimit(t *testing
...
@@ -130,15 +169,132 @@ func TestAccountTestService_OpenAI429PersistsSnapshotWithoutRateLimit(t *testing
ID
:
88
,
ID
:
88
,
Platform
:
PlatformOpenAI
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Type
:
AccountTypeOAuth
,
Status
:
StatusError
,
Concurrency
:
1
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"test-token"
},
Credentials
:
map
[
string
]
any
{
"access_token"
:
"test-token"
},
}
}
err
:=
svc
.
testOpenAIAccountConnection
(
ctx
,
account
,
"gpt-5.4"
,
""
)
err
:=
svc
.
testOpenAIAccountConnection
(
ctx
,
account
,
"gpt-5.4"
,
""
,
""
)
require
.
Error
(
t
,
err
)
require
.
Error
(
t
,
err
)
require
.
NotEmpty
(
t
,
repo
.
updatedExtra
)
require
.
NotEmpty
(
t
,
repo
.
updatedExtra
)
require
.
Equal
(
t
,
100.0
,
repo
.
updatedExtra
[
"codex_5h_used_percent"
])
require
.
Equal
(
t
,
100.0
,
repo
.
updatedExtra
[
"codex_5h_used_percent"
])
require
.
Equal
(
t
,
account
.
ID
,
repo
.
rateLimitedID
)
require
.
NotNil
(
t
,
repo
.
rateLimitedAt
)
require
.
Equal
(
t
,
account
.
ID
,
repo
.
clearedErrorID
)
require
.
Equal
(
t
,
StatusActive
,
account
.
Status
)
require
.
Empty
(
t
,
account
.
ErrorMessage
)
require
.
NotNil
(
t
,
account
.
RateLimitResetAt
)
}
func
TestAccountTestService_OpenAI429BodyOnlyPersistsRateLimitAndClearsStaleError
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
ctx
,
_
:=
newTestContext
()
resp
:=
newJSONResponse
(
http
.
StatusTooManyRequests
,
`{"error":{"type":"usage_limit_reached","message":"limit reached","resets_at":"1777283883"}}`
)
repo
:=
&
openAIAccountTestRepo
{}
upstream
:=
&
queuedHTTPUpstream
{
responses
:
[]
*
http
.
Response
{
resp
}}
svc
:=
&
AccountTestService
{
accountRepo
:
repo
,
httpUpstream
:
upstream
}
account
:=
&
Account
{
ID
:
77
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusError
,
ErrorMessage
:
"Access forbidden (403): account may be suspended or lack permissions"
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"test-token"
},
}
err
:=
svc
.
testOpenAIAccountConnection
(
ctx
,
account
,
"gpt-5.4"
,
""
,
""
)
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
account
.
ID
,
repo
.
rateLimitedID
)
require
.
NotNil
(
t
,
repo
.
rateLimitedAt
)
require
.
Equal
(
t
,
account
.
ID
,
repo
.
clearedErrorID
)
require
.
Equal
(
t
,
StatusActive
,
account
.
Status
)
require
.
Empty
(
t
,
account
.
ErrorMessage
)
require
.
NotNil
(
t
,
account
.
RateLimitResetAt
)
require
.
Empty
(
t
,
repo
.
updatedExtra
)
}
func
TestAccountTestService_OpenAI429ActiveAccountDoesNotClearError
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
ctx
,
_
:=
newTestContext
()
resp
:=
newJSONResponse
(
http
.
StatusTooManyRequests
,
`{"error":{"type":"usage_limit_reached","message":"limit reached","resets_in_seconds":3600}}`
)
repo
:=
&
openAIAccountTestRepo
{}
upstream
:=
&
queuedHTTPUpstream
{
responses
:
[]
*
http
.
Response
{
resp
}}
svc
:=
&
AccountTestService
{
accountRepo
:
repo
,
httpUpstream
:
upstream
}
account
:=
&
Account
{
ID
:
78
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"test-token"
},
}
err
:=
svc
.
testOpenAIAccountConnection
(
ctx
,
account
,
"gpt-5.4"
,
""
,
""
)
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
account
.
ID
,
repo
.
rateLimitedID
)
require
.
NotNil
(
t
,
repo
.
rateLimitedAt
)
require
.
Zero
(
t
,
repo
.
clearedErrorID
)
require
.
Equal
(
t
,
StatusActive
,
account
.
Status
)
require
.
NotNil
(
t
,
account
.
RateLimitResetAt
)
}
func
TestAccountTestService_OpenAI429WithoutResetSignalDoesNotMutateRuntimeState
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
ctx
,
_
:=
newTestContext
()
resp
:=
newJSONResponse
(
http
.
StatusTooManyRequests
,
`{"error":{"type":"usage_limit_reached","message":"limit reached"}}`
)
repo
:=
&
openAIAccountTestRepo
{}
upstream
:=
&
queuedHTTPUpstream
{
responses
:
[]
*
http
.
Response
{
resp
}}
svc
:=
&
AccountTestService
{
accountRepo
:
repo
,
httpUpstream
:
upstream
}
account
:=
&
Account
{
ID
:
79
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusError
,
ErrorMessage
:
"stale 403"
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"test-token"
},
}
err
:=
svc
.
testOpenAIAccountConnection
(
ctx
,
account
,
"gpt-5.4"
,
""
,
""
)
require
.
Error
(
t
,
err
)
require
.
Zero
(
t
,
repo
.
rateLimitedID
)
require
.
Zero
(
t
,
repo
.
rateLimitedID
)
require
.
Nil
(
t
,
repo
.
rateLimitedAt
)
require
.
Nil
(
t
,
repo
.
rateLimitedAt
)
require
.
Zero
(
t
,
repo
.
clearedErrorID
)
require
.
Equal
(
t
,
StatusError
,
account
.
Status
)
require
.
Equal
(
t
,
"stale 403"
,
account
.
ErrorMessage
)
require
.
Nil
(
t
,
account
.
RateLimitResetAt
)
}
func
TestAccountTestService_OpenAI401SetsPermanentErrorOnly
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
ctx
,
_
:=
newTestContext
()
resp
:=
newJSONResponse
(
http
.
StatusUnauthorized
,
`{"error":"bad token"}`
)
repo
:=
&
openAIAccountTestRepo
{}
upstream
:=
&
queuedHTTPUpstream
{
responses
:
[]
*
http
.
Response
{
resp
}}
svc
:=
&
AccountTestService
{
accountRepo
:
repo
,
httpUpstream
:
upstream
}
account
:=
&
Account
{
ID
:
80
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Status
:
StatusActive
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"test-token"
},
}
err
:=
svc
.
testOpenAIAccountConnection
(
ctx
,
account
,
"gpt-5.4"
,
""
,
""
)
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
account
.
ID
,
repo
.
setErrorID
)
require
.
Contains
(
t
,
repo
.
setErrorMsg
,
"Authentication failed (401)"
)
require
.
Zero
(
t
,
repo
.
rateLimitedID
)
require
.
Zero
(
t
,
repo
.
clearedErrorID
)
require
.
Nil
(
t
,
account
.
RateLimitResetAt
)
require
.
Nil
(
t
,
account
.
RateLimitResetAt
)
}
}
backend/internal/service/account_usage_service.go
View file @
3b7a5fff
...
@@ -110,7 +110,7 @@ const (
...
@@ -110,7 +110,7 @@ const (
apiQueryMaxJitter
=
800
*
time
.
Millisecond
// 用量查询最大随机延迟
apiQueryMaxJitter
=
800
*
time
.
Millisecond
// 用量查询最大随机延迟
windowStatsCacheTTL
=
1
*
time
.
Minute
windowStatsCacheTTL
=
1
*
time
.
Minute
openAIProbeCacheTTL
=
10
*
time
.
Minute
openAIProbeCacheTTL
=
10
*
time
.
Minute
openAICodexProbeVersion
=
"0.1
04
.0"
openAICodexProbeVersion
=
"0.1
25
.0"
)
)
// UsageCache 封装账户使用量相关的缓存
// UsageCache 封装账户使用量相关的缓存
...
...
backend/internal/service/affiliate_service.go
0 → 100644
View file @
3b7a5fff
package
service
import
(
"context"
"errors"
"math"
"strings"
"time"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
var
(
ErrAffiliateProfileNotFound
=
infraerrors
.
NotFound
(
"AFFILIATE_PROFILE_NOT_FOUND"
,
"affiliate profile not found"
)
ErrAffiliateCodeInvalid
=
infraerrors
.
BadRequest
(
"AFFILIATE_CODE_INVALID"
,
"invalid affiliate code"
)
ErrAffiliateCodeTaken
=
infraerrors
.
Conflict
(
"AFFILIATE_CODE_TAKEN"
,
"affiliate code already in use"
)
ErrAffiliateAlreadyBound
=
infraerrors
.
Conflict
(
"AFFILIATE_ALREADY_BOUND"
,
"affiliate inviter already bound"
)
ErrAffiliateQuotaEmpty
=
infraerrors
.
BadRequest
(
"AFFILIATE_QUOTA_EMPTY"
,
"no affiliate quota available to transfer"
)
)
const
(
affiliateInviteesLimit
=
100
// AffiliateCodeMinLength / AffiliateCodeMaxLength bound both system-generated
// 12-char codes and admin-customized codes (e.g. "VIP2026").
AffiliateCodeMinLength
=
4
AffiliateCodeMaxLength
=
32
)
// affiliateCodeValidChar accepts uppercase letters, digits, underscore and dash.
// All input passes through strings.ToUpper before validation, so lowercase from
// users is normalized — admins may supply mixed case in their UI.
var
affiliateCodeValidChar
=
func
()
[
256
]
bool
{
var
tbl
[
256
]
bool
for
c
:=
byte
(
'A'
);
c
<=
'Z'
;
c
++
{
tbl
[
c
]
=
true
}
for
c
:=
byte
(
'0'
);
c
<=
'9'
;
c
++
{
tbl
[
c
]
=
true
}
tbl
[
'_'
]
=
true
tbl
[
'-'
]
=
true
return
tbl
}()
// isValidAffiliateCodeFormat validates code format for both binding (user input)
// and admin updates. Caller is expected to upper-case the input first.
func
isValidAffiliateCodeFormat
(
code
string
)
bool
{
if
len
(
code
)
<
AffiliateCodeMinLength
||
len
(
code
)
>
AffiliateCodeMaxLength
{
return
false
}
for
i
:=
0
;
i
<
len
(
code
);
i
++
{
if
!
affiliateCodeValidChar
[
code
[
i
]]
{
return
false
}
}
return
true
}
type
AffiliateSummary
struct
{
UserID
int64
`json:"user_id"`
AffCode
string
`json:"aff_code"`
AffCodeCustom
bool
`json:"aff_code_custom"`
AffRebateRatePercent
*
float64
`json:"aff_rebate_rate_percent,omitempty"`
InviterID
*
int64
`json:"inviter_id,omitempty"`
AffCount
int
`json:"aff_count"`
AffQuota
float64
`json:"aff_quota"`
AffFrozenQuota
float64
`json:"aff_frozen_quota"`
AffHistoryQuota
float64
`json:"aff_history_quota"`
CreatedAt
time
.
Time
`json:"created_at"`
UpdatedAt
time
.
Time
`json:"updated_at"`
}
type
AffiliateInvitee
struct
{
UserID
int64
`json:"user_id"`
Email
string
`json:"email"`
Username
string
`json:"username"`
CreatedAt
*
time
.
Time
`json:"created_at,omitempty"`
TotalRebate
float64
`json:"total_rebate"`
}
type
AffiliateDetail
struct
{
UserID
int64
`json:"user_id"`
AffCode
string
`json:"aff_code"`
InviterID
*
int64
`json:"inviter_id,omitempty"`
AffCount
int
`json:"aff_count"`
AffQuota
float64
`json:"aff_quota"`
AffFrozenQuota
float64
`json:"aff_frozen_quota"`
AffHistoryQuota
float64
`json:"aff_history_quota"`
// EffectiveRebateRatePercent 是当前用户作为邀请人时实际生效的返利比例:
// 优先用户自己的专属比例(aff_rebate_rate_percent),否则回退到全局比例。
// 用于在用户的 /affiliate 页面直观展示「分享后能拿到多少」。
EffectiveRebateRatePercent
float64
`json:"effective_rebate_rate_percent"`
Invitees
[]
AffiliateInvitee
`json:"invitees"`
}
type
AffiliateRepository
interface
{
EnsureUserAffiliate
(
ctx
context
.
Context
,
userID
int64
)
(
*
AffiliateSummary
,
error
)
GetAffiliateByCode
(
ctx
context
.
Context
,
code
string
)
(
*
AffiliateSummary
,
error
)
BindInviter
(
ctx
context
.
Context
,
userID
,
inviterID
int64
)
(
bool
,
error
)
AccrueQuota
(
ctx
context
.
Context
,
inviterID
,
inviteeUserID
int64
,
amount
float64
,
freezeHours
int
)
(
bool
,
error
)
GetAccruedRebateFromInvitee
(
ctx
context
.
Context
,
inviterID
,
inviteeUserID
int64
)
(
float64
,
error
)
ThawFrozenQuota
(
ctx
context
.
Context
,
userID
int64
)
(
float64
,
error
)
TransferQuotaToBalance
(
ctx
context
.
Context
,
userID
int64
)
(
float64
,
float64
,
error
)
ListInvitees
(
ctx
context
.
Context
,
inviterID
int64
,
limit
int
)
([]
AffiliateInvitee
,
error
)
// 管理端:用户级专属配置
UpdateUserAffCode
(
ctx
context
.
Context
,
userID
int64
,
newCode
string
)
error
ResetUserAffCode
(
ctx
context
.
Context
,
userID
int64
)
(
string
,
error
)
SetUserRebateRate
(
ctx
context
.
Context
,
userID
int64
,
ratePercent
*
float64
)
error
BatchSetUserRebateRate
(
ctx
context
.
Context
,
userIDs
[]
int64
,
ratePercent
*
float64
)
error
ListUsersWithCustomSettings
(
ctx
context
.
Context
,
filter
AffiliateAdminFilter
)
([]
AffiliateAdminEntry
,
int64
,
error
)
}
// AffiliateAdminFilter 列表筛选条件
type
AffiliateAdminFilter
struct
{
Search
string
Page
int
PageSize
int
}
// AffiliateAdminEntry 专属用户列表条目
type
AffiliateAdminEntry
struct
{
UserID
int64
`json:"user_id"`
Email
string
`json:"email"`
Username
string
`json:"username"`
AffCode
string
`json:"aff_code"`
AffCodeCustom
bool
`json:"aff_code_custom"`
AffRebateRatePercent
*
float64
`json:"aff_rebate_rate_percent,omitempty"`
AffCount
int
`json:"aff_count"`
}
type
AffiliateService
struct
{
repo
AffiliateRepository
settingService
*
SettingService
authCacheInvalidator
APIKeyAuthCacheInvalidator
billingCacheService
*
BillingCacheService
}
func
NewAffiliateService
(
repo
AffiliateRepository
,
settingService
*
SettingService
,
authCacheInvalidator
APIKeyAuthCacheInvalidator
,
billingCacheService
*
BillingCacheService
)
*
AffiliateService
{
return
&
AffiliateService
{
repo
:
repo
,
settingService
:
settingService
,
authCacheInvalidator
:
authCacheInvalidator
,
billingCacheService
:
billingCacheService
,
}
}
// IsEnabled reports whether the affiliate (邀请返利) feature is turned on.
func
(
s
*
AffiliateService
)
IsEnabled
(
ctx
context
.
Context
)
bool
{
if
s
==
nil
||
s
.
settingService
==
nil
{
return
AffiliateEnabledDefault
}
return
s
.
settingService
.
IsAffiliateEnabled
(
ctx
)
}
func
(
s
*
AffiliateService
)
EnsureUserAffiliate
(
ctx
context
.
Context
,
userID
int64
)
(
*
AffiliateSummary
,
error
)
{
if
userID
<=
0
{
return
nil
,
infraerrors
.
BadRequest
(
"INVALID_USER"
,
"invalid user"
)
}
if
s
==
nil
||
s
.
repo
==
nil
{
return
nil
,
infraerrors
.
ServiceUnavailable
(
"SERVICE_UNAVAILABLE"
,
"affiliate service unavailable"
)
}
return
s
.
repo
.
EnsureUserAffiliate
(
ctx
,
userID
)
}
func
(
s
*
AffiliateService
)
GetAffiliateDetail
(
ctx
context
.
Context
,
userID
int64
)
(
*
AffiliateDetail
,
error
)
{
// Lazy thaw: move any matured frozen quota to available before reading.
if
s
!=
nil
&&
s
.
repo
!=
nil
{
// best-effort: thaw failure is non-fatal
_
,
_
=
s
.
repo
.
ThawFrozenQuota
(
ctx
,
userID
)
}
summary
,
err
:=
s
.
EnsureUserAffiliate
(
ctx
,
userID
)
if
err
!=
nil
{
return
nil
,
err
}
invitees
,
err
:=
s
.
listInvitees
(
ctx
,
userID
)
if
err
!=
nil
{
return
nil
,
err
}
return
&
AffiliateDetail
{
UserID
:
summary
.
UserID
,
AffCode
:
summary
.
AffCode
,
InviterID
:
summary
.
InviterID
,
AffCount
:
summary
.
AffCount
,
AffQuota
:
summary
.
AffQuota
,
AffFrozenQuota
:
summary
.
AffFrozenQuota
,
AffHistoryQuota
:
summary
.
AffHistoryQuota
,
EffectiveRebateRatePercent
:
s
.
resolveRebateRatePercent
(
ctx
,
summary
),
Invitees
:
invitees
,
},
nil
}
func
(
s
*
AffiliateService
)
BindInviterByCode
(
ctx
context
.
Context
,
userID
int64
,
rawCode
string
)
error
{
code
:=
strings
.
ToUpper
(
strings
.
TrimSpace
(
rawCode
))
if
code
==
""
{
return
nil
}
if
s
==
nil
||
s
.
repo
==
nil
{
return
infraerrors
.
ServiceUnavailable
(
"SERVICE_UNAVAILABLE"
,
"affiliate service unavailable"
)
}
// 总开关关闭时,注册阶段静默忽略 aff 参数(不报错,避免阻断注册流程)
if
!
s
.
IsEnabled
(
ctx
)
{
return
nil
}
if
!
isValidAffiliateCodeFormat
(
code
)
{
return
ErrAffiliateCodeInvalid
}
selfSummary
,
err
:=
s
.
repo
.
EnsureUserAffiliate
(
ctx
,
userID
)
if
err
!=
nil
{
return
err
}
if
selfSummary
.
InviterID
!=
nil
{
return
nil
}
inviterSummary
,
err
:=
s
.
repo
.
GetAffiliateByCode
(
ctx
,
code
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrAffiliateProfileNotFound
)
{
return
ErrAffiliateCodeInvalid
}
return
err
}
if
inviterSummary
==
nil
||
inviterSummary
.
UserID
<=
0
||
inviterSummary
.
UserID
==
userID
{
return
ErrAffiliateCodeInvalid
}
bound
,
err
:=
s
.
repo
.
BindInviter
(
ctx
,
userID
,
inviterSummary
.
UserID
)
if
err
!=
nil
{
return
err
}
if
!
bound
{
return
ErrAffiliateAlreadyBound
}
return
nil
}
func
(
s
*
AffiliateService
)
AccrueInviteRebate
(
ctx
context
.
Context
,
inviteeUserID
int64
,
baseRechargeAmount
float64
)
(
float64
,
error
)
{
if
s
==
nil
||
s
.
repo
==
nil
{
return
0
,
nil
}
if
inviteeUserID
<=
0
||
baseRechargeAmount
<=
0
||
math
.
IsNaN
(
baseRechargeAmount
)
||
math
.
IsInf
(
baseRechargeAmount
,
0
)
{
return
0
,
nil
}
// 总开关关闭时,新充值不再产生返利
if
!
s
.
IsEnabled
(
ctx
)
{
return
0
,
nil
}
inviteeSummary
,
err
:=
s
.
repo
.
EnsureUserAffiliate
(
ctx
,
inviteeUserID
)
if
err
!=
nil
{
return
0
,
err
}
if
inviteeSummary
.
InviterID
==
nil
||
*
inviteeSummary
.
InviterID
<=
0
{
return
0
,
nil
}
// 加载邀请人 profile,优先使用专属比例(覆盖全局)
inviterSummary
,
err
:=
s
.
repo
.
EnsureUserAffiliate
(
ctx
,
*
inviteeSummary
.
InviterID
)
if
err
!=
nil
{
return
0
,
err
}
// 有效期检查:超过返利有效期后不再产生返利
if
s
.
settingService
!=
nil
{
if
durationDays
:=
s
.
settingService
.
GetAffiliateRebateDurationDays
(
ctx
);
durationDays
>
0
{
if
time
.
Now
()
.
After
(
inviteeSummary
.
CreatedAt
.
AddDate
(
0
,
0
,
durationDays
))
{
return
0
,
nil
}
}
}
rebateRatePercent
:=
s
.
resolveRebateRatePercent
(
ctx
,
inviterSummary
)
rebate
:=
roundTo
(
baseRechargeAmount
*
(
rebateRatePercent
/
100
),
8
)
if
rebate
<=
0
{
return
0
,
nil
}
// 单人上限检查:精确截断到剩余额度
if
s
.
settingService
!=
nil
{
if
perInviteeCap
:=
s
.
settingService
.
GetAffiliateRebatePerInviteeCap
(
ctx
);
perInviteeCap
>
0
{
existing
,
err
:=
s
.
repo
.
GetAccruedRebateFromInvitee
(
ctx
,
*
inviteeSummary
.
InviterID
,
inviteeUserID
)
if
err
!=
nil
{
return
0
,
err
}
if
existing
>=
perInviteeCap
{
return
0
,
nil
}
if
remaining
:=
perInviteeCap
-
existing
;
rebate
>
remaining
{
rebate
=
roundTo
(
remaining
,
8
)
}
}
}
var
freezeHours
int
if
s
.
settingService
!=
nil
{
freezeHours
=
s
.
settingService
.
GetAffiliateRebateFreezeHours
(
ctx
)
}
applied
,
err
:=
s
.
repo
.
AccrueQuota
(
ctx
,
*
inviteeSummary
.
InviterID
,
inviteeUserID
,
rebate
,
freezeHours
)
if
err
!=
nil
{
return
0
,
err
}
if
!
applied
{
return
0
,
nil
}
return
rebate
,
nil
}
// resolveRebateRatePercent returns the inviter's exclusive rate when set,
// otherwise the global setting value (clamped to [Min, Max]).
func
(
s
*
AffiliateService
)
resolveRebateRatePercent
(
ctx
context
.
Context
,
inviter
*
AffiliateSummary
)
float64
{
if
inviter
!=
nil
&&
inviter
.
AffRebateRatePercent
!=
nil
{
v
:=
*
inviter
.
AffRebateRatePercent
if
math
.
IsNaN
(
v
)
||
math
.
IsInf
(
v
,
0
)
{
return
s
.
globalRebateRatePercent
(
ctx
)
}
return
clampAffiliateRebateRate
(
v
)
}
return
s
.
globalRebateRatePercent
(
ctx
)
}
// globalRebateRatePercent reads the system-wide rebate rate via SettingService,
// returning the documented default when SettingService is unavailable.
func
(
s
*
AffiliateService
)
globalRebateRatePercent
(
ctx
context
.
Context
)
float64
{
if
s
==
nil
||
s
.
settingService
==
nil
{
return
AffiliateRebateRateDefault
}
return
s
.
settingService
.
GetAffiliateRebateRatePercent
(
ctx
)
}
func
(
s
*
AffiliateService
)
TransferAffiliateQuota
(
ctx
context
.
Context
,
userID
int64
)
(
float64
,
float64
,
error
)
{
if
s
==
nil
||
s
.
repo
==
nil
{
return
0
,
0
,
infraerrors
.
ServiceUnavailable
(
"SERVICE_UNAVAILABLE"
,
"affiliate service unavailable"
)
}
transferred
,
balance
,
err
:=
s
.
repo
.
TransferQuotaToBalance
(
ctx
,
userID
)
if
err
!=
nil
{
return
0
,
0
,
err
}
if
transferred
>
0
{
s
.
invalidateAffiliateCaches
(
ctx
,
userID
)
}
return
transferred
,
balance
,
nil
}
func
(
s
*
AffiliateService
)
listInvitees
(
ctx
context
.
Context
,
inviterID
int64
)
([]
AffiliateInvitee
,
error
)
{
if
s
==
nil
||
s
.
repo
==
nil
{
return
nil
,
infraerrors
.
ServiceUnavailable
(
"SERVICE_UNAVAILABLE"
,
"affiliate service unavailable"
)
}
invitees
,
err
:=
s
.
repo
.
ListInvitees
(
ctx
,
inviterID
,
affiliateInviteesLimit
)
if
err
!=
nil
{
return
nil
,
err
}
for
i
:=
range
invitees
{
invitees
[
i
]
.
Email
=
maskEmail
(
invitees
[
i
]
.
Email
)
}
return
invitees
,
nil
}
func
roundTo
(
v
float64
,
scale
int
)
float64
{
factor
:=
math
.
Pow10
(
scale
)
return
math
.
Round
(
v
*
factor
)
/
factor
}
func
maskEmail
(
email
string
)
string
{
email
=
strings
.
TrimSpace
(
email
)
if
email
==
""
{
return
""
}
at
:=
strings
.
Index
(
email
,
"@"
)
if
at
<=
0
||
at
>=
len
(
email
)
-
1
{
return
"***"
}
local
:=
email
[
:
at
]
domain
:=
email
[
at
+
1
:
]
dot
:=
strings
.
LastIndex
(
domain
,
"."
)
maskedLocal
:=
maskSegment
(
local
)
if
dot
<=
0
||
dot
>=
len
(
domain
)
-
1
{
return
maskedLocal
+
"@"
+
maskSegment
(
domain
)
}
domainName
:=
domain
[
:
dot
]
tld
:=
domain
[
dot
:
]
return
maskedLocal
+
"@"
+
maskSegment
(
domainName
)
+
tld
}
func
maskSegment
(
s
string
)
string
{
r
:=
[]
rune
(
s
)
if
len
(
r
)
==
0
{
return
"***"
}
if
len
(
r
)
==
1
{
return
string
(
r
[
0
])
+
"***"
}
return
string
(
r
[
0
])
+
"***"
}
func
(
s
*
AffiliateService
)
invalidateAffiliateCaches
(
ctx
context
.
Context
,
userID
int64
)
{
if
s
.
authCacheInvalidator
!=
nil
{
s
.
authCacheInvalidator
.
InvalidateAuthCacheByUserID
(
ctx
,
userID
)
}
if
s
.
billingCacheService
!=
nil
{
if
err
:=
s
.
billingCacheService
.
InvalidateUserBalance
(
ctx
,
userID
);
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.affiliate"
,
"[Affiliate] Failed to invalidate billing cache for user %d: %v"
,
userID
,
err
)
}
}
}
// =========================
// Admin: 专属配置管理
// =========================
// validateExclusiveRate ensures a per-user override is finite and within
// [Min, Max]. nil is always valid (means "clear / fall back to global").
func
validateExclusiveRate
(
ratePercent
*
float64
)
error
{
if
ratePercent
==
nil
{
return
nil
}
v
:=
*
ratePercent
if
math
.
IsNaN
(
v
)
||
math
.
IsInf
(
v
,
0
)
{
return
infraerrors
.
BadRequest
(
"INVALID_RATE"
,
"invalid rebate rate"
)
}
if
v
<
AffiliateRebateRateMin
||
v
>
AffiliateRebateRateMax
{
return
infraerrors
.
BadRequest
(
"INVALID_RATE"
,
"rebate rate out of range"
)
}
return
nil
}
// AdminUpdateUserAffCode 管理员改写用户的邀请码(专属邀请码)。
func
(
s
*
AffiliateService
)
AdminUpdateUserAffCode
(
ctx
context
.
Context
,
userID
int64
,
rawCode
string
)
error
{
if
s
==
nil
||
s
.
repo
==
nil
{
return
infraerrors
.
ServiceUnavailable
(
"SERVICE_UNAVAILABLE"
,
"affiliate service unavailable"
)
}
code
:=
strings
.
ToUpper
(
strings
.
TrimSpace
(
rawCode
))
if
!
isValidAffiliateCodeFormat
(
code
)
{
return
ErrAffiliateCodeInvalid
}
return
s
.
repo
.
UpdateUserAffCode
(
ctx
,
userID
,
code
)
}
// AdminResetUserAffCode 重置用户邀请码为系统随机码。
func
(
s
*
AffiliateService
)
AdminResetUserAffCode
(
ctx
context
.
Context
,
userID
int64
)
(
string
,
error
)
{
if
s
==
nil
||
s
.
repo
==
nil
{
return
""
,
infraerrors
.
ServiceUnavailable
(
"SERVICE_UNAVAILABLE"
,
"affiliate service unavailable"
)
}
return
s
.
repo
.
ResetUserAffCode
(
ctx
,
userID
)
}
// AdminSetUserRebateRate 设置/清除用户专属返利比例。ratePercent==nil 表示清除。
func
(
s
*
AffiliateService
)
AdminSetUserRebateRate
(
ctx
context
.
Context
,
userID
int64
,
ratePercent
*
float64
)
error
{
if
s
==
nil
||
s
.
repo
==
nil
{
return
infraerrors
.
ServiceUnavailable
(
"SERVICE_UNAVAILABLE"
,
"affiliate service unavailable"
)
}
if
err
:=
validateExclusiveRate
(
ratePercent
);
err
!=
nil
{
return
err
}
return
s
.
repo
.
SetUserRebateRate
(
ctx
,
userID
,
ratePercent
)
}
// AdminBatchSetUserRebateRate 批量设置/清除用户专属返利比例。
func
(
s
*
AffiliateService
)
AdminBatchSetUserRebateRate
(
ctx
context
.
Context
,
userIDs
[]
int64
,
ratePercent
*
float64
)
error
{
if
s
==
nil
||
s
.
repo
==
nil
{
return
infraerrors
.
ServiceUnavailable
(
"SERVICE_UNAVAILABLE"
,
"affiliate service unavailable"
)
}
if
err
:=
validateExclusiveRate
(
ratePercent
);
err
!=
nil
{
return
err
}
cleaned
:=
make
([]
int64
,
0
,
len
(
userIDs
))
for
_
,
uid
:=
range
userIDs
{
if
uid
>
0
{
cleaned
=
append
(
cleaned
,
uid
)
}
}
if
len
(
cleaned
)
==
0
{
return
nil
}
return
s
.
repo
.
BatchSetUserRebateRate
(
ctx
,
cleaned
,
ratePercent
)
}
// AdminListCustomUsers 列出有专属配置的用户。
func
(
s
*
AffiliateService
)
AdminListCustomUsers
(
ctx
context
.
Context
,
filter
AffiliateAdminFilter
)
([]
AffiliateAdminEntry
,
int64
,
error
)
{
if
s
==
nil
||
s
.
repo
==
nil
{
return
nil
,
0
,
infraerrors
.
ServiceUnavailable
(
"SERVICE_UNAVAILABLE"
,
"affiliate service unavailable"
)
}
return
s
.
repo
.
ListUsersWithCustomSettings
(
ctx
,
filter
)
}
backend/internal/service/affiliate_service_test.go
0 → 100644
View file @
3b7a5fff
//go:build unit
package
service
import
(
"context"
"math"
"testing"
"github.com/stretchr/testify/require"
)
// TestResolveRebateRatePercent_PerUserOverride verifies that per-inviter
// AffRebateRatePercent overrides the global rate, that NULL falls back to the
// global rate, and that out-of-range exclusive rates are clamped silently.
//
// SettingService is left nil here so globalRebateRatePercent returns the
// documented default (AffiliateRebateRateDefault = 20%) — this exercises the
// fallback path without spinning up a settings stub.
func
TestResolveRebateRatePercent_PerUserOverride
(
t
*
testing
.
T
)
{
t
.
Parallel
()
svc
:=
&
AffiliateService
{}
// nil exclusive rate → falls back to global default (20%)
require
.
InDelta
(
t
,
AffiliateRebateRateDefault
,
svc
.
resolveRebateRatePercent
(
context
.
Background
(),
&
AffiliateSummary
{}),
1e-9
)
// exclusive rate set → overrides global
rate
:=
50.0
require
.
InDelta
(
t
,
50.0
,
svc
.
resolveRebateRatePercent
(
context
.
Background
(),
&
AffiliateSummary
{
AffRebateRatePercent
:
&
rate
}),
1e-9
)
// exclusive rate 0 → returns 0 (no rebate, intentional)
zero
:=
0.0
require
.
InDelta
(
t
,
0.0
,
svc
.
resolveRebateRatePercent
(
context
.
Background
(),
&
AffiliateSummary
{
AffRebateRatePercent
:
&
zero
}),
1e-9
)
// exclusive rate above max → clamped to Max
tooHigh
:=
250.0
require
.
InDelta
(
t
,
AffiliateRebateRateMax
,
svc
.
resolveRebateRatePercent
(
context
.
Background
(),
&
AffiliateSummary
{
AffRebateRatePercent
:
&
tooHigh
}),
1e-9
)
// exclusive rate below min → clamped to Min
tooLow
:=
-
5.0
require
.
InDelta
(
t
,
AffiliateRebateRateMin
,
svc
.
resolveRebateRatePercent
(
context
.
Background
(),
&
AffiliateSummary
{
AffRebateRatePercent
:
&
tooLow
}),
1e-9
)
}
// TestIsEnabled_NilSettingServiceReturnsDefault verifies that IsEnabled
// safely handles a nil settingService dependency by returning the default
// (off). This protects callers from nil-pointer crashes in misconfigured
// environments.
func
TestIsEnabled_NilSettingServiceReturnsDefault
(
t
*
testing
.
T
)
{
t
.
Parallel
()
svc
:=
&
AffiliateService
{}
require
.
False
(
t
,
svc
.
IsEnabled
(
context
.
Background
()))
require
.
Equal
(
t
,
AffiliateEnabledDefault
,
svc
.
IsEnabled
(
context
.
Background
()))
}
// TestValidateExclusiveRate_BoundaryAndInvalid covers the validator used by
// admin-facing rate setters: nil is always valid (clear), in-range values
// are accepted, NaN/Inf and out-of-range values produce a typed BadRequest.
func
TestValidateExclusiveRate_BoundaryAndInvalid
(
t
*
testing
.
T
)
{
t
.
Parallel
()
require
.
NoError
(
t
,
validateExclusiveRate
(
nil
))
for
_
,
v
:=
range
[]
float64
{
0
,
0.01
,
50
,
99.99
,
100
}
{
v
:=
v
require
.
NoError
(
t
,
validateExclusiveRate
(
&
v
),
"value %v should be valid"
,
v
)
}
for
_
,
v
:=
range
[]
float64
{
-
0.01
,
100.01
,
-
100
,
200
}
{
v
:=
v
require
.
Error
(
t
,
validateExclusiveRate
(
&
v
),
"value %v should be rejected"
,
v
)
}
nan
:=
math
.
NaN
()
require
.
Error
(
t
,
validateExclusiveRate
(
&
nan
))
posInf
:=
math
.
Inf
(
1
)
require
.
Error
(
t
,
validateExclusiveRate
(
&
posInf
))
negInf
:=
math
.
Inf
(
-
1
)
require
.
Error
(
t
,
validateExclusiveRate
(
&
negInf
))
}
func
TestMaskEmail
(
t
*
testing
.
T
)
{
t
.
Parallel
()
require
.
Equal
(
t
,
"a***@g***.com"
,
maskEmail
(
"alice@gmail.com"
))
require
.
Equal
(
t
,
"x***@d***"
,
maskEmail
(
"x@domain"
))
require
.
Equal
(
t
,
""
,
maskEmail
(
""
))
}
func
TestIsValidAffiliateCodeFormat
(
t
*
testing
.
T
)
{
t
.
Parallel
()
// 邀请码格式校验同时服务于:
// 1) 系统自动生成的 12 位随机码(A-Z 去 I/O,2-9 去 0/1)
// 2) 管理员设置的自定义专属码(如 "VIP2026"、"NEW_USER-1")
// 因此校验放宽到 [A-Z0-9_-]{4,32}(要求调用方先 ToUpper)。
cases
:=
[]
struct
{
name
string
in
string
want
bool
}{
{
"valid canonical 12-char"
,
"ABCDEFGHJKLM"
,
true
},
{
"valid all digits 2-9"
,
"234567892345"
,
true
},
{
"valid mixed"
,
"A2B3C4D5E6F7"
,
true
},
{
"valid admin custom short"
,
"VIP1"
,
true
},
{
"valid admin custom with hyphen"
,
"NEW-USER"
,
true
},
{
"valid admin custom with underscore"
,
"VIP_2026"
,
true
},
{
"valid 32-char max"
,
"ABCDEFGHIJKLMNOPQRSTUVWXYZ012345"
,
true
},
// Previously-excluded chars (I/O/0/1) are now allowed since admins may use them.
{
"letter I now allowed"
,
"IBCDEFGHJKLM"
,
true
},
{
"letter O now allowed"
,
"OBCDEFGHJKLM"
,
true
},
{
"digit 0 now allowed"
,
"0BCDEFGHJKLM"
,
true
},
{
"digit 1 now allowed"
,
"1BCDEFGHJKLM"
,
true
},
{
"too short (3 chars)"
,
"ABC"
,
false
},
{
"too long (33 chars)"
,
"ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456"
,
false
},
{
"lowercase rejected (caller must ToUpper first)"
,
"abcdefghjklm"
,
false
},
{
"empty"
,
""
,
false
},
{
"utf8 non-ascii"
,
"ÄÄÄÄÄÄ"
,
false
},
// bytes out of charset
{
"ascii punctuation ."
,
"ABCDEFGHJK.M"
,
false
},
{
"whitespace"
,
"ABCDEFGHJK M"
,
false
},
}
for
_
,
tc
:=
range
cases
{
tc
:=
tc
t
.
Run
(
tc
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
require
.
Equal
(
t
,
tc
.
want
,
isValidAffiliateCodeFormat
(
tc
.
in
))
})
}
}
backend/internal/service/antigravity_gateway_service.go
View file @
3b7a5fff
...
@@ -21,6 +21,7 @@ import (
...
@@ -21,6 +21,7 @@ import (
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/google/uuid"
...
@@ -1739,6 +1740,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
...
@@ -1739,6 +1740,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
var
usage
*
ClaudeUsage
var
usage
*
ClaudeUsage
var
firstTokenMs
*
int
var
firstTokenMs
*
int
var
clientDisconnect
bool
var
clientDisconnect
bool
var
responseBody
string
if
claudeReq
.
Stream
{
if
claudeReq
.
Stream
{
// 客户端要求流式,直接透传转换
// 客户端要求流式,直接透传转换
streamRes
,
err
:=
s
.
handleClaudeStreamingResponse
(
c
,
resp
,
startTime
,
originalModel
)
streamRes
,
err
:=
s
.
handleClaudeStreamingResponse
(
c
,
resp
,
startTime
,
originalModel
)
...
@@ -1749,6 +1751,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
...
@@ -1749,6 +1751,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
usage
=
streamRes
.
usage
usage
=
streamRes
.
usage
firstTokenMs
=
streamRes
.
firstTokenMs
firstTokenMs
=
streamRes
.
firstTokenMs
clientDisconnect
=
streamRes
.
clientDisconnect
clientDisconnect
=
streamRes
.
clientDisconnect
// 流式:从 context buffer 读取采集的文本
if
captureBuilder
,
ok
:=
ctx
.
Value
(
ctxkey
.
ResponseCaptureBuffer
)
.
(
*
strings
.
Builder
);
ok
&&
captureBuilder
!=
nil
{
responseBody
=
captureBuilder
.
String
()
}
}
else
{
}
else
{
// 客户端要求非流式,收集流式响应后转换返回
// 客户端要求非流式,收集流式响应后转换返回
streamRes
,
err
:=
s
.
handleClaudeStreamToNonStreaming
(
c
,
resp
,
startTime
,
originalModel
)
streamRes
,
err
:=
s
.
handleClaudeStreamToNonStreaming
(
c
,
resp
,
startTime
,
originalModel
)
...
@@ -1758,6 +1764,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
...
@@ -1758,6 +1764,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
}
}
usage
=
streamRes
.
usage
usage
=
streamRes
.
usage
firstTokenMs
=
streamRes
.
firstTokenMs
firstTokenMs
=
streamRes
.
firstTokenMs
responseBody
=
streamRes
.
responseBody
}
}
return
&
ForwardResult
{
return
&
ForwardResult
{
...
@@ -1769,6 +1776,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
...
@@ -1769,6 +1776,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
Duration
:
time
.
Since
(
startTime
),
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
FirstTokenMs
:
firstTokenMs
,
ClientDisconnect
:
clientDisconnect
,
ClientDisconnect
:
clientDisconnect
,
ResponseBody
:
responseBody
,
},
nil
},
nil
}
}
...
@@ -2421,6 +2429,7 @@ handleSuccess:
...
@@ -2421,6 +2429,7 @@ handleSuccess:
var
usage
*
ClaudeUsage
var
usage
*
ClaudeUsage
var
firstTokenMs
*
int
var
firstTokenMs
*
int
var
clientDisconnect
bool
var
clientDisconnect
bool
var
responseBody
string
if
stream
{
if
stream
{
// 客户端要求流式,直接透传
// 客户端要求流式,直接透传
...
@@ -2432,6 +2441,10 @@ handleSuccess:
...
@@ -2432,6 +2441,10 @@ handleSuccess:
usage
=
streamRes
.
usage
usage
=
streamRes
.
usage
firstTokenMs
=
streamRes
.
firstTokenMs
firstTokenMs
=
streamRes
.
firstTokenMs
clientDisconnect
=
streamRes
.
clientDisconnect
clientDisconnect
=
streamRes
.
clientDisconnect
// 流式:从 context buffer 读取采集的文本
if
captureBuilder
,
ok
:=
ctx
.
Value
(
ctxkey
.
ResponseCaptureBuffer
)
.
(
*
strings
.
Builder
);
ok
&&
captureBuilder
!=
nil
{
responseBody
=
captureBuilder
.
String
()
}
}
else
{
}
else
{
// 客户端要求非流式,收集流式响应后返回
// 客户端要求非流式,收集流式响应后返回
streamRes
,
err
:=
s
.
handleGeminiStreamToNonStreaming
(
c
,
resp
,
startTime
)
streamRes
,
err
:=
s
.
handleGeminiStreamToNonStreaming
(
c
,
resp
,
startTime
)
...
@@ -2441,6 +2454,7 @@ handleSuccess:
...
@@ -2441,6 +2454,7 @@ handleSuccess:
}
}
usage
=
streamRes
.
usage
usage
=
streamRes
.
usage
firstTokenMs
=
streamRes
.
firstTokenMs
firstTokenMs
=
streamRes
.
firstTokenMs
responseBody
=
streamRes
.
responseBody
}
}
if
usage
==
nil
{
if
usage
==
nil
{
...
@@ -2465,6 +2479,7 @@ handleSuccess:
...
@@ -2465,6 +2479,7 @@ handleSuccess:
ClientDisconnect
:
clientDisconnect
,
ClientDisconnect
:
clientDisconnect
,
ImageCount
:
imageCount
,
ImageCount
:
imageCount
,
ImageSize
:
imageSize
,
ImageSize
:
imageSize
,
ResponseBody
:
responseBody
,
},
nil
},
nil
}
}
...
@@ -2968,7 +2983,8 @@ func (s *AntigravityGatewayService) resolveResetTime(resetAt *int64, defaultDur
...
@@ -2968,7 +2983,8 @@ func (s *AntigravityGatewayService) resolveResetTime(resetAt *int64, defaultDur
type
antigravityStreamResult
struct
{
type
antigravityStreamResult
struct
{
usage
*
ClaudeUsage
usage
*
ClaudeUsage
firstTokenMs
*
int
firstTokenMs
*
int
clientDisconnect
bool
// 客户端是否在流式传输过程中断开
clientDisconnect
bool
// 客户端是否在流式传输过程中断开
responseBody
string
// 响应体内容(非流式:完整 JSON;流式:从上下文 buffer 读取)
}
}
// antigravityClientWriter 封装流式响应的客户端写入,自动检测断开并标记。
// antigravityClientWriter 封装流式响应的客户端写入,自动检测断开并标记。
...
@@ -3124,6 +3140,9 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
...
@@ -3124,6 +3140,9 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
cw
:=
newAntigravityClientWriter
(
c
.
Writer
,
flusher
,
"antigravity gemini"
)
cw
:=
newAntigravityClientWriter
(
c
.
Writer
,
flusher
,
"antigravity gemini"
)
// 响应体文本采集:若上下文注入了 ResponseCaptureBuffer,则写入文本内容
captureBuilder
,
_
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
ResponseCaptureBuffer
)
.
(
*
strings
.
Builder
)
// 仅发送一次错误事件,避免多次写入导致协议混乱
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent
:=
false
errorEventSent
:=
false
sendErrorEvent
:=
func
(
reason
string
)
{
sendErrorEvent
:=
func
(
reason
string
)
{
...
@@ -3197,6 +3216,16 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
...
@@ -3197,6 +3216,16 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
firstTokenMs
=
&
ms
firstTokenMs
=
&
ms
}
}
// 采集文本用于响应捕获
if
captureBuilder
!=
nil
&&
len
(
inner
)
>
0
{
gjson
.
GetBytes
(
inner
,
"candidates.0.content.parts"
)
.
ForEach
(
func
(
_
,
v
gjson
.
Result
)
bool
{
if
t
:=
v
.
Get
(
"text"
)
.
String
();
t
!=
""
{
captureBuilder
.
WriteString
(
t
)
}
return
true
})
}
cw
.
Fprintf
(
"data: %s
\n\n
"
,
payload
)
cw
.
Fprintf
(
"data: %s
\n\n
"
,
payload
)
continue
continue
}
}
...
@@ -3418,7 +3447,7 @@ returnResponse:
...
@@ -3418,7 +3447,7 @@ returnResponse:
}
}
c
.
Data
(
http
.
StatusOK
,
"application/json"
,
respBody
)
c
.
Data
(
http
.
StatusOK
,
"application/json"
,
respBody
)
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
responseBody
:
strings
.
Join
(
collectedTextParts
,
""
)
},
nil
}
}
// getOrCreateGeminiParts 获取 Gemini 响应的 parts 结构,返回深拷贝和更新回调
// getOrCreateGeminiParts 获取 Gemini 响应的 parts 结构,返回深拷贝和更新回调
...
@@ -3867,7 +3896,7 @@ returnResponse:
...
@@ -3867,7 +3896,7 @@ returnResponse:
CacheReadInputTokens
:
agUsage
.
CacheReadInputTokens
,
CacheReadInputTokens
:
agUsage
.
CacheReadInputTokens
,
}
}
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
responseBody
:
string
(
claudeResp
)
},
nil
}
}
// handleClaudeStreamingResponse 处理 Claude 流式响应(Gemini SSE → Claude SSE 转换)
// handleClaudeStreamingResponse 处理 Claude 流式响应(Gemini SSE → Claude SSE 转换)
...
@@ -3971,6 +4000,9 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
...
@@ -3971,6 +4000,9 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
cw
:=
newAntigravityClientWriter
(
c
.
Writer
,
flusher
,
"antigravity claude"
)
cw
:=
newAntigravityClientWriter
(
c
.
Writer
,
flusher
,
"antigravity claude"
)
// 响应体文本采集:若上下文注入了 ResponseCaptureBuffer,则写入文本 delta
captureBuilder
,
_
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
ResponseCaptureBuffer
)
.
(
*
strings
.
Builder
)
// 仅发送一次错误事件,避免多次写入导致协议混乱
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent
:=
false
errorEventSent
:=
false
sendErrorEvent
:=
func
(
reason
string
)
{
sendErrorEvent
:=
func
(
reason
string
)
{
...
@@ -4024,7 +4056,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
...
@@ -4024,7 +4056,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
lastDataAt
=
time
.
Now
()
lastDataAt
=
time
.
Now
()
// 处理 SSE 行,转换为 Claude 格式
// 处理 SSE 行,转换为 Claude 格式
claudeEvents
:=
processor
.
ProcessLine
(
strings
.
TrimRight
(
ev
.
line
,
"
\r\n
"
))
trimmedLine
:=
strings
.
TrimRight
(
ev
.
line
,
"
\r\n
"
)
claudeEvents
:=
processor
.
ProcessLine
(
trimmedLine
)
if
len
(
claudeEvents
)
>
0
{
if
len
(
claudeEvents
)
>
0
{
if
firstTokenMs
==
nil
{
if
firstTokenMs
==
nil
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
...
@@ -4033,6 +4066,22 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
...
@@ -4033,6 +4066,22 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
cw
.
Write
(
claudeEvents
)
cw
.
Write
(
claudeEvents
)
}
}
// 采集文本用于响应捕获
if
captureBuilder
!=
nil
&&
strings
.
HasPrefix
(
trimmedLine
,
"data:"
)
{
data
:=
strings
.
TrimSpace
(
strings
.
TrimPrefix
(
trimmedLine
,
"data:"
))
if
data
!=
""
&&
data
!=
"[DONE]"
{
// v1internal 格式: response.candidates.0.content.parts.0.text
// 直接 Gemini 格式: candidates.0.content.parts.0.text
text
:=
gjson
.
Get
(
data
,
"response.candidates.0.content.parts.0.text"
)
.
String
()
if
text
==
""
{
text
=
gjson
.
Get
(
data
,
"candidates.0.content.parts.0.text"
)
.
String
()
}
if
text
!=
""
{
captureBuilder
.
WriteString
(
text
)
}
}
}
case
<-
intervalCh
:
case
<-
intervalCh
:
lastRead
:=
time
.
Unix
(
0
,
atomic
.
LoadInt64
(
&
lastReadAt
))
lastRead
:=
time
.
Unix
(
0
,
atomic
.
LoadInt64
(
&
lastReadAt
))
if
time
.
Since
(
lastRead
)
<
streamInterval
{
if
time
.
Since
(
lastRead
)
<
streamInterval
{
...
...
backend/internal/service/auth_oauth_email_flow.go
View file @
3b7a5fff
...
@@ -175,6 +175,7 @@ func (s *AuthService) FinalizeOAuthEmailAccount(
...
@@ -175,6 +175,7 @@ func (s *AuthService) FinalizeOAuthEmailAccount(
user
*
User
,
user
*
User
,
invitationCode
string
,
invitationCode
string
,
signupSource
string
,
signupSource
string
,
affiliateCode
string
,
)
error
{
)
error
{
if
s
==
nil
||
user
==
nil
||
user
.
ID
<=
0
{
if
s
==
nil
||
user
==
nil
||
user
.
ID
<=
0
{
return
ErrServiceUnavailable
return
ErrServiceUnavailable
...
@@ -194,6 +195,7 @@ func (s *AuthService) FinalizeOAuthEmailAccount(
...
@@ -194,6 +195,7 @@ func (s *AuthService) FinalizeOAuthEmailAccount(
s
.
updateOAuthSignupSource
(
ctx
,
user
.
ID
,
signupSource
)
s
.
updateOAuthSignupSource
(
ctx
,
user
.
ID
,
signupSource
)
grantPlan
:=
s
.
resolveSignupGrantPlan
(
ctx
,
signupSource
)
grantPlan
:=
s
.
resolveSignupGrantPlan
(
ctx
,
signupSource
)
s
.
assignSubscriptions
(
ctx
,
user
.
ID
,
grantPlan
.
Subscriptions
,
"auto assigned by signup defaults"
)
s
.
assignSubscriptions
(
ctx
,
user
.
ID
,
grantPlan
.
Subscriptions
,
"auto assigned by signup defaults"
)
s
.
bindOAuthAffiliate
(
ctx
,
user
.
ID
,
affiliateCode
)
return
nil
return
nil
}
}
...
...
backend/internal/service/auth_oauth_email_flow_test.go
View file @
3b7a5fff
...
@@ -137,6 +137,7 @@ func newOAuthEmailFlowAuthService(
...
@@ -137,6 +137,7 @@ func newOAuthEmailFlowAuthService(
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
)
)
}
}
...
...
backend/internal/service/auth_service.go
View file @
3b7a5fff
...
@@ -72,6 +72,7 @@ type AuthService struct {
...
@@ -72,6 +72,7 @@ type AuthService struct {
turnstileService
*
TurnstileService
turnstileService
*
TurnstileService
emailQueueService
*
EmailQueueService
emailQueueService
*
EmailQueueService
promoService
*
PromoService
promoService
*
PromoService
affiliateService
*
AffiliateService
defaultSubAssigner
DefaultSubscriptionAssigner
defaultSubAssigner
DefaultSubscriptionAssigner
}
}
...
@@ -98,6 +99,7 @@ func NewAuthService(
...
@@ -98,6 +99,7 @@ func NewAuthService(
emailQueueService
*
EmailQueueService
,
emailQueueService
*
EmailQueueService
,
promoService
*
PromoService
,
promoService
*
PromoService
,
defaultSubAssigner
DefaultSubscriptionAssigner
,
defaultSubAssigner
DefaultSubscriptionAssigner
,
affiliateService
*
AffiliateService
,
)
*
AuthService
{
)
*
AuthService
{
return
&
AuthService
{
return
&
AuthService
{
entClient
:
entClient
,
entClient
:
entClient
,
...
@@ -110,6 +112,7 @@ func NewAuthService(
...
@@ -110,6 +112,7 @@ func NewAuthService(
turnstileService
:
turnstileService
,
turnstileService
:
turnstileService
,
emailQueueService
:
emailQueueService
,
emailQueueService
:
emailQueueService
,
promoService
:
promoService
,
promoService
:
promoService
,
affiliateService
:
affiliateService
,
defaultSubAssigner
:
defaultSubAssigner
,
defaultSubAssigner
:
defaultSubAssigner
,
}
}
}
}
...
@@ -123,11 +126,11 @@ func (s *AuthService) EntClient() *dbent.Client {
...
@@ -123,11 +126,11 @@ func (s *AuthService) EntClient() *dbent.Client {
// Register 用户注册,返回token和用户
// Register 用户注册,返回token和用户
func
(
s
*
AuthService
)
Register
(
ctx
context
.
Context
,
email
,
password
string
)
(
string
,
*
User
,
error
)
{
func
(
s
*
AuthService
)
Register
(
ctx
context
.
Context
,
email
,
password
string
)
(
string
,
*
User
,
error
)
{
return
s
.
RegisterWithVerification
(
ctx
,
email
,
password
,
""
,
""
,
""
)
return
s
.
RegisterWithVerification
(
ctx
,
email
,
password
,
""
,
""
,
""
,
""
)
}
}
// RegisterWithVerification 用户注册(支持邮件验证、优惠码和邀请码),返回token和用户
// RegisterWithVerification 用户注册(支持邮件验证、优惠码
、邀请码
和邀请
返利
码),返回token和用户
。
func
(
s
*
AuthService
)
RegisterWithVerification
(
ctx
context
.
Context
,
email
,
password
,
verifyCode
,
promoCode
,
invitationCode
string
)
(
string
,
*
User
,
error
)
{
func
(
s
*
AuthService
)
RegisterWithVerification
(
ctx
context
.
Context
,
email
,
password
,
verifyCode
,
promoCode
,
invitationCode
,
affiliateCode
string
)
(
string
,
*
User
,
error
)
{
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
if
s
.
settingService
==
nil
||
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
if
s
.
settingService
==
nil
||
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
return
""
,
nil
,
ErrRegDisabled
return
""
,
nil
,
ErrRegDisabled
...
@@ -223,6 +226,17 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
...
@@ -223,6 +226,17 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
}
}
s
.
postAuthUserBootstrap
(
ctx
,
user
,
"email"
,
true
)
s
.
postAuthUserBootstrap
(
ctx
,
user
,
"email"
,
true
)
s
.
assignSubscriptions
(
ctx
,
user
.
ID
,
grantPlan
.
Subscriptions
,
"auto assigned by signup defaults"
)
s
.
assignSubscriptions
(
ctx
,
user
.
ID
,
grantPlan
.
Subscriptions
,
"auto assigned by signup defaults"
)
if
s
.
affiliateService
!=
nil
{
if
_
,
err
:=
s
.
affiliateService
.
EnsureUserAffiliate
(
ctx
,
user
.
ID
);
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Failed to initialize affiliate profile for user %d: %v"
,
user
.
ID
,
err
)
}
if
code
:=
strings
.
TrimSpace
(
affiliateCode
);
code
!=
""
{
if
err
:=
s
.
affiliateService
.
BindInviterByCode
(
ctx
,
user
.
ID
,
code
);
err
!=
nil
{
// 邀请返利码绑定失败不影响注册,只记录日志
logger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Failed to bind affiliate inviter for user %d: %v"
,
user
.
ID
,
err
)
}
}
}
// 标记邀请码为已使用(如果使用了邀请码)
// 标记邀请码为已使用(如果使用了邀请码)
if
invitationRedeemCode
!=
nil
{
if
invitationRedeemCode
!=
nil
{
...
@@ -549,7 +563,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
...
@@ -549,7 +563,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair。
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair。
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token。
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token。
// invitationCode 仅在邀请码注册模式下新用户注册时使用;已有账号登录时忽略。
// invitationCode 仅在邀请码注册模式下新用户注册时使用;已有账号登录时忽略。
func
(
s
*
AuthService
)
LoginOrRegisterOAuthWithTokenPair
(
ctx
context
.
Context
,
email
,
username
,
invitationCode
string
)
(
*
TokenPair
,
*
User
,
error
)
{
// affiliateCode 用于邀请返利绑定,仅在新用户注册时使用。
func
(
s
*
AuthService
)
LoginOrRegisterOAuthWithTokenPair
(
ctx
context
.
Context
,
email
,
username
,
invitationCode
,
affiliateCode
string
)
(
*
TokenPair
,
*
User
,
error
)
{
// 检查 refreshTokenCache 是否可用
// 检查 refreshTokenCache 是否可用
if
s
.
refreshTokenCache
==
nil
{
if
s
.
refreshTokenCache
==
nil
{
return
nil
,
nil
,
errors
.
New
(
"refresh token cache not configured"
)
return
nil
,
nil
,
errors
.
New
(
"refresh token cache not configured"
)
...
@@ -652,6 +667,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
...
@@ -652,6 +667,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
user
=
newUser
user
=
newUser
s
.
postAuthUserBootstrap
(
ctx
,
user
,
signupSource
,
false
)
s
.
postAuthUserBootstrap
(
ctx
,
user
,
signupSource
,
false
)
s
.
assignSubscriptions
(
ctx
,
user
.
ID
,
grantPlan
.
Subscriptions
,
"auto assigned by signup defaults"
)
s
.
assignSubscriptions
(
ctx
,
user
.
ID
,
grantPlan
.
Subscriptions
,
"auto assigned by signup defaults"
)
s
.
bindOAuthAffiliate
(
ctx
,
user
.
ID
,
affiliateCode
)
}
}
}
else
{
}
else
{
if
err
:=
s
.
userRepo
.
Create
(
ctx
,
newUser
);
err
!=
nil
{
if
err
:=
s
.
userRepo
.
Create
(
ctx
,
newUser
);
err
!=
nil
{
...
@@ -669,6 +685,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
...
@@ -669,6 +685,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
user
=
newUser
user
=
newUser
s
.
postAuthUserBootstrap
(
ctx
,
user
,
signupSource
,
false
)
s
.
postAuthUserBootstrap
(
ctx
,
user
,
signupSource
,
false
)
s
.
assignSubscriptions
(
ctx
,
user
.
ID
,
grantPlan
.
Subscriptions
,
"auto assigned by signup defaults"
)
s
.
assignSubscriptions
(
ctx
,
user
.
ID
,
grantPlan
.
Subscriptions
,
"auto assigned by signup defaults"
)
s
.
bindOAuthAffiliate
(
ctx
,
user
.
ID
,
affiliateCode
)
if
invitationRedeemCode
!=
nil
{
if
invitationRedeemCode
!=
nil
{
if
err
:=
s
.
redeemRepo
.
Use
(
ctx
,
invitationRedeemCode
.
ID
,
user
.
ID
);
err
!=
nil
{
if
err
:=
s
.
redeemRepo
.
Use
(
ctx
,
invitationRedeemCode
.
ID
,
user
.
ID
);
err
!=
nil
{
return
nil
,
nil
,
ErrInvitationCodeInvalid
return
nil
,
nil
,
ErrInvitationCodeInvalid
...
@@ -763,6 +780,22 @@ func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource
...
@@ -763,6 +780,22 @@ func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource
}
}
}
}
// bindOAuthAffiliate initializes the affiliate profile and binds the inviter
// for an OAuth-registered user. Failures are logged but never block registration.
func
(
s
*
AuthService
)
bindOAuthAffiliate
(
ctx
context
.
Context
,
userID
int64
,
affiliateCode
string
)
{
if
s
.
affiliateService
==
nil
||
userID
<=
0
{
return
}
if
_
,
err
:=
s
.
affiliateService
.
EnsureUserAffiliate
(
ctx
,
userID
);
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Failed to initialize affiliate profile for user %d: %v"
,
userID
,
err
)
}
if
code
:=
strings
.
TrimSpace
(
affiliateCode
);
code
!=
""
{
if
err
:=
s
.
affiliateService
.
BindInviterByCode
(
ctx
,
userID
,
code
);
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Failed to bind affiliate inviter for user %d: %v"
,
userID
,
err
)
}
}
}
func
(
s
*
AuthService
)
postAuthUserBootstrap
(
ctx
context
.
Context
,
user
*
User
,
signupSource
string
,
touchLogin
bool
)
{
func
(
s
*
AuthService
)
postAuthUserBootstrap
(
ctx
context
.
Context
,
user
*
User
,
signupSource
string
,
touchLogin
bool
)
{
if
user
==
nil
||
user
.
ID
<=
0
{
if
user
==
nil
||
user
.
ID
<=
0
{
return
return
...
...
backend/internal/service/auth_service_email_bind_test.go
View file @
3b7a5fff
...
@@ -110,7 +110,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants (
...
@@ -110,7 +110,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants (
emailSvc
=
service
.
NewEmailService
(
settingRepo
,
emailCache
)
emailSvc
=
service
.
NewEmailService
(
settingRepo
,
emailCache
)
}
}
svc
:=
service
.
NewAuthService
(
client
,
repo
,
nil
,
refreshTokenCache
,
cfg
,
settingSvc
,
emailSvc
,
nil
,
nil
,
nil
,
defaultSubAssigner
)
svc
:=
service
.
NewAuthService
(
client
,
repo
,
nil
,
refreshTokenCache
,
cfg
,
settingSvc
,
emailSvc
,
nil
,
nil
,
nil
,
defaultSubAssigner
,
nil
)
return
svc
,
repo
,
client
return
svc
,
repo
,
client
}
}
...
@@ -467,7 +467,7 @@ func TestAuthServiceBindEmailIdentity_RevokesExistingAccessAndRefreshTokens(t *t
...
@@ -467,7 +467,7 @@ func TestAuthServiceBindEmailIdentity_RevokesExistingAccessAndRefreshTokens(t *t
},
},
}
}
emailService
:=
service
.
NewEmailService
(
nil
,
cache
)
emailService
:=
service
.
NewEmailService
(
nil
,
cache
)
svc
:=
service
.
NewAuthService
(
nil
,
userRepo
,
nil
,
refreshTokenCache
,
cfg
,
nil
,
emailService
,
nil
,
nil
,
nil
,
nil
)
svc
:=
service
.
NewAuthService
(
nil
,
userRepo
,
nil
,
refreshTokenCache
,
cfg
,
nil
,
emailService
,
nil
,
nil
,
nil
,
nil
,
nil
)
oldTokenPair
,
err
:=
svc
.
GenerateTokenPair
(
ctx
,
&
service
.
User
{
oldTokenPair
,
err
:=
svc
.
GenerateTokenPair
(
ctx
,
&
service
.
User
{
ID
:
41
,
ID
:
41
,
...
...
backend/internal/service/auth_service_identity_sync_test.go
View file @
3b7a5fff
...
@@ -137,7 +137,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants (
...
@@ -137,7 +137,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants (
values
:
settings
,
values
:
settings
,
},
cfg
)
},
cfg
)
svc
:=
service
.
NewAuthService
(
client
,
repo
,
nil
,
nil
,
cfg
,
settingSvc
,
nil
,
nil
,
nil
,
nil
,
defaultSubAssigner
)
svc
:=
service
.
NewAuthService
(
client
,
repo
,
nil
,
nil
,
cfg
,
settingSvc
,
nil
,
nil
,
nil
,
nil
,
defaultSubAssigner
,
nil
)
return
svc
,
repo
,
client
return
svc
,
repo
,
client
}
}
...
...
Prev
1
2
3
4
5
6
7
…
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment