Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
2bd288a6
Commit
2bd288a6
authored
Dec 29, 2025
by
song
Browse files
Merge branch 'main' into feature/antigravity_auth
parents
234e98f1
c01db6b1
Changes
40
Hide whitespace changes
Inline
Side-by-side
README_CN.md
View file @
2bd288a6
...
@@ -283,6 +283,16 @@ npm run dev
...
@@ -283,6 +283,16 @@ npm run dev
---
---
## 简易模式
简易模式适合个人开发者或内部团队快速使用,不依赖完整 SaaS 功能。
-
启用方式:设置环境变量
`RUN_MODE=simple`
-
功能差异:隐藏 SaaS 相关功能,跳过计费流程
-
安全注意事项:生产环境需同时设置
`SIMPLE_MODE_CONFIRM=true`
才允许启动
---
## 项目结构
## 项目结构
```
```
...
...
backend/cmd/server/main.go
View file @
2bd288a6
...
@@ -107,6 +107,14 @@ func runSetupServer() {
...
@@ -107,6 +107,14 @@ func runSetupServer() {
}
}
func
runMainServer
()
{
func
runMainServer
()
{
cfg
,
err
:=
config
.
Load
()
if
err
!=
nil
{
log
.
Fatalf
(
"Failed to load config: %v"
,
err
)
}
if
cfg
.
RunMode
==
config
.
RunModeSimple
{
log
.
Println
(
"⚠️ WARNING: Running in SIMPLE mode - billing and quota checks are DISABLED"
)
}
buildInfo
:=
handler
.
BuildInfo
{
buildInfo
:=
handler
.
BuildInfo
{
Version
:
Version
,
Version
:
Version
,
BuildType
:
BuildType
,
BuildType
:
BuildType
,
...
...
backend/cmd/server/wire_gen.go
View file @
2bd288a6
...
@@ -49,7 +49,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
...
@@ -49,7 +49,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
emailQueueService
:=
service
.
ProvideEmailQueueService
(
emailService
)
emailQueueService
:=
service
.
ProvideEmailQueueService
(
emailService
)
authService
:=
service
.
NewAuthService
(
userRepository
,
configConfig
,
settingService
,
emailService
,
turnstileService
,
emailQueueService
)
authService
:=
service
.
NewAuthService
(
userRepository
,
configConfig
,
settingService
,
emailService
,
turnstileService
,
emailQueueService
)
userService
:=
service
.
NewUserService
(
userRepository
)
userService
:=
service
.
NewUserService
(
userRepository
)
authHandler
:=
handler
.
NewAuthHandler
(
authService
,
userService
)
authHandler
:=
handler
.
NewAuthHandler
(
configConfig
,
authService
,
userService
)
userHandler
:=
handler
.
NewUserHandler
(
userService
)
userHandler
:=
handler
.
NewUserHandler
(
userService
)
apiKeyRepository
:=
repository
.
NewApiKeyRepository
(
db
)
apiKeyRepository
:=
repository
.
NewApiKeyRepository
(
db
)
groupRepository
:=
repository
.
NewGroupRepository
(
db
)
groupRepository
:=
repository
.
NewGroupRepository
(
db
)
...
@@ -62,7 +62,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
...
@@ -62,7 +62,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
usageHandler
:=
handler
.
NewUsageHandler
(
usageService
,
apiKeyService
)
usageHandler
:=
handler
.
NewUsageHandler
(
usageService
,
apiKeyService
)
redeemCodeRepository
:=
repository
.
NewRedeemCodeRepository
(
db
)
redeemCodeRepository
:=
repository
.
NewRedeemCodeRepository
(
db
)
billingCache
:=
repository
.
NewBillingCache
(
client
)
billingCache
:=
repository
.
NewBillingCache
(
client
)
billingCacheService
:=
service
.
NewBillingCacheService
(
billingCache
,
userRepository
,
userSubscriptionRepository
)
billingCacheService
:=
service
.
NewBillingCacheService
(
billingCache
,
userRepository
,
userSubscriptionRepository
,
configConfig
)
subscriptionService
:=
service
.
NewSubscriptionService
(
groupRepository
,
userSubscriptionRepository
,
billingCacheService
)
subscriptionService
:=
service
.
NewSubscriptionService
(
groupRepository
,
userSubscriptionRepository
,
billingCacheService
)
redeemCache
:=
repository
.
NewRedeemCache
(
client
)
redeemCache
:=
repository
.
NewRedeemCache
(
client
)
redeemService
:=
service
.
NewRedeemService
(
redeemCodeRepository
,
userRepository
,
subscriptionService
,
redeemCache
,
billingCacheService
)
redeemService
:=
service
.
NewRedeemService
(
redeemCodeRepository
,
userRepository
,
subscriptionService
,
redeemCache
,
billingCacheService
)
...
@@ -132,7 +132,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
...
@@ -132,7 +132,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
handlers
:=
handler
.
ProvideHandlers
(
authHandler
,
userHandler
,
apiKeyHandler
,
usageHandler
,
redeemHandler
,
subscriptionHandler
,
adminHandlers
,
gatewayHandler
,
openAIGatewayHandler
,
handlerSettingHandler
)
handlers
:=
handler
.
ProvideHandlers
(
authHandler
,
userHandler
,
apiKeyHandler
,
usageHandler
,
redeemHandler
,
subscriptionHandler
,
adminHandlers
,
gatewayHandler
,
openAIGatewayHandler
,
handlerSettingHandler
)
jwtAuthMiddleware
:=
middleware
.
NewJWTAuthMiddleware
(
authService
,
userService
)
jwtAuthMiddleware
:=
middleware
.
NewJWTAuthMiddleware
(
authService
,
userService
)
adminAuthMiddleware
:=
middleware
.
NewAdminAuthMiddleware
(
authService
,
userService
,
settingService
)
adminAuthMiddleware
:=
middleware
.
NewAdminAuthMiddleware
(
authService
,
userService
,
settingService
)
apiKeyAuthMiddleware
:=
middleware
.
NewApiKeyAuthMiddleware
(
apiKeyService
,
subscriptionService
)
apiKeyAuthMiddleware
:=
middleware
.
NewApiKeyAuthMiddleware
(
apiKeyService
,
subscriptionService
,
configConfig
)
engine
:=
server
.
ProvideRouter
(
configConfig
,
handlers
,
jwtAuthMiddleware
,
adminAuthMiddleware
,
apiKeyAuthMiddleware
,
apiKeyService
,
subscriptionService
)
engine
:=
server
.
ProvideRouter
(
configConfig
,
handlers
,
jwtAuthMiddleware
,
adminAuthMiddleware
,
apiKeyAuthMiddleware
,
apiKeyService
,
subscriptionService
)
httpServer
:=
server
.
ProvideHTTPServer
(
configConfig
,
engine
)
httpServer
:=
server
.
ProvideHTTPServer
(
configConfig
,
engine
)
tokenRefreshService
:=
service
.
ProvideTokenRefreshService
(
accountRepository
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
antigravityOAuthService
,
configConfig
)
tokenRefreshService
:=
service
.
ProvideTokenRefreshService
(
accountRepository
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
antigravityOAuthService
,
configConfig
)
...
...
backend/internal/config/config.go
View file @
2bd288a6
...
@@ -7,6 +7,11 @@ import (
...
@@ -7,6 +7,11 @@ import (
"github.com/spf13/viper"
"github.com/spf13/viper"
)
)
const
(
RunModeStandard
=
"standard"
RunModeSimple
=
"simple"
)
type
Config
struct
{
type
Config
struct
{
Server
ServerConfig
`mapstructure:"server"`
Server
ServerConfig
`mapstructure:"server"`
Database
DatabaseConfig
`mapstructure:"database"`
Database
DatabaseConfig
`mapstructure:"database"`
...
@@ -17,6 +22,7 @@ type Config struct {
...
@@ -17,6 +22,7 @@ type Config struct {
Pricing
PricingConfig
`mapstructure:"pricing"`
Pricing
PricingConfig
`mapstructure:"pricing"`
Gateway
GatewayConfig
`mapstructure:"gateway"`
Gateway
GatewayConfig
`mapstructure:"gateway"`
TokenRefresh
TokenRefreshConfig
`mapstructure:"token_refresh"`
TokenRefresh
TokenRefreshConfig
`mapstructure:"token_refresh"`
RunMode
string
`mapstructure:"run_mode" yaml:"run_mode"`
Timezone
string
`mapstructure:"timezone"`
// e.g. "Asia/Shanghai", "UTC"
Timezone
string
`mapstructure:"timezone"`
// e.g. "Asia/Shanghai", "UTC"
Gemini
GeminiConfig
`mapstructure:"gemini"`
Gemini
GeminiConfig
`mapstructure:"gemini"`
}
}
...
@@ -135,6 +141,16 @@ type RateLimitConfig struct {
...
@@ -135,6 +141,16 @@ type RateLimitConfig struct {
OverloadCooldownMinutes
int
`mapstructure:"overload_cooldown_minutes"`
// 529过载冷却时间(分钟)
OverloadCooldownMinutes
int
`mapstructure:"overload_cooldown_minutes"`
// 529过载冷却时间(分钟)
}
}
func
NormalizeRunMode
(
value
string
)
string
{
normalized
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
value
))
switch
normalized
{
case
RunModeStandard
,
RunModeSimple
:
return
normalized
default
:
return
RunModeStandard
}
}
func
Load
()
(
*
Config
,
error
)
{
func
Load
()
(
*
Config
,
error
)
{
viper
.
SetConfigName
(
"config"
)
viper
.
SetConfigName
(
"config"
)
viper
.
SetConfigType
(
"yaml"
)
viper
.
SetConfigType
(
"yaml"
)
...
@@ -161,6 +177,8 @@ func Load() (*Config, error) {
...
@@ -161,6 +177,8 @@ func Load() (*Config, error) {
return
nil
,
fmt
.
Errorf
(
"unmarshal config error: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"unmarshal config error: %w"
,
err
)
}
}
cfg
.
RunMode
=
NormalizeRunMode
(
cfg
.
RunMode
)
if
err
:=
cfg
.
Validate
();
err
!=
nil
{
if
err
:=
cfg
.
Validate
();
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"validate config error: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"validate config error: %w"
,
err
)
}
}
...
@@ -169,6 +187,8 @@ func Load() (*Config, error) {
...
@@ -169,6 +187,8 @@ func Load() (*Config, error) {
}
}
func
setDefaults
()
{
func
setDefaults
()
{
viper
.
SetDefault
(
"run_mode"
,
RunModeStandard
)
// Server
// Server
viper
.
SetDefault
(
"server.host"
,
"0.0.0.0"
)
viper
.
SetDefault
(
"server.host"
,
"0.0.0.0"
)
viper
.
SetDefault
(
"server.port"
,
8080
)
viper
.
SetDefault
(
"server.port"
,
8080
)
...
...
backend/internal/config/config_test.go
0 → 100644
View file @
2bd288a6
package
config
import
"testing"
func
TestNormalizeRunMode
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
input
string
expected
string
}{
{
"simple"
,
"simple"
},
{
"SIMPLE"
,
"simple"
},
{
"standard"
,
"standard"
},
{
"invalid"
,
"standard"
},
{
""
,
"standard"
},
}
for
_
,
tt
:=
range
tests
{
result
:=
NormalizeRunMode
(
tt
.
input
)
if
result
!=
tt
.
expected
{
t
.
Errorf
(
"NormalizeRunMode(%q) = %q, want %q"
,
tt
.
input
,
result
,
tt
.
expected
)
}
}
}
backend/internal/handler/auth_handler.go
View file @
2bd288a6
package
handler
package
handler
import
(
import
(
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
middleware2
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
...
@@ -11,13 +12,15 @@ import (
...
@@ -11,13 +12,15 @@ import (
// AuthHandler handles authentication-related requests
// AuthHandler handles authentication-related requests
type
AuthHandler
struct
{
type
AuthHandler
struct
{
cfg
*
config
.
Config
authService
*
service
.
AuthService
authService
*
service
.
AuthService
userService
*
service
.
UserService
userService
*
service
.
UserService
}
}
// NewAuthHandler creates a new AuthHandler
// NewAuthHandler creates a new AuthHandler
func
NewAuthHandler
(
authService
*
service
.
AuthService
,
userService
*
service
.
UserService
)
*
AuthHandler
{
func
NewAuthHandler
(
cfg
*
config
.
Config
,
authService
*
service
.
AuthService
,
userService
*
service
.
UserService
)
*
AuthHandler
{
return
&
AuthHandler
{
return
&
AuthHandler
{
cfg
:
cfg
,
authService
:
authService
,
authService
:
authService
,
userService
:
userService
,
userService
:
userService
,
}
}
...
@@ -157,5 +160,15 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
...
@@ -157,5 +160,15 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
return
return
}
}
response
.
Success
(
c
,
dto
.
UserFromService
(
user
))
type
UserResponse
struct
{
*
dto
.
User
RunMode
string
`json:"run_mode"`
}
runMode
:=
config
.
RunModeStandard
if
h
.
cfg
!=
nil
{
runMode
=
h
.
cfg
.
RunMode
}
response
.
Success
(
c
,
UserResponse
{
User
:
dto
.
UserFromService
(
user
),
RunMode
:
runMode
})
}
}
backend/internal/repository/auto_migrate.go
View file @
2bd288a6
...
@@ -30,6 +30,11 @@ func AutoMigrate(db *gorm.DB) error {
...
@@ -30,6 +30,11 @@ func AutoMigrate(db *gorm.DB) error {
return
err
return
err
}
}
// 创建默认分组(简易模式支持)
if
err
:=
ensureDefaultGroups
(
db
);
err
!=
nil
{
return
err
}
// 修复无效的过期时间(年份超过 2099 会导致 JSON 序列化失败)
// 修复无效的过期时间(年份超过 2099 会导致 JSON 序列化失败)
return
fixInvalidExpiresAt
(
db
)
return
fixInvalidExpiresAt
(
db
)
}
}
...
@@ -47,3 +52,55 @@ func fixInvalidExpiresAt(db *gorm.DB) error {
...
@@ -47,3 +52,55 @@ func fixInvalidExpiresAt(db *gorm.DB) error {
}
}
return
nil
return
nil
}
}
// ensureDefaultGroups 确保默认分组存在(简易模式支持)
// 为每个平台创建一个默认分组,配置最大权限以确保简易模式下不受限制
func
ensureDefaultGroups
(
db
*
gorm
.
DB
)
error
{
defaultGroups
:=
[]
struct
{
name
string
platform
string
description
string
}{
{
name
:
"anthropic-default"
,
platform
:
"anthropic"
,
description
:
"Default group for Anthropic accounts (Simple Mode)"
,
},
{
name
:
"openai-default"
,
platform
:
"openai"
,
description
:
"Default group for OpenAI accounts (Simple Mode)"
,
},
{
name
:
"gemini-default"
,
platform
:
"gemini"
,
description
:
"Default group for Gemini accounts (Simple Mode)"
,
},
}
for
_
,
dg
:=
range
defaultGroups
{
var
count
int64
if
err
:=
db
.
Model
(
&
groupModel
{})
.
Where
(
"name = ?"
,
dg
.
name
)
.
Count
(
&
count
)
.
Error
;
err
!=
nil
{
return
err
}
if
count
==
0
{
group
:=
&
groupModel
{
Name
:
dg
.
name
,
Description
:
dg
.
description
,
Platform
:
dg
.
platform
,
RateMultiplier
:
1.0
,
IsExclusive
:
false
,
Status
:
"active"
,
SubscriptionType
:
"standard"
,
}
if
err
:=
db
.
Create
(
group
)
.
Error
;
err
!=
nil
{
log
.
Printf
(
"[AutoMigrate] Failed to create default group %s: %v"
,
dg
.
name
,
err
)
return
err
}
log
.
Printf
(
"[AutoMigrate] Created default group: %s (platform: %s)"
,
dg
.
name
,
dg
.
platform
)
}
}
return
nil
}
backend/internal/repository/group_repo_integration_test.go
View file @
2bd288a6
...
@@ -82,8 +82,9 @@ func (s *GroupRepoSuite) TestList() {
...
@@ -82,8 +82,9 @@ func (s *GroupRepoSuite) TestList() {
groups
,
page
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
})
groups
,
page
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
})
s
.
Require
()
.
NoError
(
err
,
"List"
)
s
.
Require
()
.
NoError
(
err
,
"List"
)
s
.
Require
()
.
Len
(
groups
,
2
)
// 3 default groups + 2 test groups = 5 total
s
.
Require
()
.
Equal
(
int64
(
2
),
page
.
Total
)
s
.
Require
()
.
Len
(
groups
,
5
)
s
.
Require
()
.
Equal
(
int64
(
5
),
page
.
Total
)
}
}
func
(
s
*
GroupRepoSuite
)
TestListWithFilters_Platform
()
{
func
(
s
*
GroupRepoSuite
)
TestListWithFilters_Platform
()
{
...
@@ -92,8 +93,12 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() {
...
@@ -92,8 +93,12 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() {
groups
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
service
.
PlatformOpenAI
,
""
,
nil
)
groups
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
service
.
PlatformOpenAI
,
""
,
nil
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
groups
,
1
)
// 1 default openai group + 1 test openai group = 2 total
s
.
Require
()
.
Equal
(
service
.
PlatformOpenAI
,
groups
[
0
]
.
Platform
)
s
.
Require
()
.
Len
(
groups
,
2
)
// Verify all groups are OpenAI platform
for
_
,
g
:=
range
groups
{
s
.
Require
()
.
Equal
(
service
.
PlatformOpenAI
,
g
.
Platform
)
}
}
}
func
(
s
*
GroupRepoSuite
)
TestListWithFilters_Status
()
{
func
(
s
*
GroupRepoSuite
)
TestListWithFilters_Status
()
{
...
@@ -151,8 +156,17 @@ func (s *GroupRepoSuite) TestListActive() {
...
@@ -151,8 +156,17 @@ func (s *GroupRepoSuite) TestListActive() {
groups
,
err
:=
s
.
repo
.
ListActive
(
s
.
ctx
)
groups
,
err
:=
s
.
repo
.
ListActive
(
s
.
ctx
)
s
.
Require
()
.
NoError
(
err
,
"ListActive"
)
s
.
Require
()
.
NoError
(
err
,
"ListActive"
)
s
.
Require
()
.
Len
(
groups
,
1
)
// 3 default groups (all active) + 1 test active group = 4 total
s
.
Require
()
.
Equal
(
"active1"
,
groups
[
0
]
.
Name
)
s
.
Require
()
.
Len
(
groups
,
4
)
// Verify our test group is in the results
var
found
bool
for
_
,
g
:=
range
groups
{
if
g
.
Name
==
"active1"
{
found
=
true
break
}
}
s
.
Require
()
.
True
(
found
,
"active1 group should be in results"
)
}
}
func
(
s
*
GroupRepoSuite
)
TestListActiveByPlatform
()
{
func
(
s
*
GroupRepoSuite
)
TestListActiveByPlatform
()
{
...
@@ -162,8 +176,17 @@ func (s *GroupRepoSuite) TestListActiveByPlatform() {
...
@@ -162,8 +176,17 @@ func (s *GroupRepoSuite) TestListActiveByPlatform() {
groups
,
err
:=
s
.
repo
.
ListActiveByPlatform
(
s
.
ctx
,
service
.
PlatformAnthropic
)
groups
,
err
:=
s
.
repo
.
ListActiveByPlatform
(
s
.
ctx
,
service
.
PlatformAnthropic
)
s
.
Require
()
.
NoError
(
err
,
"ListActiveByPlatform"
)
s
.
Require
()
.
NoError
(
err
,
"ListActiveByPlatform"
)
s
.
Require
()
.
Len
(
groups
,
1
)
// 1 default anthropic group + 1 test active anthropic group = 2 total
s
.
Require
()
.
Equal
(
"g1"
,
groups
[
0
]
.
Name
)
s
.
Require
()
.
Len
(
groups
,
2
)
// Verify our test group is in the results
var
found
bool
for
_
,
g
:=
range
groups
{
if
g
.
Name
==
"g1"
{
found
=
true
break
}
}
s
.
Require
()
.
True
(
found
,
"g1 group should be in results"
)
}
}
// --- ExistsByName ---
// --- ExistsByName ---
...
...
backend/internal/server/api_contract_test.go
View file @
2bd288a6
...
@@ -59,7 +59,8 @@ func TestAPIContracts(t *testing.T) {
...
@@ -59,7 +59,8 @@ func TestAPIContracts(t *testing.T) {
"status": "active",
"status": "active",
"allowed_groups": null,
"allowed_groups": null,
"created_at": "2025-01-02T03:04:05Z",
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
"updated_at": "2025-01-02T03:04:05Z",
"run_mode": "standard"
}
}
}`
,
}`
,
},
},
...
@@ -369,6 +370,7 @@ func newContractDeps(t *testing.T) *contractDeps {
...
@@ -369,6 +370,7 @@ func newContractDeps(t *testing.T) *contractDeps {
Default
:
config
.
DefaultConfig
{
Default
:
config
.
DefaultConfig
{
ApiKeyPrefix
:
"sk-"
,
ApiKeyPrefix
:
"sk-"
,
},
},
RunMode
:
config
.
RunModeStandard
,
}
}
userService
:=
service
.
NewUserService
(
userRepo
)
userService
:=
service
.
NewUserService
(
userRepo
)
...
@@ -380,7 +382,7 @@ func newContractDeps(t *testing.T) *contractDeps {
...
@@ -380,7 +382,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo
:=
newStubSettingRepo
()
settingRepo
:=
newStubSettingRepo
()
settingService
:=
service
.
NewSettingService
(
settingRepo
,
cfg
)
settingService
:=
service
.
NewSettingService
(
settingRepo
,
cfg
)
authHandler
:=
handler
.
NewAuthHandler
(
nil
,
userService
)
authHandler
:=
handler
.
NewAuthHandler
(
cfg
,
nil
,
userService
)
apiKeyHandler
:=
handler
.
NewAPIKeyHandler
(
apiKeyService
)
apiKeyHandler
:=
handler
.
NewAPIKeyHandler
(
apiKeyService
)
usageHandler
:=
handler
.
NewUsageHandler
(
usageService
,
apiKeyService
)
usageHandler
:=
handler
.
NewUsageHandler
(
usageService
,
apiKeyService
)
adminSettingHandler
:=
adminhandler
.
NewSettingHandler
(
settingService
,
nil
)
adminSettingHandler
:=
adminhandler
.
NewSettingHandler
(
settingService
,
nil
)
...
...
backend/internal/server/http.go
View file @
2bd288a6
...
@@ -36,7 +36,7 @@ func ProvideRouter(
...
@@ -36,7 +36,7 @@ func ProvideRouter(
r
:=
gin
.
New
()
r
:=
gin
.
New
()
r
.
Use
(
middleware2
.
Recovery
())
r
.
Use
(
middleware2
.
Recovery
())
return
SetupRouter
(
r
,
handlers
,
jwtAuth
,
adminAuth
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
)
return
SetupRouter
(
r
,
handlers
,
jwtAuth
,
adminAuth
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
,
cfg
)
}
}
// ProvideHTTPServer 提供 HTTP 服务器
// ProvideHTTPServer 提供 HTTP 服务器
...
...
backend/internal/server/middleware/api_key_auth.go
View file @
2bd288a6
...
@@ -5,18 +5,19 @@ import (
...
@@ -5,18 +5,19 @@ import (
"log"
"log"
"strings"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin"
)
)
// NewApiKeyAuthMiddleware 创建 API Key 认证中间件
// NewApiKeyAuthMiddleware 创建 API Key 认证中间件
func
NewApiKeyAuthMiddleware
(
apiKeyService
*
service
.
ApiKeyService
,
subscriptionService
*
service
.
SubscriptionService
)
ApiKeyAuthMiddleware
{
func
NewApiKeyAuthMiddleware
(
apiKeyService
*
service
.
ApiKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
cfg
*
config
.
Config
)
ApiKeyAuthMiddleware
{
return
ApiKeyAuthMiddleware
(
apiKeyAuthWithSubscription
(
apiKeyService
,
subscriptionService
))
return
ApiKeyAuthMiddleware
(
apiKeyAuthWithSubscription
(
apiKeyService
,
subscriptionService
,
cfg
))
}
}
// apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证)
// apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证)
func
apiKeyAuthWithSubscription
(
apiKeyService
*
service
.
ApiKeyService
,
subscriptionService
*
service
.
SubscriptionService
)
gin
.
HandlerFunc
{
func
apiKeyAuthWithSubscription
(
apiKeyService
*
service
.
ApiKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
cfg
*
config
.
Config
)
gin
.
HandlerFunc
{
return
func
(
c
*
gin
.
Context
)
{
return
func
(
c
*
gin
.
Context
)
{
// 尝试从Authorization header中提取API key (Bearer scheme)
// 尝试从Authorization header中提取API key (Bearer scheme)
authHeader
:=
c
.
GetHeader
(
"Authorization"
)
authHeader
:=
c
.
GetHeader
(
"Authorization"
)
...
@@ -85,6 +86,18 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
...
@@ -85,6 +86,18 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
return
return
}
}
if
cfg
.
RunMode
==
config
.
RunModeSimple
{
// 简易模式:跳过余额和订阅检查,但仍需设置必要的上下文
c
.
Set
(
string
(
ContextKeyApiKey
),
apiKey
)
c
.
Set
(
string
(
ContextKeyUser
),
AuthSubject
{
UserID
:
apiKey
.
User
.
ID
,
Concurrency
:
apiKey
.
User
.
Concurrency
,
})
c
.
Set
(
string
(
ContextKeyUserRole
),
apiKey
.
User
.
Role
)
c
.
Next
()
return
}
// 判断计费方式:订阅模式 vs 余额模式
// 判断计费方式:订阅模式 vs 余额模式
isSubscriptionType
:=
apiKey
.
Group
!=
nil
&&
apiKey
.
Group
.
IsSubscriptionType
()
isSubscriptionType
:=
apiKey
.
Group
!=
nil
&&
apiKey
.
Group
.
IsSubscriptionType
()
...
...
backend/internal/server/middleware/api_key_auth_google.go
View file @
2bd288a6
...
@@ -4,6 +4,7 @@ import (
...
@@ -4,6 +4,7 @@ import (
"errors"
"errors"
"strings"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/service"
...
@@ -11,15 +12,15 @@ import (
...
@@ -11,15 +12,15 @@ import (
)
)
// ApiKeyAuthGoogle is a Google-style error wrapper for API key auth.
// ApiKeyAuthGoogle is a Google-style error wrapper for API key auth.
func
ApiKeyAuthGoogle
(
apiKeyService
*
service
.
ApiKeyService
)
gin
.
HandlerFunc
{
func
ApiKeyAuthGoogle
(
apiKeyService
*
service
.
ApiKeyService
,
cfg
*
config
.
Config
)
gin
.
HandlerFunc
{
return
ApiKeyAuthWithSubscriptionGoogle
(
apiKeyService
,
nil
)
return
ApiKeyAuthWithSubscriptionGoogle
(
apiKeyService
,
nil
,
cfg
)
}
}
// ApiKeyAuthWithSubscriptionGoogle behaves like ApiKeyAuthWithSubscription but returns Google-style errors:
// ApiKeyAuthWithSubscriptionGoogle behaves like ApiKeyAuthWithSubscription but returns Google-style errors:
// {"error":{"code":401,"message":"...","status":"UNAUTHENTICATED"}}
// {"error":{"code":401,"message":"...","status":"UNAUTHENTICATED"}}
//
//
// It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations.
// It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations.
func
ApiKeyAuthWithSubscriptionGoogle
(
apiKeyService
*
service
.
ApiKeyService
,
subscriptionService
*
service
.
SubscriptionService
)
gin
.
HandlerFunc
{
func
ApiKeyAuthWithSubscriptionGoogle
(
apiKeyService
*
service
.
ApiKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
cfg
*
config
.
Config
)
gin
.
HandlerFunc
{
return
func
(
c
*
gin
.
Context
)
{
return
func
(
c
*
gin
.
Context
)
{
apiKeyString
:=
extractAPIKeyFromRequest
(
c
)
apiKeyString
:=
extractAPIKeyFromRequest
(
c
)
if
apiKeyString
==
""
{
if
apiKeyString
==
""
{
...
@@ -50,6 +51,18 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs
...
@@ -50,6 +51,18 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs
return
return
}
}
// 简易模式:跳过余额和订阅检查
if
cfg
.
RunMode
==
config
.
RunModeSimple
{
c
.
Set
(
string
(
ContextKeyApiKey
),
apiKey
)
c
.
Set
(
string
(
ContextKeyUser
),
AuthSubject
{
UserID
:
apiKey
.
User
.
ID
,
Concurrency
:
apiKey
.
User
.
Concurrency
,
})
c
.
Set
(
string
(
ContextKeyUserRole
),
apiKey
.
User
.
Role
)
c
.
Next
()
return
}
isSubscriptionType
:=
apiKey
.
Group
!=
nil
&&
apiKey
.
Group
.
IsSubscriptionType
()
isSubscriptionType
:=
apiKey
.
Group
!=
nil
&&
apiKey
.
Group
.
IsSubscriptionType
()
if
isSubscriptionType
&&
subscriptionService
!=
nil
{
if
isSubscriptionType
&&
subscriptionService
!=
nil
{
subscription
,
err
:=
subscriptionService
.
GetActiveSubscription
(
subscription
,
err
:=
subscriptionService
.
GetActiveSubscription
(
...
...
backend/internal/server/middleware/api_key_auth_test.go
0 → 100644
View file @
2bd288a6
//go:build unit
package
middleware
import
(
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func
TestSimpleModeBypassesQuotaCheck
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
limit
:=
1.0
group
:=
&
service
.
Group
{
ID
:
42
,
Name
:
"sub"
,
Status
:
service
.
StatusActive
,
SubscriptionType
:
service
.
SubscriptionTypeSubscription
,
DailyLimitUSD
:
&
limit
,
}
user
:=
&
service
.
User
{
ID
:
7
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
ApiKey
{
ID
:
100
,
UserID
:
user
.
ID
,
Key
:
"test-key"
,
Status
:
service
.
StatusActive
,
User
:
user
,
Group
:
group
,
}
apiKey
.
GroupID
=
&
group
.
ID
apiKeyRepo
:=
&
stubApiKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
ApiKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrApiKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
}
t
.
Run
(
"simple_mode_bypasses_quota_check"
,
func
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
apiKeyService
:=
service
.
NewApiKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
cfg
)
subscriptionService
:=
service
.
NewSubscriptionService
(
nil
,
&
stubUserSubscriptionRepo
{},
nil
)
router
:=
newAuthTestRouter
(
apiKeyService
,
subscriptionService
,
cfg
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
.
Key
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
})
t
.
Run
(
"standard_mode_enforces_quota_check"
,
func
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeStandard
}
apiKeyService
:=
service
.
NewApiKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
cfg
)
now
:=
time
.
Now
()
sub
:=
&
service
.
UserSubscription
{
ID
:
55
,
UserID
:
user
.
ID
,
GroupID
:
group
.
ID
,
Status
:
service
.
SubscriptionStatusActive
,
ExpiresAt
:
now
.
Add
(
24
*
time
.
Hour
),
DailyWindowStart
:
&
now
,
DailyUsageUSD
:
10
,
}
subscriptionRepo
:=
&
stubUserSubscriptionRepo
{
getActive
:
func
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
service
.
UserSubscription
,
error
)
{
if
userID
!=
sub
.
UserID
||
groupID
!=
sub
.
GroupID
{
return
nil
,
service
.
ErrSubscriptionNotFound
}
clone
:=
*
sub
return
&
clone
,
nil
},
updateStatus
:
func
(
ctx
context
.
Context
,
subscriptionID
int64
,
status
string
)
error
{
return
nil
},
activateWindow
:
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
return
nil
},
resetDaily
:
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
return
nil
},
resetWeekly
:
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
return
nil
},
resetMonthly
:
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
return
nil
},
}
subscriptionService
:=
service
.
NewSubscriptionService
(
nil
,
subscriptionRepo
,
nil
)
router
:=
newAuthTestRouter
(
apiKeyService
,
subscriptionService
,
cfg
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
.
Key
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusTooManyRequests
,
w
.
Code
)
require
.
Contains
(
t
,
w
.
Body
.
String
(),
"USAGE_LIMIT_EXCEEDED"
)
})
}
func
newAuthTestRouter
(
apiKeyService
*
service
.
ApiKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
cfg
*
config
.
Config
)
*
gin
.
Engine
{
router
:=
gin
.
New
()
router
.
Use
(
gin
.
HandlerFunc
(
NewApiKeyAuthMiddleware
(
apiKeyService
,
subscriptionService
,
cfg
)))
router
.
GET
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"ok"
:
true
})
})
return
router
}
type
stubApiKeyRepo
struct
{
getByKey
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
ApiKey
,
error
)
}
func
(
r
*
stubApiKeyRepo
)
Create
(
ctx
context
.
Context
,
key
*
service
.
ApiKey
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubApiKeyRepo
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
ApiKey
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubApiKeyRepo
)
GetByKey
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
ApiKey
,
error
)
{
if
r
.
getByKey
!=
nil
{
return
r
.
getByKey
(
ctx
,
key
)
}
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubApiKeyRepo
)
Update
(
ctx
context
.
Context
,
key
*
service
.
ApiKey
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubApiKeyRepo
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubApiKeyRepo
)
ListByUserID
(
ctx
context
.
Context
,
userID
int64
,
params
pagination
.
PaginationParams
)
([]
service
.
ApiKey
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubApiKeyRepo
)
VerifyOwnership
(
ctx
context
.
Context
,
userID
int64
,
apiKeyIDs
[]
int64
)
([]
int64
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubApiKeyRepo
)
CountByUserID
(
ctx
context
.
Context
,
userID
int64
)
(
int64
,
error
)
{
return
0
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubApiKeyRepo
)
ExistsByKey
(
ctx
context
.
Context
,
key
string
)
(
bool
,
error
)
{
return
false
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubApiKeyRepo
)
ListByGroupID
(
ctx
context
.
Context
,
groupID
int64
,
params
pagination
.
PaginationParams
)
([]
service
.
ApiKey
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubApiKeyRepo
)
SearchApiKeys
(
ctx
context
.
Context
,
userID
int64
,
keyword
string
,
limit
int
)
([]
service
.
ApiKey
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubApiKeyRepo
)
ClearGroupIDByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
return
0
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubApiKeyRepo
)
CountByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
return
0
,
errors
.
New
(
"not implemented"
)
}
type
stubUserSubscriptionRepo
struct
{
getActive
func
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
service
.
UserSubscription
,
error
)
updateStatus
func
(
ctx
context
.
Context
,
subscriptionID
int64
,
status
string
)
error
activateWindow
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
resetDaily
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
resetWeekly
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
resetMonthly
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
}
func
(
r
*
stubUserSubscriptionRepo
)
Create
(
ctx
context
.
Context
,
sub
*
service
.
UserSubscription
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserSubscriptionRepo
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
UserSubscription
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserSubscriptionRepo
)
GetByUserIDAndGroupID
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
service
.
UserSubscription
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserSubscriptionRepo
)
GetActiveByUserIDAndGroupID
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
service
.
UserSubscription
,
error
)
{
if
r
.
getActive
!=
nil
{
return
r
.
getActive
(
ctx
,
userID
,
groupID
)
}
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserSubscriptionRepo
)
Update
(
ctx
context
.
Context
,
sub
*
service
.
UserSubscription
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserSubscriptionRepo
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserSubscriptionRepo
)
ListByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
service
.
UserSubscription
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserSubscriptionRepo
)
ListActiveByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
service
.
UserSubscription
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserSubscriptionRepo
)
ListByGroupID
(
ctx
context
.
Context
,
groupID
int64
,
params
pagination
.
PaginationParams
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserSubscriptionRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserSubscriptionRepo
)
ExistsByUserIDAndGroupID
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
bool
,
error
)
{
return
false
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserSubscriptionRepo
)
ExtendExpiry
(
ctx
context
.
Context
,
subscriptionID
int64
,
newExpiresAt
time
.
Time
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserSubscriptionRepo
)
UpdateStatus
(
ctx
context
.
Context
,
subscriptionID
int64
,
status
string
)
error
{
if
r
.
updateStatus
!=
nil
{
return
r
.
updateStatus
(
ctx
,
subscriptionID
,
status
)
}
return
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserSubscriptionRepo
)
UpdateNotes
(
ctx
context
.
Context
,
subscriptionID
int64
,
notes
string
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserSubscriptionRepo
)
ActivateWindows
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
if
r
.
activateWindow
!=
nil
{
return
r
.
activateWindow
(
ctx
,
id
,
start
)
}
return
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserSubscriptionRepo
)
ResetDailyUsage
(
ctx
context
.
Context
,
id
int64
,
newWindowStart
time
.
Time
)
error
{
if
r
.
resetDaily
!=
nil
{
return
r
.
resetDaily
(
ctx
,
id
,
newWindowStart
)
}
return
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserSubscriptionRepo
)
ResetWeeklyUsage
(
ctx
context
.
Context
,
id
int64
,
newWindowStart
time
.
Time
)
error
{
if
r
.
resetWeekly
!=
nil
{
return
r
.
resetWeekly
(
ctx
,
id
,
newWindowStart
)
}
return
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserSubscriptionRepo
)
ResetMonthlyUsage
(
ctx
context
.
Context
,
id
int64
,
newWindowStart
time
.
Time
)
error
{
if
r
.
resetMonthly
!=
nil
{
return
r
.
resetMonthly
(
ctx
,
id
,
newWindowStart
)
}
return
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserSubscriptionRepo
)
IncrementUsage
(
ctx
context
.
Context
,
id
int64
,
costUSD
float64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserSubscriptionRepo
)
BatchUpdateExpiredStatus
(
ctx
context
.
Context
)
(
int64
,
error
)
{
return
0
,
errors
.
New
(
"not implemented"
)
}
backend/internal/server/router.go
View file @
2bd288a6
package
server
package
server
import
(
import
(
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/handler"
middleware2
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
middleware2
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/server/routes"
"github.com/Wei-Shaw/sub2api/internal/server/routes"
...
@@ -19,6 +20,7 @@ func SetupRouter(
...
@@ -19,6 +20,7 @@ func SetupRouter(
apiKeyAuth
middleware2
.
ApiKeyAuthMiddleware
,
apiKeyAuth
middleware2
.
ApiKeyAuthMiddleware
,
apiKeyService
*
service
.
ApiKeyService
,
apiKeyService
*
service
.
ApiKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
subscriptionService
*
service
.
SubscriptionService
,
cfg
*
config
.
Config
,
)
*
gin
.
Engine
{
)
*
gin
.
Engine
{
// 应用中间件
// 应用中间件
r
.
Use
(
middleware2
.
Logger
())
r
.
Use
(
middleware2
.
Logger
())
...
@@ -30,7 +32,7 @@ func SetupRouter(
...
@@ -30,7 +32,7 @@ func SetupRouter(
}
}
// 注册路由
// 注册路由
registerRoutes
(
r
,
handlers
,
jwtAuth
,
adminAuth
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
)
registerRoutes
(
r
,
handlers
,
jwtAuth
,
adminAuth
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
,
cfg
)
return
r
return
r
}
}
...
@@ -44,6 +46,7 @@ func registerRoutes(
...
@@ -44,6 +46,7 @@ func registerRoutes(
apiKeyAuth
middleware2
.
ApiKeyAuthMiddleware
,
apiKeyAuth
middleware2
.
ApiKeyAuthMiddleware
,
apiKeyService
*
service
.
ApiKeyService
,
apiKeyService
*
service
.
ApiKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
subscriptionService
*
service
.
SubscriptionService
,
cfg
*
config
.
Config
,
)
{
)
{
// 通用路由(健康检查、状态等)
// 通用路由(健康检查、状态等)
routes
.
RegisterCommonRoutes
(
r
)
routes
.
RegisterCommonRoutes
(
r
)
...
@@ -55,5 +58,5 @@ func registerRoutes(
...
@@ -55,5 +58,5 @@ func registerRoutes(
routes
.
RegisterAuthRoutes
(
v1
,
h
,
jwtAuth
)
routes
.
RegisterAuthRoutes
(
v1
,
h
,
jwtAuth
)
routes
.
RegisterUserRoutes
(
v1
,
h
,
jwtAuth
)
routes
.
RegisterUserRoutes
(
v1
,
h
,
jwtAuth
)
routes
.
RegisterAdminRoutes
(
v1
,
h
,
adminAuth
)
routes
.
RegisterAdminRoutes
(
v1
,
h
,
adminAuth
)
routes
.
RegisterGatewayRoutes
(
r
,
h
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
)
routes
.
RegisterGatewayRoutes
(
r
,
h
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
,
cfg
)
}
}
backend/internal/server/routes/gateway.go
View file @
2bd288a6
package
routes
package
routes
import
(
import
(
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/service"
...
@@ -15,6 +16,7 @@ func RegisterGatewayRoutes(
...
@@ -15,6 +16,7 @@ func RegisterGatewayRoutes(
apiKeyAuth
middleware
.
ApiKeyAuthMiddleware
,
apiKeyAuth
middleware
.
ApiKeyAuthMiddleware
,
apiKeyService
*
service
.
ApiKeyService
,
apiKeyService
*
service
.
ApiKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
subscriptionService
*
service
.
SubscriptionService
,
cfg
*
config
.
Config
,
)
{
)
{
// API网关(Claude API兼容)
// API网关(Claude API兼容)
gateway
:=
r
.
Group
(
"/v1"
)
gateway
:=
r
.
Group
(
"/v1"
)
...
@@ -30,7 +32,7 @@ func RegisterGatewayRoutes(
...
@@ -30,7 +32,7 @@ func RegisterGatewayRoutes(
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
gemini
:=
r
.
Group
(
"/v1beta"
)
gemini
:=
r
.
Group
(
"/v1beta"
)
gemini
.
Use
(
middleware
.
ApiKeyAuthWithSubscriptionGoogle
(
apiKeyService
,
subscriptionService
))
gemini
.
Use
(
middleware
.
ApiKeyAuthWithSubscriptionGoogle
(
apiKeyService
,
subscriptionService
,
cfg
))
{
{
gemini
.
GET
(
"/models"
,
h
.
Gateway
.
GeminiV1BetaListModels
)
gemini
.
GET
(
"/models"
,
h
.
Gateway
.
GeminiV1BetaListModels
)
gemini
.
GET
(
"/models/:model"
,
h
.
Gateway
.
GeminiV1BetaGetModel
)
gemini
.
GET
(
"/models/:model"
,
h
.
Gateway
.
GeminiV1BetaGetModel
)
...
@@ -54,7 +56,7 @@ func RegisterGatewayRoutes(
...
@@ -54,7 +56,7 @@ func RegisterGatewayRoutes(
antigravityV1Beta
:=
r
.
Group
(
"/antigravity/v1beta"
)
antigravityV1Beta
:=
r
.
Group
(
"/antigravity/v1beta"
)
antigravityV1Beta
.
Use
(
middleware
.
ForcePlatform
(
service
.
PlatformAntigravity
))
antigravityV1Beta
.
Use
(
middleware
.
ForcePlatform
(
service
.
PlatformAntigravity
))
antigravityV1Beta
.
Use
(
middleware
.
ApiKeyAuthWithSubscriptionGoogle
(
apiKeyService
,
subscriptionService
))
antigravityV1Beta
.
Use
(
middleware
.
ApiKeyAuthWithSubscriptionGoogle
(
apiKeyService
,
subscriptionService
,
cfg
))
{
{
antigravityV1Beta
.
GET
(
"/models"
,
h
.
Gateway
.
GeminiV1BetaListModels
)
antigravityV1Beta
.
GET
(
"/models"
,
h
.
Gateway
.
GeminiV1BetaListModels
)
antigravityV1Beta
.
GET
(
"/models/:model"
,
h
.
Gateway
.
GeminiV1BetaGetModel
)
antigravityV1Beta
.
GET
(
"/models/:model"
,
h
.
Gateway
.
GeminiV1BetaGetModel
)
...
...
backend/internal/service/account_usage_service.go
View file @
2bd288a6
...
@@ -54,15 +54,23 @@ type UsageLogRepository interface {
...
@@ -54,15 +54,23 @@ type UsageLogRepository interface {
GetApiKeyStatsAggregated
(
ctx
context
.
Context
,
apiKeyID
int64
,
startTime
,
endTime
time
.
Time
)
(
*
usagestats
.
UsageStats
,
error
)
GetApiKeyStatsAggregated
(
ctx
context
.
Context
,
apiKeyID
int64
,
startTime
,
endTime
time
.
Time
)
(
*
usagestats
.
UsageStats
,
error
)
}
}
// usageCache 用于缓存usage数据
// apiUsageCache 缓存从 Anthropic API 获取的使用率数据(utilization, resets_at)
type
usageCache
struct
{
type
apiUsageCache
struct
{
data
*
UsageInfo
response
*
ClaudeUsageResponse
timestamp
time
.
Time
}
// windowStatsCache 缓存从本地数据库查询的窗口统计(requests, tokens, cost)
type
windowStatsCache
struct
{
stats
*
WindowStats
timestamp
time
.
Time
timestamp
time
.
Time
}
}
var
(
var
(
usageCacheMap
=
sync
.
Map
{}
apiCacheMap
=
sync
.
Map
{}
// 缓存 API 响应
cacheTTL
=
10
*
time
.
Minute
windowStatsCacheMap
=
sync
.
Map
{}
// 缓存窗口统计
apiCacheTTL
=
10
*
time
.
Minute
windowStatsCacheTTL
=
1
*
time
.
Minute
)
)
// WindowStats 窗口期统计
// WindowStats 窗口期统计
...
@@ -126,7 +134,7 @@ func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLog
...
@@ -126,7 +134,7 @@ func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLog
}
}
// GetUsage 获取账号使用量
// GetUsage 获取账号使用量
// OAuth账号: 调用Anthropic API获取真实数据(需要profile scope),缓存10分钟
// OAuth账号: 调用Anthropic API获取真实数据(需要profile scope),
API响应
缓存10分钟
,窗口统计缓存1分钟
// Setup Token账号: 根据session_window推算5h窗口,7d数据不可用(没有profile scope)
// Setup Token账号: 根据session_window推算5h窗口,7d数据不可用(没有profile scope)
// API Key账号: 不支持usage查询
// API Key账号: 不支持usage查询
func
(
s
*
AccountUsageService
)
GetUsage
(
ctx
context
.
Context
,
accountID
int64
)
(
*
UsageInfo
,
error
)
{
func
(
s
*
AccountUsageService
)
GetUsage
(
ctx
context
.
Context
,
accountID
int64
)
(
*
UsageInfo
,
error
)
{
...
@@ -137,30 +145,34 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
...
@@ -137,30 +145,34 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
// 只有oauth类型账号可以通过API获取usage(有profile scope)
// 只有oauth类型账号可以通过API获取usage(有profile scope)
if
account
.
CanGetUsage
()
{
if
account
.
CanGetUsage
()
{
// 检查缓存
var
apiResp
*
ClaudeUsageResponse
if
cached
,
ok
:=
usageCacheMap
.
Load
(
accountID
);
ok
{
cache
,
ok
:=
cached
.
(
*
usageCache
)
// 1. 检查 API 缓存(10 分钟)
if
!
ok
{
if
cached
,
ok
:=
apiCacheMap
.
Load
(
accountID
);
ok
{
usageCacheMap
.
Delete
(
accountID
)
if
cache
,
ok
:=
cached
.
(
*
apiUsageCache
);
ok
&&
time
.
Since
(
cache
.
timestamp
)
<
apiCacheTTL
{
}
else
if
time
.
Since
(
cache
.
timestamp
)
<
cacheTTL
{
apiResp
=
cache
.
response
return
cache
.
data
,
nil
}
}
}
}
// 从API获取数据
// 2. 如果没有缓存,从 API 获取
usage
,
err
:=
s
.
fetchOAuthUsage
(
ctx
,
account
)
if
apiResp
==
nil
{
if
err
!=
nil
{
apiResp
,
err
=
s
.
fetchOAuthUsageRaw
(
ctx
,
account
)
return
nil
,
err
if
err
!=
nil
{
return
nil
,
err
}
// 缓存 API 响应
apiCacheMap
.
Store
(
accountID
,
&
apiUsageCache
{
response
:
apiResp
,
timestamp
:
time
.
Now
(),
})
}
}
// 添加5h窗口统计数据
// 3. 构建 UsageInfo(每次都重新计算 RemainingSeconds)
s
.
addWindowStats
(
ctx
,
account
,
usage
)
now
:=
time
.
Now
()
usage
:=
s
.
buildUsageInfo
(
apiResp
,
&
now
)
// 缓存结果
// 4. 添加窗口统计(有独立缓存,1 分钟)
usageCacheMap
.
Store
(
accountID
,
&
usageCache
{
s
.
addWindowStats
(
ctx
,
account
,
usage
)
data
:
usage
,
timestamp
:
time
.
Now
(),
})
return
usage
,
nil
return
usage
,
nil
}
}
...
@@ -177,31 +189,54 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
...
@@ -177,31 +189,54 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return
nil
,
fmt
.
Errorf
(
"account type %s does not support usage query"
,
account
.
Type
)
return
nil
,
fmt
.
Errorf
(
"account type %s does not support usage query"
,
account
.
Type
)
}
}
// addWindowStats 为usage数据添加窗口期统计
// addWindowStats 为 usage 数据添加窗口期统计
// 使用独立缓存(1 分钟),与 API 缓存分离
func
(
s
*
AccountUsageService
)
addWindowStats
(
ctx
context
.
Context
,
account
*
Account
,
usage
*
UsageInfo
)
{
func
(
s
*
AccountUsageService
)
addWindowStats
(
ctx
context
.
Context
,
account
*
Account
,
usage
*
UsageInfo
)
{
if
usage
.
FiveHour
==
nil
{
// 修复:即使 FiveHour 为 nil,也要尝试获取统计数据
// 因为 SevenDay/SevenDaySonnet 可能需要
if
usage
.
FiveHour
==
nil
&&
usage
.
SevenDay
==
nil
&&
usage
.
SevenDaySonnet
==
nil
{
return
return
}
}
// 使用session_window_start作为统计起始时间
// 检查窗口统计缓存(1 分钟)
var
startTime
time
.
Time
var
windowStats
*
WindowStats
if
account
.
SessionWindowStart
!=
nil
{
if
cached
,
ok
:=
windowStatsCacheMap
.
Load
(
account
.
ID
);
ok
{
startTime
=
*
account
.
SessionWindowStart
if
cache
,
ok
:=
cached
.
(
*
windowStatsCache
);
ok
&&
time
.
Since
(
cache
.
timestamp
)
<
windowStatsCacheTTL
{
}
else
{
windowStats
=
cache
.
stats
// 如果没有窗口信息,使用5小时前作为默认
}
startTime
=
time
.
Now
()
.
Add
(
-
5
*
time
.
Hour
)
}
}
stats
,
err
:=
s
.
usageLogRepo
.
GetAccountWindowStats
(
ctx
,
account
.
ID
,
startTime
)
// 如果没有缓存,从数据库查询
if
err
!=
nil
{
if
windowStats
==
nil
{
log
.
Printf
(
"Failed to get window stats for account %d: %v"
,
account
.
ID
,
err
)
var
startTime
time
.
Time
return
if
account
.
SessionWindowStart
!=
nil
{
startTime
=
*
account
.
SessionWindowStart
}
else
{
startTime
=
time
.
Now
()
.
Add
(
-
5
*
time
.
Hour
)
}
stats
,
err
:=
s
.
usageLogRepo
.
GetAccountWindowStats
(
ctx
,
account
.
ID
,
startTime
)
if
err
!=
nil
{
log
.
Printf
(
"Failed to get window stats for account %d: %v"
,
account
.
ID
,
err
)
return
}
windowStats
=
&
WindowStats
{
Requests
:
stats
.
Requests
,
Tokens
:
stats
.
Tokens
,
Cost
:
stats
.
Cost
,
}
// 缓存窗口统计(1 分钟)
windowStatsCacheMap
.
Store
(
account
.
ID
,
&
windowStatsCache
{
stats
:
windowStats
,
timestamp
:
time
.
Now
(),
})
}
}
usage
.
FiveHour
.
WindowStats
=
&
WindowStats
{
// 为 FiveHour 添加 WindowStats(5h 窗口统计)
Requests
:
stats
.
Requests
,
if
usage
.
FiveHour
!=
nil
{
Tokens
:
stats
.
Tokens
,
usage
.
FiveHour
.
WindowStats
=
windowStats
Cost
:
stats
.
Cost
,
}
}
}
}
...
@@ -227,8 +262,8 @@ func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountI
...
@@ -227,8 +262,8 @@ func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountI
return
stats
,
nil
return
stats
,
nil
}
}
// fetchOAuthUsage 从Anthropic API获取
OAuth账号的使用量
// fetchOAuthUsage
Raw
从
Anthropic API
获取
原始响应(不构建 UsageInfo)
func
(
s
*
AccountUsageService
)
fetchOAuthUsage
(
ctx
context
.
Context
,
account
*
Account
)
(
*
UsageInfo
,
error
)
{
func
(
s
*
AccountUsageService
)
fetchOAuthUsage
Raw
(
ctx
context
.
Context
,
account
*
Account
)
(
*
ClaudeUsageResponse
,
error
)
{
accessToken
:=
account
.
GetCredential
(
"access_token"
)
accessToken
:=
account
.
GetCredential
(
"access_token"
)
if
accessToken
==
""
{
if
accessToken
==
""
{
return
nil
,
fmt
.
Errorf
(
"no access token available"
)
return
nil
,
fmt
.
Errorf
(
"no access token available"
)
...
@@ -239,13 +274,7 @@ func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *Acco
...
@@ -239,13 +274,7 @@ func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *Acco
proxyURL
=
account
.
Proxy
.
URL
()
proxyURL
=
account
.
Proxy
.
URL
()
}
}
usageResp
,
err
:=
s
.
usageFetcher
.
FetchUsage
(
ctx
,
accessToken
,
proxyURL
)
return
s
.
usageFetcher
.
FetchUsage
(
ctx
,
accessToken
,
proxyURL
)
if
err
!=
nil
{
return
nil
,
err
}
now
:=
time
.
Now
()
return
s
.
buildUsageInfo
(
usageResp
,
&
now
),
nil
}
}
// parseTime 尝试多种格式解析时间
// parseTime 尝试多种格式解析时间
...
@@ -270,20 +299,16 @@ func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedA
...
@@ -270,20 +299,16 @@ func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedA
UpdatedAt
:
updatedAt
,
UpdatedAt
:
updatedAt
,
}
}
// 5小时窗口
// 5小时窗口 - 始终创建对象(即使 ResetsAt 为空)
info
.
FiveHour
=
&
UsageProgress
{
Utilization
:
resp
.
FiveHour
.
Utilization
,
}
if
resp
.
FiveHour
.
ResetsAt
!=
""
{
if
resp
.
FiveHour
.
ResetsAt
!=
""
{
if
fiveHourReset
,
err
:=
parseTime
(
resp
.
FiveHour
.
ResetsAt
);
err
==
nil
{
if
fiveHourReset
,
err
:=
parseTime
(
resp
.
FiveHour
.
ResetsAt
);
err
==
nil
{
info
.
FiveHour
=
&
UsageProgress
{
info
.
FiveHour
.
ResetsAt
=
&
fiveHourReset
Utilization
:
resp
.
FiveHour
.
Utilization
,
info
.
FiveHour
.
RemainingSeconds
=
int
(
time
.
Until
(
fiveHourReset
)
.
Seconds
())
ResetsAt
:
&
fiveHourReset
,
RemainingSeconds
:
int
(
time
.
Until
(
fiveHourReset
)
.
Seconds
()),
}
}
else
{
}
else
{
log
.
Printf
(
"Failed to parse FiveHour.ResetsAt: %s, error: %v"
,
resp
.
FiveHour
.
ResetsAt
,
err
)
log
.
Printf
(
"Failed to parse FiveHour.ResetsAt: %s, error: %v"
,
resp
.
FiveHour
.
ResetsAt
,
err
)
// 即使解析失败也返回utilization
info
.
FiveHour
=
&
UsageProgress
{
Utilization
:
resp
.
FiveHour
.
Utilization
,
}
}
}
}
}
...
...
backend/internal/service/admin_service.go
View file @
2bd288a6
...
@@ -609,12 +609,30 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
...
@@ -609,12 +609,30 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
if
err
:=
s
.
accountRepo
.
Create
(
ctx
,
account
);
err
!=
nil
{
if
err
:=
s
.
accountRepo
.
Create
(
ctx
,
account
);
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
// 绑定分组
// 绑定分组
if
len
(
input
.
GroupIDs
)
>
0
{
groupIDs
:=
input
.
GroupIDs
if
err
:=
s
.
accountRepo
.
BindGroups
(
ctx
,
account
.
ID
,
input
.
GroupIDs
);
err
!=
nil
{
// 如果没有指定分组,自动绑定对应平台的默认分组
if
len
(
groupIDs
)
==
0
{
defaultGroupName
:=
input
.
Platform
+
"-default"
groups
,
err
:=
s
.
groupRepo
.
ListActiveByPlatform
(
ctx
,
input
.
Platform
)
if
err
==
nil
{
for
_
,
g
:=
range
groups
{
if
g
.
Name
==
defaultGroupName
{
groupIDs
=
[]
int64
{
g
.
ID
}
log
.
Printf
(
"[CreateAccount] Auto-binding account %d to default group %s (ID: %d)"
,
account
.
ID
,
defaultGroupName
,
g
.
ID
)
break
}
}
}
}
if
len
(
groupIDs
)
>
0
{
if
err
:=
s
.
accountRepo
.
BindGroups
(
ctx
,
account
.
ID
,
groupIDs
);
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
}
}
return
account
,
nil
return
account
,
nil
}
}
...
...
backend/internal/service/billing_cache_service.go
View file @
2bd288a6
...
@@ -6,6 +6,7 @@ import (
...
@@ -6,6 +6,7 @@ import (
"log"
"log"
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
)
)
...
@@ -32,14 +33,16 @@ type BillingCacheService struct {
...
@@ -32,14 +33,16 @@ type BillingCacheService struct {
cache
BillingCache
cache
BillingCache
userRepo
UserRepository
userRepo
UserRepository
subRepo
UserSubscriptionRepository
subRepo
UserSubscriptionRepository
cfg
*
config
.
Config
}
}
// NewBillingCacheService 创建计费缓存服务
// NewBillingCacheService 创建计费缓存服务
func
NewBillingCacheService
(
cache
BillingCache
,
userRepo
UserRepository
,
subRepo
UserSubscriptionRepository
)
*
BillingCacheService
{
func
NewBillingCacheService
(
cache
BillingCache
,
userRepo
UserRepository
,
subRepo
UserSubscriptionRepository
,
cfg
*
config
.
Config
)
*
BillingCacheService
{
return
&
BillingCacheService
{
return
&
BillingCacheService
{
cache
:
cache
,
cache
:
cache
,
userRepo
:
userRepo
,
userRepo
:
userRepo
,
subRepo
:
subRepo
,
subRepo
:
subRepo
,
cfg
:
cfg
,
}
}
}
}
...
@@ -224,6 +227,11 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
...
@@ -224,6 +227,11 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
// 余额模式:检查缓存余额 > 0
// 余额模式:检查缓存余额 > 0
// 订阅模式:检查缓存用量未超过限额(Group限额从参数传入)
// 订阅模式:检查缓存用量未超过限额(Group限额从参数传入)
func
(
s
*
BillingCacheService
)
CheckBillingEligibility
(
ctx
context
.
Context
,
user
*
User
,
apiKey
*
ApiKey
,
group
*
Group
,
subscription
*
UserSubscription
)
error
{
func
(
s
*
BillingCacheService
)
CheckBillingEligibility
(
ctx
context
.
Context
,
user
*
User
,
apiKey
*
ApiKey
,
group
*
Group
,
subscription
*
UserSubscription
)
error
{
// 简易模式:跳过所有计费检查
if
s
.
cfg
.
RunMode
==
config
.
RunModeSimple
{
return
nil
}
// 判断计费模式
// 判断计费模式
isSubscriptionMode
:=
group
!=
nil
&&
group
.
IsSubscriptionType
()
&&
subscription
!=
nil
isSubscriptionMode
:=
group
!=
nil
&&
group
.
IsSubscriptionType
()
&&
subscription
!=
nil
...
...
backend/internal/service/gateway_service.go
View file @
2bd288a6
...
@@ -357,7 +357,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
...
@@ -357,7 +357,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
// 2. 获取可调度账号列表(单平台)
// 2. 获取可调度账号列表(单平台)
var
accounts
[]
Account
var
accounts
[]
Account
var
err
error
var
err
error
if
groupID
!=
nil
{
if
s
.
cfg
.
RunMode
==
config
.
RunModeSimple
{
// 简易模式:忽略 groupID,查询所有可用账号
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatform
(
ctx
,
platform
)
}
else
if
groupID
!=
nil
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByGroupIDAndPlatform
(
ctx
,
*
groupID
,
platform
)
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByGroupIDAndPlatform
(
ctx
,
*
groupID
,
platform
)
}
else
{
}
else
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatform
(
ctx
,
platform
)
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatform
(
ctx
,
platform
)
...
@@ -1226,6 +1229,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
...
@@ -1226,6 +1229,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
log
.
Printf
(
"Create usage log failed: %v"
,
err
)
log
.
Printf
(
"Create usage log failed: %v"
,
err
)
}
}
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
RunMode
==
config
.
RunModeSimple
{
log
.
Printf
(
"[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d"
,
usageLog
.
UserID
,
usageLog
.
TotalTokens
())
s
.
deferredService
.
ScheduleLastUsedUpdate
(
account
.
ID
)
return
nil
}
// 根据计费类型执行扣费
// 根据计费类型执行扣费
if
isSubscriptionBilling
{
if
isSubscriptionBilling
{
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
...
...
backend/internal/service/openai_gateway_service.go
View file @
2bd288a6
...
@@ -10,6 +10,7 @@ import (
...
@@ -10,6 +10,7 @@ import (
"errors"
"errors"
"fmt"
"fmt"
"io"
"io"
"log"
"net/http"
"net/http"
"regexp"
"regexp"
"strconv"
"strconv"
...
@@ -155,7 +156,10 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
...
@@ -155,7 +156,10 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
// 2. Get schedulable OpenAI accounts
// 2. Get schedulable OpenAI accounts
var
accounts
[]
Account
var
accounts
[]
Account
var
err
error
var
err
error
if
groupID
!=
nil
{
// 简易模式:忽略分组限制,查询所有可用账号
if
s
.
cfg
.
RunMode
==
config
.
RunModeSimple
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatform
(
ctx
,
PlatformOpenAI
)
}
else
if
groupID
!=
nil
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByGroupIDAndPlatform
(
ctx
,
*
groupID
,
PlatformOpenAI
)
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByGroupIDAndPlatform
(
ctx
,
*
groupID
,
PlatformOpenAI
)
}
else
{
}
else
{
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatform
(
ctx
,
PlatformOpenAI
)
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatform
(
ctx
,
PlatformOpenAI
)
...
@@ -754,6 +758,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
...
@@ -754,6 +758,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
_
=
s
.
usageLogRepo
.
Create
(
ctx
,
usageLog
)
_
=
s
.
usageLogRepo
.
Create
(
ctx
,
usageLog
)
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
RunMode
==
config
.
RunModeSimple
{
log
.
Printf
(
"[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d"
,
usageLog
.
UserID
,
usageLog
.
TotalTokens
())
s
.
deferredService
.
ScheduleLastUsedUpdate
(
account
.
ID
)
return
nil
}
// Deduct based on billing type
// Deduct based on billing type
if
isSubscriptionBilling
{
if
isSubscriptionBilling
{
if
cost
.
TotalCost
>
0
{
if
cost
.
TotalCost
>
0
{
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment