Commit 2bd288a6 authored by song's avatar song
Browse files

Merge branch 'main' into feature/antigravity_auth

parents 234e98f1 c01db6b1
...@@ -283,6 +283,16 @@ npm run dev ...@@ -283,6 +283,16 @@ npm run dev
--- ---
## 简易模式
简易模式适合个人开发者或内部团队快速使用,不依赖完整 SaaS 功能。
- 启用方式:设置环境变量 `RUN_MODE=simple`
- 功能差异:隐藏 SaaS 相关功能,跳过计费流程
- 安全注意事项:生产环境需同时设置 `SIMPLE_MODE_CONFIRM=true` 才允许启动
---
## 项目结构 ## 项目结构
``` ```
......
...@@ -107,6 +107,14 @@ func runSetupServer() { ...@@ -107,6 +107,14 @@ func runSetupServer() {
} }
func runMainServer() { func runMainServer() {
cfg, err := config.Load()
if err != nil {
log.Fatalf("Failed to load config: %v", err)
}
if cfg.RunMode == config.RunModeSimple {
log.Println("⚠️ WARNING: Running in SIMPLE mode - billing and quota checks are DISABLED")
}
buildInfo := handler.BuildInfo{ buildInfo := handler.BuildInfo{
Version: Version, Version: Version,
BuildType: BuildType, BuildType: BuildType,
......
...@@ -49,7 +49,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -49,7 +49,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
emailQueueService := service.ProvideEmailQueueService(emailService) emailQueueService := service.ProvideEmailQueueService(emailService)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService) authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
userService := service.NewUserService(userRepository) userService := service.NewUserService(userRepository)
authHandler := handler.NewAuthHandler(authService, userService) authHandler := handler.NewAuthHandler(configConfig, authService, userService)
userHandler := handler.NewUserHandler(userService) userHandler := handler.NewUserHandler(userService)
apiKeyRepository := repository.NewApiKeyRepository(db) apiKeyRepository := repository.NewApiKeyRepository(db)
groupRepository := repository.NewGroupRepository(db) groupRepository := repository.NewGroupRepository(db)
...@@ -62,7 +62,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -62,7 +62,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
redeemCodeRepository := repository.NewRedeemCodeRepository(db) redeemCodeRepository := repository.NewRedeemCodeRepository(db)
billingCache := repository.NewBillingCache(client) billingCache := repository.NewBillingCache(client)
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository) billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
redeemCache := repository.NewRedeemCache(client) redeemCache := repository.NewRedeemCache(client)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService)
...@@ -132,7 +132,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -132,7 +132,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler) handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
apiKeyAuthMiddleware := middleware.NewApiKeyAuthMiddleware(apiKeyService, subscriptionService) apiKeyAuthMiddleware := middleware.NewApiKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService) engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService)
httpServer := server.ProvideHTTPServer(configConfig, engine) httpServer := server.ProvideHTTPServer(configConfig, engine)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig) tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig)
......
...@@ -7,6 +7,11 @@ import ( ...@@ -7,6 +7,11 @@ import (
"github.com/spf13/viper" "github.com/spf13/viper"
) )
const (
RunModeStandard = "standard"
RunModeSimple = "simple"
)
type Config struct { type Config struct {
Server ServerConfig `mapstructure:"server"` Server ServerConfig `mapstructure:"server"`
Database DatabaseConfig `mapstructure:"database"` Database DatabaseConfig `mapstructure:"database"`
...@@ -17,6 +22,7 @@ type Config struct { ...@@ -17,6 +22,7 @@ type Config struct {
Pricing PricingConfig `mapstructure:"pricing"` Pricing PricingConfig `mapstructure:"pricing"`
Gateway GatewayConfig `mapstructure:"gateway"` Gateway GatewayConfig `mapstructure:"gateway"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
Gemini GeminiConfig `mapstructure:"gemini"` Gemini GeminiConfig `mapstructure:"gemini"`
} }
...@@ -135,6 +141,16 @@ type RateLimitConfig struct { ...@@ -135,6 +141,16 @@ type RateLimitConfig struct {
OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟) OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟)
} }
func NormalizeRunMode(value string) string {
normalized := strings.ToLower(strings.TrimSpace(value))
switch normalized {
case RunModeStandard, RunModeSimple:
return normalized
default:
return RunModeStandard
}
}
func Load() (*Config, error) { func Load() (*Config, error) {
viper.SetConfigName("config") viper.SetConfigName("config")
viper.SetConfigType("yaml") viper.SetConfigType("yaml")
...@@ -161,6 +177,8 @@ func Load() (*Config, error) { ...@@ -161,6 +177,8 @@ func Load() (*Config, error) {
return nil, fmt.Errorf("unmarshal config error: %w", err) return nil, fmt.Errorf("unmarshal config error: %w", err)
} }
cfg.RunMode = NormalizeRunMode(cfg.RunMode)
if err := cfg.Validate(); err != nil { if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("validate config error: %w", err) return nil, fmt.Errorf("validate config error: %w", err)
} }
...@@ -169,6 +187,8 @@ func Load() (*Config, error) { ...@@ -169,6 +187,8 @@ func Load() (*Config, error) {
} }
func setDefaults() { func setDefaults() {
viper.SetDefault("run_mode", RunModeStandard)
// Server // Server
viper.SetDefault("server.host", "0.0.0.0") viper.SetDefault("server.host", "0.0.0.0")
viper.SetDefault("server.port", 8080) viper.SetDefault("server.port", 8080)
......
package config
import "testing"
func TestNormalizeRunMode(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"simple", "simple"},
{"SIMPLE", "simple"},
{"standard", "standard"},
{"invalid", "standard"},
{"", "standard"},
}
for _, tt := range tests {
result := NormalizeRunMode(tt.input)
if result != tt.expected {
t.Errorf("NormalizeRunMode(%q) = %q, want %q", tt.input, result, tt.expected)
}
}
}
package handler package handler
import ( import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
...@@ -11,13 +12,15 @@ import ( ...@@ -11,13 +12,15 @@ import (
// AuthHandler handles authentication-related requests // AuthHandler handles authentication-related requests
type AuthHandler struct { type AuthHandler struct {
cfg *config.Config
authService *service.AuthService authService *service.AuthService
userService *service.UserService userService *service.UserService
} }
// NewAuthHandler creates a new AuthHandler // NewAuthHandler creates a new AuthHandler
func NewAuthHandler(authService *service.AuthService, userService *service.UserService) *AuthHandler { func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService) *AuthHandler {
return &AuthHandler{ return &AuthHandler{
cfg: cfg,
authService: authService, authService: authService,
userService: userService, userService: userService,
} }
...@@ -157,5 +160,15 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) { ...@@ -157,5 +160,15 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
return return
} }
response.Success(c, dto.UserFromService(user)) type UserResponse struct {
*dto.User
RunMode string `json:"run_mode"`
}
runMode := config.RunModeStandard
if h.cfg != nil {
runMode = h.cfg.RunMode
}
response.Success(c, UserResponse{User: dto.UserFromService(user), RunMode: runMode})
} }
...@@ -30,6 +30,11 @@ func AutoMigrate(db *gorm.DB) error { ...@@ -30,6 +30,11 @@ func AutoMigrate(db *gorm.DB) error {
return err return err
} }
// 创建默认分组(简易模式支持)
if err := ensureDefaultGroups(db); err != nil {
return err
}
// 修复无效的过期时间(年份超过 2099 会导致 JSON 序列化失败) // 修复无效的过期时间(年份超过 2099 会导致 JSON 序列化失败)
return fixInvalidExpiresAt(db) return fixInvalidExpiresAt(db)
} }
...@@ -47,3 +52,55 @@ func fixInvalidExpiresAt(db *gorm.DB) error { ...@@ -47,3 +52,55 @@ func fixInvalidExpiresAt(db *gorm.DB) error {
} }
return nil return nil
} }
// ensureDefaultGroups 确保默认分组存在(简易模式支持)
// 为每个平台创建一个默认分组,配置最大权限以确保简易模式下不受限制
func ensureDefaultGroups(db *gorm.DB) error {
defaultGroups := []struct {
name string
platform string
description string
}{
{
name: "anthropic-default",
platform: "anthropic",
description: "Default group for Anthropic accounts (Simple Mode)",
},
{
name: "openai-default",
platform: "openai",
description: "Default group for OpenAI accounts (Simple Mode)",
},
{
name: "gemini-default",
platform: "gemini",
description: "Default group for Gemini accounts (Simple Mode)",
},
}
for _, dg := range defaultGroups {
var count int64
if err := db.Model(&groupModel{}).Where("name = ?", dg.name).Count(&count).Error; err != nil {
return err
}
if count == 0 {
group := &groupModel{
Name: dg.name,
Description: dg.description,
Platform: dg.platform,
RateMultiplier: 1.0,
IsExclusive: false,
Status: "active",
SubscriptionType: "standard",
}
if err := db.Create(group).Error; err != nil {
log.Printf("[AutoMigrate] Failed to create default group %s: %v", dg.name, err)
return err
}
log.Printf("[AutoMigrate] Created default group: %s (platform: %s)", dg.name, dg.platform)
}
}
return nil
}
...@@ -82,8 +82,9 @@ func (s *GroupRepoSuite) TestList() { ...@@ -82,8 +82,9 @@ func (s *GroupRepoSuite) TestList() {
groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List") s.Require().NoError(err, "List")
s.Require().Len(groups, 2) // 3 default groups + 2 test groups = 5 total
s.Require().Equal(int64(2), page.Total) s.Require().Len(groups, 5)
s.Require().Equal(int64(5), page.Total)
} }
func (s *GroupRepoSuite) TestListWithFilters_Platform() { func (s *GroupRepoSuite) TestListWithFilters_Platform() {
...@@ -92,8 +93,12 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() { ...@@ -92,8 +93,12 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() {
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil) groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(groups, 1) // 1 default openai group + 1 test openai group = 2 total
s.Require().Equal(service.PlatformOpenAI, groups[0].Platform) s.Require().Len(groups, 2)
// Verify all groups are OpenAI platform
for _, g := range groups {
s.Require().Equal(service.PlatformOpenAI, g.Platform)
}
} }
func (s *GroupRepoSuite) TestListWithFilters_Status() { func (s *GroupRepoSuite) TestListWithFilters_Status() {
...@@ -151,8 +156,17 @@ func (s *GroupRepoSuite) TestListActive() { ...@@ -151,8 +156,17 @@ func (s *GroupRepoSuite) TestListActive() {
groups, err := s.repo.ListActive(s.ctx) groups, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive") s.Require().NoError(err, "ListActive")
s.Require().Len(groups, 1) // 3 default groups (all active) + 1 test active group = 4 total
s.Require().Equal("active1", groups[0].Name) s.Require().Len(groups, 4)
// Verify our test group is in the results
var found bool
for _, g := range groups {
if g.Name == "active1" {
found = true
break
}
}
s.Require().True(found, "active1 group should be in results")
} }
func (s *GroupRepoSuite) TestListActiveByPlatform() { func (s *GroupRepoSuite) TestListActiveByPlatform() {
...@@ -162,8 +176,17 @@ func (s *GroupRepoSuite) TestListActiveByPlatform() { ...@@ -162,8 +176,17 @@ func (s *GroupRepoSuite) TestListActiveByPlatform() {
groups, err := s.repo.ListActiveByPlatform(s.ctx, service.PlatformAnthropic) groups, err := s.repo.ListActiveByPlatform(s.ctx, service.PlatformAnthropic)
s.Require().NoError(err, "ListActiveByPlatform") s.Require().NoError(err, "ListActiveByPlatform")
s.Require().Len(groups, 1) // 1 default anthropic group + 1 test active anthropic group = 2 total
s.Require().Equal("g1", groups[0].Name) s.Require().Len(groups, 2)
// Verify our test group is in the results
var found bool
for _, g := range groups {
if g.Name == "g1" {
found = true
break
}
}
s.Require().True(found, "g1 group should be in results")
} }
// --- ExistsByName --- // --- ExistsByName ---
......
...@@ -59,7 +59,8 @@ func TestAPIContracts(t *testing.T) { ...@@ -59,7 +59,8 @@ func TestAPIContracts(t *testing.T) {
"status": "active", "status": "active",
"allowed_groups": null, "allowed_groups": null,
"created_at": "2025-01-02T03:04:05Z", "created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z" "updated_at": "2025-01-02T03:04:05Z",
"run_mode": "standard"
} }
}`, }`,
}, },
...@@ -369,6 +370,7 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -369,6 +370,7 @@ func newContractDeps(t *testing.T) *contractDeps {
Default: config.DefaultConfig{ Default: config.DefaultConfig{
ApiKeyPrefix: "sk-", ApiKeyPrefix: "sk-",
}, },
RunMode: config.RunModeStandard,
} }
userService := service.NewUserService(userRepo) userService := service.NewUserService(userRepo)
...@@ -380,7 +382,7 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -380,7 +382,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo() settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg) settingService := service.NewSettingService(settingRepo, cfg)
authHandler := handler.NewAuthHandler(nil, userService) authHandler := handler.NewAuthHandler(cfg, nil, userService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil)
......
...@@ -36,7 +36,7 @@ func ProvideRouter( ...@@ -36,7 +36,7 @@ func ProvideRouter(
r := gin.New() r := gin.New()
r.Use(middleware2.Recovery()) r.Use(middleware2.Recovery())
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService) return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg)
} }
// ProvideHTTPServer 提供 HTTP 服务器 // ProvideHTTPServer 提供 HTTP 服务器
......
...@@ -5,18 +5,19 @@ import ( ...@@ -5,18 +5,19 @@ import (
"log" "log"
"strings" "strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// NewApiKeyAuthMiddleware 创建 API Key 认证中间件 // NewApiKeyAuthMiddleware 创建 API Key 认证中间件
func NewApiKeyAuthMiddleware(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService) ApiKeyAuthMiddleware { func NewApiKeyAuthMiddleware(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) ApiKeyAuthMiddleware {
return ApiKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService)) return ApiKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg))
} }
// apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证) // apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证)
func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService) gin.HandlerFunc { func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// 尝试从Authorization header中提取API key (Bearer scheme) // 尝试从Authorization header中提取API key (Bearer scheme)
authHeader := c.GetHeader("Authorization") authHeader := c.GetHeader("Authorization")
...@@ -85,6 +86,18 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti ...@@ -85,6 +86,18 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
return return
} }
if cfg.RunMode == config.RunModeSimple {
// 简易模式:跳过余额和订阅检查,但仍需设置必要的上下文
c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency,
})
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
c.Next()
return
}
// 判断计费方式:订阅模式 vs 余额模式 // 判断计费方式:订阅模式 vs 余额模式
isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType() isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"errors" "errors"
"strings" "strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -11,15 +12,15 @@ import ( ...@@ -11,15 +12,15 @@ import (
) )
// ApiKeyAuthGoogle is a Google-style error wrapper for API key auth. // ApiKeyAuthGoogle is a Google-style error wrapper for API key auth.
func ApiKeyAuthGoogle(apiKeyService *service.ApiKeyService) gin.HandlerFunc { func ApiKeyAuthGoogle(apiKeyService *service.ApiKeyService, cfg *config.Config) gin.HandlerFunc {
return ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil) return ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)
} }
// ApiKeyAuthWithSubscriptionGoogle behaves like ApiKeyAuthWithSubscription but returns Google-style errors: // ApiKeyAuthWithSubscriptionGoogle behaves like ApiKeyAuthWithSubscription but returns Google-style errors:
// {"error":{"code":401,"message":"...","status":"UNAUTHENTICATED"}} // {"error":{"code":401,"message":"...","status":"UNAUTHENTICATED"}}
// //
// It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations. // It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations.
func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService) gin.HandlerFunc { func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
apiKeyString := extractAPIKeyFromRequest(c) apiKeyString := extractAPIKeyFromRequest(c)
if apiKeyString == "" { if apiKeyString == "" {
...@@ -50,6 +51,18 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs ...@@ -50,6 +51,18 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs
return return
} }
// 简易模式:跳过余额和订阅检查
if cfg.RunMode == config.RunModeSimple {
c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency,
})
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
c.Next()
return
}
isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType() isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
if isSubscriptionType && subscriptionService != nil { if isSubscriptionType && subscriptionService != nil {
subscription, err := subscriptionService.GetActiveSubscription( subscription, err := subscriptionService.GetActiveSubscription(
......
//go:build unit
package middleware
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
gin.SetMode(gin.TestMode)
limit := 1.0
group := &service.Group{
ID: 42,
Name: "sub",
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeSubscription,
DailyLimitUSD: &limit,
}
user := &service.User{
ID: 7,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
Concurrency: 3,
}
apiKey := &service.ApiKey{
ID: 100,
UserID: user.ID,
Key: "test-key",
Status: service.StatusActive,
User: user,
Group: group,
}
apiKey.GroupID = &group.ID
apiKeyRepo := &stubApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
if key != apiKey.Key {
return nil, service.ErrApiKeyNotFound
}
clone := *apiKey
return &clone, nil
},
}
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("x-api-key", apiKey.Key)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
})
t.Run("standard_mode_enforces_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeStandard}
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
now := time.Now()
sub := &service.UserSubscription{
ID: 55,
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: now.Add(24 * time.Hour),
DailyWindowStart: &now,
DailyUsageUSD: 10,
}
subscriptionRepo := &stubUserSubscriptionRepo{
getActive: func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
if userID != sub.UserID || groupID != sub.GroupID {
return nil, service.ErrSubscriptionNotFound
}
clone := *sub
return &clone, nil
},
updateStatus: func(ctx context.Context, subscriptionID int64, status string) error { return nil },
activateWindow: func(ctx context.Context, id int64, start time.Time) error { return nil },
resetDaily: func(ctx context.Context, id int64, start time.Time) error { return nil },
resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil },
resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil },
}
subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("x-api-key", apiKey.Key)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusTooManyRequests, w.Code)
require.Contains(t, w.Body.String(), "USAGE_LIMIT_EXCEEDED")
})
}
func newAuthTestRouter(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
router := gin.New()
router.Use(gin.HandlerFunc(NewApiKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
router.GET("/t", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
return router
}
type stubApiKeyRepo struct {
getByKey func(ctx context.Context, key string) (*service.ApiKey, error)
}
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
return errors.New("not implemented")
}
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
if r.getByKey != nil {
return r.getByKey(ctx, key)
}
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
return errors.New("not implemented")
}
func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
return false, errors.New("not implemented")
}
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
type stubUserSubscriptionRepo struct {
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
updateStatus func(ctx context.Context, subscriptionID int64, status string) error
activateWindow func(ctx context.Context, id int64, start time.Time) error
resetDaily func(ctx context.Context, id int64, start time.Time) error
resetWeekly func(ctx context.Context, id int64, start time.Time) error
resetMonthly func(ctx context.Context, id int64, start time.Time) error
}
func (r *stubUserSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error {
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
if r.getActive != nil {
return r.getActive(ctx, userID, groupID)
}
return nil, errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) Update(ctx context.Context, sub *service.UserSubscription) error {
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
return false, errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
if r.updateStatus != nil {
return r.updateStatus(ctx, subscriptionID, status)
}
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
if r.activateWindow != nil {
return r.activateWindow(ctx, id, start)
}
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
if r.resetDaily != nil {
return r.resetDaily(ctx, id, newWindowStart)
}
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
if r.resetWeekly != nil {
return r.resetWeekly(ctx, id, newWindowStart)
}
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
if r.resetMonthly != nil {
return r.resetMonthly(ctx, id, newWindowStart)
}
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
return 0, errors.New("not implemented")
}
package server package server
import ( import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/handler"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/server/routes" "github.com/Wei-Shaw/sub2api/internal/server/routes"
...@@ -19,6 +20,7 @@ func SetupRouter( ...@@ -19,6 +20,7 @@ func SetupRouter(
apiKeyAuth middleware2.ApiKeyAuthMiddleware, apiKeyAuth middleware2.ApiKeyAuthMiddleware,
apiKeyService *service.ApiKeyService, apiKeyService *service.ApiKeyService,
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
cfg *config.Config,
) *gin.Engine { ) *gin.Engine {
// 应用中间件 // 应用中间件
r.Use(middleware2.Logger()) r.Use(middleware2.Logger())
...@@ -30,7 +32,7 @@ func SetupRouter( ...@@ -30,7 +32,7 @@ func SetupRouter(
} }
// 注册路由 // 注册路由
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService) registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg)
return r return r
} }
...@@ -44,6 +46,7 @@ func registerRoutes( ...@@ -44,6 +46,7 @@ func registerRoutes(
apiKeyAuth middleware2.ApiKeyAuthMiddleware, apiKeyAuth middleware2.ApiKeyAuthMiddleware,
apiKeyService *service.ApiKeyService, apiKeyService *service.ApiKeyService,
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
cfg *config.Config,
) { ) {
// 通用路由(健康检查、状态等) // 通用路由(健康检查、状态等)
routes.RegisterCommonRoutes(r) routes.RegisterCommonRoutes(r)
...@@ -55,5 +58,5 @@ func registerRoutes( ...@@ -55,5 +58,5 @@ func registerRoutes(
routes.RegisterAuthRoutes(v1, h, jwtAuth) routes.RegisterAuthRoutes(v1, h, jwtAuth)
routes.RegisterUserRoutes(v1, h, jwtAuth) routes.RegisterUserRoutes(v1, h, jwtAuth)
routes.RegisterAdminRoutes(v1, h, adminAuth) routes.RegisterAdminRoutes(v1, h, adminAuth)
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService) routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, cfg)
} }
package routes package routes
import ( import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -15,6 +16,7 @@ func RegisterGatewayRoutes( ...@@ -15,6 +16,7 @@ func RegisterGatewayRoutes(
apiKeyAuth middleware.ApiKeyAuthMiddleware, apiKeyAuth middleware.ApiKeyAuthMiddleware,
apiKeyService *service.ApiKeyService, apiKeyService *service.ApiKeyService,
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
cfg *config.Config,
) { ) {
// API网关(Claude API兼容) // API网关(Claude API兼容)
gateway := r.Group("/v1") gateway := r.Group("/v1")
...@@ -30,7 +32,7 @@ func RegisterGatewayRoutes( ...@@ -30,7 +32,7 @@ func RegisterGatewayRoutes(
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
gemini := r.Group("/v1beta") gemini := r.Group("/v1beta")
gemini.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService)) gemini.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
{ {
gemini.GET("/models", h.Gateway.GeminiV1BetaListModels) gemini.GET("/models", h.Gateway.GeminiV1BetaListModels)
gemini.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel) gemini.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
...@@ -54,7 +56,7 @@ func RegisterGatewayRoutes( ...@@ -54,7 +56,7 @@ func RegisterGatewayRoutes(
antigravityV1Beta := r.Group("/antigravity/v1beta") antigravityV1Beta := r.Group("/antigravity/v1beta")
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity)) antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1Beta.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService)) antigravityV1Beta.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
{ {
antigravityV1Beta.GET("/models", h.Gateway.GeminiV1BetaListModels) antigravityV1Beta.GET("/models", h.Gateway.GeminiV1BetaListModels)
antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel) antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
......
...@@ -54,15 +54,23 @@ type UsageLogRepository interface { ...@@ -54,15 +54,23 @@ type UsageLogRepository interface {
GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
} }
// usageCache 用于缓存usage数据 // apiUsageCache 缓存从 Anthropic API 获取的使用率数据(utilization, resets_at)
type usageCache struct { type apiUsageCache struct {
data *UsageInfo response *ClaudeUsageResponse
timestamp time.Time
}
// windowStatsCache 缓存从本地数据库查询的窗口统计(requests, tokens, cost)
type windowStatsCache struct {
stats *WindowStats
timestamp time.Time timestamp time.Time
} }
var ( var (
usageCacheMap = sync.Map{} apiCacheMap = sync.Map{} // 缓存 API 响应
cacheTTL = 10 * time.Minute windowStatsCacheMap = sync.Map{} // 缓存窗口统计
apiCacheTTL = 10 * time.Minute
windowStatsCacheTTL = 1 * time.Minute
) )
// WindowStats 窗口期统计 // WindowStats 窗口期统计
...@@ -126,7 +134,7 @@ func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLog ...@@ -126,7 +134,7 @@ func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLog
} }
// GetUsage 获取账号使用量 // GetUsage 获取账号使用量
// OAuth账号: 调用Anthropic API获取真实数据(需要profile scope),缓存10分钟 // OAuth账号: 调用Anthropic API获取真实数据(需要profile scope),API响应缓存10分钟,窗口统计缓存1分钟
// Setup Token账号: 根据session_window推算5h窗口,7d数据不可用(没有profile scope) // Setup Token账号: 根据session_window推算5h窗口,7d数据不可用(没有profile scope)
// API Key账号: 不支持usage查询 // API Key账号: 不支持usage查询
func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*UsageInfo, error) { func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*UsageInfo, error) {
...@@ -137,30 +145,34 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U ...@@ -137,30 +145,34 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
// 只有oauth类型账号可以通过API获取usage(有profile scope) // 只有oauth类型账号可以通过API获取usage(有profile scope)
if account.CanGetUsage() { if account.CanGetUsage() {
// 检查缓存 var apiResp *ClaudeUsageResponse
if cached, ok := usageCacheMap.Load(accountID); ok {
cache, ok := cached.(*usageCache) // 1. 检查 API 缓存(10 分钟)
if !ok { if cached, ok := apiCacheMap.Load(accountID); ok {
usageCacheMap.Delete(accountID) if cache, ok := cached.(*apiUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL {
} else if time.Since(cache.timestamp) < cacheTTL { apiResp = cache.response
return cache.data, nil
} }
} }
// 从API获取数据 // 2. 如果没有缓存,从 API 获取
usage, err := s.fetchOAuthUsage(ctx, account) if apiResp == nil {
if err != nil { apiResp, err = s.fetchOAuthUsageRaw(ctx, account)
return nil, err if err != nil {
return nil, err
}
// 缓存 API 响应
apiCacheMap.Store(accountID, &apiUsageCache{
response: apiResp,
timestamp: time.Now(),
})
} }
// 添加5h窗口统计数据 // 3. 构建 UsageInfo(每次都重新计算 RemainingSeconds)
s.addWindowStats(ctx, account, usage) now := time.Now()
usage := s.buildUsageInfo(apiResp, &now)
// 缓存结果 // 4. 添加窗口统计(有独立缓存,1 分钟)
usageCacheMap.Store(accountID, &usageCache{ s.addWindowStats(ctx, account, usage)
data: usage,
timestamp: time.Now(),
})
return usage, nil return usage, nil
} }
...@@ -177,31 +189,54 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U ...@@ -177,31 +189,54 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return nil, fmt.Errorf("account type %s does not support usage query", account.Type) return nil, fmt.Errorf("account type %s does not support usage query", account.Type)
} }
// addWindowStats 为usage数据添加窗口期统计 // addWindowStats 为 usage 数据添加窗口期统计
// 使用独立缓存(1 分钟),与 API 缓存分离
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) { func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) {
if usage.FiveHour == nil { // 修复:即使 FiveHour 为 nil,也要尝试获取统计数据
// 因为 SevenDay/SevenDaySonnet 可能需要
if usage.FiveHour == nil && usage.SevenDay == nil && usage.SevenDaySonnet == nil {
return return
} }
// 使用session_window_start作为统计起始时间 // 检查窗口统计缓存(1 分钟)
var startTime time.Time var windowStats *WindowStats
if account.SessionWindowStart != nil { if cached, ok := windowStatsCacheMap.Load(account.ID); ok {
startTime = *account.SessionWindowStart if cache, ok := cached.(*windowStatsCache); ok && time.Since(cache.timestamp) < windowStatsCacheTTL {
} else { windowStats = cache.stats
// 如果没有窗口信息,使用5小时前作为默认 }
startTime = time.Now().Add(-5 * time.Hour)
} }
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime) // 如果没有缓存,从数据库查询
if err != nil { if windowStats == nil {
log.Printf("Failed to get window stats for account %d: %v", account.ID, err) var startTime time.Time
return if account.SessionWindowStart != nil {
startTime = *account.SessionWindowStart
} else {
startTime = time.Now().Add(-5 * time.Hour)
}
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
if err != nil {
log.Printf("Failed to get window stats for account %d: %v", account.ID, err)
return
}
windowStats = &WindowStats{
Requests: stats.Requests,
Tokens: stats.Tokens,
Cost: stats.Cost,
}
// 缓存窗口统计(1 分钟)
windowStatsCacheMap.Store(account.ID, &windowStatsCache{
stats: windowStats,
timestamp: time.Now(),
})
} }
usage.FiveHour.WindowStats = &WindowStats{ // 为 FiveHour 添加 WindowStats(5h 窗口统计)
Requests: stats.Requests, if usage.FiveHour != nil {
Tokens: stats.Tokens, usage.FiveHour.WindowStats = windowStats
Cost: stats.Cost,
} }
} }
...@@ -227,8 +262,8 @@ func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountI ...@@ -227,8 +262,8 @@ func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountI
return stats, nil return stats, nil
} }
// fetchOAuthUsage 从Anthropic API获取OAuth账号的使用量 // fetchOAuthUsageRaw Anthropic API 获取原始响应(不构建 UsageInfo)
func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *Account) (*UsageInfo, error) { func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *Account) (*ClaudeUsageResponse, error) {
accessToken := account.GetCredential("access_token") accessToken := account.GetCredential("access_token")
if accessToken == "" { if accessToken == "" {
return nil, fmt.Errorf("no access token available") return nil, fmt.Errorf("no access token available")
...@@ -239,13 +274,7 @@ func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *Acco ...@@ -239,13 +274,7 @@ func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *Acco
proxyURL = account.Proxy.URL() proxyURL = account.Proxy.URL()
} }
usageResp, err := s.usageFetcher.FetchUsage(ctx, accessToken, proxyURL) return s.usageFetcher.FetchUsage(ctx, accessToken, proxyURL)
if err != nil {
return nil, err
}
now := time.Now()
return s.buildUsageInfo(usageResp, &now), nil
} }
// parseTime 尝试多种格式解析时间 // parseTime 尝试多种格式解析时间
...@@ -270,20 +299,16 @@ func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedA ...@@ -270,20 +299,16 @@ func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedA
UpdatedAt: updatedAt, UpdatedAt: updatedAt,
} }
// 5小时窗口 // 5小时窗口 - 始终创建对象(即使 ResetsAt 为空)
info.FiveHour = &UsageProgress{
Utilization: resp.FiveHour.Utilization,
}
if resp.FiveHour.ResetsAt != "" { if resp.FiveHour.ResetsAt != "" {
if fiveHourReset, err := parseTime(resp.FiveHour.ResetsAt); err == nil { if fiveHourReset, err := parseTime(resp.FiveHour.ResetsAt); err == nil {
info.FiveHour = &UsageProgress{ info.FiveHour.ResetsAt = &fiveHourReset
Utilization: resp.FiveHour.Utilization, info.FiveHour.RemainingSeconds = int(time.Until(fiveHourReset).Seconds())
ResetsAt: &fiveHourReset,
RemainingSeconds: int(time.Until(fiveHourReset).Seconds()),
}
} else { } else {
log.Printf("Failed to parse FiveHour.ResetsAt: %s, error: %v", resp.FiveHour.ResetsAt, err) log.Printf("Failed to parse FiveHour.ResetsAt: %s, error: %v", resp.FiveHour.ResetsAt, err)
// 即使解析失败也返回utilization
info.FiveHour = &UsageProgress{
Utilization: resp.FiveHour.Utilization,
}
} }
} }
......
...@@ -609,12 +609,30 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou ...@@ -609,12 +609,30 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
if err := s.accountRepo.Create(ctx, account); err != nil { if err := s.accountRepo.Create(ctx, account); err != nil {
return nil, err return nil, err
} }
// 绑定分组 // 绑定分组
if len(input.GroupIDs) > 0 { groupIDs := input.GroupIDs
if err := s.accountRepo.BindGroups(ctx, account.ID, input.GroupIDs); err != nil { // 如果没有指定分组,自动绑定对应平台的默认分组
if len(groupIDs) == 0 {
defaultGroupName := input.Platform + "-default"
groups, err := s.groupRepo.ListActiveByPlatform(ctx, input.Platform)
if err == nil {
for _, g := range groups {
if g.Name == defaultGroupName {
groupIDs = []int64{g.ID}
log.Printf("[CreateAccount] Auto-binding account %d to default group %s (ID: %d)", account.ID, defaultGroupName, g.ID)
break
}
}
}
}
if len(groupIDs) > 0 {
if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil {
return nil, err return nil, err
} }
} }
return account, nil return account, nil
} }
......
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"log" "log"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
) )
...@@ -32,14 +33,16 @@ type BillingCacheService struct { ...@@ -32,14 +33,16 @@ type BillingCacheService struct {
cache BillingCache cache BillingCache
userRepo UserRepository userRepo UserRepository
subRepo UserSubscriptionRepository subRepo UserSubscriptionRepository
cfg *config.Config
} }
// NewBillingCacheService 创建计费缓存服务 // NewBillingCacheService 创建计费缓存服务
func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository) *BillingCacheService { func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, cfg *config.Config) *BillingCacheService {
return &BillingCacheService{ return &BillingCacheService{
cache: cache, cache: cache,
userRepo: userRepo, userRepo: userRepo,
subRepo: subRepo, subRepo: subRepo,
cfg: cfg,
} }
} }
...@@ -224,6 +227,11 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID ...@@ -224,6 +227,11 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
// 余额模式:检查缓存余额 > 0 // 余额模式:检查缓存余额 > 0
// 订阅模式:检查缓存用量未超过限额(Group限额从参数传入) // 订阅模式:检查缓存用量未超过限额(Group限额从参数传入)
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *ApiKey, group *Group, subscription *UserSubscription) error { func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *ApiKey, group *Group, subscription *UserSubscription) error {
// 简易模式:跳过所有计费检查
if s.cfg.RunMode == config.RunModeSimple {
return nil
}
// 判断计费模式 // 判断计费模式
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
......
...@@ -357,7 +357,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ...@@ -357,7 +357,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
// 2. 获取可调度账号列表(单平台) // 2. 获取可调度账号列表(单平台)
var accounts []Account var accounts []Account
var err error var err error
if groupID != nil { if s.cfg.RunMode == config.RunModeSimple {
// 简易模式:忽略 groupID,查询所有可用账号
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
} else if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform) accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform)
} else { } else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
...@@ -1226,6 +1229,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -1226,6 +1229,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
log.Printf("Create usage log failed: %v", err) log.Printf("Create usage log failed: %v", err)
} }
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
// 根据计费类型执行扣费 // 根据计费类型执行扣费
if isSubscriptionBilling { if isSubscriptionBilling {
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率) // 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
......
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log"
"net/http" "net/http"
"regexp" "regexp"
"strconv" "strconv"
...@@ -155,7 +156,10 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C ...@@ -155,7 +156,10 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
// 2. Get schedulable OpenAI accounts // 2. Get schedulable OpenAI accounts
var accounts []Account var accounts []Account
var err error var err error
if groupID != nil { // 简易模式:忽略分组限制,查询所有可用账号
if s.cfg.RunMode == config.RunModeSimple {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
} else if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI) accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI)
} else { } else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI) accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
...@@ -754,6 +758,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec ...@@ -754,6 +758,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
_ = s.usageLogRepo.Create(ctx, usageLog) _ = s.usageLogRepo.Create(ctx, usageLog)
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
// Deduct based on billing type // Deduct based on billing type
if isSubscriptionBilling { if isSubscriptionBilling {
if cost.TotalCost > 0 { if cost.TotalCost > 0 {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment