Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
f6f072cb
Commit
f6f072cb
authored
Jan 10, 2026
by
Edric Li
Browse files
Merge branch 'main' into feat/api-key-ip-restriction
parents
5265b12c
ff087586
Changes
108
Show whitespace changes
Inline
Side-by-side
backend/internal/repository/group_repo_integration_test.go
View file @
f6f072cb
...
...
@@ -4,6 +4,8 @@ package repository
import
(
"context"
"database/sql"
"errors"
"testing"
dbent
"github.com/Wei-Shaw/sub2api/ent"
...
...
@@ -19,6 +21,20 @@ type GroupRepoSuite struct {
repo
*
groupRepository
}
type
forbidSQLExecutor
struct
{
called
bool
}
func
(
s
*
forbidSQLExecutor
)
ExecContext
(
ctx
context
.
Context
,
query
string
,
args
...
any
)
(
sql
.
Result
,
error
)
{
s
.
called
=
true
return
nil
,
errors
.
New
(
"unexpected sql exec"
)
}
func
(
s
*
forbidSQLExecutor
)
QueryContext
(
ctx
context
.
Context
,
query
string
,
args
...
any
)
(
*
sql
.
Rows
,
error
)
{
s
.
called
=
true
return
nil
,
errors
.
New
(
"unexpected sql query"
)
}
func
(
s
*
GroupRepoSuite
)
SetupTest
()
{
s
.
ctx
=
context
.
Background
()
tx
:=
testEntTx
(
s
.
T
())
...
...
@@ -57,6 +73,26 @@ func (s *GroupRepoSuite) TestGetByID_NotFound() {
s
.
Require
()
.
ErrorIs
(
err
,
service
.
ErrGroupNotFound
)
}
func
(
s
*
GroupRepoSuite
)
TestGetByIDLite_DoesNotUseAccountCount
()
{
group
:=
&
service
.
Group
{
Name
:
"lite-group"
,
Platform
:
service
.
PlatformAnthropic
,
RateMultiplier
:
1.0
,
IsExclusive
:
false
,
Status
:
service
.
StatusActive
,
SubscriptionType
:
service
.
SubscriptionTypeStandard
,
}
s
.
Require
()
.
NoError
(
s
.
repo
.
Create
(
s
.
ctx
,
group
))
spy
:=
&
forbidSQLExecutor
{}
repo
:=
newGroupRepositoryWithSQL
(
s
.
tx
.
Client
(),
spy
)
got
,
err
:=
repo
.
GetByIDLite
(
s
.
ctx
,
group
.
ID
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Equal
(
group
.
ID
,
got
.
ID
)
s
.
Require
()
.
False
(
spy
.
called
,
"expected no direct sql executor usage"
)
}
func
(
s
*
GroupRepoSuite
)
TestUpdate
()
{
group
:=
&
service
.
Group
{
Name
:
"original"
,
...
...
backend/internal/repository/promo_code_repo.go
0 → 100644
View file @
f6f072cb
package
repository
import
(
"context"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type
promoCodeRepository
struct
{
client
*
dbent
.
Client
}
func
NewPromoCodeRepository
(
client
*
dbent
.
Client
)
service
.
PromoCodeRepository
{
return
&
promoCodeRepository
{
client
:
client
}
}
func
(
r
*
promoCodeRepository
)
Create
(
ctx
context
.
Context
,
code
*
service
.
PromoCode
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
builder
:=
client
.
PromoCode
.
Create
()
.
SetCode
(
code
.
Code
)
.
SetBonusAmount
(
code
.
BonusAmount
)
.
SetMaxUses
(
code
.
MaxUses
)
.
SetUsedCount
(
code
.
UsedCount
)
.
SetStatus
(
code
.
Status
)
.
SetNotes
(
code
.
Notes
)
if
code
.
ExpiresAt
!=
nil
{
builder
.
SetExpiresAt
(
*
code
.
ExpiresAt
)
}
created
,
err
:=
builder
.
Save
(
ctx
)
if
err
!=
nil
{
return
err
}
code
.
ID
=
created
.
ID
code
.
CreatedAt
=
created
.
CreatedAt
code
.
UpdatedAt
=
created
.
UpdatedAt
return
nil
}
func
(
r
*
promoCodeRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
PromoCode
,
error
)
{
m
,
err
:=
r
.
client
.
PromoCode
.
Query
()
.
Where
(
promocode
.
IDEQ
(
id
))
.
Only
(
ctx
)
if
err
!=
nil
{
if
dbent
.
IsNotFound
(
err
)
{
return
nil
,
service
.
ErrPromoCodeNotFound
}
return
nil
,
err
}
return
promoCodeEntityToService
(
m
),
nil
}
func
(
r
*
promoCodeRepository
)
GetByCode
(
ctx
context
.
Context
,
code
string
)
(
*
service
.
PromoCode
,
error
)
{
m
,
err
:=
r
.
client
.
PromoCode
.
Query
()
.
Where
(
promocode
.
CodeEqualFold
(
code
))
.
Only
(
ctx
)
if
err
!=
nil
{
if
dbent
.
IsNotFound
(
err
)
{
return
nil
,
service
.
ErrPromoCodeNotFound
}
return
nil
,
err
}
return
promoCodeEntityToService
(
m
),
nil
}
func
(
r
*
promoCodeRepository
)
GetByCodeForUpdate
(
ctx
context
.
Context
,
code
string
)
(
*
service
.
PromoCode
,
error
)
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
m
,
err
:=
client
.
PromoCode
.
Query
()
.
Where
(
promocode
.
CodeEqualFold
(
code
))
.
ForUpdate
()
.
Only
(
ctx
)
if
err
!=
nil
{
if
dbent
.
IsNotFound
(
err
)
{
return
nil
,
service
.
ErrPromoCodeNotFound
}
return
nil
,
err
}
return
promoCodeEntityToService
(
m
),
nil
}
func
(
r
*
promoCodeRepository
)
Update
(
ctx
context
.
Context
,
code
*
service
.
PromoCode
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
builder
:=
client
.
PromoCode
.
UpdateOneID
(
code
.
ID
)
.
SetCode
(
code
.
Code
)
.
SetBonusAmount
(
code
.
BonusAmount
)
.
SetMaxUses
(
code
.
MaxUses
)
.
SetUsedCount
(
code
.
UsedCount
)
.
SetStatus
(
code
.
Status
)
.
SetNotes
(
code
.
Notes
)
if
code
.
ExpiresAt
!=
nil
{
builder
.
SetExpiresAt
(
*
code
.
ExpiresAt
)
}
else
{
builder
.
ClearExpiresAt
()
}
updated
,
err
:=
builder
.
Save
(
ctx
)
if
err
!=
nil
{
if
dbent
.
IsNotFound
(
err
)
{
return
service
.
ErrPromoCodeNotFound
}
return
err
}
code
.
UpdatedAt
=
updated
.
UpdatedAt
return
nil
}
func
(
r
*
promoCodeRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
_
,
err
:=
client
.
PromoCode
.
Delete
()
.
Where
(
promocode
.
IDEQ
(
id
))
.
Exec
(
ctx
)
return
err
}
func
(
r
*
promoCodeRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
service
.
PromoCode
,
*
pagination
.
PaginationResult
,
error
)
{
return
r
.
ListWithFilters
(
ctx
,
params
,
""
,
""
)
}
func
(
r
*
promoCodeRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
status
,
search
string
)
([]
service
.
PromoCode
,
*
pagination
.
PaginationResult
,
error
)
{
q
:=
r
.
client
.
PromoCode
.
Query
()
if
status
!=
""
{
q
=
q
.
Where
(
promocode
.
StatusEQ
(
status
))
}
if
search
!=
""
{
q
=
q
.
Where
(
promocode
.
CodeContainsFold
(
search
))
}
total
,
err
:=
q
.
Count
(
ctx
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
codes
,
err
:=
q
.
Offset
(
params
.
Offset
())
.
Limit
(
params
.
Limit
())
.
Order
(
dbent
.
Desc
(
promocode
.
FieldID
))
.
All
(
ctx
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
outCodes
:=
promoCodeEntitiesToService
(
codes
)
return
outCodes
,
paginationResultFromTotal
(
int64
(
total
),
params
),
nil
}
func
(
r
*
promoCodeRepository
)
CreateUsage
(
ctx
context
.
Context
,
usage
*
service
.
PromoCodeUsage
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
created
,
err
:=
client
.
PromoCodeUsage
.
Create
()
.
SetPromoCodeID
(
usage
.
PromoCodeID
)
.
SetUserID
(
usage
.
UserID
)
.
SetBonusAmount
(
usage
.
BonusAmount
)
.
SetUsedAt
(
usage
.
UsedAt
)
.
Save
(
ctx
)
if
err
!=
nil
{
return
err
}
usage
.
ID
=
created
.
ID
return
nil
}
func
(
r
*
promoCodeRepository
)
GetUsageByPromoCodeAndUser
(
ctx
context
.
Context
,
promoCodeID
,
userID
int64
)
(
*
service
.
PromoCodeUsage
,
error
)
{
m
,
err
:=
r
.
client
.
PromoCodeUsage
.
Query
()
.
Where
(
promocodeusage
.
PromoCodeIDEQ
(
promoCodeID
),
promocodeusage
.
UserIDEQ
(
userID
),
)
.
Only
(
ctx
)
if
err
!=
nil
{
if
dbent
.
IsNotFound
(
err
)
{
return
nil
,
nil
}
return
nil
,
err
}
return
promoCodeUsageEntityToService
(
m
),
nil
}
func
(
r
*
promoCodeRepository
)
ListUsagesByPromoCode
(
ctx
context
.
Context
,
promoCodeID
int64
,
params
pagination
.
PaginationParams
)
([]
service
.
PromoCodeUsage
,
*
pagination
.
PaginationResult
,
error
)
{
q
:=
r
.
client
.
PromoCodeUsage
.
Query
()
.
Where
(
promocodeusage
.
PromoCodeIDEQ
(
promoCodeID
))
total
,
err
:=
q
.
Count
(
ctx
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
usages
,
err
:=
q
.
WithUser
()
.
Offset
(
params
.
Offset
())
.
Limit
(
params
.
Limit
())
.
Order
(
dbent
.
Desc
(
promocodeusage
.
FieldID
))
.
All
(
ctx
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
outUsages
:=
promoCodeUsageEntitiesToService
(
usages
)
return
outUsages
,
paginationResultFromTotal
(
int64
(
total
),
params
),
nil
}
func
(
r
*
promoCodeRepository
)
IncrementUsedCount
(
ctx
context
.
Context
,
id
int64
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
_
,
err
:=
client
.
PromoCode
.
UpdateOneID
(
id
)
.
AddUsedCount
(
1
)
.
Save
(
ctx
)
return
err
}
// Entity to Service conversions
func
promoCodeEntityToService
(
m
*
dbent
.
PromoCode
)
*
service
.
PromoCode
{
if
m
==
nil
{
return
nil
}
return
&
service
.
PromoCode
{
ID
:
m
.
ID
,
Code
:
m
.
Code
,
BonusAmount
:
m
.
BonusAmount
,
MaxUses
:
m
.
MaxUses
,
UsedCount
:
m
.
UsedCount
,
Status
:
m
.
Status
,
ExpiresAt
:
m
.
ExpiresAt
,
Notes
:
derefString
(
m
.
Notes
),
CreatedAt
:
m
.
CreatedAt
,
UpdatedAt
:
m
.
UpdatedAt
,
}
}
func
promoCodeEntitiesToService
(
models
[]
*
dbent
.
PromoCode
)
[]
service
.
PromoCode
{
out
:=
make
([]
service
.
PromoCode
,
0
,
len
(
models
))
for
i
:=
range
models
{
if
s
:=
promoCodeEntityToService
(
models
[
i
]);
s
!=
nil
{
out
=
append
(
out
,
*
s
)
}
}
return
out
}
func
promoCodeUsageEntityToService
(
m
*
dbent
.
PromoCodeUsage
)
*
service
.
PromoCodeUsage
{
if
m
==
nil
{
return
nil
}
out
:=
&
service
.
PromoCodeUsage
{
ID
:
m
.
ID
,
PromoCodeID
:
m
.
PromoCodeID
,
UserID
:
m
.
UserID
,
BonusAmount
:
m
.
BonusAmount
,
UsedAt
:
m
.
UsedAt
,
}
if
m
.
Edges
.
User
!=
nil
{
out
.
User
=
userEntityToService
(
m
.
Edges
.
User
)
}
return
out
}
func
promoCodeUsageEntitiesToService
(
models
[]
*
dbent
.
PromoCodeUsage
)
[]
service
.
PromoCodeUsage
{
out
:=
make
([]
service
.
PromoCodeUsage
,
0
,
len
(
models
))
for
i
:=
range
models
{
if
s
:=
promoCodeUsageEntityToService
(
models
[
i
]);
s
!=
nil
{
out
=
append
(
out
,
*
s
)
}
}
return
out
}
backend/internal/repository/wire.go
View file @
f6f072cb
...
...
@@ -45,6 +45,7 @@ var ProviderSet = wire.NewSet(
NewAccountRepository
,
NewProxyRepository
,
NewRedeemCodeRepository
,
NewPromoCodeRepository
,
NewUsageLogRepository
,
NewSettingRepository
,
NewUserSubscriptionRepository
,
...
...
backend/internal/server/api_contract_test.go
View file @
f6f072cb
...
...
@@ -398,7 +398,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo
:=
newStubSettingRepo
()
settingService
:=
service
.
NewSettingService
(
settingRepo
,
cfg
)
authHandler
:=
handler
.
NewAuthHandler
(
cfg
,
nil
,
userService
,
settingService
)
authHandler
:=
handler
.
NewAuthHandler
(
cfg
,
nil
,
userService
,
settingService
,
nil
)
apiKeyHandler
:=
handler
.
NewAPIKeyHandler
(
apiKeyService
)
usageHandler
:=
handler
.
NewUsageHandler
(
usageService
,
apiKeyService
)
adminSettingHandler
:=
adminhandler
.
NewSettingHandler
(
settingService
,
nil
,
nil
)
...
...
@@ -575,6 +575,10 @@ func (stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, err
return
nil
,
service
.
ErrGroupNotFound
}
func
(
stubGroupRepo
)
GetByIDLite
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
Group
,
error
)
{
return
nil
,
service
.
ErrGroupNotFound
}
func
(
stubGroupRepo
)
Update
(
ctx
context
.
Context
,
group
*
service
.
Group
)
error
{
return
errors
.
New
(
"not implemented"
)
}
...
...
backend/internal/server/http.go
View file @
f6f072cb
...
...
@@ -13,6 +13,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/google/wire"
"github.com/redis/go-redis/v9"
)
// ProviderSet 提供服务器层的依赖
...
...
@@ -31,6 +32,7 @@ func ProvideRouter(
apiKeyService
*
service
.
APIKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
settingService
*
service
.
SettingService
,
redisClient
*
redis
.
Client
,
)
*
gin
.
Engine
{
if
cfg
.
Server
.
Mode
==
"release"
{
gin
.
SetMode
(
gin
.
ReleaseMode
)
...
...
@@ -48,7 +50,7 @@ func ProvideRouter(
}
}
return
SetupRouter
(
r
,
handlers
,
jwtAuth
,
adminAuth
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
,
settingService
,
cfg
)
return
SetupRouter
(
r
,
handlers
,
jwtAuth
,
adminAuth
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
,
settingService
,
cfg
,
redisClient
)
}
// ProvideHTTPServer 提供 HTTP 服务器
...
...
backend/internal/server/middleware/api_key_auth.go
View file @
f6f072cb
package
middleware
import
(
"context"
"errors"
"log"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/service"
...
...
@@ -103,6 +105,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
Concurrency
:
apiKey
.
User
.
Concurrency
,
})
c
.
Set
(
string
(
ContextKeyUserRole
),
apiKey
.
User
.
Role
)
setGroupContext
(
c
,
apiKey
.
Group
)
c
.
Next
()
return
}
...
...
@@ -161,6 +164,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
Concurrency
:
apiKey
.
User
.
Concurrency
,
})
c
.
Set
(
string
(
ContextKeyUserRole
),
apiKey
.
User
.
Role
)
setGroupContext
(
c
,
apiKey
.
Group
)
c
.
Next
()
}
...
...
@@ -185,3 +189,14 @@ func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool
subscription
,
ok
:=
value
.
(
*
service
.
UserSubscription
)
return
subscription
,
ok
}
func
setGroupContext
(
c
*
gin
.
Context
,
group
*
service
.
Group
)
{
if
!
service
.
IsGroupContextValid
(
group
)
{
return
}
if
existing
,
ok
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
Group
)
.
(
*
service
.
Group
);
ok
&&
existing
!=
nil
&&
existing
.
ID
==
group
.
ID
&&
service
.
IsGroupContextValid
(
existing
)
{
return
}
ctx
:=
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
Group
,
group
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
}
backend/internal/server/middleware/api_key_auth_google.go
View file @
f6f072cb
...
...
@@ -63,6 +63,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
Concurrency
:
apiKey
.
User
.
Concurrency
,
})
c
.
Set
(
string
(
ContextKeyUserRole
),
apiKey
.
User
.
Role
)
setGroupContext
(
c
,
apiKey
.
Group
)
c
.
Next
()
return
}
...
...
@@ -102,6 +103,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
Concurrency
:
apiKey
.
User
.
Concurrency
,
})
c
.
Set
(
string
(
ContextKeyUserRole
),
apiKey
.
User
.
Role
)
setGroupContext
(
c
,
apiKey
.
Group
)
c
.
Next
()
}
}
...
...
backend/internal/server/middleware/api_key_auth_google_test.go
View file @
f6f072cb
...
...
@@ -9,6 +9,7 @@ import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
...
...
@@ -133,6 +134,70 @@ func TestApiKeyAuthWithSubscriptionGoogle_QueryApiKeyRejected(t *testing.T) {
require
.
Equal
(
t
,
"INVALID_ARGUMENT"
,
resp
.
Error
.
Status
)
}
func
TestApiKeyAuthWithSubscriptionGoogleSetsGroupContext
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
group
:=
&
service
.
Group
{
ID
:
99
,
Name
:
"g1"
,
Status
:
service
.
StatusActive
,
Platform
:
service
.
PlatformGemini
,
Hydrated
:
true
,
}
user
:=
&
service
.
User
{
ID
:
7
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
APIKey
{
ID
:
100
,
UserID
:
user
.
ID
,
Key
:
"test-key"
,
Status
:
service
.
StatusActive
,
User
:
user
,
Group
:
group
,
}
apiKey
.
GroupID
=
&
group
.
ID
apiKeyService
:=
service
.
NewAPIKeyService
(
fakeAPIKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
},
nil
,
nil
,
nil
,
nil
,
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
},
)
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
r
:=
gin
.
New
()
r
.
Use
(
APIKeyAuthWithSubscriptionGoogle
(
apiKeyService
,
nil
,
cfg
))
r
.
GET
(
"/v1beta/test"
,
func
(
c
*
gin
.
Context
)
{
groupFromCtx
,
ok
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
Group
)
.
(
*
service
.
Group
)
if
!
ok
||
groupFromCtx
==
nil
||
groupFromCtx
.
ID
!=
group
.
ID
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"ok"
:
false
})
return
}
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"ok"
:
true
})
})
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/v1beta/test"
,
nil
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
.
Key
)
rec
:=
httptest
.
NewRecorder
()
r
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
}
func
TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
...
...
backend/internal/server/middleware/api_key_auth_test.go
View file @
f6f072cb
...
...
@@ -11,6 +11,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
...
...
@@ -25,6 +26,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
ID
:
42
,
Name
:
"sub"
,
Status
:
service
.
StatusActive
,
Hydrated
:
true
,
SubscriptionType
:
service
.
SubscriptionTypeSubscription
,
DailyLimitUSD
:
&
limit
,
}
...
...
@@ -110,6 +112,129 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
})
}
func
TestAPIKeyAuthSetsGroupContext
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
group
:=
&
service
.
Group
{
ID
:
101
,
Name
:
"g1"
,
Status
:
service
.
StatusActive
,
Platform
:
service
.
PlatformAnthropic
,
Hydrated
:
true
,
}
user
:=
&
service
.
User
{
ID
:
7
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
APIKey
{
ID
:
100
,
UserID
:
user
.
ID
,
Key
:
"test-key"
,
Status
:
service
.
StatusActive
,
User
:
user
,
Group
:
group
,
}
apiKey
.
GroupID
=
&
group
.
ID
apiKeyRepo
:=
&
stubApiKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
}
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
cfg
)
router
:=
gin
.
New
()
router
.
Use
(
gin
.
HandlerFunc
(
NewAPIKeyAuthMiddleware
(
apiKeyService
,
nil
,
cfg
)))
router
.
GET
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
groupFromCtx
,
ok
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
Group
)
.
(
*
service
.
Group
)
if
!
ok
||
groupFromCtx
==
nil
||
groupFromCtx
.
ID
!=
group
.
ID
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"ok"
:
false
})
return
}
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"ok"
:
true
})
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
.
Key
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
}
func
TestAPIKeyAuthOverwritesInvalidContextGroup
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
group
:=
&
service
.
Group
{
ID
:
101
,
Name
:
"g1"
,
Status
:
service
.
StatusActive
,
Platform
:
service
.
PlatformAnthropic
,
Hydrated
:
true
,
}
user
:=
&
service
.
User
{
ID
:
7
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
APIKey
{
ID
:
100
,
UserID
:
user
.
ID
,
Key
:
"test-key"
,
Status
:
service
.
StatusActive
,
User
:
user
,
Group
:
group
,
}
apiKey
.
GroupID
=
&
group
.
ID
apiKeyRepo
:=
&
stubApiKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
}
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
cfg
)
router
:=
gin
.
New
()
router
.
Use
(
gin
.
HandlerFunc
(
NewAPIKeyAuthMiddleware
(
apiKeyService
,
nil
,
cfg
)))
invalidGroup
:=
&
service
.
Group
{
ID
:
group
.
ID
,
Platform
:
group
.
Platform
,
Status
:
group
.
Status
,
}
router
.
GET
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
groupFromCtx
,
ok
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
Group
)
.
(
*
service
.
Group
)
if
!
ok
||
groupFromCtx
==
nil
||
groupFromCtx
.
ID
!=
group
.
ID
||
!
groupFromCtx
.
Hydrated
||
groupFromCtx
==
invalidGroup
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"ok"
:
false
})
return
}
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"ok"
:
true
})
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
.
Key
)
req
=
req
.
WithContext
(
context
.
WithValue
(
req
.
Context
(),
ctxkey
.
Group
,
invalidGroup
))
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
}
func
newAuthTestRouter
(
apiKeyService
*
service
.
APIKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
cfg
*
config
.
Config
)
*
gin
.
Engine
{
router
:=
gin
.
New
()
router
.
Use
(
gin
.
HandlerFunc
(
NewAPIKeyAuthMiddleware
(
apiKeyService
,
subscriptionService
,
cfg
)))
...
...
backend/internal/server/router.go
View file @
f6f072cb
...
...
@@ -11,6 +11,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/web"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
)
// SetupRouter 配置路由器中间件和路由
...
...
@@ -24,6 +25,7 @@ func SetupRouter(
subscriptionService
*
service
.
SubscriptionService
,
settingService
*
service
.
SettingService
,
cfg
*
config
.
Config
,
redisClient
*
redis
.
Client
,
)
*
gin
.
Engine
{
// 应用中间件
r
.
Use
(
middleware2
.
Logger
())
...
...
@@ -44,7 +46,7 @@ func SetupRouter(
}
// 注册路由
registerRoutes
(
r
,
handlers
,
jwtAuth
,
adminAuth
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
,
cfg
)
registerRoutes
(
r
,
handlers
,
jwtAuth
,
adminAuth
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
,
cfg
,
redisClient
)
return
r
}
...
...
@@ -59,6 +61,7 @@ func registerRoutes(
apiKeyService
*
service
.
APIKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
cfg
*
config
.
Config
,
redisClient
*
redis
.
Client
,
)
{
// 通用路由(健康检查、状态等)
routes
.
RegisterCommonRoutes
(
r
)
...
...
@@ -67,7 +70,7 @@ func registerRoutes(
v1
:=
r
.
Group
(
"/api/v1"
)
// 注册各模块路由
routes
.
RegisterAuthRoutes
(
v1
,
h
,
jwtAuth
)
routes
.
RegisterAuthRoutes
(
v1
,
h
,
jwtAuth
,
redisClient
)
routes
.
RegisterUserRoutes
(
v1
,
h
,
jwtAuth
)
routes
.
RegisterAdminRoutes
(
v1
,
h
,
adminAuth
)
routes
.
RegisterGatewayRoutes
(
r
,
h
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
,
cfg
)
...
...
backend/internal/server/routes/admin.go
View file @
f6f072cb
...
...
@@ -44,6 +44,9 @@ func RegisterAdminRoutes(
// 卡密管理
registerRedeemCodeRoutes
(
admin
,
h
)
// 优惠码管理
registerPromoCodeRoutes
(
admin
,
h
)
// 系统设置
registerSettingsRoutes
(
admin
,
h
)
...
...
@@ -201,6 +204,18 @@ func registerRedeemCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
}
}
func
registerPromoCodeRoutes
(
admin
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
)
{
promoCodes
:=
admin
.
Group
(
"/promo-codes"
)
{
promoCodes
.
GET
(
""
,
h
.
Admin
.
Promo
.
List
)
promoCodes
.
GET
(
"/:id"
,
h
.
Admin
.
Promo
.
GetByID
)
promoCodes
.
POST
(
""
,
h
.
Admin
.
Promo
.
Create
)
promoCodes
.
PUT
(
"/:id"
,
h
.
Admin
.
Promo
.
Update
)
promoCodes
.
DELETE
(
"/:id"
,
h
.
Admin
.
Promo
.
Delete
)
promoCodes
.
GET
(
"/:id/usages"
,
h
.
Admin
.
Promo
.
GetUsages
)
}
}
func
registerSettingsRoutes
(
admin
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
)
{
adminSettings
:=
admin
.
Group
(
"/settings"
)
{
...
...
backend/internal/server/routes/auth.go
View file @
f6f072cb
package
routes
import
(
"time"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/middleware"
servermiddleware
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
)
// RegisterAuthRoutes 注册认证相关路由
func
RegisterAuthRoutes
(
v1
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
,
jwtAuth
middleware
.
JWTAuthMiddleware
,
jwtAuth
servermiddleware
.
JWTAuthMiddleware
,
redisClient
*
redis
.
Client
,
)
{
// 创建速率限制器
rateLimiter
:=
middleware
.
NewRateLimiter
(
redisClient
)
// 公开接口
auth
:=
v1
.
Group
(
"/auth"
)
{
auth
.
POST
(
"/register"
,
h
.
Auth
.
Register
)
auth
.
POST
(
"/login"
,
h
.
Auth
.
Login
)
auth
.
POST
(
"/send-verify-code"
,
h
.
Auth
.
SendVerifyCode
)
// 优惠码验证接口添加速率限制:每分钟最多 10 次
auth
.
POST
(
"/validate-promo-code"
,
rateLimiter
.
Limit
(
"validate-promo"
,
10
,
time
.
Minute
),
h
.
Auth
.
ValidatePromoCode
)
auth
.
GET
(
"/oauth/linuxdo/start"
,
h
.
Auth
.
LinuxDoOAuthStart
)
auth
.
GET
(
"/oauth/linuxdo/callback"
,
h
.
Auth
.
LinuxDoOAuthCallback
)
}
...
...
backend/internal/service/account_service.go
View file @
f6f072cb
...
...
@@ -49,10 +49,12 @@ type AccountRepository interface {
ListSchedulableByGroupIDAndPlatforms
(
ctx
context
.
Context
,
groupID
int64
,
platforms
[]
string
)
([]
Account
,
error
)
SetRateLimited
(
ctx
context
.
Context
,
id
int64
,
resetAt
time
.
Time
)
error
SetAntigravityQuotaScopeLimit
(
ctx
context
.
Context
,
id
int64
,
scope
AntigravityQuotaScope
,
resetAt
time
.
Time
)
error
SetOverloaded
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
)
error
SetTempUnschedulable
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
,
reason
string
)
error
ClearTempUnschedulable
(
ctx
context
.
Context
,
id
int64
)
error
ClearRateLimit
(
ctx
context
.
Context
,
id
int64
)
error
ClearAntigravityQuotaScopes
(
ctx
context
.
Context
,
id
int64
)
error
UpdateSessionWindow
(
ctx
context
.
Context
,
id
int64
,
start
,
end
*
time
.
Time
,
status
string
)
error
UpdateExtra
(
ctx
context
.
Context
,
id
int64
,
updates
map
[
string
]
any
)
error
BulkUpdate
(
ctx
context
.
Context
,
ids
[]
int64
,
updates
AccountBulkUpdate
)
(
int64
,
error
)
...
...
backend/internal/service/account_service_delete_test.go
View file @
f6f072cb
...
...
@@ -139,6 +139,10 @@ func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt
panic
(
"unexpected SetRateLimited call"
)
}
func
(
s
*
accountRepoStub
)
SetAntigravityQuotaScopeLimit
(
ctx
context
.
Context
,
id
int64
,
scope
AntigravityQuotaScope
,
resetAt
time
.
Time
)
error
{
panic
(
"unexpected SetAntigravityQuotaScopeLimit call"
)
}
func
(
s
*
accountRepoStub
)
SetOverloaded
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
)
error
{
panic
(
"unexpected SetOverloaded call"
)
}
...
...
@@ -155,6 +159,10 @@ func (s *accountRepoStub) ClearRateLimit(ctx context.Context, id int64) error {
panic
(
"unexpected ClearRateLimit call"
)
}
func
(
s
*
accountRepoStub
)
ClearAntigravityQuotaScopes
(
ctx
context
.
Context
,
id
int64
)
error
{
panic
(
"unexpected ClearAntigravityQuotaScopes call"
)
}
func
(
s
*
accountRepoStub
)
UpdateSessionWindow
(
ctx
context
.
Context
,
id
int64
,
start
,
end
*
time
.
Time
,
status
string
)
error
{
panic
(
"unexpected UpdateSessionWindow call"
)
}
...
...
backend/internal/service/admin_service.go
View file @
f6f072cb
...
...
@@ -576,18 +576,33 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro
return
fmt
.
Errorf
(
"cannot set self as fallback group"
)
}
visited
:=
map
[
int64
]
struct
{}{}
nextID
:=
fallbackGroupID
for
{
if
_
,
seen
:=
visited
[
nextID
];
seen
{
return
fmt
.
Errorf
(
"fallback group cycle detected"
)
}
visited
[
nextID
]
=
struct
{}{}
if
currentGroupID
>
0
&&
nextID
==
currentGroupID
{
return
fmt
.
Errorf
(
"fallback group cycle detected"
)
}
// 检查降级分组是否存在
fallbackGroup
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
fallbackGroup
ID
)
fallbackGroup
,
err
:=
s
.
groupRepo
.
GetByID
Lite
(
ctx
,
next
ID
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"fallback group not found: %w"
,
err
)
}
// 降级分组不能启用 claude_code_only,否则会造成死循环
if
fallbackGroup
.
ClaudeCodeOnly
{
if
nextID
==
fallbackGroupID
&&
fallbackGroup
.
ClaudeCodeOnly
{
return
fmt
.
Errorf
(
"fallback group cannot have claude_code_only enabled"
)
}
if
fallbackGroup
.
FallbackGroupID
==
nil
{
return
nil
}
nextID
=
*
fallbackGroup
.
FallbackGroupID
}
}
func
(
s
*
adminServiceImpl
)
UpdateGroup
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateGroupInput
)
(
*
Group
,
error
)
{
...
...
backend/internal/service/admin_service_delete_test.go
View file @
f6f072cb
...
...
@@ -107,6 +107,10 @@ func (s *groupRepoStub) GetByID(ctx context.Context, id int64) (*Group, error) {
panic
(
"unexpected GetByID call"
)
}
func
(
s
*
groupRepoStub
)
GetByIDLite
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
panic
(
"unexpected GetByIDLite call"
)
}
func
(
s
*
groupRepoStub
)
Update
(
ctx
context
.
Context
,
group
*
Group
)
error
{
panic
(
"unexpected Update call"
)
}
...
...
backend/internal/service/admin_service_group_test.go
View file @
f6f072cb
...
...
@@ -45,6 +45,13 @@ func (s *groupRepoStubForAdmin) GetByID(_ context.Context, _ int64) (*Group, err
return
s
.
getByID
,
nil
}
func
(
s
*
groupRepoStubForAdmin
)
GetByIDLite
(
_
context
.
Context
,
_
int64
)
(
*
Group
,
error
)
{
if
s
.
getErr
!=
nil
{
return
nil
,
s
.
getErr
}
return
s
.
getByID
,
nil
}
func
(
s
*
groupRepoStubForAdmin
)
Delete
(
_
context
.
Context
,
_
int64
)
error
{
panic
(
"unexpected Delete call"
)
}
...
...
@@ -290,3 +297,84 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
require
.
True
(
t
,
*
repo
.
listWithFiltersIsExclusive
)
})
}
func
TestAdminService_ValidateFallbackGroup_DetectsCycle
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
1
)
fallbackID
:=
int64
(
2
)
repo
:=
&
groupRepoStubForFallbackCycle
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
{
ID
:
groupID
,
FallbackGroupID
:
&
fallbackID
,
},
fallbackID
:
{
ID
:
fallbackID
,
FallbackGroupID
:
&
groupID
,
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
err
:=
svc
.
validateFallbackGroup
(
context
.
Background
(),
groupID
,
fallbackID
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"fallback group cycle"
)
}
type
groupRepoStubForFallbackCycle
struct
{
groups
map
[
int64
]
*
Group
}
func
(
s
*
groupRepoStubForFallbackCycle
)
Create
(
_
context
.
Context
,
_
*
Group
)
error
{
panic
(
"unexpected Create call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
Update
(
_
context
.
Context
,
_
*
Group
)
error
{
panic
(
"unexpected Update call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
return
s
.
GetByIDLite
(
ctx
,
id
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
GetByIDLite
(
_
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
if
g
,
ok
:=
s
.
groups
[
id
];
ok
{
return
g
,
nil
}
return
nil
,
ErrGroupNotFound
}
func
(
s
*
groupRepoStubForFallbackCycle
)
Delete
(
_
context
.
Context
,
_
int64
)
error
{
panic
(
"unexpected Delete call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
DeleteCascade
(
_
context
.
Context
,
_
int64
)
([]
int64
,
error
)
{
panic
(
"unexpected DeleteCascade call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
List
(
_
context
.
Context
,
_
pagination
.
PaginationParams
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected List call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
ListWithFilters
(
_
context
.
Context
,
_
pagination
.
PaginationParams
,
_
,
_
,
_
string
,
_
*
bool
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected ListWithFilters call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
ListActive
(
_
context
.
Context
)
([]
Group
,
error
)
{
panic
(
"unexpected ListActive call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
ListActiveByPlatform
(
_
context
.
Context
,
_
string
)
([]
Group
,
error
)
{
panic
(
"unexpected ListActiveByPlatform call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
ExistsByName
(
_
context
.
Context
,
_
string
)
(
bool
,
error
)
{
panic
(
"unexpected ExistsByName call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
GetAccountCount
(
_
context
.
Context
,
_
int64
)
(
int64
,
error
)
{
panic
(
"unexpected GetAccountCount call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
DeleteAccountGroupsByGroupID
(
_
context
.
Context
,
_
int64
)
(
int64
,
error
)
{
panic
(
"unexpected DeleteAccountGroupsByGroupID call"
)
}
backend/internal/service/antigravity_gateway_service.go
View file @
f6f072cb
...
...
@@ -93,6 +93,7 @@ var antigravityPrefixMapping = []struct {
// 长前缀优先
{
"gemini-2.5-flash-image"
,
"gemini-3-pro-image"
},
// gemini-2.5-flash-image → 3-pro-image
{
"gemini-3-pro-image"
,
"gemini-3-pro-image"
},
// gemini-3-pro-image-preview 等
{
"gemini-3-flash"
,
"gemini-3-flash"
},
// gemini-3-flash-preview 等 → gemini-3-flash
{
"claude-3-5-sonnet"
,
"claude-sonnet-4-5"
},
// 旧版 claude-3-5-sonnet-xxx
{
"claude-sonnet-4-5"
,
"claude-sonnet-4-5"
},
// claude-sonnet-4-5-xxx
{
"claude-haiku-4-5"
,
"claude-sonnet-4-5"
},
// claude-haiku-4-5-xxx → sonnet
...
...
@@ -502,6 +503,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
originalModel
:=
claudeReq
.
Model
mappedModel
:=
s
.
getMappedModel
(
account
,
claudeReq
.
Model
)
quotaScope
,
_
:=
resolveAntigravityQuotaScope
(
originalModel
)
// 获取 access_token
if
s
.
tokenProvider
==
nil
{
...
...
@@ -603,7 +605,7 @@ urlFallbackLoop:
}
// 所有重试都失败,标记限流状态
if
resp
.
StatusCode
==
429
{
s
.
handleUpstreamError
(
ctx
,
prefix
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
)
s
.
handleUpstreamError
(
ctx
,
prefix
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
,
quotaScope
)
}
// 最后一次尝试也失败
resp
=
&
http
.
Response
{
...
...
@@ -696,7 +698,7 @@ urlFallbackLoop:
// 处理错误响应(重试后仍失败或不触发重试)
if
resp
.
StatusCode
>=
400
{
s
.
handleUpstreamError
(
ctx
,
prefix
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
)
s
.
handleUpstreamError
(
ctx
,
prefix
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
,
quotaScope
)
if
s
.
shouldFailoverUpstreamError
(
resp
.
StatusCode
)
{
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
resp
.
StatusCode
}
...
...
@@ -1021,6 +1023,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if
len
(
body
)
==
0
{
return
nil
,
s
.
writeGoogleError
(
c
,
http
.
StatusBadRequest
,
"Request body is empty"
)
}
quotaScope
,
_
:=
resolveAntigravityQuotaScope
(
originalModel
)
// 解析请求以获取 image_size(用于图片计费)
imageSize
:=
s
.
extractImageSize
(
body
)
...
...
@@ -1146,7 +1149,7 @@ urlFallbackLoop:
}
// 所有重试都失败,标记限流状态
if
resp
.
StatusCode
==
429
{
s
.
handleUpstreamError
(
ctx
,
prefix
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
)
s
.
handleUpstreamError
(
ctx
,
prefix
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
,
quotaScope
)
}
resp
=
&
http
.
Response
{
StatusCode
:
resp
.
StatusCode
,
...
...
@@ -1200,7 +1203,7 @@ urlFallbackLoop:
goto
handleSuccess
}
s
.
handleUpstreamError
(
ctx
,
prefix
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
)
s
.
handleUpstreamError
(
ctx
,
prefix
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
,
quotaScope
)
if
s
.
shouldFailoverUpstreamError
(
resp
.
StatusCode
)
{
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
resp
.
StatusCode
}
...
...
@@ -1314,7 +1317,7 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
}
}
func
(
s
*
AntigravityGatewayService
)
handleUpstreamError
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
)
{
func
(
s
*
AntigravityGatewayService
)
handleUpstreamError
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
)
{
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
if
statusCode
==
429
{
resetAt
:=
ParseGeminiRateLimitResetTime
(
body
)
...
...
@@ -1325,13 +1328,23 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre
defaultDur
=
5
*
time
.
Minute
}
ra
:=
time
.
Now
()
.
Add
(
defaultDur
)
log
.
Printf
(
"%s status=429 rate_limited reset_in=%v (fallback)"
,
prefix
,
defaultDur
)
_
=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
ra
)
log
.
Printf
(
"%s status=429 rate_limited scope=%s reset_in=%v (fallback)"
,
prefix
,
quotaScope
,
defaultDur
)
if
quotaScope
==
""
{
return
}
if
err
:=
s
.
accountRepo
.
SetAntigravityQuotaScopeLimit
(
ctx
,
account
.
ID
,
quotaScope
,
ra
);
err
!=
nil
{
log
.
Printf
(
"%s status=429 rate_limit_set_failed scope=%s error=%v"
,
prefix
,
quotaScope
,
err
)
}
return
}
resetTime
:=
time
.
Unix
(
*
resetAt
,
0
)
log
.
Printf
(
"%s status=429 rate_limited reset_at=%v reset_in=%v"
,
prefix
,
resetTime
.
Format
(
"15:04:05"
),
time
.
Until
(
resetTime
)
.
Truncate
(
time
.
Second
))
_
=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
resetTime
)
log
.
Printf
(
"%s status=429 rate_limited scope=%s reset_at=%v reset_in=%v"
,
prefix
,
quotaScope
,
resetTime
.
Format
(
"15:04:05"
),
time
.
Until
(
resetTime
)
.
Truncate
(
time
.
Second
))
if
quotaScope
==
""
{
return
}
if
err
:=
s
.
accountRepo
.
SetAntigravityQuotaScopeLimit
(
ctx
,
account
.
ID
,
quotaScope
,
resetTime
);
err
!=
nil
{
log
.
Printf
(
"%s status=429 rate_limit_set_failed scope=%s error=%v"
,
prefix
,
quotaScope
,
err
)
}
return
}
// 其他错误码继续使用 rateLimitService
...
...
backend/internal/service/antigravity_quota_scope.go
0 → 100644
View file @
f6f072cb
package
service
import
(
"strings"
"time"
)
const
antigravityQuotaScopesKey
=
"antigravity_quota_scopes"
// AntigravityQuotaScope 表示 Antigravity 的配额域
type
AntigravityQuotaScope
string
const
(
AntigravityQuotaScopeClaude
AntigravityQuotaScope
=
"claude"
AntigravityQuotaScopeGeminiText
AntigravityQuotaScope
=
"gemini_text"
AntigravityQuotaScopeGeminiImage
AntigravityQuotaScope
=
"gemini_image"
)
// resolveAntigravityQuotaScope 根据模型名称解析配额域
func
resolveAntigravityQuotaScope
(
requestedModel
string
)
(
AntigravityQuotaScope
,
bool
)
{
model
:=
normalizeAntigravityModelName
(
requestedModel
)
if
model
==
""
{
return
""
,
false
}
switch
{
case
strings
.
HasPrefix
(
model
,
"claude-"
)
:
return
AntigravityQuotaScopeClaude
,
true
case
strings
.
HasPrefix
(
model
,
"gemini-"
)
:
if
isImageGenerationModel
(
model
)
{
return
AntigravityQuotaScopeGeminiImage
,
true
}
return
AntigravityQuotaScopeGeminiText
,
true
default
:
return
""
,
false
}
}
func
normalizeAntigravityModelName
(
model
string
)
string
{
normalized
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
model
))
normalized
=
strings
.
TrimPrefix
(
normalized
,
"models/"
)
return
normalized
}
// IsSchedulableForModel 结合 Antigravity 配额域限流判断是否可调度
func
(
a
*
Account
)
IsSchedulableForModel
(
requestedModel
string
)
bool
{
if
a
==
nil
{
return
false
}
if
!
a
.
IsSchedulable
()
{
return
false
}
if
a
.
Platform
!=
PlatformAntigravity
{
return
true
}
scope
,
ok
:=
resolveAntigravityQuotaScope
(
requestedModel
)
if
!
ok
{
return
true
}
resetAt
:=
a
.
antigravityQuotaScopeResetAt
(
scope
)
if
resetAt
==
nil
{
return
true
}
now
:=
time
.
Now
()
return
!
now
.
Before
(
*
resetAt
)
}
func
(
a
*
Account
)
antigravityQuotaScopeResetAt
(
scope
AntigravityQuotaScope
)
*
time
.
Time
{
if
a
==
nil
||
a
.
Extra
==
nil
||
scope
==
""
{
return
nil
}
rawScopes
,
ok
:=
a
.
Extra
[
antigravityQuotaScopesKey
]
.
(
map
[
string
]
any
)
if
!
ok
{
return
nil
}
rawScope
,
ok
:=
rawScopes
[
string
(
scope
)]
.
(
map
[
string
]
any
)
if
!
ok
{
return
nil
}
resetAtRaw
,
ok
:=
rawScope
[
"rate_limit_reset_at"
]
.
(
string
)
if
!
ok
||
strings
.
TrimSpace
(
resetAtRaw
)
==
""
{
return
nil
}
resetAt
,
err
:=
time
.
Parse
(
time
.
RFC3339
,
resetAtRaw
)
if
err
!=
nil
{
return
nil
}
return
&
resetAt
}
backend/internal/service/auth_service.go
View file @
f6f072cb
...
...
@@ -52,6 +52,7 @@ type AuthService struct {
emailService
*
EmailService
turnstileService
*
TurnstileService
emailQueueService
*
EmailQueueService
promoService
*
PromoService
}
// NewAuthService 创建认证服务实例
...
...
@@ -62,6 +63,7 @@ func NewAuthService(
emailService
*
EmailService
,
turnstileService
*
TurnstileService
,
emailQueueService
*
EmailQueueService
,
promoService
*
PromoService
,
)
*
AuthService
{
return
&
AuthService
{
userRepo
:
userRepo
,
...
...
@@ -70,16 +72,17 @@ func NewAuthService(
emailService
:
emailService
,
turnstileService
:
turnstileService
,
emailQueueService
:
emailQueueService
,
promoService
:
promoService
,
}
}
// Register 用户注册,返回token和用户
func
(
s
*
AuthService
)
Register
(
ctx
context
.
Context
,
email
,
password
string
)
(
string
,
*
User
,
error
)
{
return
s
.
RegisterWithVerification
(
ctx
,
email
,
password
,
""
)
return
s
.
RegisterWithVerification
(
ctx
,
email
,
password
,
""
,
""
)
}
// RegisterWithVerification 用户注册(支持邮件验证),返回token和用户
func
(
s
*
AuthService
)
RegisterWithVerification
(
ctx
context
.
Context
,
email
,
password
,
verifyCode
string
)
(
string
,
*
User
,
error
)
{
// RegisterWithVerification 用户注册(支持邮件验证
和优惠码
),返回token和用户
func
(
s
*
AuthService
)
RegisterWithVerification
(
ctx
context
.
Context
,
email
,
password
,
verifyCode
,
promoCode
string
)
(
string
,
*
User
,
error
)
{
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
if
s
.
settingService
==
nil
||
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
return
""
,
nil
,
ErrRegDisabled
...
...
@@ -150,6 +153,19 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return
""
,
nil
,
ErrServiceUnavailable
}
// 应用优惠码(如果提供)
if
promoCode
!=
""
&&
s
.
promoService
!=
nil
{
if
err
:=
s
.
promoService
.
ApplyPromoCode
(
ctx
,
user
.
ID
,
promoCode
);
err
!=
nil
{
// 优惠码应用失败不影响注册,只记录日志
log
.
Printf
(
"[Auth] Failed to apply promo code for user %d: %v"
,
user
.
ID
,
err
)
}
else
{
// 重新获取用户信息以获取更新后的余额
if
updatedUser
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
user
.
ID
);
err
==
nil
{
user
=
updatedUser
}
}
}
// 生成token
token
,
err
:=
s
.
GenerateToken
(
user
)
if
err
!=
nil
{
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment