Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
8f0ea7a0
Commit
8f0ea7a0
authored
Mar 14, 2026
by
InCerry
Browse files
Merge branch 'main' into fix/enc_coot
parents
e4a4dfd0
a1dc0089
Changes
81
Show whitespace changes
Inline
Side-by-side
backend/internal/repository/account_repo.go
View file @
8f0ea7a0
...
...
@@ -397,9 +397,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
r
.
sql
,
service
.
SchedulerOutboxEventAccountChanged
,
&
account
.
ID
,
nil
,
buildSchedulerGroupPayload
(
account
.
GroupIDs
));
err
!=
nil
{
logger
.
LegacyPrintf
(
"repository.account"
,
"[SchedulerOutbox] enqueue account update failed: account=%d err=%v"
,
account
.
ID
,
err
)
}
if
account
.
Status
==
service
.
StatusError
||
account
.
Status
==
service
.
StatusDisabled
||
!
account
.
Schedulable
{
// 普通账号编辑(如 model_mapping / credentials)也需要立即刷新单账号快照,
// 否则网关在 outbox worker 延迟或异常时仍可能读到旧配置。
r
.
syncSchedulerAccountSnapshot
(
ctx
,
account
.
ID
)
}
return
nil
}
...
...
@@ -1727,8 +1727,96 @@ func (r *accountRepository) FindByExtraField(ctx context.Context, key string, va
// nowUTC is a SQL expression to generate a UTC RFC3339 timestamp string.
const
nowUTC
=
`to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"')`
// dailyExpiredExpr is a SQL expression that evaluates to TRUE when daily quota period has expired.
// Supports both rolling (24h from start) and fixed (pre-computed reset_at) modes.
const
dailyExpiredExpr
=
`(
CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed'
THEN NOW() >= COALESCE((extra->>'quota_daily_reset_at')::timestamptz, '1970-01-01'::timestamptz)
ELSE COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
+ '24 hours'::interval <= NOW()
END
)`
// weeklyExpiredExpr is a SQL expression that evaluates to TRUE when weekly quota period has expired.
const
weeklyExpiredExpr
=
`(
CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed'
THEN NOW() >= COALESCE((extra->>'quota_weekly_reset_at')::timestamptz, '1970-01-01'::timestamptz)
ELSE COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
+ '168 hours'::interval <= NOW()
END
)`
// nextDailyResetAtExpr is a SQL expression to compute the next daily reset_at when a reset occurs.
// For fixed mode: computes the next future reset time based on NOW(), timezone, and configured hour.
// This correctly handles long-inactive accounts by jumping directly to the next valid reset point.
const
nextDailyResetAtExpr
=
`(
CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed'
THEN to_char((
-- Compute today's reset point in the configured timezone, then pick next future one
CASE WHEN NOW() >= (
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
-- NOW() is at or past today's reset point → next reset is tomorrow
THEN (
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
+ '1 day'::interval
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
-- NOW() is before today's reset point → next reset is today
ELSE (
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
END
) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
ELSE NULL END
)`
// nextWeeklyResetAtExpr is a SQL expression to compute the next weekly reset_at when a reset occurs.
// For fixed mode: computes the next future reset time based on NOW(), timezone, configured day and hour.
// This correctly handles long-inactive accounts by jumping directly to the next valid reset point.
const
nextWeeklyResetAtExpr
=
`(
CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed'
THEN to_char((
-- Compute this week's reset point in the configured timezone
-- Step 1: get today's date at reset hour in configured tz
-- Step 2: compute days forward to target weekday
-- Step 3: if same day but past reset hour, advance 7 days
CASE
WHEN (
-- days_forward = (target_day - current_day + 7) % 7
(COALESCE((extra->>'quota_weekly_reset_day')::int, 1)
- EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int
+ 7) % 7
) = 0 AND NOW() >= (
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
-- Same weekday and past reset hour → next week
THEN (
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
+ '7 days'::interval
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
ELSE (
-- Advance to target weekday this week (or next if days_forward > 0)
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
+ ((
(COALESCE((extra->>'quota_weekly_reset_day')::int, 1)
- EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int
+ 7) % 7
) || ' days')::interval
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
END
) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
ELSE NULL END
)`
// IncrementQuotaUsed 原子递增账号的配额用量(总/日/周三个维度)
// 日/周额度在周期过期时自动重置为 0 再递增。
// 支持滚动窗口(rolling)和固定时间(fixed)两种重置模式。
func
(
r
*
accountRepository
)
IncrementQuotaUsed
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
error
{
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
`UPDATE accounts SET extra = (
...
...
@@ -1739,31 +1827,35 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
jsonb_build_object(
'quota_daily_used',
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
+ '24 hours'::interval <= NOW()
CASE WHEN `
+
dailyExpiredExpr
+
`
THEN $1
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
'quota_daily_start',
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
+ '24 hours'::interval <= NOW()
CASE WHEN `
+
dailyExpiredExpr
+
`
THEN `
+
nowUTC
+
`
ELSE COALESCE(extra->>'quota_daily_start', `
+
nowUTC
+
`) END
)
-- 固定模式重置时更新下次重置时间
|| CASE WHEN `
+
dailyExpiredExpr
+
` AND `
+
nextDailyResetAtExpr
+
` IS NOT NULL
THEN jsonb_build_object('quota_daily_reset_at', `
+
nextDailyResetAtExpr
+
`)
ELSE '{}'::jsonb END
ELSE '{}'::jsonb END
-- 周额度:仅在 quota_weekly_limit > 0 时处理
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
jsonb_build_object(
'quota_weekly_used',
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
+ '168 hours'::interval <= NOW()
CASE WHEN `
+
weeklyExpiredExpr
+
`
THEN $1
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
'quota_weekly_start',
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
+ '168 hours'::interval <= NOW()
CASE WHEN `
+
weeklyExpiredExpr
+
`
THEN `
+
nowUTC
+
`
ELSE COALESCE(extra->>'quota_weekly_start', `
+
nowUTC
+
`) END
)
-- 固定模式重置时更新下次重置时间
|| CASE WHEN `
+
weeklyExpiredExpr
+
` AND `
+
nextWeeklyResetAtExpr
+
` IS NOT NULL
THEN jsonb_build_object('quota_weekly_reset_at', `
+
nextWeeklyResetAtExpr
+
`)
ELSE '{}'::jsonb END
ELSE '{}'::jsonb END
), updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
...
...
@@ -1796,12 +1888,13 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
}
// ResetQuotaUsed 重置账号所有维度的配额用量为 0
// 保留固定重置模式的配置字段(quota_daily_reset_mode 等),仅清零用量和窗口起始时间
func
(
r
*
accountRepository
)
ResetQuotaUsed
(
ctx
context
.
Context
,
id
int64
)
error
{
_
,
err
:=
r
.
sql
.
ExecContext
(
ctx
,
`UPDATE accounts SET extra = (
COALESCE(extra, '{}'::jsonb)
|| '{"quota_used": 0, "quota_daily_used": 0, "quota_weekly_used": 0}'::jsonb
) - 'quota_daily_start' - 'quota_weekly_start', updated_at = NOW()
) - 'quota_daily_start' - 'quota_weekly_start'
- 'quota_daily_reset_at' - 'quota_weekly_reset_at'
, updated_at = NOW()
WHERE id = $1 AND deleted_at IS NULL`
,
id
)
if
err
!=
nil
{
...
...
backend/internal/repository/account_repo_integration_test.go
View file @
8f0ea7a0
...
...
@@ -142,6 +142,35 @@ func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnDisabled() {
s
.
Require
()
.
Equal
(
service
.
StatusDisabled
,
cacheRecorder
.
setAccounts
[
0
]
.
Status
)
}
func
(
s
*
AccountRepoSuite
)
TestUpdate_SyncSchedulerSnapshotOnCredentialsChange
()
{
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
client
,
&
service
.
Account
{
Name
:
"sync-credentials-update"
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"gpt-5"
:
"gpt-5.1"
,
},
},
})
cacheRecorder
:=
&
schedulerCacheRecorder
{}
s
.
repo
.
schedulerCache
=
cacheRecorder
account
.
Credentials
=
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"gpt-5"
:
"gpt-5.2"
,
},
}
err
:=
s
.
repo
.
Update
(
s
.
ctx
,
account
)
s
.
Require
()
.
NoError
(
err
,
"Update"
)
s
.
Require
()
.
Len
(
cacheRecorder
.
setAccounts
,
1
)
s
.
Require
()
.
Equal
(
account
.
ID
,
cacheRecorder
.
setAccounts
[
0
]
.
ID
)
mapping
,
ok
:=
cacheRecorder
.
setAccounts
[
0
]
.
Credentials
[
"model_mapping"
]
.
(
map
[
string
]
any
)
s
.
Require
()
.
True
(
ok
)
s
.
Require
()
.
Equal
(
"gpt-5.2"
,
mapping
[
"gpt-5"
])
}
func
(
s
*
AccountRepoSuite
)
TestDelete
()
{
account
:=
mustCreateAccount
(
s
.
T
(),
s
.
client
,
&
service
.
Account
{
Name
:
"to-delete"
})
...
...
backend/internal/server/api_contract_test.go
View file @
8f0ea7a0
...
...
@@ -537,6 +537,7 @@ func TestAPIContracts(t *testing.T) {
"purchase_subscription_url": "",
"min_claude_code_version": "",
"allow_ungrouped_key_scheduling": false,
"backend_mode_enabled": false,
"custom_menu_items": []
}
}`
,
...
...
backend/internal/server/middleware/backend_mode_guard.go
0 → 100644
View file @
8f0ea7a0
package
middleware
import
(
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// BackendModeUserGuard blocks non-admin users from accessing user routes when backend mode is enabled.
// Must be placed AFTER JWT auth middleware so that the user role is available in context.
func
BackendModeUserGuard
(
settingService
*
service
.
SettingService
)
gin
.
HandlerFunc
{
return
func
(
c
*
gin
.
Context
)
{
if
settingService
==
nil
||
!
settingService
.
IsBackendModeEnabled
(
c
.
Request
.
Context
())
{
c
.
Next
()
return
}
role
,
_
:=
GetUserRoleFromContext
(
c
)
if
role
==
"admin"
{
c
.
Next
()
return
}
response
.
Forbidden
(
c
,
"Backend mode is active. User self-service is disabled."
)
c
.
Abort
()
}
}
// BackendModeAuthGuard selectively blocks auth endpoints when backend mode is enabled.
// Allows: login, login/2fa, logout, refresh (admin needs these).
// Blocks: register, forgot-password, reset-password, OAuth, etc.
func
BackendModeAuthGuard
(
settingService
*
service
.
SettingService
)
gin
.
HandlerFunc
{
return
func
(
c
*
gin
.
Context
)
{
if
settingService
==
nil
||
!
settingService
.
IsBackendModeEnabled
(
c
.
Request
.
Context
())
{
c
.
Next
()
return
}
path
:=
c
.
Request
.
URL
.
Path
// Allow login, 2FA, logout, refresh, public settings
allowedSuffixes
:=
[]
string
{
"/auth/login"
,
"/auth/login/2fa"
,
"/auth/logout"
,
"/auth/refresh"
}
for
_
,
suffix
:=
range
allowedSuffixes
{
if
strings
.
HasSuffix
(
path
,
suffix
)
{
c
.
Next
()
return
}
}
response
.
Forbidden
(
c
,
"Backend mode is active. Registration and self-service auth flows are disabled."
)
c
.
Abort
()
}
}
backend/internal/server/middleware/backend_mode_guard_test.go
0 → 100644
View file @
8f0ea7a0
//go:build unit
package
middleware
import
(
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type
bmSettingRepo
struct
{
values
map
[
string
]
string
}
func
(
r
*
bmSettingRepo
)
Get
(
_
context
.
Context
,
_
string
)
(
*
service
.
Setting
,
error
)
{
panic
(
"unexpected Get call"
)
}
func
(
r
*
bmSettingRepo
)
GetValue
(
_
context
.
Context
,
key
string
)
(
string
,
error
)
{
v
,
ok
:=
r
.
values
[
key
]
if
!
ok
{
return
""
,
service
.
ErrSettingNotFound
}
return
v
,
nil
}
func
(
r
*
bmSettingRepo
)
Set
(
_
context
.
Context
,
_
,
_
string
)
error
{
panic
(
"unexpected Set call"
)
}
func
(
r
*
bmSettingRepo
)
GetMultiple
(
_
context
.
Context
,
_
[]
string
)
(
map
[
string
]
string
,
error
)
{
panic
(
"unexpected GetMultiple call"
)
}
func
(
r
*
bmSettingRepo
)
SetMultiple
(
_
context
.
Context
,
settings
map
[
string
]
string
)
error
{
if
r
.
values
==
nil
{
r
.
values
=
make
(
map
[
string
]
string
,
len
(
settings
))
}
for
key
,
value
:=
range
settings
{
r
.
values
[
key
]
=
value
}
return
nil
}
func
(
r
*
bmSettingRepo
)
GetAll
(
_
context
.
Context
)
(
map
[
string
]
string
,
error
)
{
panic
(
"unexpected GetAll call"
)
}
func
(
r
*
bmSettingRepo
)
Delete
(
_
context
.
Context
,
_
string
)
error
{
panic
(
"unexpected Delete call"
)
}
func
newBackendModeSettingService
(
t
*
testing
.
T
,
enabled
string
)
*
service
.
SettingService
{
t
.
Helper
()
repo
:=
&
bmSettingRepo
{
values
:
map
[
string
]
string
{
service
.
SettingKeyBackendModeEnabled
:
enabled
,
},
}
svc
:=
service
.
NewSettingService
(
repo
,
&
config
.
Config
{})
require
.
NoError
(
t
,
svc
.
UpdateSettings
(
context
.
Background
(),
&
service
.
SystemSettings
{
BackendModeEnabled
:
enabled
==
"true"
,
}))
return
svc
}
func
stringPtr
(
v
string
)
*
string
{
return
&
v
}
func
TestBackendModeUserGuard
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
nilService
bool
enabled
string
role
*
string
wantStatus
int
}{
{
name
:
"disabled_allows_all"
,
enabled
:
"false"
,
role
:
stringPtr
(
"user"
),
wantStatus
:
http
.
StatusOK
,
},
{
name
:
"nil_service_allows_all"
,
nilService
:
true
,
role
:
stringPtr
(
"user"
),
wantStatus
:
http
.
StatusOK
,
},
{
name
:
"enabled_admin_allowed"
,
enabled
:
"true"
,
role
:
stringPtr
(
"admin"
),
wantStatus
:
http
.
StatusOK
,
},
{
name
:
"enabled_user_blocked"
,
enabled
:
"true"
,
role
:
stringPtr
(
"user"
),
wantStatus
:
http
.
StatusForbidden
,
},
{
name
:
"enabled_no_role_blocked"
,
enabled
:
"true"
,
wantStatus
:
http
.
StatusForbidden
,
},
{
name
:
"enabled_empty_role_blocked"
,
enabled
:
"true"
,
role
:
stringPtr
(
""
),
wantStatus
:
http
.
StatusForbidden
,
},
}
for
_
,
tc
:=
range
tests
{
tc
:=
tc
t
.
Run
(
tc
.
name
,
func
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
r
:=
gin
.
New
()
if
tc
.
role
!=
nil
{
role
:=
*
tc
.
role
r
.
Use
(
func
(
c
*
gin
.
Context
)
{
c
.
Set
(
string
(
ContextKeyUserRole
),
role
)
c
.
Next
()
})
}
var
svc
*
service
.
SettingService
if
!
tc
.
nilService
{
svc
=
newBackendModeSettingService
(
t
,
tc
.
enabled
)
}
r
.
Use
(
BackendModeUserGuard
(
svc
))
r
.
GET
(
"/test"
,
func
(
c
*
gin
.
Context
)
{
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"ok"
:
true
})
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/test"
,
nil
)
r
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
tc
.
wantStatus
,
w
.
Code
)
})
}
}
func
TestBackendModeAuthGuard
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
nilService
bool
enabled
string
path
string
wantStatus
int
}{
{
name
:
"disabled_allows_all"
,
enabled
:
"false"
,
path
:
"/api/v1/auth/register"
,
wantStatus
:
http
.
StatusOK
,
},
{
name
:
"nil_service_allows_all"
,
nilService
:
true
,
path
:
"/api/v1/auth/register"
,
wantStatus
:
http
.
StatusOK
,
},
{
name
:
"enabled_allows_login"
,
enabled
:
"true"
,
path
:
"/api/v1/auth/login"
,
wantStatus
:
http
.
StatusOK
,
},
{
name
:
"enabled_allows_login_2fa"
,
enabled
:
"true"
,
path
:
"/api/v1/auth/login/2fa"
,
wantStatus
:
http
.
StatusOK
,
},
{
name
:
"enabled_allows_logout"
,
enabled
:
"true"
,
path
:
"/api/v1/auth/logout"
,
wantStatus
:
http
.
StatusOK
,
},
{
name
:
"enabled_allows_refresh"
,
enabled
:
"true"
,
path
:
"/api/v1/auth/refresh"
,
wantStatus
:
http
.
StatusOK
,
},
{
name
:
"enabled_blocks_register"
,
enabled
:
"true"
,
path
:
"/api/v1/auth/register"
,
wantStatus
:
http
.
StatusForbidden
,
},
{
name
:
"enabled_blocks_forgot_password"
,
enabled
:
"true"
,
path
:
"/api/v1/auth/forgot-password"
,
wantStatus
:
http
.
StatusForbidden
,
},
}
for
_
,
tc
:=
range
tests
{
tc
:=
tc
t
.
Run
(
tc
.
name
,
func
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
r
:=
gin
.
New
()
var
svc
*
service
.
SettingService
if
!
tc
.
nilService
{
svc
=
newBackendModeSettingService
(
t
,
tc
.
enabled
)
}
r
.
Use
(
BackendModeAuthGuard
(
svc
))
r
.
Any
(
"/*path"
,
func
(
c
*
gin
.
Context
)
{
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"ok"
:
true
})
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
tc
.
path
,
nil
)
r
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
tc
.
wantStatus
,
w
.
Code
)
})
}
}
backend/internal/server/router.go
View file @
8f0ea7a0
...
...
@@ -107,9 +107,9 @@ func registerRoutes(
v1
:=
r
.
Group
(
"/api/v1"
)
// 注册各模块路由
routes
.
RegisterAuthRoutes
(
v1
,
h
,
jwtAuth
,
redisClient
)
routes
.
RegisterUserRoutes
(
v1
,
h
,
jwtAuth
)
routes
.
RegisterSoraClientRoutes
(
v1
,
h
,
jwtAuth
)
routes
.
RegisterAuthRoutes
(
v1
,
h
,
jwtAuth
,
redisClient
,
settingService
)
routes
.
RegisterUserRoutes
(
v1
,
h
,
jwtAuth
,
settingService
)
routes
.
RegisterSoraClientRoutes
(
v1
,
h
,
jwtAuth
,
settingService
)
routes
.
RegisterAdminRoutes
(
v1
,
h
,
adminAuth
)
routes
.
RegisterGatewayRoutes
(
r
,
h
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
,
opsService
,
settingService
,
cfg
)
}
backend/internal/server/routes/auth.go
View file @
8f0ea7a0
...
...
@@ -6,6 +6,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/middleware"
servermiddleware
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
...
...
@@ -17,12 +18,14 @@ func RegisterAuthRoutes(
h
*
handler
.
Handlers
,
jwtAuth
servermiddleware
.
JWTAuthMiddleware
,
redisClient
*
redis
.
Client
,
settingService
*
service
.
SettingService
,
)
{
// 创建速率限制器
rateLimiter
:=
middleware
.
NewRateLimiter
(
redisClient
)
// 公开接口
auth
:=
v1
.
Group
(
"/auth"
)
auth
.
Use
(
servermiddleware
.
BackendModeAuthGuard
(
settingService
))
{
// 注册/登录/2FA/验证码发送均属于高风险入口,增加服务端兜底限流(Redis 故障时 fail-close)
auth
.
POST
(
"/register"
,
rateLimiter
.
LimitWithOptions
(
"auth-register"
,
5
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
...
...
@@ -78,6 +81,7 @@ func RegisterAuthRoutes(
// 需要认证的当前用户信息
authenticated
:=
v1
.
Group
(
""
)
authenticated
.
Use
(
gin
.
HandlerFunc
(
jwtAuth
))
authenticated
.
Use
(
servermiddleware
.
BackendModeUserGuard
(
settingService
))
{
authenticated
.
GET
(
"/auth/me"
,
h
.
Auth
.
GetCurrentUser
)
// 撤销所有会话(需要认证)
...
...
backend/internal/server/routes/auth_rate_limit_test.go
View file @
8f0ea7a0
...
...
@@ -29,6 +29,7 @@ func newAuthRoutesTestRouter(redisClient *redis.Client) *gin.Engine {
c
.
Next
()
}),
redisClient
,
nil
,
)
return
router
...
...
backend/internal/server/routes/sora_client.go
View file @
8f0ea7a0
...
...
@@ -3,6 +3,7 @@ package routes
import
(
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
...
...
@@ -12,6 +13,7 @@ func RegisterSoraClientRoutes(
v1
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
,
jwtAuth
middleware
.
JWTAuthMiddleware
,
settingService
*
service
.
SettingService
,
)
{
if
h
.
SoraClient
==
nil
{
return
...
...
@@ -19,6 +21,7 @@ func RegisterSoraClientRoutes(
authenticated
:=
v1
.
Group
(
"/sora"
)
authenticated
.
Use
(
gin
.
HandlerFunc
(
jwtAuth
))
authenticated
.
Use
(
middleware
.
BackendModeUserGuard
(
settingService
))
{
authenticated
.
POST
(
"/generate"
,
h
.
SoraClient
.
Generate
)
authenticated
.
GET
(
"/generations"
,
h
.
SoraClient
.
ListGenerations
)
...
...
backend/internal/server/routes/user.go
View file @
8f0ea7a0
...
...
@@ -3,6 +3,7 @@ package routes
import
(
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
...
...
@@ -12,9 +13,11 @@ func RegisterUserRoutes(
v1
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
,
jwtAuth
middleware
.
JWTAuthMiddleware
,
settingService
*
service
.
SettingService
,
)
{
authenticated
:=
v1
.
Group
(
""
)
authenticated
.
Use
(
gin
.
HandlerFunc
(
jwtAuth
))
authenticated
.
Use
(
middleware
.
BackendModeUserGuard
(
settingService
))
{
// 用户接口
user
:=
authenticated
.
Group
(
"/user"
)
...
...
backend/internal/service/account.go
View file @
8f0ea7a0
...
...
@@ -3,6 +3,7 @@ package service
import
(
"encoding/json"
"errors"
"hash/fnv"
"reflect"
"sort"
...
...
@@ -522,16 +523,23 @@ func (a *Account) IsModelSupported(requestedModel string) bool {
// GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配)
// 如果未配置 mapping,返回原始模型名
func
(
a
*
Account
)
GetMappedModel
(
requestedModel
string
)
string
{
mappedModel
,
_
:=
a
.
ResolveMappedModel
(
requestedModel
)
return
mappedModel
}
// ResolveMappedModel 获取映射后的模型名,并返回是否命中了账号级映射。
// matched=true 表示命中了精确映射或通配符映射,即使映射结果与原模型名相同。
func
(
a
*
Account
)
ResolveMappedModel
(
requestedModel
string
)
(
mappedModel
string
,
matched
bool
)
{
mapping
:=
a
.
GetModelMapping
()
if
len
(
mapping
)
==
0
{
return
requestedModel
return
requestedModel
,
false
}
// 精确匹配优先
if
mappedModel
,
exists
:=
mapping
[
requestedModel
];
exists
{
return
mappedModel
return
mappedModel
,
true
}
// 通配符匹配(最长优先)
return
matchWildcardMapping
(
mapping
,
requestedModel
)
return
matchWildcardMapping
Result
(
mapping
,
requestedModel
)
}
func
(
a
*
Account
)
GetBaseURL
()
string
{
...
...
@@ -605,9 +613,7 @@ func matchWildcard(pattern, str string) bool {
return
matchAntigravityWildcard
(
pattern
,
str
)
}
// matchWildcardMapping 通配符映射匹配(最长优先)
// 如果没有匹配,返回原始字符串
func
matchWildcardMapping
(
mapping
map
[
string
]
string
,
requestedModel
string
)
string
{
func
matchWildcardMappingResult
(
mapping
map
[
string
]
string
,
requestedModel
string
)
(
string
,
bool
)
{
// 收集所有匹配的 pattern,按长度降序排序(最长优先)
type
patternMatch
struct
{
pattern
string
...
...
@@ -622,7 +628,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri
}
if
len
(
matches
)
==
0
{
return
requestedModel
// 无匹配,返回原始模型名
return
requestedModel
,
false
// 无匹配,返回原始模型名
}
// 按 pattern 长度降序排序
...
...
@@ -633,7 +639,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri
return
matches
[
i
]
.
pattern
<
matches
[
j
]
.
pattern
})
return
matches
[
0
]
.
target
return
matches
[
0
]
.
target
,
true
}
func
(
a
*
Account
)
IsCustomErrorCodesEnabled
()
bool
{
...
...
@@ -651,7 +657,7 @@ func (a *Account) IsCustomErrorCodesEnabled() bool {
// IsPoolMode 检查 API Key 账号是否启用池模式。
// 池模式下,上游错误不标记本地账号状态,而是在同一账号上重试。
func
(
a
*
Account
)
IsPoolMode
()
bool
{
if
a
.
Type
!=
AccountTypeAPIKey
||
a
.
Credentials
==
nil
{
if
!
a
.
IsAPIKeyOrBedrock
()
||
a
.
Credentials
==
nil
{
return
false
}
if
v
,
ok
:=
a
.
Credentials
[
"pool_mode"
];
ok
{
...
...
@@ -766,11 +772,16 @@ func (a *Account) IsInterceptWarmupEnabled() bool {
}
func
(
a
*
Account
)
IsBedrock
()
bool
{
return
a
.
Platform
==
PlatformAnthropic
&&
(
a
.
Type
==
AccountTypeBedrock
||
a
.
Type
==
AccountTypeBedrockAPIKey
)
return
a
.
Platform
==
PlatformAnthropic
&&
a
.
Type
==
AccountTypeBedrock
}
func
(
a
*
Account
)
IsBedrockAPIKey
()
bool
{
return
a
.
Platform
==
PlatformAnthropic
&&
a
.
Type
==
AccountTypeBedrockAPIKey
return
a
.
IsBedrock
()
&&
a
.
GetCredential
(
"auth_mode"
)
==
"apikey"
}
// IsAPIKeyOrBedrock 返回账号类型是否支持配额和池模式等特性
func
(
a
*
Account
)
IsAPIKeyOrBedrock
()
bool
{
return
a
.
Type
==
AccountTypeAPIKey
||
a
.
Type
==
AccountTypeBedrock
}
func
(
a
*
Account
)
IsOpenAI
()
bool
{
...
...
@@ -1269,6 +1280,240 @@ func (a *Account) getExtraTime(key string) time.Time {
return
time
.
Time
{}
}
// getExtraString 从 Extra 中读取指定 key 的字符串值
func
(
a
*
Account
)
getExtraString
(
key
string
)
string
{
if
a
.
Extra
==
nil
{
return
""
}
if
v
,
ok
:=
a
.
Extra
[
key
];
ok
{
if
s
,
ok
:=
v
.
(
string
);
ok
{
return
s
}
}
return
""
}
// getExtraInt 从 Extra 中读取指定 key 的 int 值
func
(
a
*
Account
)
getExtraInt
(
key
string
)
int
{
if
a
.
Extra
==
nil
{
return
0
}
if
v
,
ok
:=
a
.
Extra
[
key
];
ok
{
return
int
(
parseExtraFloat64
(
v
))
}
return
0
}
// GetQuotaDailyResetMode 获取日额度重置模式:"rolling"(默认)或 "fixed"
func
(
a
*
Account
)
GetQuotaDailyResetMode
()
string
{
if
m
:=
a
.
getExtraString
(
"quota_daily_reset_mode"
);
m
==
"fixed"
{
return
"fixed"
}
return
"rolling"
}
// GetQuotaDailyResetHour 获取固定重置的小时(0-23),默认 0
func
(
a
*
Account
)
GetQuotaDailyResetHour
()
int
{
return
a
.
getExtraInt
(
"quota_daily_reset_hour"
)
}
// GetQuotaWeeklyResetMode 获取周额度重置模式:"rolling"(默认)或 "fixed"
func
(
a
*
Account
)
GetQuotaWeeklyResetMode
()
string
{
if
m
:=
a
.
getExtraString
(
"quota_weekly_reset_mode"
);
m
==
"fixed"
{
return
"fixed"
}
return
"rolling"
}
// GetQuotaWeeklyResetDay 获取固定重置的星期几(0=周日, 1=周一, ..., 6=周六),默认 1(周一)
func
(
a
*
Account
)
GetQuotaWeeklyResetDay
()
int
{
if
a
.
Extra
==
nil
{
return
1
}
if
_
,
ok
:=
a
.
Extra
[
"quota_weekly_reset_day"
];
!
ok
{
return
1
}
return
a
.
getExtraInt
(
"quota_weekly_reset_day"
)
}
// GetQuotaWeeklyResetHour 获取周配额固定重置的小时(0-23),默认 0
func
(
a
*
Account
)
GetQuotaWeeklyResetHour
()
int
{
return
a
.
getExtraInt
(
"quota_weekly_reset_hour"
)
}
// GetQuotaResetTimezone 获取固定重置的时区名(IANA),默认 "UTC"
func
(
a
*
Account
)
GetQuotaResetTimezone
()
string
{
if
tz
:=
a
.
getExtraString
(
"quota_reset_timezone"
);
tz
!=
""
{
return
tz
}
return
"UTC"
}
// nextFixedDailyReset 计算在 after 之后的下一个每日固定重置时间点
func
nextFixedDailyReset
(
hour
int
,
tz
*
time
.
Location
,
after
time
.
Time
)
time
.
Time
{
t
:=
after
.
In
(
tz
)
today
:=
time
.
Date
(
t
.
Year
(),
t
.
Month
(),
t
.
Day
(),
hour
,
0
,
0
,
0
,
tz
)
if
!
after
.
Before
(
today
)
{
return
today
.
AddDate
(
0
,
0
,
1
)
}
return
today
}
// lastFixedDailyReset 计算 now 之前最近一次的每日固定重置时间点
func
lastFixedDailyReset
(
hour
int
,
tz
*
time
.
Location
,
now
time
.
Time
)
time
.
Time
{
t
:=
now
.
In
(
tz
)
today
:=
time
.
Date
(
t
.
Year
(),
t
.
Month
(),
t
.
Day
(),
hour
,
0
,
0
,
0
,
tz
)
if
now
.
Before
(
today
)
{
return
today
.
AddDate
(
0
,
0
,
-
1
)
}
return
today
}
// nextFixedWeeklyReset 计算在 after 之后的下一个每周固定重置时间点
// day: 0=Sunday, 1=Monday, ..., 6=Saturday
func
nextFixedWeeklyReset
(
day
,
hour
int
,
tz
*
time
.
Location
,
after
time
.
Time
)
time
.
Time
{
t
:=
after
.
In
(
tz
)
todayReset
:=
time
.
Date
(
t
.
Year
(),
t
.
Month
(),
t
.
Day
(),
hour
,
0
,
0
,
0
,
tz
)
currentDay
:=
int
(
todayReset
.
Weekday
())
daysForward
:=
(
day
-
currentDay
+
7
)
%
7
if
daysForward
==
0
&&
!
after
.
Before
(
todayReset
)
{
daysForward
=
7
}
return
todayReset
.
AddDate
(
0
,
0
,
daysForward
)
}
// lastFixedWeeklyReset 计算 now 之前最近一次的每周固定重置时间点
func
lastFixedWeeklyReset
(
day
,
hour
int
,
tz
*
time
.
Location
,
now
time
.
Time
)
time
.
Time
{
t
:=
now
.
In
(
tz
)
todayReset
:=
time
.
Date
(
t
.
Year
(),
t
.
Month
(),
t
.
Day
(),
hour
,
0
,
0
,
0
,
tz
)
currentDay
:=
int
(
todayReset
.
Weekday
())
daysBack
:=
(
currentDay
-
day
+
7
)
%
7
if
daysBack
==
0
&&
now
.
Before
(
todayReset
)
{
daysBack
=
7
}
return
todayReset
.
AddDate
(
0
,
0
,
-
daysBack
)
}
// isFixedDailyPeriodExpired 检查日配额是否在固定时间模式下已过期
func
(
a
*
Account
)
isFixedDailyPeriodExpired
(
periodStart
time
.
Time
)
bool
{
if
periodStart
.
IsZero
()
{
return
true
}
tz
,
err
:=
time
.
LoadLocation
(
a
.
GetQuotaResetTimezone
())
if
err
!=
nil
{
tz
=
time
.
UTC
}
lastReset
:=
lastFixedDailyReset
(
a
.
GetQuotaDailyResetHour
(),
tz
,
time
.
Now
())
return
periodStart
.
Before
(
lastReset
)
}
// isFixedWeeklyPeriodExpired 检查周配额是否在固定时间模式下已过期
func
(
a
*
Account
)
isFixedWeeklyPeriodExpired
(
periodStart
time
.
Time
)
bool
{
if
periodStart
.
IsZero
()
{
return
true
}
tz
,
err
:=
time
.
LoadLocation
(
a
.
GetQuotaResetTimezone
())
if
err
!=
nil
{
tz
=
time
.
UTC
}
lastReset
:=
lastFixedWeeklyReset
(
a
.
GetQuotaWeeklyResetDay
(),
a
.
GetQuotaWeeklyResetHour
(),
tz
,
time
.
Now
())
return
periodStart
.
Before
(
lastReset
)
}
// ComputeQuotaResetAt 根据当前配置计算并填充 extra 中的 quota_daily_reset_at / quota_weekly_reset_at
// 在保存账号配置时调用
func
ComputeQuotaResetAt
(
extra
map
[
string
]
any
)
{
now
:=
time
.
Now
()
tzName
,
_
:=
extra
[
"quota_reset_timezone"
]
.
(
string
)
if
tzName
==
""
{
tzName
=
"UTC"
}
tz
,
err
:=
time
.
LoadLocation
(
tzName
)
if
err
!=
nil
{
tz
=
time
.
UTC
}
// 日配额固定重置时间
if
mode
,
_
:=
extra
[
"quota_daily_reset_mode"
]
.
(
string
);
mode
==
"fixed"
{
hour
:=
int
(
parseExtraFloat64
(
extra
[
"quota_daily_reset_hour"
]))
if
hour
<
0
||
hour
>
23
{
hour
=
0
}
resetAt
:=
nextFixedDailyReset
(
hour
,
tz
,
now
)
extra
[
"quota_daily_reset_at"
]
=
resetAt
.
UTC
()
.
Format
(
time
.
RFC3339
)
}
else
{
delete
(
extra
,
"quota_daily_reset_at"
)
}
// 周配额固定重置时间
if
mode
,
_
:=
extra
[
"quota_weekly_reset_mode"
]
.
(
string
);
mode
==
"fixed"
{
day
:=
1
// 默认周一
if
d
,
ok
:=
extra
[
"quota_weekly_reset_day"
];
ok
{
day
=
int
(
parseExtraFloat64
(
d
))
}
if
day
<
0
||
day
>
6
{
day
=
1
}
hour
:=
int
(
parseExtraFloat64
(
extra
[
"quota_weekly_reset_hour"
]))
if
hour
<
0
||
hour
>
23
{
hour
=
0
}
resetAt
:=
nextFixedWeeklyReset
(
day
,
hour
,
tz
,
now
)
extra
[
"quota_weekly_reset_at"
]
=
resetAt
.
UTC
()
.
Format
(
time
.
RFC3339
)
}
else
{
delete
(
extra
,
"quota_weekly_reset_at"
)
}
}
// ValidateQuotaResetConfig 校验配额固定重置时间配置的合法性
func
ValidateQuotaResetConfig
(
extra
map
[
string
]
any
)
error
{
if
extra
==
nil
{
return
nil
}
// 校验时区
if
tz
,
ok
:=
extra
[
"quota_reset_timezone"
]
.
(
string
);
ok
&&
tz
!=
""
{
if
_
,
err
:=
time
.
LoadLocation
(
tz
);
err
!=
nil
{
return
errors
.
New
(
"invalid quota_reset_timezone: must be a valid IANA timezone name"
)
}
}
// 日配额重置模式
if
mode
,
ok
:=
extra
[
"quota_daily_reset_mode"
]
.
(
string
);
ok
{
if
mode
!=
"rolling"
&&
mode
!=
"fixed"
{
return
errors
.
New
(
"quota_daily_reset_mode must be 'rolling' or 'fixed'"
)
}
}
// 日配额重置小时
if
v
,
ok
:=
extra
[
"quota_daily_reset_hour"
];
ok
{
hour
:=
int
(
parseExtraFloat64
(
v
))
if
hour
<
0
||
hour
>
23
{
return
errors
.
New
(
"quota_daily_reset_hour must be between 0 and 23"
)
}
}
// 周配额重置模式
if
mode
,
ok
:=
extra
[
"quota_weekly_reset_mode"
]
.
(
string
);
ok
{
if
mode
!=
"rolling"
&&
mode
!=
"fixed"
{
return
errors
.
New
(
"quota_weekly_reset_mode must be 'rolling' or 'fixed'"
)
}
}
// 周配额重置星期几
if
v
,
ok
:=
extra
[
"quota_weekly_reset_day"
];
ok
{
day
:=
int
(
parseExtraFloat64
(
v
))
if
day
<
0
||
day
>
6
{
return
errors
.
New
(
"quota_weekly_reset_day must be between 0 (Sunday) and 6 (Saturday)"
)
}
}
// 周配额重置小时
if
v
,
ok
:=
extra
[
"quota_weekly_reset_hour"
];
ok
{
hour
:=
int
(
parseExtraFloat64
(
v
))
if
hour
<
0
||
hour
>
23
{
return
errors
.
New
(
"quota_weekly_reset_hour must be between 0 and 23"
)
}
}
return
nil
}
// HasAnyQuotaLimit 检查是否配置了任一维度的配额限制
func
(
a
*
Account
)
HasAnyQuotaLimit
()
bool
{
return
a
.
GetQuotaLimit
()
>
0
||
a
.
GetQuotaDailyLimit
()
>
0
||
a
.
GetQuotaWeeklyLimit
()
>
0
...
...
@@ -1291,14 +1536,26 @@ func (a *Account) IsQuotaExceeded() bool {
// 日额度(周期过期视为未超限,下次 increment 会重置)
if
limit
:=
a
.
GetQuotaDailyLimit
();
limit
>
0
{
start
:=
a
.
getExtraTime
(
"quota_daily_start"
)
if
!
isPeriodExpired
(
start
,
24
*
time
.
Hour
)
&&
a
.
GetQuotaDailyUsed
()
>=
limit
{
var
expired
bool
if
a
.
GetQuotaDailyResetMode
()
==
"fixed"
{
expired
=
a
.
isFixedDailyPeriodExpired
(
start
)
}
else
{
expired
=
isPeriodExpired
(
start
,
24
*
time
.
Hour
)
}
if
!
expired
&&
a
.
GetQuotaDailyUsed
()
>=
limit
{
return
true
}
}
// 周额度
if
limit
:=
a
.
GetQuotaWeeklyLimit
();
limit
>
0
{
start
:=
a
.
getExtraTime
(
"quota_weekly_start"
)
if
!
isPeriodExpired
(
start
,
7
*
24
*
time
.
Hour
)
&&
a
.
GetQuotaWeeklyUsed
()
>=
limit
{
var
expired
bool
if
a
.
GetQuotaWeeklyResetMode
()
==
"fixed"
{
expired
=
a
.
isFixedWeeklyPeriodExpired
(
start
)
}
else
{
expired
=
isPeriodExpired
(
start
,
7
*
24
*
time
.
Hour
)
}
if
!
expired
&&
a
.
GetQuotaWeeklyUsed
()
>=
limit
{
return
true
}
}
...
...
backend/internal/service/account_quota_reset_test.go
0 → 100644
View file @
8f0ea7a0
//go:build unit
package
service
import
(
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// nextFixedDailyReset
// ---------------------------------------------------------------------------
func
TestNextFixedDailyReset_BeforeResetHour
(
t
*
testing
.
T
)
{
tz
:=
time
.
UTC
// 2026-03-14 06:00 UTC, reset hour = 9
after
:=
time
.
Date
(
2026
,
3
,
14
,
6
,
0
,
0
,
0
,
tz
)
got
:=
nextFixedDailyReset
(
9
,
tz
,
after
)
want
:=
time
.
Date
(
2026
,
3
,
14
,
9
,
0
,
0
,
0
,
tz
)
assert
.
Equal
(
t
,
want
,
got
)
}
func
TestNextFixedDailyReset_AtResetHour
(
t
*
testing
.
T
)
{
tz
:=
time
.
UTC
// Exactly at reset hour → should return tomorrow
after
:=
time
.
Date
(
2026
,
3
,
14
,
9
,
0
,
0
,
0
,
tz
)
got
:=
nextFixedDailyReset
(
9
,
tz
,
after
)
want
:=
time
.
Date
(
2026
,
3
,
15
,
9
,
0
,
0
,
0
,
tz
)
assert
.
Equal
(
t
,
want
,
got
)
}
func
TestNextFixedDailyReset_AfterResetHour
(
t
*
testing
.
T
)
{
tz
:=
time
.
UTC
// After reset hour → should return tomorrow
after
:=
time
.
Date
(
2026
,
3
,
14
,
15
,
30
,
0
,
0
,
tz
)
got
:=
nextFixedDailyReset
(
9
,
tz
,
after
)
want
:=
time
.
Date
(
2026
,
3
,
15
,
9
,
0
,
0
,
0
,
tz
)
assert
.
Equal
(
t
,
want
,
got
)
}
func
TestNextFixedDailyReset_MidnightReset
(
t
*
testing
.
T
)
{
tz
:=
time
.
UTC
// Reset at hour 0 (midnight), currently 23:59
after
:=
time
.
Date
(
2026
,
3
,
14
,
23
,
59
,
0
,
0
,
tz
)
got
:=
nextFixedDailyReset
(
0
,
tz
,
after
)
want
:=
time
.
Date
(
2026
,
3
,
15
,
0
,
0
,
0
,
0
,
tz
)
assert
.
Equal
(
t
,
want
,
got
)
}
func
TestNextFixedDailyReset_NonUTCTimezone
(
t
*
testing
.
T
)
{
tz
,
err
:=
time
.
LoadLocation
(
"Asia/Shanghai"
)
require
.
NoError
(
t
,
err
)
// 2026-03-14 07:00 UTC = 2026-03-14 15:00 CST, reset hour = 9 (CST)
after
:=
time
.
Date
(
2026
,
3
,
14
,
7
,
0
,
0
,
0
,
time
.
UTC
)
got
:=
nextFixedDailyReset
(
9
,
tz
,
after
)
// Already past 9:00 CST today → tomorrow 9:00 CST = 2026-03-15 01:00 UTC
want
:=
time
.
Date
(
2026
,
3
,
15
,
9
,
0
,
0
,
0
,
tz
)
assert
.
Equal
(
t
,
want
,
got
)
}
// ---------------------------------------------------------------------------
// lastFixedDailyReset
// ---------------------------------------------------------------------------
func
TestLastFixedDailyReset_BeforeResetHour
(
t
*
testing
.
T
)
{
tz
:=
time
.
UTC
now
:=
time
.
Date
(
2026
,
3
,
14
,
6
,
0
,
0
,
0
,
tz
)
got
:=
lastFixedDailyReset
(
9
,
tz
,
now
)
// Before today's 9:00 → yesterday 9:00
want
:=
time
.
Date
(
2026
,
3
,
13
,
9
,
0
,
0
,
0
,
tz
)
assert
.
Equal
(
t
,
want
,
got
)
}
func
TestLastFixedDailyReset_AtResetHour
(
t
*
testing
.
T
)
{
tz
:=
time
.
UTC
now
:=
time
.
Date
(
2026
,
3
,
14
,
9
,
0
,
0
,
0
,
tz
)
got
:=
lastFixedDailyReset
(
9
,
tz
,
now
)
// At exactly 9:00 → today 9:00
want
:=
time
.
Date
(
2026
,
3
,
14
,
9
,
0
,
0
,
0
,
tz
)
assert
.
Equal
(
t
,
want
,
got
)
}
func
TestLastFixedDailyReset_AfterResetHour
(
t
*
testing
.
T
)
{
tz
:=
time
.
UTC
now
:=
time
.
Date
(
2026
,
3
,
14
,
15
,
0
,
0
,
0
,
tz
)
got
:=
lastFixedDailyReset
(
9
,
tz
,
now
)
// After 9:00 → today 9:00
want
:=
time
.
Date
(
2026
,
3
,
14
,
9
,
0
,
0
,
0
,
tz
)
assert
.
Equal
(
t
,
want
,
got
)
}
// ---------------------------------------------------------------------------
// nextFixedWeeklyReset
// ---------------------------------------------------------------------------
func
TestNextFixedWeeklyReset_TargetDayAhead
(
t
*
testing
.
T
)
{
tz
:=
time
.
UTC
// 2026-03-14 is Saturday (day=6), target = Monday (day=1), hour = 9
after
:=
time
.
Date
(
2026
,
3
,
14
,
10
,
0
,
0
,
0
,
tz
)
got
:=
nextFixedWeeklyReset
(
1
,
9
,
tz
,
after
)
// Next Monday = 2026-03-16
want
:=
time
.
Date
(
2026
,
3
,
16
,
9
,
0
,
0
,
0
,
tz
)
assert
.
Equal
(
t
,
want
,
got
)
}
func
TestNextFixedWeeklyReset_TargetDayToday_BeforeHour
(
t
*
testing
.
T
)
{
tz
:=
time
.
UTC
// 2026-03-16 is Monday (day=1), target = Monday, hour = 9, before 9:00
after
:=
time
.
Date
(
2026
,
3
,
16
,
6
,
0
,
0
,
0
,
tz
)
got
:=
nextFixedWeeklyReset
(
1
,
9
,
tz
,
after
)
// Today at 9:00
want
:=
time
.
Date
(
2026
,
3
,
16
,
9
,
0
,
0
,
0
,
tz
)
assert
.
Equal
(
t
,
want
,
got
)
}
func
TestNextFixedWeeklyReset_TargetDayToday_AtHour
(
t
*
testing
.
T
)
{
tz
:=
time
.
UTC
// 2026-03-16 is Monday, target = Monday, hour = 9, exactly at 9:00
after
:=
time
.
Date
(
2026
,
3
,
16
,
9
,
0
,
0
,
0
,
tz
)
got
:=
nextFixedWeeklyReset
(
1
,
9
,
tz
,
after
)
// Next Monday at 9:00
want
:=
time
.
Date
(
2026
,
3
,
23
,
9
,
0
,
0
,
0
,
tz
)
assert
.
Equal
(
t
,
want
,
got
)
}
func
TestNextFixedWeeklyReset_TargetDayToday_AfterHour
(
t
*
testing
.
T
)
{
tz
:=
time
.
UTC
// 2026-03-16 is Monday, target = Monday, hour = 9, after 9:00
after
:=
time
.
Date
(
2026
,
3
,
16
,
15
,
0
,
0
,
0
,
tz
)
got
:=
nextFixedWeeklyReset
(
1
,
9
,
tz
,
after
)
// Next Monday at 9:00
want
:=
time
.
Date
(
2026
,
3
,
23
,
9
,
0
,
0
,
0
,
tz
)
assert
.
Equal
(
t
,
want
,
got
)
}
func
TestNextFixedWeeklyReset_TargetDayPast
(
t
*
testing
.
T
)
{
tz
:=
time
.
UTC
// 2026-03-18 is Wednesday (day=3), target = Monday (day=1)
after
:=
time
.
Date
(
2026
,
3
,
18
,
10
,
0
,
0
,
0
,
tz
)
got
:=
nextFixedWeeklyReset
(
1
,
9
,
tz
,
after
)
// Next Monday = 2026-03-23
want
:=
time
.
Date
(
2026
,
3
,
23
,
9
,
0
,
0
,
0
,
tz
)
assert
.
Equal
(
t
,
want
,
got
)
}
func
TestNextFixedWeeklyReset_Sunday
(
t
*
testing
.
T
)
{
tz
:=
time
.
UTC
// 2026-03-14 is Saturday (day=6), target = Sunday (day=0)
after
:=
time
.
Date
(
2026
,
3
,
14
,
10
,
0
,
0
,
0
,
tz
)
got
:=
nextFixedWeeklyReset
(
0
,
0
,
tz
,
after
)
// Next Sunday = 2026-03-15
want
:=
time
.
Date
(
2026
,
3
,
15
,
0
,
0
,
0
,
0
,
tz
)
assert
.
Equal
(
t
,
want
,
got
)
}
// ---------------------------------------------------------------------------
// lastFixedWeeklyReset
// ---------------------------------------------------------------------------
func
TestLastFixedWeeklyReset_SameDay_AfterHour
(
t
*
testing
.
T
)
{
tz
:=
time
.
UTC
// 2026-03-16 is Monday (day=1), target = Monday, hour = 9, now = 15:00
now
:=
time
.
Date
(
2026
,
3
,
16
,
15
,
0
,
0
,
0
,
tz
)
got
:=
lastFixedWeeklyReset
(
1
,
9
,
tz
,
now
)
// Today at 9:00
want
:=
time
.
Date
(
2026
,
3
,
16
,
9
,
0
,
0
,
0
,
tz
)
assert
.
Equal
(
t
,
want
,
got
)
}
func
TestLastFixedWeeklyReset_SameDay_BeforeHour
(
t
*
testing
.
T
)
{
tz
:=
time
.
UTC
// 2026-03-16 is Monday, target = Monday, hour = 9, now = 06:00
now
:=
time
.
Date
(
2026
,
3
,
16
,
6
,
0
,
0
,
0
,
tz
)
got
:=
lastFixedWeeklyReset
(
1
,
9
,
tz
,
now
)
// Last Monday at 9:00 = 2026-03-09
want
:=
time
.
Date
(
2026
,
3
,
9
,
9
,
0
,
0
,
0
,
tz
)
assert
.
Equal
(
t
,
want
,
got
)
}
func
TestLastFixedWeeklyReset_DifferentDay
(
t
*
testing
.
T
)
{
tz
:=
time
.
UTC
// 2026-03-18 is Wednesday (day=3), target = Monday (day=1)
now
:=
time
.
Date
(
2026
,
3
,
18
,
10
,
0
,
0
,
0
,
tz
)
got
:=
lastFixedWeeklyReset
(
1
,
9
,
tz
,
now
)
// Last Monday = 2026-03-16
want
:=
time
.
Date
(
2026
,
3
,
16
,
9
,
0
,
0
,
0
,
tz
)
assert
.
Equal
(
t
,
want
,
got
)
}
// ---------------------------------------------------------------------------
// isFixedDailyPeriodExpired
// ---------------------------------------------------------------------------
func
TestIsFixedDailyPeriodExpired_ZeroPeriodStart
(
t
*
testing
.
T
)
{
a
:=
&
Account
{
Extra
:
map
[
string
]
any
{
"quota_daily_reset_mode"
:
"fixed"
,
"quota_daily_reset_hour"
:
float64
(
9
),
"quota_reset_timezone"
:
"UTC"
,
}}
assert
.
True
(
t
,
a
.
isFixedDailyPeriodExpired
(
time
.
Time
{}))
}
func
TestIsFixedDailyPeriodExpired_NotExpired
(
t
*
testing
.
T
)
{
a
:=
&
Account
{
Extra
:
map
[
string
]
any
{
"quota_daily_reset_mode"
:
"fixed"
,
"quota_daily_reset_hour"
:
float64
(
9
),
"quota_reset_timezone"
:
"UTC"
,
}}
// Period started after the most recent reset → not expired
// (This test uses a time very close to "now", which is after the last reset)
periodStart
:=
time
.
Now
()
.
Add
(
-
1
*
time
.
Minute
)
assert
.
False
(
t
,
a
.
isFixedDailyPeriodExpired
(
periodStart
))
}
func
TestIsFixedDailyPeriodExpired_Expired
(
t
*
testing
.
T
)
{
a
:=
&
Account
{
Extra
:
map
[
string
]
any
{
"quota_daily_reset_mode"
:
"fixed"
,
"quota_daily_reset_hour"
:
float64
(
9
),
"quota_reset_timezone"
:
"UTC"
,
}}
// Period started 3 days ago → definitely expired
periodStart
:=
time
.
Now
()
.
Add
(
-
72
*
time
.
Hour
)
assert
.
True
(
t
,
a
.
isFixedDailyPeriodExpired
(
periodStart
))
}
func
TestIsFixedDailyPeriodExpired_InvalidTimezone
(
t
*
testing
.
T
)
{
a
:=
&
Account
{
Extra
:
map
[
string
]
any
{
"quota_daily_reset_mode"
:
"fixed"
,
"quota_daily_reset_hour"
:
float64
(
9
),
"quota_reset_timezone"
:
"Invalid/Timezone"
,
}}
// Invalid timezone falls back to UTC
periodStart
:=
time
.
Now
()
.
Add
(
-
72
*
time
.
Hour
)
assert
.
True
(
t
,
a
.
isFixedDailyPeriodExpired
(
periodStart
))
}
// ---------------------------------------------------------------------------
// isFixedWeeklyPeriodExpired
// ---------------------------------------------------------------------------
func
TestIsFixedWeeklyPeriodExpired_ZeroPeriodStart
(
t
*
testing
.
T
)
{
a
:=
&
Account
{
Extra
:
map
[
string
]
any
{
"quota_weekly_reset_mode"
:
"fixed"
,
"quota_weekly_reset_day"
:
float64
(
1
),
"quota_weekly_reset_hour"
:
float64
(
9
),
"quota_reset_timezone"
:
"UTC"
,
}}
assert
.
True
(
t
,
a
.
isFixedWeeklyPeriodExpired
(
time
.
Time
{}))
}
func
TestIsFixedWeeklyPeriodExpired_NotExpired
(
t
*
testing
.
T
)
{
a
:=
&
Account
{
Extra
:
map
[
string
]
any
{
"quota_weekly_reset_mode"
:
"fixed"
,
"quota_weekly_reset_day"
:
float64
(
1
),
"quota_weekly_reset_hour"
:
float64
(
9
),
"quota_reset_timezone"
:
"UTC"
,
}}
// Period started 1 minute ago → not expired
periodStart
:=
time
.
Now
()
.
Add
(
-
1
*
time
.
Minute
)
assert
.
False
(
t
,
a
.
isFixedWeeklyPeriodExpired
(
periodStart
))
}
func
TestIsFixedWeeklyPeriodExpired_Expired
(
t
*
testing
.
T
)
{
a
:=
&
Account
{
Extra
:
map
[
string
]
any
{
"quota_weekly_reset_mode"
:
"fixed"
,
"quota_weekly_reset_day"
:
float64
(
1
),
"quota_weekly_reset_hour"
:
float64
(
9
),
"quota_reset_timezone"
:
"UTC"
,
}}
// Period started 10 days ago → definitely expired
periodStart
:=
time
.
Now
()
.
Add
(
-
240
*
time
.
Hour
)
assert
.
True
(
t
,
a
.
isFixedWeeklyPeriodExpired
(
periodStart
))
}
// ---------------------------------------------------------------------------
// ValidateQuotaResetConfig
// ---------------------------------------------------------------------------
func
TestValidateQuotaResetConfig_NilExtra
(
t
*
testing
.
T
)
{
assert
.
NoError
(
t
,
ValidateQuotaResetConfig
(
nil
))
}
func
TestValidateQuotaResetConfig_EmptyExtra
(
t
*
testing
.
T
)
{
assert
.
NoError
(
t
,
ValidateQuotaResetConfig
(
map
[
string
]
any
{}))
}
func
TestValidateQuotaResetConfig_ValidFixed
(
t
*
testing
.
T
)
{
extra
:=
map
[
string
]
any
{
"quota_daily_reset_mode"
:
"fixed"
,
"quota_daily_reset_hour"
:
float64
(
9
),
"quota_weekly_reset_mode"
:
"fixed"
,
"quota_weekly_reset_day"
:
float64
(
1
),
"quota_weekly_reset_hour"
:
float64
(
0
),
"quota_reset_timezone"
:
"Asia/Shanghai"
,
}
assert
.
NoError
(
t
,
ValidateQuotaResetConfig
(
extra
))
}
func
TestValidateQuotaResetConfig_ValidRolling
(
t
*
testing
.
T
)
{
extra
:=
map
[
string
]
any
{
"quota_daily_reset_mode"
:
"rolling"
,
"quota_weekly_reset_mode"
:
"rolling"
,
}
assert
.
NoError
(
t
,
ValidateQuotaResetConfig
(
extra
))
}
func
TestValidateQuotaResetConfig_InvalidTimezone
(
t
*
testing
.
T
)
{
extra
:=
map
[
string
]
any
{
"quota_reset_timezone"
:
"Not/A/Timezone"
,
}
err
:=
ValidateQuotaResetConfig
(
extra
)
require
.
Error
(
t
,
err
)
assert
.
Contains
(
t
,
err
.
Error
(),
"quota_reset_timezone"
)
}
func
TestValidateQuotaResetConfig_InvalidDailyMode
(
t
*
testing
.
T
)
{
extra
:=
map
[
string
]
any
{
"quota_daily_reset_mode"
:
"invalid"
,
}
err
:=
ValidateQuotaResetConfig
(
extra
)
require
.
Error
(
t
,
err
)
assert
.
Contains
(
t
,
err
.
Error
(),
"quota_daily_reset_mode"
)
}
func
TestValidateQuotaResetConfig_InvalidDailyHour_TooHigh
(
t
*
testing
.
T
)
{
extra
:=
map
[
string
]
any
{
"quota_daily_reset_hour"
:
float64
(
24
),
}
err
:=
ValidateQuotaResetConfig
(
extra
)
require
.
Error
(
t
,
err
)
assert
.
Contains
(
t
,
err
.
Error
(),
"quota_daily_reset_hour"
)
}
func
TestValidateQuotaResetConfig_InvalidDailyHour_Negative
(
t
*
testing
.
T
)
{
extra
:=
map
[
string
]
any
{
"quota_daily_reset_hour"
:
float64
(
-
1
),
}
err
:=
ValidateQuotaResetConfig
(
extra
)
require
.
Error
(
t
,
err
)
assert
.
Contains
(
t
,
err
.
Error
(),
"quota_daily_reset_hour"
)
}
func
TestValidateQuotaResetConfig_InvalidWeeklyMode
(
t
*
testing
.
T
)
{
extra
:=
map
[
string
]
any
{
"quota_weekly_reset_mode"
:
"unknown"
,
}
err
:=
ValidateQuotaResetConfig
(
extra
)
require
.
Error
(
t
,
err
)
assert
.
Contains
(
t
,
err
.
Error
(),
"quota_weekly_reset_mode"
)
}
func
TestValidateQuotaResetConfig_InvalidWeeklyDay_TooHigh
(
t
*
testing
.
T
)
{
extra
:=
map
[
string
]
any
{
"quota_weekly_reset_day"
:
float64
(
7
),
}
err
:=
ValidateQuotaResetConfig
(
extra
)
require
.
Error
(
t
,
err
)
assert
.
Contains
(
t
,
err
.
Error
(),
"quota_weekly_reset_day"
)
}
func
TestValidateQuotaResetConfig_InvalidWeeklyDay_Negative
(
t
*
testing
.
T
)
{
extra
:=
map
[
string
]
any
{
"quota_weekly_reset_day"
:
float64
(
-
1
),
}
err
:=
ValidateQuotaResetConfig
(
extra
)
require
.
Error
(
t
,
err
)
assert
.
Contains
(
t
,
err
.
Error
(),
"quota_weekly_reset_day"
)
}
func
TestValidateQuotaResetConfig_InvalidWeeklyHour
(
t
*
testing
.
T
)
{
extra
:=
map
[
string
]
any
{
"quota_weekly_reset_hour"
:
float64
(
25
),
}
err
:=
ValidateQuotaResetConfig
(
extra
)
require
.
Error
(
t
,
err
)
assert
.
Contains
(
t
,
err
.
Error
(),
"quota_weekly_reset_hour"
)
}
func
TestValidateQuotaResetConfig_BoundaryValues
(
t
*
testing
.
T
)
{
// All boundary values should be valid
extra
:=
map
[
string
]
any
{
"quota_daily_reset_hour"
:
float64
(
23
),
"quota_weekly_reset_day"
:
float64
(
0
),
// Sunday
"quota_weekly_reset_hour"
:
float64
(
0
),
"quota_reset_timezone"
:
"UTC"
,
}
assert
.
NoError
(
t
,
ValidateQuotaResetConfig
(
extra
))
extra2
:=
map
[
string
]
any
{
"quota_daily_reset_hour"
:
float64
(
0
),
"quota_weekly_reset_day"
:
float64
(
6
),
// Saturday
"quota_weekly_reset_hour"
:
float64
(
23
),
}
assert
.
NoError
(
t
,
ValidateQuotaResetConfig
(
extra2
))
}
// ---------------------------------------------------------------------------
// ComputeQuotaResetAt
// ---------------------------------------------------------------------------
func
TestComputeQuotaResetAt_RollingMode_NoResetAt
(
t
*
testing
.
T
)
{
extra
:=
map
[
string
]
any
{
"quota_daily_reset_mode"
:
"rolling"
,
"quota_weekly_reset_mode"
:
"rolling"
,
}
ComputeQuotaResetAt
(
extra
)
_
,
hasDailyResetAt
:=
extra
[
"quota_daily_reset_at"
]
_
,
hasWeeklyResetAt
:=
extra
[
"quota_weekly_reset_at"
]
assert
.
False
(
t
,
hasDailyResetAt
,
"rolling mode should not set quota_daily_reset_at"
)
assert
.
False
(
t
,
hasWeeklyResetAt
,
"rolling mode should not set quota_weekly_reset_at"
)
}
func
TestComputeQuotaResetAt_RollingMode_ClearsExistingResetAt
(
t
*
testing
.
T
)
{
extra
:=
map
[
string
]
any
{
"quota_daily_reset_mode"
:
"rolling"
,
"quota_weekly_reset_mode"
:
"rolling"
,
"quota_daily_reset_at"
:
"2026-03-14T09:00:00Z"
,
"quota_weekly_reset_at"
:
"2026-03-16T09:00:00Z"
,
}
ComputeQuotaResetAt
(
extra
)
_
,
hasDailyResetAt
:=
extra
[
"quota_daily_reset_at"
]
_
,
hasWeeklyResetAt
:=
extra
[
"quota_weekly_reset_at"
]
assert
.
False
(
t
,
hasDailyResetAt
,
"rolling mode should remove quota_daily_reset_at"
)
assert
.
False
(
t
,
hasWeeklyResetAt
,
"rolling mode should remove quota_weekly_reset_at"
)
}
func
TestComputeQuotaResetAt_FixedDaily_SetsResetAt
(
t
*
testing
.
T
)
{
extra
:=
map
[
string
]
any
{
"quota_daily_reset_mode"
:
"fixed"
,
"quota_daily_reset_hour"
:
float64
(
9
),
"quota_reset_timezone"
:
"UTC"
,
}
ComputeQuotaResetAt
(
extra
)
resetAtStr
,
ok
:=
extra
[
"quota_daily_reset_at"
]
.
(
string
)
require
.
True
(
t
,
ok
,
"quota_daily_reset_at should be set"
)
resetAt
,
err
:=
time
.
Parse
(
time
.
RFC3339
,
resetAtStr
)
require
.
NoError
(
t
,
err
)
// Reset time should be in the future
assert
.
True
(
t
,
resetAt
.
After
(
time
.
Now
()),
"reset_at should be in the future"
)
// Reset hour should be 9 UTC
assert
.
Equal
(
t
,
9
,
resetAt
.
UTC
()
.
Hour
())
}
func
TestComputeQuotaResetAt_FixedWeekly_SetsResetAt
(
t
*
testing
.
T
)
{
extra
:=
map
[
string
]
any
{
"quota_weekly_reset_mode"
:
"fixed"
,
"quota_weekly_reset_day"
:
float64
(
1
),
// Monday
"quota_weekly_reset_hour"
:
float64
(
0
),
"quota_reset_timezone"
:
"UTC"
,
}
ComputeQuotaResetAt
(
extra
)
resetAtStr
,
ok
:=
extra
[
"quota_weekly_reset_at"
]
.
(
string
)
require
.
True
(
t
,
ok
,
"quota_weekly_reset_at should be set"
)
resetAt
,
err
:=
time
.
Parse
(
time
.
RFC3339
,
resetAtStr
)
require
.
NoError
(
t
,
err
)
// Reset time should be in the future
assert
.
True
(
t
,
resetAt
.
After
(
time
.
Now
()),
"reset_at should be in the future"
)
// Reset day should be Monday
assert
.
Equal
(
t
,
time
.
Monday
,
resetAt
.
UTC
()
.
Weekday
())
}
func
TestComputeQuotaResetAt_FixedDaily_WithTimezone
(
t
*
testing
.
T
)
{
tz
,
err
:=
time
.
LoadLocation
(
"Asia/Shanghai"
)
require
.
NoError
(
t
,
err
)
extra
:=
map
[
string
]
any
{
"quota_daily_reset_mode"
:
"fixed"
,
"quota_daily_reset_hour"
:
float64
(
9
),
"quota_reset_timezone"
:
"Asia/Shanghai"
,
}
ComputeQuotaResetAt
(
extra
)
resetAtStr
,
ok
:=
extra
[
"quota_daily_reset_at"
]
.
(
string
)
require
.
True
(
t
,
ok
)
resetAt
,
err
:=
time
.
Parse
(
time
.
RFC3339
,
resetAtStr
)
require
.
NoError
(
t
,
err
)
// In Shanghai timezone, the hour should be 9
assert
.
Equal
(
t
,
9
,
resetAt
.
In
(
tz
)
.
Hour
())
}
func
TestComputeQuotaResetAt_DefaultTimezone
(
t
*
testing
.
T
)
{
extra
:=
map
[
string
]
any
{
"quota_daily_reset_mode"
:
"fixed"
,
"quota_daily_reset_hour"
:
float64
(
12
),
}
ComputeQuotaResetAt
(
extra
)
resetAtStr
,
ok
:=
extra
[
"quota_daily_reset_at"
]
.
(
string
)
require
.
True
(
t
,
ok
)
resetAt
,
err
:=
time
.
Parse
(
time
.
RFC3339
,
resetAtStr
)
require
.
NoError
(
t
,
err
)
// Default timezone is UTC
assert
.
Equal
(
t
,
12
,
resetAt
.
UTC
()
.
Hour
())
}
func
TestComputeQuotaResetAt_InvalidHour_ClampedToZero
(
t
*
testing
.
T
)
{
extra
:=
map
[
string
]
any
{
"quota_daily_reset_mode"
:
"fixed"
,
"quota_daily_reset_hour"
:
float64
(
99
),
"quota_reset_timezone"
:
"UTC"
,
}
ComputeQuotaResetAt
(
extra
)
resetAtStr
,
ok
:=
extra
[
"quota_daily_reset_at"
]
.
(
string
)
require
.
True
(
t
,
ok
)
resetAt
,
err
:=
time
.
Parse
(
time
.
RFC3339
,
resetAtStr
)
require
.
NoError
(
t
,
err
)
// Invalid hour → clamped to 0
assert
.
Equal
(
t
,
0
,
resetAt
.
UTC
()
.
Hour
())
}
backend/internal/service/account_usage_service.go
View file @
8f0ea7a0
...
...
@@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"log"
"log/slog"
"math/rand/v2"
"net/http"
"strings"
...
...
@@ -100,6 +101,7 @@ type antigravityUsageCache struct {
const
(
apiCacheTTL
=
3
*
time
.
Minute
apiErrorCacheTTL
=
1
*
time
.
Minute
// 负缓存 TTL:429 等错误缓存 1 分钟
antigravityErrorTTL
=
1
*
time
.
Minute
// Antigravity 错误缓存 TTL(可恢复错误)
apiQueryMaxJitter
=
800
*
time
.
Millisecond
// 用量查询最大随机延迟
windowStatsCacheTTL
=
1
*
time
.
Minute
openAIProbeCacheTTL
=
10
*
time
.
Minute
...
...
@@ -111,7 +113,8 @@ type UsageCache struct {
apiCache
sync
.
Map
// accountID -> *apiUsageCache
windowStatsCache
sync
.
Map
// accountID -> *windowStatsCache
antigravityCache
sync
.
Map
// accountID -> *antigravityUsageCache
apiFlight
singleflight
.
Group
// 防止同一账号的并发请求击穿缓存
apiFlight
singleflight
.
Group
// 防止同一账号的并发请求击穿缓存(Anthropic)
antigravityFlight
singleflight
.
Group
// 防止同一 Antigravity 账号的并发请求击穿缓存
openAIProbeCache
sync
.
Map
// accountID -> time.Time
}
...
...
@@ -149,6 +152,18 @@ type AntigravityModelQuota struct {
ResetTime
string
`json:"reset_time"`
// 重置时间 ISO8601
}
// AntigravityModelDetail Antigravity 单个模型的详细能力信息
type
AntigravityModelDetail
struct
{
DisplayName
string
`json:"display_name,omitempty"`
SupportsImages
*
bool
`json:"supports_images,omitempty"`
SupportsThinking
*
bool
`json:"supports_thinking,omitempty"`
ThinkingBudget
*
int
`json:"thinking_budget,omitempty"`
Recommended
*
bool
`json:"recommended,omitempty"`
MaxTokens
*
int
`json:"max_tokens,omitempty"`
MaxOutputTokens
*
int
`json:"max_output_tokens,omitempty"`
SupportedMimeTypes
map
[
string
]
bool
`json:"supported_mime_types,omitempty"`
}
// UsageInfo 账号使用量信息
type
UsageInfo
struct
{
UpdatedAt
*
time
.
Time
`json:"updated_at,omitempty"`
// 更新时间
...
...
@@ -164,6 +179,33 @@ type UsageInfo struct {
// Antigravity 多模型配额
AntigravityQuota
map
[
string
]
*
AntigravityModelQuota
`json:"antigravity_quota,omitempty"`
// Antigravity 账号级信息
SubscriptionTier
string
`json:"subscription_tier,omitempty"`
// 归一化订阅等级: FREE/PRO/ULTRA/UNKNOWN
SubscriptionTierRaw
string
`json:"subscription_tier_raw,omitempty"`
// 上游原始订阅等级名称
// Antigravity 模型详细能力信息(与 antigravity_quota 同 key)
AntigravityQuotaDetails
map
[
string
]
*
AntigravityModelDetail
`json:"antigravity_quota_details,omitempty"`
// Antigravity 废弃模型转发规则 (old_model_id -> new_model_id)
ModelForwardingRules
map
[
string
]
string
`json:"model_forwarding_rules,omitempty"`
// Antigravity 账号是否被上游禁止 (HTTP 403)
IsForbidden
bool
`json:"is_forbidden,omitempty"`
ForbiddenReason
string
`json:"forbidden_reason,omitempty"`
ForbiddenType
string
`json:"forbidden_type,omitempty"`
// "validation" / "violation" / "forbidden"
ValidationURL
string
`json:"validation_url,omitempty"`
// 验证/申诉链接
// 状态标记(从 ForbiddenType / HTTP 错误码推导)
NeedsVerify
bool
`json:"needs_verify,omitempty"`
// 需要人工验证(forbidden_type=validation)
IsBanned
bool
`json:"is_banned,omitempty"`
// 账号被封(forbidden_type=violation)
NeedsReauth
bool
`json:"needs_reauth,omitempty"`
// token 失效需重新授权(401)
// 错误码(机器可读):forbidden / unauthenticated / rate_limited / network_error
ErrorCode
string
`json:"error_code,omitempty"`
// 获取 usage 时的错误信息(降级返回,而非 500)
Error
string
`json:"error,omitempty"`
}
// ClaudeUsageResponse Anthropic API返回的usage结构
...
...
@@ -648,10 +690,11 @@ func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account *
return
&
UsageInfo
{
UpdatedAt
:
&
now
},
nil
}
// 1. 检查缓存
(10 分钟)
// 1. 检查缓存
if
cached
,
ok
:=
s
.
cache
.
antigravityCache
.
Load
(
account
.
ID
);
ok
{
if
cache
,
ok
:=
cached
.
(
*
antigravityUsageCache
);
ok
&&
time
.
Since
(
cache
.
timestamp
)
<
apiCacheTTL
{
// 重新计算 RemainingSeconds
if
cache
,
ok
:=
cached
.
(
*
antigravityUsageCache
);
ok
{
ttl
:=
antigravityCacheTTL
(
cache
.
usageInfo
)
if
time
.
Since
(
cache
.
timestamp
)
<
ttl
{
usage
:=
cache
.
usageInfo
if
usage
.
FiveHour
!=
nil
&&
usage
.
FiveHour
.
ResetsAt
!=
nil
{
usage
.
FiveHour
.
RemainingSeconds
=
int
(
time
.
Until
(
*
usage
.
FiveHour
.
ResetsAt
)
.
Seconds
())
...
...
@@ -659,23 +702,145 @@ func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account *
return
usage
,
nil
}
}
}
// 2. singleflight 防止并发击穿
flightKey
:=
fmt
.
Sprintf
(
"ag-usage:%d"
,
account
.
ID
)
result
,
flightErr
,
_
:=
s
.
cache
.
antigravityFlight
.
Do
(
flightKey
,
func
()
(
any
,
error
)
{
// 再次检查缓存(等待期间可能已被填充)
if
cached
,
ok
:=
s
.
cache
.
antigravityCache
.
Load
(
account
.
ID
);
ok
{
if
cache
,
ok
:=
cached
.
(
*
antigravityUsageCache
);
ok
{
ttl
:=
antigravityCacheTTL
(
cache
.
usageInfo
)
if
time
.
Since
(
cache
.
timestamp
)
<
ttl
{
usage
:=
cache
.
usageInfo
// 重新计算 RemainingSeconds,避免返回过时的剩余秒数
recalcAntigravityRemainingSeconds
(
usage
)
return
usage
,
nil
}
}
}
// 2. 获取代理 URL
proxyURL
:=
s
.
antigravityQuotaFetcher
.
GetProxyURL
(
ctx
,
account
)
// 使用独立 context,避免调用方 cancel 导致所有共享 flight 的请求失败
fetchCtx
,
fetchCancel
:=
context
.
WithTimeout
(
context
.
Background
(),
30
*
time
.
Second
)
defer
fetchCancel
()
// 3. 调用 API 获取额度
r
esult
,
err
:=
s
.
antigravityQuotaFetcher
.
FetchQuota
(
c
tx
,
account
,
proxyURL
)
proxyURL
:=
s
.
antigravityQuotaFetcher
.
GetProxyURL
(
fetchCtx
,
account
)
fetchR
esult
,
err
:=
s
.
antigravityQuotaFetcher
.
FetchQuota
(
fetchC
tx
,
account
,
proxyURL
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"fetch antigravity quota failed: %w"
,
err
)
degraded
:=
buildAntigravityDegradedUsage
(
err
)
enrichUsageWithAccountError
(
degraded
,
account
)
s
.
cache
.
antigravityCache
.
Store
(
account
.
ID
,
&
antigravityUsageCache
{
usageInfo
:
degraded
,
timestamp
:
time
.
Now
(),
})
return
degraded
,
nil
}
// 4. 缓存结果
enrichUsageWithAccountError
(
fetchResult
.
UsageInfo
,
account
)
s
.
cache
.
antigravityCache
.
Store
(
account
.
ID
,
&
antigravityUsageCache
{
usageInfo
:
r
esult
.
UsageInfo
,
usageInfo
:
fetchR
esult
.
UsageInfo
,
timestamp
:
time
.
Now
(),
})
return
fetchResult
.
UsageInfo
,
nil
})
if
flightErr
!=
nil
{
return
nil
,
flightErr
}
usage
,
ok
:=
result
.
(
*
UsageInfo
)
if
!
ok
||
usage
==
nil
{
now
:=
time
.
Now
()
return
&
UsageInfo
{
UpdatedAt
:
&
now
},
nil
}
return
usage
,
nil
}
return
result
.
UsageInfo
,
nil
// recalcAntigravityRemainingSeconds 重新计算 Antigravity UsageInfo 中各窗口的 RemainingSeconds
// 用于从缓存取出时更新倒计时,避免返回过时的剩余秒数
func
recalcAntigravityRemainingSeconds
(
info
*
UsageInfo
)
{
if
info
==
nil
{
return
}
if
info
.
FiveHour
!=
nil
&&
info
.
FiveHour
.
ResetsAt
!=
nil
{
remaining
:=
int
(
time
.
Until
(
*
info
.
FiveHour
.
ResetsAt
)
.
Seconds
())
if
remaining
<
0
{
remaining
=
0
}
info
.
FiveHour
.
RemainingSeconds
=
remaining
}
}
// antigravityCacheTTL 根据 UsageInfo 内容决定缓存 TTL
// 403 forbidden 状态稳定,缓存与成功相同(3 分钟);
// 其他错误(401/网络)可能快速恢复,缓存 1 分钟。
func
antigravityCacheTTL
(
info
*
UsageInfo
)
time
.
Duration
{
if
info
==
nil
{
return
antigravityErrorTTL
}
if
info
.
IsForbidden
{
return
apiCacheTTL
// 封号/验证状态不会很快变
}
if
info
.
ErrorCode
!=
""
||
info
.
Error
!=
""
{
return
antigravityErrorTTL
}
return
apiCacheTTL
}
// buildAntigravityDegradedUsage 从 FetchQuota 错误构建降级 UsageInfo
func
buildAntigravityDegradedUsage
(
err
error
)
*
UsageInfo
{
now
:=
time
.
Now
()
errMsg
:=
fmt
.
Sprintf
(
"usage API error: %v"
,
err
)
slog
.
Warn
(
"antigravity usage fetch failed, returning degraded response"
,
"error"
,
err
)
info
:=
&
UsageInfo
{
UpdatedAt
:
&
now
,
Error
:
errMsg
,
}
// 从错误信息推断 error_code 和状态标记
// 错误格式来自 antigravity/client.go: "fetchAvailableModels 失败 (HTTP %d): ..."
errStr
:=
err
.
Error
()
switch
{
case
strings
.
Contains
(
errStr
,
"HTTP 401"
)
||
strings
.
Contains
(
errStr
,
"UNAUTHENTICATED"
)
||
strings
.
Contains
(
errStr
,
"invalid_grant"
)
:
info
.
ErrorCode
=
errorCodeUnauthenticated
info
.
NeedsReauth
=
true
case
strings
.
Contains
(
errStr
,
"HTTP 429"
)
:
info
.
ErrorCode
=
errorCodeRateLimited
default
:
info
.
ErrorCode
=
errorCodeNetworkError
}
return
info
}
// enrichUsageWithAccountError 结合账号错误状态修正 UsageInfo
// 场景 1(成功路径):FetchAvailableModels 正常返回,但账号已因 403 被标记为 error,
//
// 需要在正常 usage 数据上附加 forbidden/validation 信息。
//
// 场景 2(降级路径):被封号的账号 OAuth token 失效,FetchAvailableModels 返回 401,
//
// 降级逻辑设置了 needs_reauth,但账号实际是 403 封号/需验证,需覆盖为正确状态。
func
enrichUsageWithAccountError
(
info
*
UsageInfo
,
account
*
Account
)
{
if
info
==
nil
||
account
==
nil
||
account
.
Status
!=
StatusError
{
return
}
msg
:=
strings
.
ToLower
(
account
.
ErrorMessage
)
if
!
strings
.
Contains
(
msg
,
"403"
)
&&
!
strings
.
Contains
(
msg
,
"forbidden"
)
&&
!
strings
.
Contains
(
msg
,
"violation"
)
&&
!
strings
.
Contains
(
msg
,
"validation"
)
{
return
}
fbType
:=
classifyForbiddenType
(
account
.
ErrorMessage
)
info
.
IsForbidden
=
true
info
.
ForbiddenType
=
fbType
info
.
ForbiddenReason
=
account
.
ErrorMessage
info
.
NeedsVerify
=
fbType
==
forbiddenTypeValidation
info
.
IsBanned
=
fbType
==
forbiddenTypeViolation
info
.
ValidationURL
=
extractValidationURL
(
account
.
ErrorMessage
)
info
.
ErrorCode
=
errorCodeForbidden
info
.
NeedsReauth
=
false
}
// addWindowStats 为 usage 数据添加窗口期统计
...
...
backend/internal/service/account_wildcard_test.go
View file @
8f0ea7a0
...
...
@@ -43,12 +43,13 @@ func TestMatchWildcard(t *testing.T) {
}
}
func
TestMatchWildcardMapping
(
t
*
testing
.
T
)
{
func
TestMatchWildcardMapping
Result
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
mapping
map
[
string
]
string
requestedModel
string
expected
string
matched
bool
}{
// 精确匹配优先于通配符
{
...
...
@@ -59,6 +60,7 @@ func TestMatchWildcardMapping(t *testing.T) {
},
requestedModel
:
"claude-sonnet-4-5"
,
expected
:
"claude-sonnet-4-5-exact"
,
matched
:
true
,
},
// 最长通配符优先
...
...
@@ -71,6 +73,7 @@ func TestMatchWildcardMapping(t *testing.T) {
},
requestedModel
:
"claude-sonnet-4-5"
,
expected
:
"claude-sonnet-4-series"
,
matched
:
true
,
},
// 单个通配符
...
...
@@ -81,6 +84,7 @@ func TestMatchWildcardMapping(t *testing.T) {
},
requestedModel
:
"claude-opus-4-5"
,
expected
:
"claude-mapped"
,
matched
:
true
,
},
// 无匹配返回原始模型
...
...
@@ -91,6 +95,7 @@ func TestMatchWildcardMapping(t *testing.T) {
},
requestedModel
:
"gemini-3-flash"
,
expected
:
"gemini-3-flash"
,
matched
:
false
,
},
// 空映射返回原始模型
...
...
@@ -99,6 +104,7 @@ func TestMatchWildcardMapping(t *testing.T) {
mapping
:
map
[
string
]
string
{},
requestedModel
:
"claude-sonnet-4-5"
,
expected
:
"claude-sonnet-4-5"
,
matched
:
false
,
},
// Gemini 模型映射
...
...
@@ -110,14 +116,15 @@ func TestMatchWildcardMapping(t *testing.T) {
},
requestedModel
:
"gemini-3-flash-preview"
,
expected
:
"gemini-3-pro-high"
,
matched
:
true
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
result
:=
matchWildcardMapping
(
tt
.
mapping
,
tt
.
requestedModel
)
if
result
!=
tt
.
expected
{
t
.
Errorf
(
"matchWildcardMapping(%v, %q) = %q, want %q"
,
tt
.
mapping
,
tt
.
requestedModel
,
result
,
tt
.
expect
ed
)
result
,
matched
:=
matchWildcardMapping
Result
(
tt
.
mapping
,
tt
.
requestedModel
)
if
result
!=
tt
.
expected
||
matched
!=
tt
.
matched
{
t
.
Errorf
(
"matchWildcardMapping
Result
(%v, %q) =
(
%q,
%v),
want
(
%q
, %v)
"
,
tt
.
mapping
,
tt
.
requestedModel
,
result
,
matched
,
tt
.
expected
,
tt
.
match
ed
)
}
})
}
...
...
@@ -268,6 +275,69 @@ func TestAccountGetMappedModel(t *testing.T) {
}
}
func
TestAccountResolveMappedModel
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
credentials
map
[
string
]
any
requestedModel
string
expectedModel
string
expectedMatch
bool
}{
{
name
:
"no mapping reports unmatched"
,
credentials
:
nil
,
requestedModel
:
"gpt-5.4"
,
expectedModel
:
"gpt-5.4"
,
expectedMatch
:
false
,
},
{
name
:
"exact passthrough mapping still counts as matched"
,
credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"gpt-5.4"
:
"gpt-5.4"
,
},
},
requestedModel
:
"gpt-5.4"
,
expectedModel
:
"gpt-5.4"
,
expectedMatch
:
true
,
},
{
name
:
"wildcard passthrough mapping still counts as matched"
,
credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"gpt-*"
:
"gpt-5.4"
,
},
},
requestedModel
:
"gpt-5.4"
,
expectedModel
:
"gpt-5.4"
,
expectedMatch
:
true
,
},
{
name
:
"missing mapping reports unmatched"
,
credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"gpt-5.2"
:
"gpt-5.2"
,
},
},
requestedModel
:
"gpt-5.4"
,
expectedModel
:
"gpt-5.4"
,
expectedMatch
:
false
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
Credentials
:
tt
.
credentials
,
}
mappedModel
,
matched
:=
account
.
ResolveMappedModel
(
tt
.
requestedModel
)
if
mappedModel
!=
tt
.
expectedModel
||
matched
!=
tt
.
expectedMatch
{
t
.
Fatalf
(
"ResolveMappedModel(%q) = (%q, %v), want (%q, %v)"
,
tt
.
requestedModel
,
mappedModel
,
matched
,
tt
.
expectedModel
,
tt
.
expectedMatch
)
}
})
}
}
func
TestAccountGetModelMapping_AntigravityEnsuresGeminiDefaultPassthroughs
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
Platform
:
PlatformAntigravity
,
...
...
backend/internal/service/admin_service.go
View file @
8f0ea7a0
...
...
@@ -1462,6 +1462,13 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
Status
:
StatusActive
,
Schedulable
:
true
,
}
// 预计算固定时间重置的下次重置时间
if
account
.
Extra
!=
nil
{
if
err
:=
ValidateQuotaResetConfig
(
account
.
Extra
);
err
!=
nil
{
return
nil
,
err
}
ComputeQuotaResetAt
(
account
.
Extra
)
}
if
input
.
ExpiresAt
!=
nil
&&
*
input
.
ExpiresAt
>
0
{
expiresAt
:=
time
.
Unix
(
*
input
.
ExpiresAt
,
0
)
account
.
ExpiresAt
=
&
expiresAt
...
...
@@ -1535,6 +1542,11 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
}
}
account
.
Extra
=
input
.
Extra
// 校验并预计算固定时间重置的下次重置时间
if
err
:=
ValidateQuotaResetConfig
(
account
.
Extra
);
err
!=
nil
{
return
nil
,
err
}
ComputeQuotaResetAt
(
account
.
Extra
)
}
if
input
.
ProxyID
!=
nil
{
// 0 表示清除代理(前端发送 0 而不是 null 来表达清除意图)
...
...
backend/internal/service/antigravity_quota_fetcher.go
View file @
8f0ea7a0
...
...
@@ -2,12 +2,29 @@ package service
import
(
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"regexp"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
const
(
forbiddenTypeValidation
=
"validation"
forbiddenTypeViolation
=
"violation"
forbiddenTypeForbidden
=
"forbidden"
// 机器可读的错误码
errorCodeForbidden
=
"forbidden"
errorCodeUnauthenticated
=
"unauthenticated"
errorCodeRateLimited
=
"rate_limited"
errorCodeNetworkError
=
"network_error"
)
// AntigravityQuotaFetcher 从 Antigravity API 获取额度
type
AntigravityQuotaFetcher
struct
{
proxyRepo
ProxyRepository
...
...
@@ -40,11 +57,32 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
// 调用 API 获取配额
modelsResp
,
modelsRaw
,
err
:=
client
.
FetchAvailableModels
(
ctx
,
accessToken
,
projectID
)
if
err
!=
nil
{
// 403 Forbidden: 不报错,返回 is_forbidden 标记
var
forbiddenErr
*
antigravity
.
ForbiddenError
if
errors
.
As
(
err
,
&
forbiddenErr
)
{
now
:=
time
.
Now
()
fbType
:=
classifyForbiddenType
(
forbiddenErr
.
Body
)
return
&
QuotaResult
{
UsageInfo
:
&
UsageInfo
{
UpdatedAt
:
&
now
,
IsForbidden
:
true
,
ForbiddenReason
:
forbiddenErr
.
Body
,
ForbiddenType
:
fbType
,
ValidationURL
:
extractValidationURL
(
forbiddenErr
.
Body
),
NeedsVerify
:
fbType
==
forbiddenTypeValidation
,
IsBanned
:
fbType
==
forbiddenTypeViolation
,
ErrorCode
:
errorCodeForbidden
,
},
},
nil
}
return
nil
,
err
}
// 调用 LoadCodeAssist 获取订阅等级(非关键路径,失败不影响主流程)
tierRaw
,
tierNormalized
:=
f
.
fetchSubscriptionTier
(
ctx
,
client
,
accessToken
)
// 转换为 UsageInfo
usageInfo
:=
f
.
buildUsageInfo
(
modelsResp
)
usageInfo
:=
f
.
buildUsageInfo
(
modelsResp
,
tierRaw
,
tierNormalized
)
return
&
QuotaResult
{
UsageInfo
:
usageInfo
,
...
...
@@ -52,15 +90,52 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
},
nil
}
// fetchSubscriptionTier 获取账号订阅等级,失败返回空字符串
func
(
f
*
AntigravityQuotaFetcher
)
fetchSubscriptionTier
(
ctx
context
.
Context
,
client
*
antigravity
.
Client
,
accessToken
string
)
(
raw
,
normalized
string
)
{
loadResp
,
_
,
err
:=
client
.
LoadCodeAssist
(
ctx
,
accessToken
)
if
err
!=
nil
{
slog
.
Warn
(
"failed to fetch subscription tier"
,
"error"
,
err
)
return
""
,
""
}
if
loadResp
==
nil
{
return
""
,
""
}
raw
=
loadResp
.
GetTier
()
// 已有方法:paidTier > currentTier
normalized
=
normalizeTier
(
raw
)
return
raw
,
normalized
}
// normalizeTier 将原始 tier 字符串归一化为 FREE/PRO/ULTRA/UNKNOWN
func
normalizeTier
(
raw
string
)
string
{
if
raw
==
""
{
return
""
}
lower
:=
strings
.
ToLower
(
raw
)
switch
{
case
strings
.
Contains
(
lower
,
"ultra"
)
:
return
"ULTRA"
case
strings
.
Contains
(
lower
,
"pro"
)
:
return
"PRO"
case
strings
.
Contains
(
lower
,
"free"
)
:
return
"FREE"
default
:
return
"UNKNOWN"
}
}
// buildUsageInfo 将 API 响应转换为 UsageInfo
func
(
f
*
AntigravityQuotaFetcher
)
buildUsageInfo
(
modelsResp
*
antigravity
.
FetchAvailableModelsResponse
)
*
UsageInfo
{
func
(
f
*
AntigravityQuotaFetcher
)
buildUsageInfo
(
modelsResp
*
antigravity
.
FetchAvailableModelsResponse
,
tierRaw
,
tierNormalized
string
)
*
UsageInfo
{
now
:=
time
.
Now
()
info
:=
&
UsageInfo
{
UpdatedAt
:
&
now
,
AntigravityQuota
:
make
(
map
[
string
]
*
AntigravityModelQuota
),
AntigravityQuotaDetails
:
make
(
map
[
string
]
*
AntigravityModelDetail
),
SubscriptionTier
:
tierNormalized
,
SubscriptionTierRaw
:
tierRaw
,
}
// 遍历所有模型,填充 AntigravityQuota
// 遍历所有模型,填充 AntigravityQuota
和 AntigravityQuotaDetails
for
modelName
,
modelInfo
:=
range
modelsResp
.
Models
{
if
modelInfo
.
QuotaInfo
==
nil
{
continue
...
...
@@ -73,6 +148,27 @@ func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAv
Utilization
:
utilization
,
ResetTime
:
modelInfo
.
QuotaInfo
.
ResetTime
,
}
// 填充模型详细能力信息
detail
:=
&
AntigravityModelDetail
{
DisplayName
:
modelInfo
.
DisplayName
,
SupportsImages
:
modelInfo
.
SupportsImages
,
SupportsThinking
:
modelInfo
.
SupportsThinking
,
ThinkingBudget
:
modelInfo
.
ThinkingBudget
,
Recommended
:
modelInfo
.
Recommended
,
MaxTokens
:
modelInfo
.
MaxTokens
,
MaxOutputTokens
:
modelInfo
.
MaxOutputTokens
,
SupportedMimeTypes
:
modelInfo
.
SupportedMimeTypes
,
}
info
.
AntigravityQuotaDetails
[
modelName
]
=
detail
}
// 废弃模型转发规则
if
len
(
modelsResp
.
DeprecatedModelIDs
)
>
0
{
info
.
ModelForwardingRules
=
make
(
map
[
string
]
string
,
len
(
modelsResp
.
DeprecatedModelIDs
))
for
oldID
,
deprecated
:=
range
modelsResp
.
DeprecatedModelIDs
{
info
.
ModelForwardingRules
[
oldID
]
=
deprecated
.
NewModelID
}
}
// 同时设置 FiveHour 用于兼容展示(取主要模型)
...
...
@@ -108,3 +204,58 @@ func (f *AntigravityQuotaFetcher) GetProxyURL(ctx context.Context, account *Acco
}
return
proxy
.
URL
()
}
// classifyForbiddenType 根据 403 响应体判断禁止类型
func
classifyForbiddenType
(
body
string
)
string
{
lower
:=
strings
.
ToLower
(
body
)
switch
{
case
strings
.
Contains
(
lower
,
"validation_required"
)
||
strings
.
Contains
(
lower
,
"verify your account"
)
||
strings
.
Contains
(
lower
,
"validation_url"
)
:
return
forbiddenTypeValidation
case
strings
.
Contains
(
lower
,
"terms of service"
)
||
strings
.
Contains
(
lower
,
"violation"
)
:
return
forbiddenTypeViolation
default
:
return
forbiddenTypeForbidden
}
}
// urlPattern 用于从 403 响应体中提取 URL(降级方案)
var
urlPattern
=
regexp
.
MustCompile
(
`https://[^\s"'\\]+`
)
// extractValidationURL 从 403 响应 JSON 中提取验证/申诉链接
func
extractValidationURL
(
body
string
)
string
{
// 1. 尝试结构化 JSON 提取: /error/details[*]/metadata/validation_url 或 appeal_url
var
parsed
struct
{
Error
struct
{
Details
[]
struct
{
Metadata
map
[
string
]
string
`json:"metadata"`
}
`json:"details"`
}
`json:"error"`
}
if
json
.
Unmarshal
([]
byte
(
body
),
&
parsed
)
==
nil
{
for
_
,
detail
:=
range
parsed
.
Error
.
Details
{
if
u
:=
detail
.
Metadata
[
"validation_url"
];
u
!=
""
{
return
u
}
if
u
:=
detail
.
Metadata
[
"appeal_url"
];
u
!=
""
{
return
u
}
}
}
// 2. 降级:正则匹配 URL
lower
:=
strings
.
ToLower
(
body
)
if
!
strings
.
Contains
(
lower
,
"validation"
)
&&
!
strings
.
Contains
(
lower
,
"verify"
)
&&
!
strings
.
Contains
(
lower
,
"appeal"
)
{
return
""
}
// 先解码常见转义再匹配
normalized
:=
strings
.
ReplaceAll
(
body
,
`\u0026`
,
"&"
)
if
m
:=
urlPattern
.
FindString
(
normalized
);
m
!=
""
{
return
m
}
return
""
}
backend/internal/service/antigravity_quota_fetcher_test.go
0 → 100644
View file @
8f0ea7a0
//go:build unit
package
service
import
(
"errors"
"testing"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// ---------------------------------------------------------------------------
// normalizeTier
// ---------------------------------------------------------------------------
func
TestNormalizeTier
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
raw
string
expected
string
}{
{
name
:
"empty string"
,
raw
:
""
,
expected
:
""
},
{
name
:
"free-tier"
,
raw
:
"free-tier"
,
expected
:
"FREE"
},
{
name
:
"g1-pro-tier"
,
raw
:
"g1-pro-tier"
,
expected
:
"PRO"
},
{
name
:
"g1-ultra-tier"
,
raw
:
"g1-ultra-tier"
,
expected
:
"ULTRA"
},
{
name
:
"unknown-something"
,
raw
:
"unknown-something"
,
expected
:
"UNKNOWN"
},
{
name
:
"Google AI Pro contains pro keyword"
,
raw
:
"Google AI Pro"
,
expected
:
"PRO"
},
{
name
:
"case insensitive FREE"
,
raw
:
"FREE-TIER"
,
expected
:
"FREE"
},
{
name
:
"case insensitive Ultra"
,
raw
:
"Ultra Plan"
,
expected
:
"ULTRA"
},
{
name
:
"arbitrary unrecognized string"
,
raw
:
"enterprise-custom"
,
expected
:
"UNKNOWN"
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
normalizeTier
(
tt
.
raw
)
require
.
Equal
(
t
,
tt
.
expected
,
got
,
"normalizeTier(%q)"
,
tt
.
raw
)
})
}
}
// ---------------------------------------------------------------------------
// buildUsageInfo
// ---------------------------------------------------------------------------
func
aqfBoolPtr
(
v
bool
)
*
bool
{
return
&
v
}
func
aqfIntPtr
(
v
int
)
*
int
{
return
&
v
}
func
TestBuildUsageInfo_BasicModels
(
t
*
testing
.
T
)
{
fetcher
:=
&
AntigravityQuotaFetcher
{}
modelsResp
:=
&
antigravity
.
FetchAvailableModelsResponse
{
Models
:
map
[
string
]
antigravity
.
ModelInfo
{
"claude-sonnet-4-20250514"
:
{
QuotaInfo
:
&
antigravity
.
ModelQuotaInfo
{
RemainingFraction
:
0.75
,
ResetTime
:
"2026-03-08T12:00:00Z"
,
},
DisplayName
:
"Claude Sonnet 4"
,
SupportsImages
:
aqfBoolPtr
(
true
),
SupportsThinking
:
aqfBoolPtr
(
false
),
ThinkingBudget
:
aqfIntPtr
(
0
),
Recommended
:
aqfBoolPtr
(
true
),
MaxTokens
:
aqfIntPtr
(
200000
),
MaxOutputTokens
:
aqfIntPtr
(
16384
),
SupportedMimeTypes
:
map
[
string
]
bool
{
"image/png"
:
true
,
"image/jpeg"
:
true
,
},
},
"gemini-2.5-pro"
:
{
QuotaInfo
:
&
antigravity
.
ModelQuotaInfo
{
RemainingFraction
:
0.50
,
ResetTime
:
"2026-03-08T15:00:00Z"
,
},
DisplayName
:
"Gemini 2.5 Pro"
,
MaxTokens
:
aqfIntPtr
(
1000000
),
MaxOutputTokens
:
aqfIntPtr
(
65536
),
},
},
}
info
:=
fetcher
.
buildUsageInfo
(
modelsResp
,
"g1-pro-tier"
,
"PRO"
)
// 基本字段
require
.
NotNil
(
t
,
info
.
UpdatedAt
,
"UpdatedAt should be set"
)
require
.
Equal
(
t
,
"PRO"
,
info
.
SubscriptionTier
)
require
.
Equal
(
t
,
"g1-pro-tier"
,
info
.
SubscriptionTierRaw
)
// AntigravityQuota
require
.
Len
(
t
,
info
.
AntigravityQuota
,
2
)
sonnetQuota
:=
info
.
AntigravityQuota
[
"claude-sonnet-4-20250514"
]
require
.
NotNil
(
t
,
sonnetQuota
)
require
.
Equal
(
t
,
25
,
sonnetQuota
.
Utilization
)
// (1 - 0.75) * 100 = 25
require
.
Equal
(
t
,
"2026-03-08T12:00:00Z"
,
sonnetQuota
.
ResetTime
)
geminiQuota
:=
info
.
AntigravityQuota
[
"gemini-2.5-pro"
]
require
.
NotNil
(
t
,
geminiQuota
)
require
.
Equal
(
t
,
50
,
geminiQuota
.
Utilization
)
// (1 - 0.50) * 100 = 50
require
.
Equal
(
t
,
"2026-03-08T15:00:00Z"
,
geminiQuota
.
ResetTime
)
// AntigravityQuotaDetails
require
.
Len
(
t
,
info
.
AntigravityQuotaDetails
,
2
)
sonnetDetail
:=
info
.
AntigravityQuotaDetails
[
"claude-sonnet-4-20250514"
]
require
.
NotNil
(
t
,
sonnetDetail
)
require
.
Equal
(
t
,
"Claude Sonnet 4"
,
sonnetDetail
.
DisplayName
)
require
.
Equal
(
t
,
aqfBoolPtr
(
true
),
sonnetDetail
.
SupportsImages
)
require
.
Equal
(
t
,
aqfBoolPtr
(
false
),
sonnetDetail
.
SupportsThinking
)
require
.
Equal
(
t
,
aqfIntPtr
(
0
),
sonnetDetail
.
ThinkingBudget
)
require
.
Equal
(
t
,
aqfBoolPtr
(
true
),
sonnetDetail
.
Recommended
)
require
.
Equal
(
t
,
aqfIntPtr
(
200000
),
sonnetDetail
.
MaxTokens
)
require
.
Equal
(
t
,
aqfIntPtr
(
16384
),
sonnetDetail
.
MaxOutputTokens
)
require
.
Equal
(
t
,
map
[
string
]
bool
{
"image/png"
:
true
,
"image/jpeg"
:
true
},
sonnetDetail
.
SupportedMimeTypes
)
geminiDetail
:=
info
.
AntigravityQuotaDetails
[
"gemini-2.5-pro"
]
require
.
NotNil
(
t
,
geminiDetail
)
require
.
Equal
(
t
,
"Gemini 2.5 Pro"
,
geminiDetail
.
DisplayName
)
require
.
Nil
(
t
,
geminiDetail
.
SupportsImages
)
require
.
Nil
(
t
,
geminiDetail
.
SupportsThinking
)
require
.
Equal
(
t
,
aqfIntPtr
(
1000000
),
geminiDetail
.
MaxTokens
)
require
.
Equal
(
t
,
aqfIntPtr
(
65536
),
geminiDetail
.
MaxOutputTokens
)
}
func
TestBuildUsageInfo_DeprecatedModels
(
t
*
testing
.
T
)
{
fetcher
:=
&
AntigravityQuotaFetcher
{}
modelsResp
:=
&
antigravity
.
FetchAvailableModelsResponse
{
Models
:
map
[
string
]
antigravity
.
ModelInfo
{
"claude-sonnet-4-20250514"
:
{
QuotaInfo
:
&
antigravity
.
ModelQuotaInfo
{
RemainingFraction
:
1.0
,
},
},
},
DeprecatedModelIDs
:
map
[
string
]
antigravity
.
DeprecatedModelInfo
{
"claude-3-sonnet-20240229"
:
{
NewModelID
:
"claude-sonnet-4-20250514"
},
"claude-3-haiku-20240307"
:
{
NewModelID
:
"claude-haiku-3.5-latest"
},
},
}
info
:=
fetcher
.
buildUsageInfo
(
modelsResp
,
""
,
""
)
require
.
Len
(
t
,
info
.
ModelForwardingRules
,
2
)
require
.
Equal
(
t
,
"claude-sonnet-4-20250514"
,
info
.
ModelForwardingRules
[
"claude-3-sonnet-20240229"
])
require
.
Equal
(
t
,
"claude-haiku-3.5-latest"
,
info
.
ModelForwardingRules
[
"claude-3-haiku-20240307"
])
}
func
TestBuildUsageInfo_NoDeprecatedModels
(
t
*
testing
.
T
)
{
fetcher
:=
&
AntigravityQuotaFetcher
{}
modelsResp
:=
&
antigravity
.
FetchAvailableModelsResponse
{
Models
:
map
[
string
]
antigravity
.
ModelInfo
{
"some-model"
:
{
QuotaInfo
:
&
antigravity
.
ModelQuotaInfo
{
RemainingFraction
:
0.9
},
},
},
}
info
:=
fetcher
.
buildUsageInfo
(
modelsResp
,
""
,
""
)
require
.
Nil
(
t
,
info
.
ModelForwardingRules
,
"ModelForwardingRules should be nil when no deprecated models"
)
}
func
TestBuildUsageInfo_EmptyModels
(
t
*
testing
.
T
)
{
fetcher
:=
&
AntigravityQuotaFetcher
{}
modelsResp
:=
&
antigravity
.
FetchAvailableModelsResponse
{
Models
:
map
[
string
]
antigravity
.
ModelInfo
{},
}
info
:=
fetcher
.
buildUsageInfo
(
modelsResp
,
""
,
""
)
require
.
NotNil
(
t
,
info
)
require
.
NotNil
(
t
,
info
.
AntigravityQuota
)
require
.
Empty
(
t
,
info
.
AntigravityQuota
)
require
.
NotNil
(
t
,
info
.
AntigravityQuotaDetails
)
require
.
Empty
(
t
,
info
.
AntigravityQuotaDetails
)
require
.
Nil
(
t
,
info
.
FiveHour
,
"FiveHour should be nil when no priority model exists"
)
}
func
TestBuildUsageInfo_ModelWithNilQuotaInfo
(
t
*
testing
.
T
)
{
fetcher
:=
&
AntigravityQuotaFetcher
{}
modelsResp
:=
&
antigravity
.
FetchAvailableModelsResponse
{
Models
:
map
[
string
]
antigravity
.
ModelInfo
{
"model-without-quota"
:
{
DisplayName
:
"No Quota Model"
,
// QuotaInfo is nil
},
},
}
info
:=
fetcher
.
buildUsageInfo
(
modelsResp
,
""
,
""
)
require
.
NotNil
(
t
,
info
)
require
.
Empty
(
t
,
info
.
AntigravityQuota
,
"models with nil QuotaInfo should be skipped"
)
require
.
Empty
(
t
,
info
.
AntigravityQuotaDetails
,
"models with nil QuotaInfo should be skipped from details too"
)
}
func
TestBuildUsageInfo_FiveHourPriorityOrder
(
t
*
testing
.
T
)
{
fetcher
:=
&
AntigravityQuotaFetcher
{}
// priorityModels = ["claude-sonnet-4-20250514", "claude-sonnet-4", "gemini-2.5-pro"]
// When the first priority model exists, it should be used for FiveHour
modelsResp
:=
&
antigravity
.
FetchAvailableModelsResponse
{
Models
:
map
[
string
]
antigravity
.
ModelInfo
{
"gemini-2.5-pro"
:
{
QuotaInfo
:
&
antigravity
.
ModelQuotaInfo
{
RemainingFraction
:
0.40
,
ResetTime
:
"2026-03-08T18:00:00Z"
,
},
},
"claude-sonnet-4-20250514"
:
{
QuotaInfo
:
&
antigravity
.
ModelQuotaInfo
{
RemainingFraction
:
0.80
,
ResetTime
:
"2026-03-08T12:00:00Z"
,
},
},
},
}
info
:=
fetcher
.
buildUsageInfo
(
modelsResp
,
""
,
""
)
require
.
NotNil
(
t
,
info
.
FiveHour
,
"FiveHour should be set when a priority model exists"
)
// claude-sonnet-4-20250514 is first in priority list, so it should be used
expectedUtilization
:=
(
1.0
-
0.80
)
*
100
// 20
require
.
InDelta
(
t
,
expectedUtilization
,
info
.
FiveHour
.
Utilization
,
0.01
)
require
.
NotNil
(
t
,
info
.
FiveHour
.
ResetsAt
,
"ResetsAt should be parsed from ResetTime"
)
}
func
TestBuildUsageInfo_FiveHourFallbackToClaude4
(
t
*
testing
.
T
)
{
fetcher
:=
&
AntigravityQuotaFetcher
{}
// Only claude-sonnet-4 exists (second in priority list), not claude-sonnet-4-20250514
modelsResp
:=
&
antigravity
.
FetchAvailableModelsResponse
{
Models
:
map
[
string
]
antigravity
.
ModelInfo
{
"claude-sonnet-4"
:
{
QuotaInfo
:
&
antigravity
.
ModelQuotaInfo
{
RemainingFraction
:
0.60
,
ResetTime
:
"2026-03-08T14:00:00Z"
,
},
},
"gemini-2.5-pro"
:
{
QuotaInfo
:
&
antigravity
.
ModelQuotaInfo
{
RemainingFraction
:
0.30
,
},
},
},
}
info
:=
fetcher
.
buildUsageInfo
(
modelsResp
,
""
,
""
)
require
.
NotNil
(
t
,
info
.
FiveHour
)
expectedUtilization
:=
(
1.0
-
0.60
)
*
100
// 40
require
.
InDelta
(
t
,
expectedUtilization
,
info
.
FiveHour
.
Utilization
,
0.01
)
}
func
TestBuildUsageInfo_FiveHourFallbackToGemini
(
t
*
testing
.
T
)
{
fetcher
:=
&
AntigravityQuotaFetcher
{}
// Only gemini-2.5-pro exists (third in priority list)
modelsResp
:=
&
antigravity
.
FetchAvailableModelsResponse
{
Models
:
map
[
string
]
antigravity
.
ModelInfo
{
"gemini-2.5-pro"
:
{
QuotaInfo
:
&
antigravity
.
ModelQuotaInfo
{
RemainingFraction
:
0.30
,
},
},
"other-model"
:
{
QuotaInfo
:
&
antigravity
.
ModelQuotaInfo
{
RemainingFraction
:
0.90
,
},
},
},
}
info
:=
fetcher
.
buildUsageInfo
(
modelsResp
,
""
,
""
)
require
.
NotNil
(
t
,
info
.
FiveHour
)
expectedUtilization
:=
(
1.0
-
0.30
)
*
100
// 70
require
.
InDelta
(
t
,
expectedUtilization
,
info
.
FiveHour
.
Utilization
,
0.01
)
}
func
TestBuildUsageInfo_FiveHourNoPriorityModel
(
t
*
testing
.
T
)
{
fetcher
:=
&
AntigravityQuotaFetcher
{}
// None of the priority models exist
modelsResp
:=
&
antigravity
.
FetchAvailableModelsResponse
{
Models
:
map
[
string
]
antigravity
.
ModelInfo
{
"some-other-model"
:
{
QuotaInfo
:
&
antigravity
.
ModelQuotaInfo
{
RemainingFraction
:
0.50
,
},
},
},
}
info
:=
fetcher
.
buildUsageInfo
(
modelsResp
,
""
,
""
)
require
.
Nil
(
t
,
info
.
FiveHour
,
"FiveHour should be nil when no priority model exists"
)
}
func
TestBuildUsageInfo_FiveHourWithEmptyResetTime
(
t
*
testing
.
T
)
{
fetcher
:=
&
AntigravityQuotaFetcher
{}
modelsResp
:=
&
antigravity
.
FetchAvailableModelsResponse
{
Models
:
map
[
string
]
antigravity
.
ModelInfo
{
"claude-sonnet-4-20250514"
:
{
QuotaInfo
:
&
antigravity
.
ModelQuotaInfo
{
RemainingFraction
:
0.50
,
ResetTime
:
""
,
// empty reset time
},
},
},
}
info
:=
fetcher
.
buildUsageInfo
(
modelsResp
,
""
,
""
)
require
.
NotNil
(
t
,
info
.
FiveHour
)
require
.
Nil
(
t
,
info
.
FiveHour
.
ResetsAt
,
"ResetsAt should be nil when ResetTime is empty"
)
require
.
Equal
(
t
,
0
,
info
.
FiveHour
.
RemainingSeconds
)
}
func
TestBuildUsageInfo_FullUtilization
(
t
*
testing
.
T
)
{
fetcher
:=
&
AntigravityQuotaFetcher
{}
modelsResp
:=
&
antigravity
.
FetchAvailableModelsResponse
{
Models
:
map
[
string
]
antigravity
.
ModelInfo
{
"claude-sonnet-4-20250514"
:
{
QuotaInfo
:
&
antigravity
.
ModelQuotaInfo
{
RemainingFraction
:
0.0
,
// fully used
ResetTime
:
"2026-03-08T12:00:00Z"
,
},
},
},
}
info
:=
fetcher
.
buildUsageInfo
(
modelsResp
,
""
,
""
)
quota
:=
info
.
AntigravityQuota
[
"claude-sonnet-4-20250514"
]
require
.
NotNil
(
t
,
quota
)
require
.
Equal
(
t
,
100
,
quota
.
Utilization
)
}
func
TestBuildUsageInfo_ZeroUtilization
(
t
*
testing
.
T
)
{
fetcher
:=
&
AntigravityQuotaFetcher
{}
modelsResp
:=
&
antigravity
.
FetchAvailableModelsResponse
{
Models
:
map
[
string
]
antigravity
.
ModelInfo
{
"claude-sonnet-4-20250514"
:
{
QuotaInfo
:
&
antigravity
.
ModelQuotaInfo
{
RemainingFraction
:
1.0
,
// fully available
},
},
},
}
info
:=
fetcher
.
buildUsageInfo
(
modelsResp
,
""
,
""
)
quota
:=
info
.
AntigravityQuota
[
"claude-sonnet-4-20250514"
]
require
.
NotNil
(
t
,
quota
)
require
.
Equal
(
t
,
0
,
quota
.
Utilization
)
}
func
TestFetchQuota_ForbiddenReturnsIsForbidden
(
t
*
testing
.
T
)
{
// 模拟 FetchQuota 遇到 403 时的行为:
// FetchAvailableModels 返回 ForbiddenError → FetchQuota 应返回 is_forbidden=true
forbiddenErr
:=
&
antigravity
.
ForbiddenError
{
StatusCode
:
403
,
Body
:
"Access denied"
,
}
// 验证 ForbiddenError 满足 errors.As
var
target
*
antigravity
.
ForbiddenError
require
.
True
(
t
,
errors
.
As
(
forbiddenErr
,
&
target
))
require
.
Equal
(
t
,
403
,
target
.
StatusCode
)
require
.
Equal
(
t
,
"Access denied"
,
target
.
Body
)
require
.
Contains
(
t
,
forbiddenErr
.
Error
(),
"403"
)
}
// ---------------------------------------------------------------------------
// classifyForbiddenType
// ---------------------------------------------------------------------------
func
TestClassifyForbiddenType
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
body
string
expected
string
}{
{
name
:
"VALIDATION_REQUIRED keyword"
,
body
:
`{"error":{"message":"VALIDATION_REQUIRED"}}`
,
expected
:
"validation"
,
},
{
name
:
"verify your account"
,
body
:
`Please verify your account to continue`
,
expected
:
"validation"
,
},
{
name
:
"contains validation_url field"
,
body
:
`{"error":{"details":[{"metadata":{"validation_url":"https://..."}}]}}`
,
expected
:
"validation"
,
},
{
name
:
"terms of service violation"
,
body
:
`Your account has been suspended for Terms of Service violation`
,
expected
:
"violation"
,
},
{
name
:
"violation keyword"
,
body
:
`Account suspended due to policy violation`
,
expected
:
"violation"
,
},
{
name
:
"generic 403"
,
body
:
`Access denied`
,
expected
:
"forbidden"
,
},
{
name
:
"empty body"
,
body
:
""
,
expected
:
"forbidden"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
classifyForbiddenType
(
tt
.
body
)
require
.
Equal
(
t
,
tt
.
expected
,
got
)
})
}
}
// ---------------------------------------------------------------------------
// extractValidationURL
// ---------------------------------------------------------------------------
func
TestExtractValidationURL
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
body
string
expected
string
}{
{
name
:
"structured validation_url"
,
body
:
`{"error":{"details":[{"metadata":{"validation_url":"https://accounts.google.com/verify?token=abc"}}]}}`
,
expected
:
"https://accounts.google.com/verify?token=abc"
,
},
{
name
:
"structured appeal_url"
,
body
:
`{"error":{"details":[{"metadata":{"appeal_url":"https://support.google.com/appeal/123"}}]}}`
,
expected
:
"https://support.google.com/appeal/123"
,
},
{
name
:
"validation_url takes priority over appeal_url"
,
body
:
`{"error":{"details":[{"metadata":{"validation_url":"https://v.com","appeal_url":"https://a.com"}}]}}`
,
expected
:
"https://v.com"
,
},
{
name
:
"fallback regex with verify keyword"
,
body
:
`Please verify your account at https://accounts.google.com/verify`
,
expected
:
"https://accounts.google.com/verify"
,
},
{
name
:
"no URL in generic forbidden"
,
body
:
`Access denied`
,
expected
:
""
,
},
{
name
:
"empty body"
,
body
:
""
,
expected
:
""
,
},
{
name
:
"URL present but no validation keywords"
,
body
:
`Error at https://example.com/something`
,
expected
:
""
,
},
{
name
:
"unicode escaped ampersand"
,
body
:
`validation required: https://accounts.google.com/verify?a=1\u0026b=2`
,
expected
:
"https://accounts.google.com/verify?a=1&b=2"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
extractValidationURL
(
tt
.
body
)
require
.
Equal
(
t
,
tt
.
expected
,
got
)
})
}
}
backend/internal/service/auth_service.go
View file @
8f0ea7a0
...
...
@@ -1087,6 +1087,12 @@ type TokenPair struct {
ExpiresIn
int
`json:"expires_in"`
// Access Token有效期(秒)
}
// TokenPairWithUser extends TokenPair with user role for backend mode checks
type
TokenPairWithUser
struct
{
TokenPair
UserRole
string
}
// GenerateTokenPair 生成Access Token和Refresh Token对
// familyID: 可选的Token家族ID,用于Token轮转时保持家族关系
func
(
s
*
AuthService
)
GenerateTokenPair
(
ctx
context
.
Context
,
user
*
User
,
familyID
string
)
(
*
TokenPair
,
error
)
{
...
...
@@ -1168,7 +1174,7 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami
// RefreshTokenPair 使用Refresh Token刷新Token对
// 实现Token轮转:每次刷新都会生成新的Refresh Token,旧Token立即失效
func
(
s
*
AuthService
)
RefreshTokenPair
(
ctx
context
.
Context
,
refreshToken
string
)
(
*
TokenPair
,
error
)
{
func
(
s
*
AuthService
)
RefreshTokenPair
(
ctx
context
.
Context
,
refreshToken
string
)
(
*
TokenPair
WithUser
,
error
)
{
// 检查 refreshTokenCache 是否可用
if
s
.
refreshTokenCache
==
nil
{
return
nil
,
ErrRefreshTokenInvalid
...
...
@@ -1233,7 +1239,14 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
}
// 生成新的Token对,保持同一个家族ID
return
s
.
GenerateTokenPair
(
ctx
,
user
,
data
.
FamilyID
)
pair
,
err
:=
s
.
GenerateTokenPair
(
ctx
,
user
,
data
.
FamilyID
)
if
err
!=
nil
{
return
nil
,
err
}
return
&
TokenPairWithUser
{
TokenPair
:
*
pair
,
UserRole
:
user
.
Role
,
},
nil
}
// RevokeRefreshToken 撤销单个Refresh Token
...
...
backend/internal/service/domain_constants.go
View file @
8f0ea7a0
...
...
@@ -33,8 +33,7 @@ const (
AccountTypeSetupToken
=
domain
.
AccountTypeSetupToken
// Setup Token类型账号(inference only scope)
AccountTypeAPIKey
=
domain
.
AccountTypeAPIKey
// API Key类型账号
AccountTypeUpstream
=
domain
.
AccountTypeUpstream
// 上游透传类型账号(通过 Base URL + API Key 连接上游)
AccountTypeBedrock
=
domain
.
AccountTypeBedrock
// AWS Bedrock 类型账号(通过 SigV4 签名连接 Bedrock)
AccountTypeBedrockAPIKey
=
domain
.
AccountTypeBedrockAPIKey
// AWS Bedrock API Key 类型账号(通过 Bearer Token 连接 Bedrock)
AccountTypeBedrock
=
domain
.
AccountTypeBedrock
// AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
)
// Redeem type constants
...
...
@@ -221,6 +220,9 @@ const (
// SettingKeyAllowUngroupedKeyScheduling 允许未分组 API Key 调度(默认 false:未分组 Key 返回 403)
SettingKeyAllowUngroupedKeyScheduling
=
"allow_ungrouped_key_scheduling"
// SettingKeyBackendModeEnabled Backend 模式:禁用用户注册和自助服务,仅管理员可登录
SettingKeyBackendModeEnabled
=
"backend_mode_enabled"
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
...
...
backend/internal/service/error_policy_test.go
View file @
8f0ea7a0
...
...
@@ -110,7 +110,9 @@ func TestCheckErrorPolicy(t *testing.T) {
expected
:
ErrorPolicyTempUnscheduled
,
},
{
name
:
"temp_unschedulable_401_second_hit_upgrades_to_none"
,
// Antigravity 401 不走升级逻辑(由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制),
// second hit 仍然返回 TempUnscheduled。
name
:
"temp_unschedulable_401_second_hit_antigravity_stays_temp"
,
account
:
&
Account
{
ID
:
15
,
Type
:
AccountTypeOAuth
,
...
...
@@ -129,7 +131,7 @@ func TestCheckErrorPolicy(t *testing.T) {
},
statusCode
:
401
,
body
:
[]
byte
(
`unauthorized`
),
expected
:
ErrorPolicy
None
,
expected
:
ErrorPolicy
TempUnscheduled
,
},
{
name
:
"temp_unschedulable_body_miss_returns_none"
,
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment