Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
6b97a8be
Commit
6b97a8be
authored
Jan 09, 2026
by
Edric Li
Browse files
Merge branch 'main' into feat/api-key-ip-restriction
parents
90798f14
62dc0b95
Changes
70
Show whitespace changes
Inline
Side-by-side
backend/internal/pkg/geminicli/oauth_test.go
View file @
6b97a8be
...
...
@@ -23,14 +23,14 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
wantErr
:
false
,
},
{
name
:
"Google One
with custom client
"
,
name
:
"Google One
always uses built-in client (even if custom credentials passed)
"
,
input
:
OAuthConfig
{
ClientID
:
"custom-client-id"
,
ClientSecret
:
"custom-client-secret"
,
},
oauthType
:
"google_one"
,
wantClientID
:
"custom-client-id"
,
wantScopes
:
Default
GoogleOneScopes
,
wantScopes
:
Default
CodeAssistScopes
,
// Uses code assist scopes even with custom client
wantErr
:
false
,
},
{
...
...
backend/internal/repository/account_repo.go
View file @
6b97a8be
...
...
@@ -831,6 +831,11 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
args
=
append
(
args
,
*
updates
.
Status
)
idx
++
}
if
updates
.
Schedulable
!=
nil
{
setClauses
=
append
(
setClauses
,
"schedulable = $"
+
itoa
(
idx
))
args
=
append
(
args
,
*
updates
.
Schedulable
)
idx
++
}
// JSONB 需要合并而非覆盖,使用 raw SQL 保持旧行为。
if
len
(
updates
.
Credentials
)
>
0
{
payload
,
err
:=
json
.
Marshal
(
updates
.
Credentials
)
...
...
backend/internal/repository/gemini_oauth_client.go
View file @
6b97a8be
...
...
@@ -30,14 +30,15 @@ func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, c
// Use different OAuth clients based on oauthType:
// - code_assist: always use built-in Gemini CLI OAuth client (public)
// - google_one:
uses configured OAuth client when provided; otherwise falls back to built-in client
// - google_one:
always use built-in Gemini CLI OAuth client (public)
// - ai_studio: requires a user-provided OAuth client
oauthCfgInput
:=
geminicli
.
OAuthConfig
{
ClientID
:
c
.
cfg
.
Gemini
.
OAuth
.
ClientID
,
ClientSecret
:
c
.
cfg
.
Gemini
.
OAuth
.
ClientSecret
,
Scopes
:
c
.
cfg
.
Gemini
.
OAuth
.
Scopes
,
}
if
oauthType
==
"code_assist"
{
if
oauthType
==
"code_assist"
||
oauthType
==
"google_one"
{
// Force use of built-in Gemini CLI OAuth client
oauthCfgInput
.
ClientID
=
""
oauthCfgInput
.
ClientSecret
=
""
}
...
...
@@ -78,7 +79,8 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh
ClientSecret
:
c
.
cfg
.
Gemini
.
OAuth
.
ClientSecret
,
Scopes
:
c
.
cfg
.
Gemini
.
OAuth
.
Scopes
,
}
if
oauthType
==
"code_assist"
{
if
oauthType
==
"code_assist"
||
oauthType
==
"google_one"
{
// Force use of built-in Gemini CLI OAuth client
oauthCfgInput
.
ClientID
=
""
oauthCfgInput
.
ClientSecret
=
""
}
...
...
backend/internal/repository/group_repo.go
View file @
6b97a8be
...
...
@@ -112,10 +112,10 @@ func (r *groupRepository) Delete(ctx context.Context, id int64) error {
}
func
(
r
*
groupRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
service
.
Group
,
*
pagination
.
PaginationResult
,
error
)
{
return
r
.
ListWithFilters
(
ctx
,
params
,
""
,
""
,
nil
)
return
r
.
ListWithFilters
(
ctx
,
params
,
""
,
""
,
""
,
nil
)
}
func
(
r
*
groupRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
status
string
,
isExclusive
*
bool
)
([]
service
.
Group
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
groupRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
status
,
search
string
,
isExclusive
*
bool
)
([]
service
.
Group
,
*
pagination
.
PaginationResult
,
error
)
{
q
:=
r
.
client
.
Group
.
Query
()
if
platform
!=
""
{
...
...
@@ -124,6 +124,12 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
if
status
!=
""
{
q
=
q
.
Where
(
group
.
StatusEQ
(
status
))
}
if
search
!=
""
{
q
=
q
.
Where
(
group
.
Or
(
group
.
NameContainsFold
(
search
),
group
.
DescriptionContainsFold
(
search
),
))
}
if
isExclusive
!=
nil
{
q
=
q
.
Where
(
group
.
IsExclusiveEQ
(
*
isExclusive
))
}
...
...
backend/internal/repository/group_repo_integration_test.go
View file @
6b97a8be
...
...
@@ -131,6 +131,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() {
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
service
.
PlatformOpenAI
,
""
,
""
,
nil
,
)
s
.
Require
()
.
NoError
(
err
,
"ListWithFilters base"
)
...
...
@@ -152,7 +153,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() {
SubscriptionType
:
service
.
SubscriptionTypeStandard
,
}))
groups
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
service
.
PlatformOpenAI
,
""
,
nil
)
groups
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
service
.
PlatformOpenAI
,
""
,
""
,
nil
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
groups
,
len
(
baseGroups
)
+
1
)
// Verify all groups are OpenAI platform
...
...
@@ -179,7 +180,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Status() {
SubscriptionType
:
service
.
SubscriptionTypeStandard
,
}))
groups
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
""
,
service
.
StatusDisabled
,
nil
)
groups
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
""
,
service
.
StatusDisabled
,
""
,
nil
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
groups
,
1
)
s
.
Require
()
.
Equal
(
service
.
StatusDisabled
,
groups
[
0
]
.
Status
)
...
...
@@ -204,12 +205,117 @@ func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
}))
isExclusive
:=
true
groups
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
""
,
""
,
&
isExclusive
)
groups
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
""
,
""
,
""
,
&
isExclusive
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
groups
,
1
)
s
.
Require
()
.
True
(
groups
[
0
]
.
IsExclusive
)
}
func
(
s
*
GroupRepoSuite
)
TestListWithFilters_Search
()
{
newRepo
:=
func
()
(
*
groupRepository
,
context
.
Context
)
{
tx
:=
testEntTx
(
s
.
T
())
return
newGroupRepositoryWithSQL
(
tx
.
Client
(),
tx
),
context
.
Background
()
}
containsID
:=
func
(
groups
[]
service
.
Group
,
id
int64
)
bool
{
for
i
:=
range
groups
{
if
groups
[
i
]
.
ID
==
id
{
return
true
}
}
return
false
}
mustCreate
:=
func
(
repo
*
groupRepository
,
ctx
context
.
Context
,
g
*
service
.
Group
)
*
service
.
Group
{
s
.
Require
()
.
NoError
(
repo
.
Create
(
ctx
,
g
))
s
.
Require
()
.
NotZero
(
g
.
ID
)
return
g
}
newGroup
:=
func
(
name
string
)
*
service
.
Group
{
return
&
service
.
Group
{
Name
:
name
,
Platform
:
service
.
PlatformAnthropic
,
RateMultiplier
:
1.0
,
IsExclusive
:
false
,
Status
:
service
.
StatusActive
,
SubscriptionType
:
service
.
SubscriptionTypeStandard
,
}
}
s
.
Run
(
"search_name_should_match"
,
func
()
{
repo
,
ctx
:=
newRepo
()
target
:=
mustCreate
(
repo
,
ctx
,
newGroup
(
"it-group-search-name-target"
))
other
:=
mustCreate
(
repo
,
ctx
,
newGroup
(
"it-group-search-name-other"
))
groups
,
_
,
err
:=
repo
.
ListWithFilters
(
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
50
},
""
,
""
,
"name-target"
,
nil
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
True
(
containsID
(
groups
,
target
.
ID
),
"expected target group to match by name"
)
s
.
Require
()
.
False
(
containsID
(
groups
,
other
.
ID
),
"expected other group to be filtered out"
)
})
s
.
Run
(
"search_description_should_match"
,
func
()
{
repo
,
ctx
:=
newRepo
()
target
:=
newGroup
(
"it-group-search-desc-target"
)
target
.
Description
=
"something about desc-needle in here"
target
=
mustCreate
(
repo
,
ctx
,
target
)
other
:=
newGroup
(
"it-group-search-desc-other"
)
other
.
Description
=
"nothing to see here"
other
=
mustCreate
(
repo
,
ctx
,
other
)
groups
,
_
,
err
:=
repo
.
ListWithFilters
(
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
50
},
""
,
""
,
"desc-needle"
,
nil
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
True
(
containsID
(
groups
,
target
.
ID
),
"expected target group to match by description"
)
s
.
Require
()
.
False
(
containsID
(
groups
,
other
.
ID
),
"expected other group to be filtered out"
)
})
s
.
Run
(
"search_nonexistent_should_return_empty"
,
func
()
{
repo
,
ctx
:=
newRepo
()
_
=
mustCreate
(
repo
,
ctx
,
newGroup
(
"it-group-search-nonexistent-baseline"
))
search
:=
s
.
T
()
.
Name
()
+
"__no_such_group__"
groups
,
_
,
err
:=
repo
.
ListWithFilters
(
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
50
},
""
,
""
,
search
,
nil
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Empty
(
groups
)
})
s
.
Run
(
"search_should_be_case_insensitive"
,
func
()
{
repo
,
ctx
:=
newRepo
()
target
:=
mustCreate
(
repo
,
ctx
,
newGroup
(
"MiXeDCaSe-Needle"
))
other
:=
mustCreate
(
repo
,
ctx
,
newGroup
(
"it-group-search-case-other"
))
groups
,
_
,
err
:=
repo
.
ListWithFilters
(
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
50
},
""
,
""
,
"mixedcase-needle"
,
nil
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
True
(
containsID
(
groups
,
target
.
ID
),
"expected case-insensitive match"
)
s
.
Require
()
.
False
(
containsID
(
groups
,
other
.
ID
),
"expected other group to be filtered out"
)
})
s
.
Run
(
"search_should_escape_like_wildcards"
,
func
()
{
repo
,
ctx
:=
newRepo
()
percentTarget
:=
mustCreate
(
repo
,
ctx
,
newGroup
(
"it-group-search-100%-target"
))
percentOther
:=
mustCreate
(
repo
,
ctx
,
newGroup
(
"it-group-search-100X-other"
))
groups
,
_
,
err
:=
repo
.
ListWithFilters
(
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
50
},
""
,
""
,
"100%"
,
nil
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
True
(
containsID
(
groups
,
percentTarget
.
ID
),
"expected literal %% match"
)
s
.
Require
()
.
False
(
containsID
(
groups
,
percentOther
.
ID
),
"expected %% not to act as wildcard"
)
underscoreTarget
:=
mustCreate
(
repo
,
ctx
,
newGroup
(
"it-group-search-ab_cd-target"
))
underscoreOther
:=
mustCreate
(
repo
,
ctx
,
newGroup
(
"it-group-search-abXcd-other"
))
groups
,
_
,
err
=
repo
.
ListWithFilters
(
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
50
},
""
,
""
,
"ab_cd"
,
nil
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
True
(
containsID
(
groups
,
underscoreTarget
.
ID
),
"expected literal _ match"
)
s
.
Require
()
.
False
(
containsID
(
groups
,
underscoreOther
.
ID
),
"expected _ not to act as wildcard"
)
})
}
func
(
s
*
GroupRepoSuite
)
TestListWithFilters_AccountCount
()
{
g1
:=
&
service
.
Group
{
Name
:
"g1"
,
...
...
@@ -244,7 +350,7 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
s
.
Require
()
.
NoError
(
err
)
isExclusive
:=
true
groups
,
page
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
service
.
PlatformAnthropic
,
service
.
StatusActive
,
&
isExclusive
)
groups
,
page
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
service
.
PlatformAnthropic
,
service
.
StatusActive
,
""
,
&
isExclusive
)
s
.
Require
()
.
NoError
(
err
,
"ListWithFilters"
)
s
.
Require
()
.
Equal
(
int64
(
1
),
page
.
Total
)
s
.
Require
()
.
Len
(
groups
,
1
)
...
...
backend/internal/server/api_contract_test.go
View file @
6b97a8be
...
...
@@ -304,6 +304,10 @@ func TestAPIContracts(t *testing.T) {
"turnstile_enabled": true,
"turnstile_site_key": "site-key",
"turnstile_secret_key_configured": true,
"linuxdo_connect_enabled": false,
"linuxdo_connect_client_id": "",
"linuxdo_connect_client_secret_configured": false,
"linuxdo_connect_redirect_url": "",
"site_name": "Sub2API",
"site_logo": "",
"site_subtitle": "Subtitle",
...
...
@@ -390,7 +394,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo
:=
newStubSettingRepo
()
settingService
:=
service
.
NewSettingService
(
settingRepo
,
cfg
)
authHandler
:=
handler
.
NewAuthHandler
(
cfg
,
nil
,
userService
)
authHandler
:=
handler
.
NewAuthHandler
(
cfg
,
nil
,
userService
,
settingService
)
apiKeyHandler
:=
handler
.
NewAPIKeyHandler
(
apiKeyService
)
usageHandler
:=
handler
.
NewUsageHandler
(
usageService
,
apiKeyService
)
adminSettingHandler
:=
adminhandler
.
NewSettingHandler
(
settingService
,
nil
,
nil
)
...
...
@@ -583,7 +587,7 @@ func (stubGroupRepo) List(ctx context.Context, params pagination.PaginationParam
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
stubGroupRepo
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
status
string
,
isExclusive
*
bool
)
([]
service
.
Group
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
stubGroupRepo
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
status
,
search
string
,
isExclusive
*
bool
)
([]
service
.
Group
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
...
...
backend/internal/server/routes/auth.go
View file @
6b97a8be
...
...
@@ -19,6 +19,8 @@ func RegisterAuthRoutes(
auth
.
POST
(
"/register"
,
h
.
Auth
.
Register
)
auth
.
POST
(
"/login"
,
h
.
Auth
.
Login
)
auth
.
POST
(
"/send-verify-code"
,
h
.
Auth
.
SendVerifyCode
)
auth
.
GET
(
"/oauth/linuxdo/start"
,
h
.
Auth
.
LinuxDoOAuthStart
)
auth
.
GET
(
"/oauth/linuxdo/callback"
,
h
.
Auth
.
LinuxDoOAuthCallback
)
}
// 公开设置(无需认证)
...
...
backend/internal/service/account_service.go
View file @
6b97a8be
...
...
@@ -66,6 +66,7 @@ type AccountBulkUpdate struct {
Concurrency
*
int
Priority
*
int
Status
*
string
Schedulable
*
bool
Credentials
map
[
string
]
any
Extra
map
[
string
]
any
}
...
...
backend/internal/service/account_test_service.go
View file @
6b97a8be
...
...
@@ -661,13 +661,7 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader)
}
if
candidates
,
ok
:=
data
[
"candidates"
]
.
([]
any
);
ok
&&
len
(
candidates
)
>
0
{
if
candidate
,
ok
:=
candidates
[
0
]
.
(
map
[
string
]
any
);
ok
{
// Check for completion
if
finishReason
,
ok
:=
candidate
[
"finishReason"
]
.
(
string
);
ok
&&
finishReason
!=
""
{
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"test_complete"
,
Success
:
true
})
return
nil
}
// Extract content
// Extract content first (before checking completion)
if
content
,
ok
:=
candidate
[
"content"
]
.
(
map
[
string
]
any
);
ok
{
if
parts
,
ok
:=
content
[
"parts"
]
.
([]
any
);
ok
{
for
_
,
part
:=
range
parts
{
...
...
@@ -679,6 +673,12 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader)
}
}
}
// Check for completion after extracting content
if
finishReason
,
ok
:=
candidate
[
"finishReason"
]
.
(
string
);
ok
&&
finishReason
!=
""
{
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"test_complete"
,
Success
:
true
})
return
nil
}
}
}
...
...
backend/internal/service/admin_service.go
View file @
6b97a8be
...
...
@@ -24,7 +24,7 @@ type AdminService interface {
GetUserUsageStats
(
ctx
context
.
Context
,
userID
int64
,
period
string
)
(
any
,
error
)
// Group management
ListGroups
(
ctx
context
.
Context
,
page
,
pageSize
int
,
platform
,
status
string
,
isExclusive
*
bool
)
([]
Group
,
int64
,
error
)
ListGroups
(
ctx
context
.
Context
,
page
,
pageSize
int
,
platform
,
status
,
search
string
,
isExclusive
*
bool
)
([]
Group
,
int64
,
error
)
GetAllGroups
(
ctx
context
.
Context
)
([]
Group
,
error
)
GetAllGroupsByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
Group
,
error
)
GetGroup
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
...
...
@@ -168,6 +168,7 @@ type BulkUpdateAccountsInput struct {
Concurrency
*
int
Priority
*
int
Status
string
Schedulable
*
bool
GroupIDs
*
[]
int64
Credentials
map
[
string
]
any
Extra
map
[
string
]
any
...
...
@@ -478,9 +479,9 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
}
// Group management implementations
func
(
s
*
adminServiceImpl
)
ListGroups
(
ctx
context
.
Context
,
page
,
pageSize
int
,
platform
,
status
string
,
isExclusive
*
bool
)
([]
Group
,
int64
,
error
)
{
func
(
s
*
adminServiceImpl
)
ListGroups
(
ctx
context
.
Context
,
page
,
pageSize
int
,
platform
,
status
,
search
string
,
isExclusive
*
bool
)
([]
Group
,
int64
,
error
)
{
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
groups
,
result
,
err
:=
s
.
groupRepo
.
ListWithFilters
(
ctx
,
params
,
platform
,
status
,
isExclusive
)
groups
,
result
,
err
:=
s
.
groupRepo
.
ListWithFilters
(
ctx
,
params
,
platform
,
status
,
search
,
isExclusive
)
if
err
!=
nil
{
return
nil
,
0
,
err
}
...
...
@@ -910,6 +911,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
if
input
.
Status
!=
""
{
repoUpdates
.
Status
=
&
input
.
Status
}
if
input
.
Schedulable
!=
nil
{
repoUpdates
.
Schedulable
=
input
.
Schedulable
}
// Run bulk update for column/jsonb fields first.
if
_
,
err
:=
s
.
accountRepo
.
BulkUpdate
(
ctx
,
input
.
AccountIDs
,
repoUpdates
);
err
!=
nil
{
...
...
backend/internal/service/admin_service_delete_test.go
View file @
6b97a8be
...
...
@@ -124,7 +124,7 @@ func (s *groupRepoStub) List(ctx context.Context, params pagination.PaginationPa
panic
(
"unexpected List call"
)
}
func
(
s
*
groupRepoStub
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
status
string
,
isExclusive
*
bool
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
s
*
groupRepoStub
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
status
,
search
string
,
isExclusive
*
bool
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected ListWithFilters call"
)
}
...
...
backend/internal/service/admin_service_group_test.go
View file @
6b97a8be
...
...
@@ -16,6 +16,16 @@ type groupRepoStubForAdmin struct {
updated
*
Group
// 记录 Update 调用的参数
getByID
*
Group
// GetByID 返回值
getErr
error
// GetByID 返回的错误
listWithFiltersCalls
int
listWithFiltersParams
pagination
.
PaginationParams
listWithFiltersPlatform
string
listWithFiltersStatus
string
listWithFiltersSearch
string
listWithFiltersIsExclusive
*
bool
listWithFiltersGroups
[]
Group
listWithFiltersResult
*
pagination
.
PaginationResult
listWithFiltersErr
error
}
func
(
s
*
groupRepoStubForAdmin
)
Create
(
_
context
.
Context
,
g
*
Group
)
error
{
...
...
@@ -47,8 +57,28 @@ func (s *groupRepoStubForAdmin) List(_ context.Context, _ pagination.PaginationP
panic
(
"unexpected List call"
)
}
func
(
s
*
groupRepoStubForAdmin
)
ListWithFilters
(
_
context
.
Context
,
_
pagination
.
PaginationParams
,
_
,
_
string
,
_
*
bool
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected ListWithFilters call"
)
func
(
s
*
groupRepoStubForAdmin
)
ListWithFilters
(
_
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
status
,
search
string
,
isExclusive
*
bool
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
s
.
listWithFiltersCalls
++
s
.
listWithFiltersParams
=
params
s
.
listWithFiltersPlatform
=
platform
s
.
listWithFiltersStatus
=
status
s
.
listWithFiltersSearch
=
search
s
.
listWithFiltersIsExclusive
=
isExclusive
if
s
.
listWithFiltersErr
!=
nil
{
return
nil
,
nil
,
s
.
listWithFiltersErr
}
result
:=
s
.
listWithFiltersResult
if
result
==
nil
{
result
=
&
pagination
.
PaginationResult
{
Total
:
int64
(
len
(
s
.
listWithFiltersGroups
)),
Page
:
params
.
Page
,
PageSize
:
params
.
PageSize
,
}
}
return
s
.
listWithFiltersGroups
,
result
,
nil
}
func
(
s
*
groupRepoStubForAdmin
)
ListActive
(
_
context
.
Context
)
([]
Group
,
error
)
{
...
...
@@ -195,3 +225,68 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
require
.
InDelta
(
t
,
0.15
,
*
repo
.
updated
.
ImagePrice2K
,
0.0001
)
// 原值保持
require
.
Nil
(
t
,
repo
.
updated
.
ImagePrice4K
)
}
func
TestAdminService_ListGroups_WithSearch
(
t
*
testing
.
T
)
{
// 测试:
// 1. search 参数正常传递到 repository 层
// 2. search 为空字符串时的行为
// 3. search 与其他过滤条件组合使用
t
.
Run
(
"search 参数正常传递到 repository 层"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
groupRepoStubForAdmin
{
listWithFiltersGroups
:
[]
Group
{{
ID
:
1
,
Name
:
"alpha"
}},
listWithFiltersResult
:
&
pagination
.
PaginationResult
{
Total
:
1
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
groups
,
total
,
err
:=
svc
.
ListGroups
(
context
.
Background
(),
1
,
20
,
""
,
""
,
"alpha"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1
),
total
)
require
.
Equal
(
t
,
[]
Group
{{
ID
:
1
,
Name
:
"alpha"
}},
groups
)
require
.
Equal
(
t
,
1
,
repo
.
listWithFiltersCalls
)
require
.
Equal
(
t
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
20
},
repo
.
listWithFiltersParams
)
require
.
Equal
(
t
,
"alpha"
,
repo
.
listWithFiltersSearch
)
require
.
Nil
(
t
,
repo
.
listWithFiltersIsExclusive
)
})
t
.
Run
(
"search 为空字符串时传递空字符串"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
groupRepoStubForAdmin
{
listWithFiltersGroups
:
[]
Group
{},
listWithFiltersResult
:
&
pagination
.
PaginationResult
{
Total
:
0
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
groups
,
total
,
err
:=
svc
.
ListGroups
(
context
.
Background
(),
2
,
10
,
""
,
""
,
""
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
Empty
(
t
,
groups
)
require
.
Equal
(
t
,
int64
(
0
),
total
)
require
.
Equal
(
t
,
1
,
repo
.
listWithFiltersCalls
)
require
.
Equal
(
t
,
pagination
.
PaginationParams
{
Page
:
2
,
PageSize
:
10
},
repo
.
listWithFiltersParams
)
require
.
Equal
(
t
,
""
,
repo
.
listWithFiltersSearch
)
require
.
Nil
(
t
,
repo
.
listWithFiltersIsExclusive
)
})
t
.
Run
(
"search 与其他过滤条件组合使用"
,
func
(
t
*
testing
.
T
)
{
isExclusive
:=
true
repo
:=
&
groupRepoStubForAdmin
{
listWithFiltersGroups
:
[]
Group
{{
ID
:
2
,
Name
:
"beta"
}},
listWithFiltersResult
:
&
pagination
.
PaginationResult
{
Total
:
42
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
groups
,
total
,
err
:=
svc
.
ListGroups
(
context
.
Background
(),
3
,
50
,
PlatformAntigravity
,
StatusActive
,
"beta"
,
&
isExclusive
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
42
),
total
)
require
.
Equal
(
t
,
[]
Group
{{
ID
:
2
,
Name
:
"beta"
}},
groups
)
require
.
Equal
(
t
,
1
,
repo
.
listWithFiltersCalls
)
require
.
Equal
(
t
,
pagination
.
PaginationParams
{
Page
:
3
,
PageSize
:
50
},
repo
.
listWithFiltersParams
)
require
.
Equal
(
t
,
PlatformAntigravity
,
repo
.
listWithFiltersPlatform
)
require
.
Equal
(
t
,
StatusActive
,
repo
.
listWithFiltersStatus
)
require
.
Equal
(
t
,
"beta"
,
repo
.
listWithFiltersSearch
)
require
.
NotNil
(
t
,
repo
.
listWithFiltersIsExclusive
)
require
.
True
(
t
,
*
repo
.
listWithFiltersIsExclusive
)
})
}
backend/internal/service/admin_service_search_test.go
0 → 100644
View file @
6b97a8be
//go:build unit
package
service
import
(
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type
accountRepoStubForAdminList
struct
{
accountRepoStub
listWithFiltersCalls
int
listWithFiltersParams
pagination
.
PaginationParams
listWithFiltersPlatform
string
listWithFiltersType
string
listWithFiltersStatus
string
listWithFiltersSearch
string
listWithFiltersAccounts
[]
Account
listWithFiltersResult
*
pagination
.
PaginationResult
listWithFiltersErr
error
}
func
(
s
*
accountRepoStubForAdminList
)
ListWithFilters
(
_
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
accountType
,
status
,
search
string
)
([]
Account
,
*
pagination
.
PaginationResult
,
error
)
{
s
.
listWithFiltersCalls
++
s
.
listWithFiltersParams
=
params
s
.
listWithFiltersPlatform
=
platform
s
.
listWithFiltersType
=
accountType
s
.
listWithFiltersStatus
=
status
s
.
listWithFiltersSearch
=
search
if
s
.
listWithFiltersErr
!=
nil
{
return
nil
,
nil
,
s
.
listWithFiltersErr
}
result
:=
s
.
listWithFiltersResult
if
result
==
nil
{
result
=
&
pagination
.
PaginationResult
{
Total
:
int64
(
len
(
s
.
listWithFiltersAccounts
)),
Page
:
params
.
Page
,
PageSize
:
params
.
PageSize
,
}
}
return
s
.
listWithFiltersAccounts
,
result
,
nil
}
type
proxyRepoStubForAdminList
struct
{
proxyRepoStub
listWithFiltersCalls
int
listWithFiltersParams
pagination
.
PaginationParams
listWithFiltersProtocol
string
listWithFiltersStatus
string
listWithFiltersSearch
string
listWithFiltersProxies
[]
Proxy
listWithFiltersResult
*
pagination
.
PaginationResult
listWithFiltersErr
error
listWithFiltersAndAccountCountCalls
int
listWithFiltersAndAccountCountParams
pagination
.
PaginationParams
listWithFiltersAndAccountCountProtocol
string
listWithFiltersAndAccountCountStatus
string
listWithFiltersAndAccountCountSearch
string
listWithFiltersAndAccountCountProxies
[]
ProxyWithAccountCount
listWithFiltersAndAccountCountResult
*
pagination
.
PaginationResult
listWithFiltersAndAccountCountErr
error
}
func
(
s
*
proxyRepoStubForAdminList
)
ListWithFilters
(
_
context
.
Context
,
params
pagination
.
PaginationParams
,
protocol
,
status
,
search
string
)
([]
Proxy
,
*
pagination
.
PaginationResult
,
error
)
{
s
.
listWithFiltersCalls
++
s
.
listWithFiltersParams
=
params
s
.
listWithFiltersProtocol
=
protocol
s
.
listWithFiltersStatus
=
status
s
.
listWithFiltersSearch
=
search
if
s
.
listWithFiltersErr
!=
nil
{
return
nil
,
nil
,
s
.
listWithFiltersErr
}
result
:=
s
.
listWithFiltersResult
if
result
==
nil
{
result
=
&
pagination
.
PaginationResult
{
Total
:
int64
(
len
(
s
.
listWithFiltersProxies
)),
Page
:
params
.
Page
,
PageSize
:
params
.
PageSize
,
}
}
return
s
.
listWithFiltersProxies
,
result
,
nil
}
func
(
s
*
proxyRepoStubForAdminList
)
ListWithFiltersAndAccountCount
(
_
context
.
Context
,
params
pagination
.
PaginationParams
,
protocol
,
status
,
search
string
)
([]
ProxyWithAccountCount
,
*
pagination
.
PaginationResult
,
error
)
{
s
.
listWithFiltersAndAccountCountCalls
++
s
.
listWithFiltersAndAccountCountParams
=
params
s
.
listWithFiltersAndAccountCountProtocol
=
protocol
s
.
listWithFiltersAndAccountCountStatus
=
status
s
.
listWithFiltersAndAccountCountSearch
=
search
if
s
.
listWithFiltersAndAccountCountErr
!=
nil
{
return
nil
,
nil
,
s
.
listWithFiltersAndAccountCountErr
}
result
:=
s
.
listWithFiltersAndAccountCountResult
if
result
==
nil
{
result
=
&
pagination
.
PaginationResult
{
Total
:
int64
(
len
(
s
.
listWithFiltersAndAccountCountProxies
)),
Page
:
params
.
Page
,
PageSize
:
params
.
PageSize
,
}
}
return
s
.
listWithFiltersAndAccountCountProxies
,
result
,
nil
}
type
redeemRepoStubForAdminList
struct
{
redeemRepoStub
listWithFiltersCalls
int
listWithFiltersParams
pagination
.
PaginationParams
listWithFiltersType
string
listWithFiltersStatus
string
listWithFiltersSearch
string
listWithFiltersCodes
[]
RedeemCode
listWithFiltersResult
*
pagination
.
PaginationResult
listWithFiltersErr
error
}
func
(
s
*
redeemRepoStubForAdminList
)
ListWithFilters
(
_
context
.
Context
,
params
pagination
.
PaginationParams
,
codeType
,
status
,
search
string
)
([]
RedeemCode
,
*
pagination
.
PaginationResult
,
error
)
{
s
.
listWithFiltersCalls
++
s
.
listWithFiltersParams
=
params
s
.
listWithFiltersType
=
codeType
s
.
listWithFiltersStatus
=
status
s
.
listWithFiltersSearch
=
search
if
s
.
listWithFiltersErr
!=
nil
{
return
nil
,
nil
,
s
.
listWithFiltersErr
}
result
:=
s
.
listWithFiltersResult
if
result
==
nil
{
result
=
&
pagination
.
PaginationResult
{
Total
:
int64
(
len
(
s
.
listWithFiltersCodes
)),
Page
:
params
.
Page
,
PageSize
:
params
.
PageSize
,
}
}
return
s
.
listWithFiltersCodes
,
result
,
nil
}
func
TestAdminService_ListAccounts_WithSearch
(
t
*
testing
.
T
)
{
t
.
Run
(
"search 参数正常传递到 repository 层"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
accountRepoStubForAdminList
{
listWithFiltersAccounts
:
[]
Account
{{
ID
:
1
,
Name
:
"acc"
}},
listWithFiltersResult
:
&
pagination
.
PaginationResult
{
Total
:
10
},
}
svc
:=
&
adminServiceImpl
{
accountRepo
:
repo
}
accounts
,
total
,
err
:=
svc
.
ListAccounts
(
context
.
Background
(),
1
,
20
,
PlatformGemini
,
AccountTypeOAuth
,
StatusActive
,
"acc"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
10
),
total
)
require
.
Equal
(
t
,
[]
Account
{{
ID
:
1
,
Name
:
"acc"
}},
accounts
)
require
.
Equal
(
t
,
1
,
repo
.
listWithFiltersCalls
)
require
.
Equal
(
t
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
20
},
repo
.
listWithFiltersParams
)
require
.
Equal
(
t
,
PlatformGemini
,
repo
.
listWithFiltersPlatform
)
require
.
Equal
(
t
,
AccountTypeOAuth
,
repo
.
listWithFiltersType
)
require
.
Equal
(
t
,
StatusActive
,
repo
.
listWithFiltersStatus
)
require
.
Equal
(
t
,
"acc"
,
repo
.
listWithFiltersSearch
)
})
}
func
TestAdminService_ListProxies_WithSearch
(
t
*
testing
.
T
)
{
t
.
Run
(
"search 参数正常传递到 repository 层"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
proxyRepoStubForAdminList
{
listWithFiltersProxies
:
[]
Proxy
{{
ID
:
2
,
Name
:
"p1"
}},
listWithFiltersResult
:
&
pagination
.
PaginationResult
{
Total
:
7
},
}
svc
:=
&
adminServiceImpl
{
proxyRepo
:
repo
}
proxies
,
total
,
err
:=
svc
.
ListProxies
(
context
.
Background
(),
3
,
50
,
"http"
,
StatusActive
,
"p1"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
7
),
total
)
require
.
Equal
(
t
,
[]
Proxy
{{
ID
:
2
,
Name
:
"p1"
}},
proxies
)
require
.
Equal
(
t
,
1
,
repo
.
listWithFiltersCalls
)
require
.
Equal
(
t
,
pagination
.
PaginationParams
{
Page
:
3
,
PageSize
:
50
},
repo
.
listWithFiltersParams
)
require
.
Equal
(
t
,
"http"
,
repo
.
listWithFiltersProtocol
)
require
.
Equal
(
t
,
StatusActive
,
repo
.
listWithFiltersStatus
)
require
.
Equal
(
t
,
"p1"
,
repo
.
listWithFiltersSearch
)
})
}
func
TestAdminService_ListProxiesWithAccountCount_WithSearch
(
t
*
testing
.
T
)
{
t
.
Run
(
"search 参数正常传递到 repository 层"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
proxyRepoStubForAdminList
{
listWithFiltersAndAccountCountProxies
:
[]
ProxyWithAccountCount
{{
Proxy
:
Proxy
{
ID
:
3
,
Name
:
"p2"
},
AccountCount
:
5
}},
listWithFiltersAndAccountCountResult
:
&
pagination
.
PaginationResult
{
Total
:
9
},
}
svc
:=
&
adminServiceImpl
{
proxyRepo
:
repo
}
proxies
,
total
,
err
:=
svc
.
ListProxiesWithAccountCount
(
context
.
Background
(),
2
,
10
,
"socks5"
,
StatusDisabled
,
"p2"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
9
),
total
)
require
.
Equal
(
t
,
[]
ProxyWithAccountCount
{{
Proxy
:
Proxy
{
ID
:
3
,
Name
:
"p2"
},
AccountCount
:
5
}},
proxies
)
require
.
Equal
(
t
,
1
,
repo
.
listWithFiltersAndAccountCountCalls
)
require
.
Equal
(
t
,
pagination
.
PaginationParams
{
Page
:
2
,
PageSize
:
10
},
repo
.
listWithFiltersAndAccountCountParams
)
require
.
Equal
(
t
,
"socks5"
,
repo
.
listWithFiltersAndAccountCountProtocol
)
require
.
Equal
(
t
,
StatusDisabled
,
repo
.
listWithFiltersAndAccountCountStatus
)
require
.
Equal
(
t
,
"p2"
,
repo
.
listWithFiltersAndAccountCountSearch
)
})
}
func
TestAdminService_ListRedeemCodes_WithSearch
(
t
*
testing
.
T
)
{
t
.
Run
(
"search 参数正常传递到 repository 层"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
redeemRepoStubForAdminList
{
listWithFiltersCodes
:
[]
RedeemCode
{{
ID
:
4
,
Code
:
"ABC"
}},
listWithFiltersResult
:
&
pagination
.
PaginationResult
{
Total
:
3
},
}
svc
:=
&
adminServiceImpl
{
redeemCodeRepo
:
repo
}
codes
,
total
,
err
:=
svc
.
ListRedeemCodes
(
context
.
Background
(),
1
,
20
,
RedeemTypeBalance
,
StatusUnused
,
"ABC"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
3
),
total
)
require
.
Equal
(
t
,
[]
RedeemCode
{{
ID
:
4
,
Code
:
"ABC"
}},
codes
)
require
.
Equal
(
t
,
1
,
repo
.
listWithFiltersCalls
)
require
.
Equal
(
t
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
20
},
repo
.
listWithFiltersParams
)
require
.
Equal
(
t
,
RedeemTypeBalance
,
repo
.
listWithFiltersType
)
require
.
Equal
(
t
,
StatusUnused
,
repo
.
listWithFiltersStatus
)
require
.
Equal
(
t
,
"ABC"
,
repo
.
listWithFiltersSearch
)
})
}
backend/internal/service/antigravity_gateway_service.go
View file @
6b97a8be
...
...
@@ -10,6 +10,7 @@ import (
"io"
"log"
mathrand
"math/rand"
"net"
"net/http"
"strings"
"sync/atomic"
...
...
@@ -27,6 +28,32 @@ const (
antigravityRetryMaxDelay
=
16
*
time
.
Second
)
// isAntigravityConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝)
func
isAntigravityConnectionError
(
err
error
)
bool
{
if
err
==
nil
{
return
false
}
// 检查超时错误
var
netErr
net
.
Error
if
errors
.
As
(
err
,
&
netErr
)
&&
netErr
.
Timeout
()
{
return
true
}
// 检查连接错误(DNS 失败、连接拒绝)
var
opErr
*
net
.
OpError
return
errors
.
As
(
err
,
&
opErr
)
}
// shouldAntigravityFallbackToNextURL 判断是否应切换到下一个 URL
// 仅连接错误和 HTTP 429 触发 URL 降级
func
shouldAntigravityFallbackToNextURL
(
err
error
,
statusCode
int
)
bool
{
if
isAntigravityConnectionError
(
err
)
{
return
true
}
return
statusCode
==
http
.
StatusTooManyRequests
}
// getSessionID 从 gin.Context 获取 session_id(用于日志追踪)
func
getSessionID
(
c
*
gin
.
Context
)
string
{
if
c
==
nil
{
...
...
@@ -181,34 +208,56 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
return
nil
,
fmt
.
Errorf
(
"构建请求失败: %w"
,
err
)
}
// 代理 URL
proxyURL
:=
""
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
proxyURL
=
account
.
Proxy
.
URL
()
}
// URL fallback 循环
availableURLs
:=
antigravity
.
DefaultURLAvailability
.
GetAvailableURLs
()
if
len
(
availableURLs
)
==
0
{
availableURLs
=
antigravity
.
BaseURLs
// 所有 URL 都不可用时,重试所有
}
var
lastErr
error
for
urlIdx
,
baseURL
:=
range
availableURLs
{
// 构建 HTTP 请求(总是使用流式 endpoint,与官方客户端一致)
req
,
err
:=
antigravity
.
NewAPIRequest
(
ctx
,
"streamGenerateContent"
,
accessToken
,
requestBody
)
req
,
err
:=
antigravity
.
NewAPIRequest
WithURL
(
ctx
,
baseURL
,
"streamGenerateContent"
,
accessToken
,
requestBody
)
if
err
!=
nil
{
return
nil
,
err
lastErr
=
err
continue
}
// 调试日志:Test 请求信息
log
.
Printf
(
"[antigravity-Test] account=%s request_size=%d url=%s"
,
account
.
Name
,
len
(
requestBody
),
req
.
URL
.
String
())
// 代理 URL
proxyURL
:=
""
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
proxyURL
=
account
.
Proxy
.
URL
()
}
// 发送请求
resp
,
err
:=
s
.
httpUpstream
.
Do
(
req
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"请求失败: %w"
,
err
)
lastErr
=
fmt
.
Errorf
(
"请求失败: %w"
,
err
)
if
shouldAntigravityFallbackToNextURL
(
err
,
0
)
&&
urlIdx
<
len
(
availableURLs
)
-
1
{
antigravity
.
DefaultURLAvailability
.
MarkUnavailable
(
baseURL
)
log
.
Printf
(
"[antigravity-Test] URL fallback: %s -> %s"
,
baseURL
,
availableURLs
[
urlIdx
+
1
])
continue
}
return
nil
,
lastErr
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
// 读取响应
respBody
,
err
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
_
=
resp
.
Body
.
Close
()
// 立即关闭,避免循环内 defer 导致的资源泄漏
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"读取响应失败: %w"
,
err
)
}
// 检查是否需要 URL 降级
if
shouldAntigravityFallbackToNextURL
(
nil
,
resp
.
StatusCode
)
&&
urlIdx
<
len
(
availableURLs
)
-
1
{
antigravity
.
DefaultURLAvailability
.
MarkUnavailable
(
baseURL
)
log
.
Printf
(
"[antigravity-Test] URL fallback (HTTP %d): %s -> %s"
,
resp
.
StatusCode
,
baseURL
,
availableURLs
[
urlIdx
+
1
])
continue
}
if
resp
.
StatusCode
>=
400
{
return
nil
,
fmt
.
Errorf
(
"API 返回 %d: %s"
,
resp
.
StatusCode
,
string
(
respBody
))
}
...
...
@@ -220,6 +269,9 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
Text
:
text
,
MappedModel
:
mappedModel
,
},
nil
}
return
nil
,
lastErr
}
// buildGeminiTestRequest 构建 Gemini 格式测试请求
...
...
@@ -484,8 +536,16 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回
action
:=
"streamGenerateContent"
// URL fallback 循环
availableURLs
:=
antigravity
.
DefaultURLAvailability
.
GetAvailableURLs
()
if
len
(
availableURLs
)
==
0
{
availableURLs
=
antigravity
.
BaseURLs
// 所有 URL 都不可用时,重试所有
}
// 重试循环
var
resp
*
http
.
Response
urlFallbackLoop
:
for
urlIdx
,
baseURL
:=
range
availableURLs
{
for
attempt
:=
1
;
attempt
<=
antigravityMaxRetries
;
attempt
++
{
// 检查 context 是否已取消(客户端断开连接)
select
{
...
...
@@ -495,13 +555,19 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
default
:
}
upstreamReq
,
err
:=
antigravity
.
NewAPIRequest
(
ctx
,
action
,
accessToken
,
geminiBody
)
upstreamReq
,
err
:=
antigravity
.
NewAPIRequest
WithURL
(
ctx
,
baseURL
,
action
,
accessToken
,
geminiBody
)
if
err
!=
nil
{
return
nil
,
err
}
resp
,
err
=
s
.
httpUpstream
.
Do
(
upstreamReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
err
!=
nil
{
// 检查是否应触发 URL 降级
if
shouldAntigravityFallbackToNextURL
(
err
,
0
)
&&
urlIdx
<
len
(
availableURLs
)
-
1
{
antigravity
.
DefaultURLAvailability
.
MarkUnavailable
(
baseURL
)
log
.
Printf
(
"%s URL fallback (connection error): %s -> %s"
,
prefix
,
baseURL
,
availableURLs
[
urlIdx
+
1
])
continue
urlFallbackLoop
}
if
attempt
<
antigravityMaxRetries
{
log
.
Printf
(
"%s status=request_failed retry=%d/%d error=%v"
,
prefix
,
attempt
,
antigravityMaxRetries
,
err
)
if
!
sleepAntigravityBackoffWithContext
(
ctx
,
attempt
)
{
...
...
@@ -514,6 +580,15 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return
nil
,
s
.
writeClaudeError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream request failed after retries"
)
}
// 检查是否应触发 URL 降级(仅 429)
if
resp
.
StatusCode
==
http
.
StatusTooManyRequests
&&
urlIdx
<
len
(
availableURLs
)
-
1
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
_
=
resp
.
Body
.
Close
()
antigravity
.
DefaultURLAvailability
.
MarkUnavailable
(
baseURL
)
log
.
Printf
(
"%s URL fallback (HTTP 429): %s -> %s body=%s"
,
prefix
,
baseURL
,
availableURLs
[
urlIdx
+
1
],
truncateForLog
(
respBody
,
200
))
continue
urlFallbackLoop
}
if
resp
.
StatusCode
>=
400
&&
s
.
shouldRetryUpstreamError
(
resp
.
StatusCode
)
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
_
=
resp
.
Body
.
Close
()
...
...
@@ -536,10 +611,11 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
Header
:
resp
.
Header
.
Clone
(),
Body
:
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
)),
}
break
break
urlFallbackLoop
}
break
break
urlFallbackLoop
}
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
...
...
@@ -1003,8 +1079,16 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后返回
upstreamAction
:=
"streamGenerateContent"
// URL fallback 循环
availableURLs
:=
antigravity
.
DefaultURLAvailability
.
GetAvailableURLs
()
if
len
(
availableURLs
)
==
0
{
availableURLs
=
antigravity
.
BaseURLs
// 所有 URL 都不可用时,重试所有
}
// 重试循环
var
resp
*
http
.
Response
urlFallbackLoop
:
for
urlIdx
,
baseURL
:=
range
availableURLs
{
for
attempt
:=
1
;
attempt
<=
antigravityMaxRetries
;
attempt
++
{
// 检查 context 是否已取消(客户端断开连接)
select
{
...
...
@@ -1014,13 +1098,19 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
default
:
}
upstreamReq
,
err
:=
antigravity
.
NewAPIRequest
(
ctx
,
upstreamAction
,
accessToken
,
wrappedBody
)
upstreamReq
,
err
:=
antigravity
.
NewAPIRequest
WithURL
(
ctx
,
baseURL
,
upstreamAction
,
accessToken
,
wrappedBody
)
if
err
!=
nil
{
return
nil
,
err
}
resp
,
err
=
s
.
httpUpstream
.
Do
(
upstreamReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
err
!=
nil
{
// 检查是否应触发 URL 降级
if
shouldAntigravityFallbackToNextURL
(
err
,
0
)
&&
urlIdx
<
len
(
availableURLs
)
-
1
{
antigravity
.
DefaultURLAvailability
.
MarkUnavailable
(
baseURL
)
log
.
Printf
(
"%s URL fallback (connection error): %s -> %s"
,
prefix
,
baseURL
,
availableURLs
[
urlIdx
+
1
])
continue
urlFallbackLoop
}
if
attempt
<
antigravityMaxRetries
{
log
.
Printf
(
"%s status=request_failed retry=%d/%d error=%v"
,
prefix
,
attempt
,
antigravityMaxRetries
,
err
)
if
!
sleepAntigravityBackoffWithContext
(
ctx
,
attempt
)
{
...
...
@@ -1033,6 +1123,15 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
return
nil
,
s
.
writeGoogleError
(
c
,
http
.
StatusBadGateway
,
"Upstream request failed after retries"
)
}
// 检查是否应触发 URL 降级(仅 429)
if
resp
.
StatusCode
==
http
.
StatusTooManyRequests
&&
urlIdx
<
len
(
availableURLs
)
-
1
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
_
=
resp
.
Body
.
Close
()
antigravity
.
DefaultURLAvailability
.
MarkUnavailable
(
baseURL
)
log
.
Printf
(
"%s URL fallback (HTTP 429): %s -> %s body=%s"
,
prefix
,
baseURL
,
availableURLs
[
urlIdx
+
1
],
truncateForLog
(
respBody
,
200
))
continue
urlFallbackLoop
}
if
resp
.
StatusCode
>=
400
&&
s
.
shouldRetryUpstreamError
(
resp
.
StatusCode
)
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
_
=
resp
.
Body
.
Close
()
...
...
@@ -1054,10 +1153,11 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
Header
:
resp
.
Header
.
Clone
(),
Body
:
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
)),
}
break
break
urlFallbackLoop
}
break
break
urlFallbackLoop
}
}
defer
func
()
{
if
resp
!=
nil
&&
resp
.
Body
!=
nil
{
...
...
backend/internal/service/auth_service.go
View file @
6b97a8be
...
...
@@ -2,9 +2,13 @@ package service
import
(
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"log"
"net/mail"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
...
...
@@ -18,6 +22,7 @@ var (
ErrInvalidCredentials
=
infraerrors
.
Unauthorized
(
"INVALID_CREDENTIALS"
,
"invalid email or password"
)
ErrUserNotActive
=
infraerrors
.
Forbidden
(
"USER_NOT_ACTIVE"
,
"user is not active"
)
ErrEmailExists
=
infraerrors
.
Conflict
(
"EMAIL_EXISTS"
,
"email already exists"
)
ErrEmailReserved
=
infraerrors
.
BadRequest
(
"EMAIL_RESERVED"
,
"email is reserved"
)
ErrInvalidToken
=
infraerrors
.
Unauthorized
(
"INVALID_TOKEN"
,
"invalid token"
)
ErrTokenExpired
=
infraerrors
.
Unauthorized
(
"TOKEN_EXPIRED"
,
"token has expired"
)
ErrTokenTooLarge
=
infraerrors
.
BadRequest
(
"TOKEN_TOO_LARGE"
,
"token too large"
)
...
...
@@ -75,23 +80,32 @@ func (s *AuthService) Register(ctx context.Context, email, password string) (str
// RegisterWithVerification 用户注册(支持邮件验证),返回token和用户
func
(
s
*
AuthService
)
RegisterWithVerification
(
ctx
context
.
Context
,
email
,
password
,
verifyCode
string
)
(
string
,
*
User
,
error
)
{
// 检查是否开放注册
if
s
.
settingService
!
=
nil
&&
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
// 检查是否开放注册
(默认关闭:settingService 未配置时不允许注册)
if
s
.
settingService
=
=
nil
||
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
return
""
,
nil
,
ErrRegDisabled
}
// 防止用户注册 LinuxDo OAuth 合成邮箱,避免第三方登录与本地账号发生碰撞。
if
isReservedEmail
(
email
)
{
return
""
,
nil
,
ErrEmailReserved
}
// 检查是否需要邮件验证
if
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsEmailVerifyEnabled
(
ctx
)
{
// 如果邮件验证已开启但邮件服务未配置,拒绝注册
// 这是一个配置错误,不应该允许绕过验证
if
s
.
emailService
==
nil
{
log
.
Println
(
"[Auth] Email verification enabled but email service not configured, rejecting registration"
)
return
""
,
nil
,
ErrServiceUnavailable
}
if
verifyCode
==
""
{
return
""
,
nil
,
ErrEmailVerifyRequired
}
// 验证邮箱验证码
if
s
.
emailService
!=
nil
{
if
err
:=
s
.
emailService
.
VerifyCode
(
ctx
,
email
,
verifyCode
);
err
!=
nil
{
return
""
,
nil
,
fmt
.
Errorf
(
"verify code: %w"
,
err
)
}
}
}
// 检查邮箱是否已存在
existsEmail
,
err
:=
s
.
userRepo
.
ExistsByEmail
(
ctx
,
email
)
...
...
@@ -128,6 +142,10 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
}
if
err
:=
s
.
userRepo
.
Create
(
ctx
,
user
);
err
!=
nil
{
// 优先检查邮箱冲突错误(竞态条件下可能发生)
if
errors
.
Is
(
err
,
ErrEmailExists
)
{
return
""
,
nil
,
ErrEmailExists
}
log
.
Printf
(
"[Auth] Database error creating user: %v"
,
err
)
return
""
,
nil
,
ErrServiceUnavailable
}
...
...
@@ -148,11 +166,15 @@ type SendVerifyCodeResult struct {
// SendVerifyCode 发送邮箱验证码(同步方式)
func
(
s
*
AuthService
)
SendVerifyCode
(
ctx
context
.
Context
,
email
string
)
error
{
// 检查是否开放注册
if
s
.
settingService
!
=
nil
&&
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
// 检查是否开放注册
(默认关闭)
if
s
.
settingService
=
=
nil
||
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
return
ErrRegDisabled
}
if
isReservedEmail
(
email
)
{
return
ErrEmailReserved
}
// 检查邮箱是否已存在
existsEmail
,
err
:=
s
.
userRepo
.
ExistsByEmail
(
ctx
,
email
)
if
err
!=
nil
{
...
...
@@ -181,12 +203,16 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
func
(
s
*
AuthService
)
SendVerifyCodeAsync
(
ctx
context
.
Context
,
email
string
)
(
*
SendVerifyCodeResult
,
error
)
{
log
.
Printf
(
"[Auth] SendVerifyCodeAsync called for email: %s"
,
email
)
// 检查是否开放注册
if
s
.
settingService
!
=
nil
&&
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
// 检查是否开放注册
(默认关闭)
if
s
.
settingService
=
=
nil
||
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
log
.
Println
(
"[Auth] Registration is disabled"
)
return
nil
,
ErrRegDisabled
}
if
isReservedEmail
(
email
)
{
return
nil
,
ErrEmailReserved
}
// 检查邮箱是否已存在
existsEmail
,
err
:=
s
.
userRepo
.
ExistsByEmail
(
ctx
,
email
)
if
err
!=
nil
{
...
...
@@ -266,7 +292,7 @@ func (s *AuthService) IsTurnstileEnabled(ctx context.Context) bool {
// IsRegistrationEnabled 检查是否开放注册
func
(
s
*
AuthService
)
IsRegistrationEnabled
(
ctx
context
.
Context
)
bool
{
if
s
.
settingService
==
nil
{
return
true
return
false
// 安全默认:settingService 未配置时关闭注册
}
return
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
}
...
...
@@ -311,6 +337,102 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
return
token
,
user
,
nil
}
// LoginOrRegisterOAuth 用于第三方 OAuth/SSO 登录:
// - 如果邮箱已存在:直接登录(不需要本地密码)
// - 如果邮箱不存在:创建新用户并登录
//
// 注意:该函数用于“终端用户登录 Sub2API 本身”的场景(不同于上游账号的 OAuth,例如 OpenAI/Gemini)。
// 为了满足现有数据库约束(需要密码哈希),新用户会生成随机密码并进行哈希保存。
func
(
s
*
AuthService
)
LoginOrRegisterOAuth
(
ctx
context
.
Context
,
email
,
username
string
)
(
string
,
*
User
,
error
)
{
email
=
strings
.
TrimSpace
(
email
)
if
email
==
""
||
len
(
email
)
>
255
{
return
""
,
nil
,
infraerrors
.
BadRequest
(
"INVALID_EMAIL"
,
"invalid email"
)
}
if
_
,
err
:=
mail
.
ParseAddress
(
email
);
err
!=
nil
{
return
""
,
nil
,
infraerrors
.
BadRequest
(
"INVALID_EMAIL"
,
"invalid email"
)
}
username
=
strings
.
TrimSpace
(
username
)
if
len
([]
rune
(
username
))
>
100
{
username
=
string
([]
rune
(
username
)[
:
100
])
}
user
,
err
:=
s
.
userRepo
.
GetByEmail
(
ctx
,
email
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrUserNotFound
)
{
// OAuth 首次登录视为注册。
if
s
.
settingService
!=
nil
&&
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
return
""
,
nil
,
ErrRegDisabled
}
randomPassword
,
err
:=
randomHexString
(
32
)
if
err
!=
nil
{
log
.
Printf
(
"[Auth] Failed to generate random password for oauth signup: %v"
,
err
)
return
""
,
nil
,
ErrServiceUnavailable
}
hashedPassword
,
err
:=
s
.
HashPassword
(
randomPassword
)
if
err
!=
nil
{
return
""
,
nil
,
fmt
.
Errorf
(
"hash password: %w"
,
err
)
}
// 新用户默认值。
defaultBalance
:=
s
.
cfg
.
Default
.
UserBalance
defaultConcurrency
:=
s
.
cfg
.
Default
.
UserConcurrency
if
s
.
settingService
!=
nil
{
defaultBalance
=
s
.
settingService
.
GetDefaultBalance
(
ctx
)
defaultConcurrency
=
s
.
settingService
.
GetDefaultConcurrency
(
ctx
)
}
newUser
:=
&
User
{
Email
:
email
,
Username
:
username
,
PasswordHash
:
hashedPassword
,
Role
:
RoleUser
,
Balance
:
defaultBalance
,
Concurrency
:
defaultConcurrency
,
Status
:
StatusActive
,
}
if
err
:=
s
.
userRepo
.
Create
(
ctx
,
newUser
);
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrEmailExists
)
{
// 并发场景:GetByEmail 与 Create 之间用户被创建。
user
,
err
=
s
.
userRepo
.
GetByEmail
(
ctx
,
email
)
if
err
!=
nil
{
log
.
Printf
(
"[Auth] Database error getting user after conflict: %v"
,
err
)
return
""
,
nil
,
ErrServiceUnavailable
}
}
else
{
log
.
Printf
(
"[Auth] Database error creating oauth user: %v"
,
err
)
return
""
,
nil
,
ErrServiceUnavailable
}
}
else
{
user
=
newUser
}
}
else
{
log
.
Printf
(
"[Auth] Database error during oauth login: %v"
,
err
)
return
""
,
nil
,
ErrServiceUnavailable
}
}
if
!
user
.
IsActive
()
{
return
""
,
nil
,
ErrUserNotActive
}
// 尽力补全:当用户名为空时,使用第三方返回的用户名回填。
if
user
.
Username
==
""
&&
username
!=
""
{
user
.
Username
=
username
if
err
:=
s
.
userRepo
.
Update
(
ctx
,
user
);
err
!=
nil
{
log
.
Printf
(
"[Auth] Failed to update username after oauth login: %v"
,
err
)
}
}
token
,
err
:=
s
.
GenerateToken
(
user
)
if
err
!=
nil
{
return
""
,
nil
,
fmt
.
Errorf
(
"generate token: %w"
,
err
)
}
return
token
,
user
,
nil
}
// ValidateToken 验证JWT token并返回用户声明
func
(
s
*
AuthService
)
ValidateToken
(
tokenString
string
)
(
*
JWTClaims
,
error
)
{
// 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。
...
...
@@ -336,6 +458,11 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
if
err
!=
nil
{
if
errors
.
Is
(
err
,
jwt
.
ErrTokenExpired
)
{
// token 过期但仍返回 claims(用于 RefreshToken 等场景)
// jwt-go 在解析时即使遇到过期错误,token.Claims 仍会被填充
if
claims
,
ok
:=
token
.
Claims
.
(
*
JWTClaims
);
ok
{
return
claims
,
ErrTokenExpired
}
return
nil
,
ErrTokenExpired
}
return
nil
,
ErrInvalidToken
...
...
@@ -348,6 +475,22 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
return
nil
,
ErrInvalidToken
}
func
randomHexString
(
byteLength
int
)
(
string
,
error
)
{
if
byteLength
<=
0
{
byteLength
=
16
}
buf
:=
make
([]
byte
,
byteLength
)
if
_
,
err
:=
rand
.
Read
(
buf
);
err
!=
nil
{
return
""
,
err
}
return
hex
.
EncodeToString
(
buf
),
nil
}
func
isReservedEmail
(
email
string
)
bool
{
normalized
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
email
))
return
strings
.
HasSuffix
(
normalized
,
LinuxDoConnectSyntheticEmailDomain
)
}
// GenerateToken 生成JWT token
func
(
s
*
AuthService
)
GenerateToken
(
user
*
User
)
(
string
,
error
)
{
now
:=
time
.
Now
()
...
...
backend/internal/service/auth_service_register_test.go
View file @
6b97a8be
...
...
@@ -113,13 +113,36 @@ func TestAuthService_Register_Disabled(t *testing.T) {
require
.
ErrorIs
(
t
,
err
,
ErrRegDisabled
)
}
func
TestAuthService_Register_EmailVerifyRequired
(
t
*
testing
.
T
)
{
func
TestAuthService_Register_DisabledByDefault
(
t
*
testing
.
T
)
{
// 当 settings 为 nil(设置项不存在)时,注册应该默认关闭
repo
:=
&
userRepoStub
{}
service
:=
newAuthService
(
repo
,
nil
,
nil
)
_
,
_
,
err
:=
service
.
Register
(
context
.
Background
(),
"user@test.com"
,
"password"
)
require
.
ErrorIs
(
t
,
err
,
ErrRegDisabled
)
}
func
TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured
(
t
*
testing
.
T
)
{
repo
:=
&
userRepoStub
{}
// 邮件验证开启但 emailCache 为 nil(emailService 未配置)
service
:=
newAuthService
(
repo
,
map
[
string
]
string
{
SettingKeyRegistrationEnabled
:
"true"
,
SettingKeyEmailVerifyEnabled
:
"true"
,
},
nil
)
// 应返回服务不可用错误,而不是允许绕过验证
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
"any-code"
)
require
.
ErrorIs
(
t
,
err
,
ErrServiceUnavailable
)
}
func
TestAuthService_Register_EmailVerifyRequired
(
t
*
testing
.
T
)
{
repo
:=
&
userRepoStub
{}
cache
:=
&
emailCacheStub
{}
// 配置 emailService
service
:=
newAuthService
(
repo
,
map
[
string
]
string
{
SettingKeyRegistrationEnabled
:
"true"
,
SettingKeyEmailVerifyEnabled
:
"true"
,
},
cache
)
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
""
)
require
.
ErrorIs
(
t
,
err
,
ErrEmailVerifyRequired
)
}
...
...
@@ -141,7 +164,9 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
func
TestAuthService_Register_EmailExists
(
t
*
testing
.
T
)
{
repo
:=
&
userRepoStub
{
exists
:
true
}
service
:=
newAuthService
(
repo
,
nil
,
nil
)
service
:=
newAuthService
(
repo
,
map
[
string
]
string
{
SettingKeyRegistrationEnabled
:
"true"
,
},
nil
)
_
,
_
,
err
:=
service
.
Register
(
context
.
Background
(),
"user@test.com"
,
"password"
)
require
.
ErrorIs
(
t
,
err
,
ErrEmailExists
)
...
...
@@ -149,23 +174,50 @@ func TestAuthService_Register_EmailExists(t *testing.T) {
func
TestAuthService_Register_CheckEmailError
(
t
*
testing
.
T
)
{
repo
:=
&
userRepoStub
{
existsErr
:
errors
.
New
(
"db down"
)}
service
:=
newAuthService
(
repo
,
nil
,
nil
)
service
:=
newAuthService
(
repo
,
map
[
string
]
string
{
SettingKeyRegistrationEnabled
:
"true"
,
},
nil
)
_
,
_
,
err
:=
service
.
Register
(
context
.
Background
(),
"user@test.com"
,
"password"
)
require
.
ErrorIs
(
t
,
err
,
ErrServiceUnavailable
)
}
func
TestAuthService_Register_ReservedEmail
(
t
*
testing
.
T
)
{
repo
:=
&
userRepoStub
{}
service
:=
newAuthService
(
repo
,
map
[
string
]
string
{
SettingKeyRegistrationEnabled
:
"true"
,
},
nil
)
_
,
_
,
err
:=
service
.
Register
(
context
.
Background
(),
"linuxdo-123@linuxdo-connect.invalid"
,
"password"
)
require
.
ErrorIs
(
t
,
err
,
ErrEmailReserved
)
}
func
TestAuthService_Register_CreateError
(
t
*
testing
.
T
)
{
repo
:=
&
userRepoStub
{
createErr
:
errors
.
New
(
"create failed"
)}
service
:=
newAuthService
(
repo
,
nil
,
nil
)
service
:=
newAuthService
(
repo
,
map
[
string
]
string
{
SettingKeyRegistrationEnabled
:
"true"
,
},
nil
)
_
,
_
,
err
:=
service
.
Register
(
context
.
Background
(),
"user@test.com"
,
"password"
)
require
.
ErrorIs
(
t
,
err
,
ErrServiceUnavailable
)
}
func
TestAuthService_Register_CreateEmailExistsRace
(
t
*
testing
.
T
)
{
// 模拟竞态条件:ExistsByEmail 返回 false,但 Create 时因唯一约束失败
repo
:=
&
userRepoStub
{
createErr
:
ErrEmailExists
}
service
:=
newAuthService
(
repo
,
map
[
string
]
string
{
SettingKeyRegistrationEnabled
:
"true"
,
},
nil
)
_
,
_
,
err
:=
service
.
Register
(
context
.
Background
(),
"user@test.com"
,
"password"
)
require
.
ErrorIs
(
t
,
err
,
ErrEmailExists
)
}
func
TestAuthService_Register_Success
(
t
*
testing
.
T
)
{
repo
:=
&
userRepoStub
{
nextID
:
5
}
service
:=
newAuthService
(
repo
,
nil
,
nil
)
service
:=
newAuthService
(
repo
,
map
[
string
]
string
{
SettingKeyRegistrationEnabled
:
"true"
,
},
nil
)
token
,
user
,
err
:=
service
.
Register
(
context
.
Background
(),
"user@test.com"
,
"password"
)
require
.
NoError
(
t
,
err
)
...
...
@@ -180,3 +232,63 @@ func TestAuthService_Register_Success(t *testing.T) {
require
.
Len
(
t
,
repo
.
created
,
1
)
require
.
True
(
t
,
user
.
CheckPassword
(
"password"
))
}
func
TestAuthService_ValidateToken_ExpiredReturnsClaimsWithError
(
t
*
testing
.
T
)
{
repo
:=
&
userRepoStub
{}
service
:=
newAuthService
(
repo
,
nil
,
nil
)
// 创建用户并生成 token
user
:=
&
User
{
ID
:
1
,
Email
:
"test@test.com"
,
Role
:
RoleUser
,
Status
:
StatusActive
,
TokenVersion
:
1
,
}
token
,
err
:=
service
.
GenerateToken
(
user
)
require
.
NoError
(
t
,
err
)
// 验证有效 token
claims
,
err
:=
service
.
ValidateToken
(
token
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
claims
)
require
.
Equal
(
t
,
int64
(
1
),
claims
.
UserID
)
// 模拟过期 token(通过创建一个过期很久的 token)
service
.
cfg
.
JWT
.
ExpireHour
=
-
1
// 设置为负数使 token 立即过期
expiredToken
,
err
:=
service
.
GenerateToken
(
user
)
require
.
NoError
(
t
,
err
)
service
.
cfg
.
JWT
.
ExpireHour
=
1
// 恢复
// 验证过期 token 应返回 claims 和 ErrTokenExpired
claims
,
err
=
service
.
ValidateToken
(
expiredToken
)
require
.
ErrorIs
(
t
,
err
,
ErrTokenExpired
)
require
.
NotNil
(
t
,
claims
,
"claims should not be nil when token is expired"
)
require
.
Equal
(
t
,
int64
(
1
),
claims
.
UserID
)
require
.
Equal
(
t
,
"test@test.com"
,
claims
.
Email
)
}
func
TestAuthService_RefreshToken_ExpiredTokenNoPanic
(
t
*
testing
.
T
)
{
user
:=
&
User
{
ID
:
1
,
Email
:
"test@test.com"
,
Role
:
RoleUser
,
Status
:
StatusActive
,
TokenVersion
:
1
,
}
repo
:=
&
userRepoStub
{
user
:
user
}
service
:=
newAuthService
(
repo
,
nil
,
nil
)
// 创建过期 token
service
.
cfg
.
JWT
.
ExpireHour
=
-
1
expiredToken
,
err
:=
service
.
GenerateToken
(
user
)
require
.
NoError
(
t
,
err
)
service
.
cfg
.
JWT
.
ExpireHour
=
1
// RefreshToken 使用过期 token 不应 panic
require
.
NotPanics
(
t
,
func
()
{
newToken
,
err
:=
service
.
RefreshToken
(
context
.
Background
(),
expiredToken
)
require
.
NoError
(
t
,
err
)
require
.
NotEmpty
(
t
,
newToken
)
})
}
backend/internal/service/domain_constants.go
View file @
6b97a8be
...
...
@@ -105,7 +105,17 @@ const (
// Request identity patch (Claude -> Gemini systemInstruction injection)
SettingKeyEnableIdentityPatch
=
"enable_identity_patch"
SettingKeyIdentityPatchPrompt
=
"identity_patch_prompt"
// LinuxDo Connect OAuth 登录(终端用户 SSO)
SettingKeyLinuxDoConnectEnabled
=
"linuxdo_connect_enabled"
SettingKeyLinuxDoConnectClientID
=
"linuxdo_connect_client_id"
SettingKeyLinuxDoConnectClientSecret
=
"linuxdo_connect_client_secret"
SettingKeyLinuxDoConnectRedirectURL
=
"linuxdo_connect_redirect_url"
)
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。
// 目的:避免第三方登录返回的用户标识与本地真实邮箱发生碰撞,进而造成账号被接管的风险。
const
LinuxDoConnectSyntheticEmailDomain
=
"@linuxdo-connect.invalid"
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
const
AdminAPIKeyPrefix
=
"admin-"
backend/internal/service/email_service.go
View file @
6b97a8be
...
...
@@ -5,6 +5,7 @@ import (
"crypto/rand"
"crypto/tls"
"fmt"
"log"
"math/big"
"net/smtp"
"strconv"
...
...
@@ -256,7 +257,9 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
// 验证码不匹配
if
data
.
Code
!=
code
{
data
.
Attempts
++
_
=
s
.
cache
.
SetVerificationCode
(
ctx
,
email
,
data
,
verifyCodeTTL
)
if
err
:=
s
.
cache
.
SetVerificationCode
(
ctx
,
email
,
data
,
verifyCodeTTL
);
err
!=
nil
{
log
.
Printf
(
"[Email] Failed to update verification attempt count: %v"
,
err
)
}
if
data
.
Attempts
>=
maxVerifyCodeAttempts
{
return
ErrVerifyCodeMaxAttempts
}
...
...
@@ -264,7 +267,9 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
}
// 验证成功,删除验证码
_
=
s
.
cache
.
DeleteVerificationCode
(
ctx
,
email
)
if
err
:=
s
.
cache
.
DeleteVerificationCode
(
ctx
,
email
);
err
!=
nil
{
log
.
Printf
(
"[Email] Failed to delete verification code after success: %v"
,
err
)
}
return
nil
}
...
...
backend/internal/service/gemini_multiplatform_test.go
View file @
6b97a8be
...
...
@@ -166,7 +166,7 @@ func (m *mockGroupRepoForGemini) DeleteCascade(ctx context.Context, id int64) ([
func
(
m
*
mockGroupRepoForGemini
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
m
*
mockGroupRepoForGemini
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
status
string
,
isExclusive
*
bool
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
m
*
mockGroupRepoForGemini
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
status
,
search
string
,
isExclusive
*
bool
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
m
*
mockGroupRepoForGemini
)
ListActive
(
ctx
context
.
Context
)
([]
Group
,
error
)
{
return
nil
,
nil
}
...
...
backend/internal/service/gemini_oauth_service.go
View file @
6b97a8be
...
...
@@ -120,15 +120,16 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
}
// OAuth client selection:
// - code_assist: always use built-in Gemini CLI OAuth client (public)
, regardless of configured client_id/secret.
// - google_one:
uses configured OAuth client when provided; otherwise falls back to built-in client.
// - ai_studio: requires a user-provided OAuth client
.
// - code_assist: always use built-in Gemini CLI OAuth client (public)
// - google_one:
always use built-in Gemini CLI OAuth client (public)
// - ai_studio: requires a user-provided OAuth client
oauthCfg
:=
geminicli
.
OAuthConfig
{
ClientID
:
s
.
cfg
.
Gemini
.
OAuth
.
ClientID
,
ClientSecret
:
s
.
cfg
.
Gemini
.
OAuth
.
ClientSecret
,
Scopes
:
s
.
cfg
.
Gemini
.
OAuth
.
Scopes
,
}
if
oauthType
==
"code_assist"
{
if
oauthType
==
"code_assist"
||
oauthType
==
"google_one"
{
// Force use of built-in Gemini CLI OAuth client
oauthCfg
.
ClientID
=
""
oauthCfg
.
ClientSecret
=
""
}
...
...
@@ -576,6 +577,20 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
case
"google_one"
:
log
.
Printf
(
"[GeminiOAuth] Processing google_one OAuth type"
)
// Google One accounts use cloudaicompanion API, which requires a project_id.
// For personal accounts, Google auto-assigns a project_id via the LoadCodeAssist API.
if
projectID
==
""
{
log
.
Printf
(
"[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API..."
)
var
err
error
projectID
,
_
,
err
=
s
.
fetchProjectID
(
ctx
,
tokenResp
.
AccessToken
,
proxyURL
)
if
err
!=
nil
{
log
.
Printf
(
"[GeminiOAuth] ERROR: Failed to fetch project_id: %v"
,
err
)
return
nil
,
fmt
.
Errorf
(
"google One accounts require a project_id, failed to auto-detect: %w"
,
err
)
}
log
.
Printf
(
"[GeminiOAuth] Successfully fetched project_id: %s"
,
projectID
)
}
log
.
Printf
(
"[GeminiOAuth] Attempting to fetch Google One tier from Drive API..."
)
// Attempt to fetch Drive storage tier
var
storageInfo
*
geminicli
.
DriveStorageInfo
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment