Commit 642842c2 authored by shaw's avatar shaw
Browse files

First commit

parent 569f4882
package model
import (
"time"
)
type AccountGroup struct {
AccountID int64 `gorm:"primaryKey" json:"account_id"`
GroupID int64 `gorm:"primaryKey" json:"group_id"`
Priority int `gorm:"default:50;not null" json:"priority"` // 分组内优先级
CreatedAt time.Time `gorm:"not null" json:"created_at"`
// 关联
Account *Account `gorm:"foreignKey:AccountID" json:"account,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func (AccountGroup) TableName() string {
return "account_groups"
}
package model
import (
"time"
"gorm.io/gorm"
)
type ApiKey struct {
ID int64 `gorm:"primaryKey" json:"id"`
UserID int64 `gorm:"index;not null" json:"user_id"`
Key string `gorm:"uniqueIndex;size:128;not null" json:"key"` // sk-xxx
Name string `gorm:"size:100;not null" json:"name"`
GroupID *int64 `gorm:"index" json:"group_id"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 关联
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func (ApiKey) TableName() string {
return "api_keys"
}
// IsActive 检查是否激活
func (k *ApiKey) IsActive() bool {
return k.Status == "active"
}
package model
import (
"time"
"gorm.io/gorm"
)
// 订阅类型常量
const (
SubscriptionTypeStandard = "standard" // 标准计费模式(按余额扣费)
SubscriptionTypeSubscription = "subscription" // 订阅模式(按限额控制)
)
type Group struct {
ID int64 `gorm:"primaryKey" json:"id"`
Name string `gorm:"uniqueIndex;size:100;not null" json:"name"`
Description string `gorm:"type:text" json:"description"`
Platform string `gorm:"size:50;default:anthropic;not null" json:"platform"` // anthropic/openai/gemini
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"`
IsExclusive bool `gorm:"default:false;not null" json:"is_exclusive"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
// 订阅功能字段
SubscriptionType string `gorm:"size:20;default:standard;not null" json:"subscription_type"` // standard/subscription
DailyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"daily_limit_usd"`
WeeklyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"monthly_limit_usd"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 关联
AccountGroups []AccountGroup `gorm:"foreignKey:GroupID" json:"account_groups,omitempty"`
// 虚拟字段 (不存储到数据库)
AccountCount int64 `gorm:"-" json:"account_count,omitempty"`
}
func (Group) TableName() string {
return "groups"
}
// IsActive 检查是否激活
func (g *Group) IsActive() bool {
return g.Status == "active"
}
// IsSubscriptionType 检查是否为订阅类型分组
func (g *Group) IsSubscriptionType() bool {
return g.SubscriptionType == SubscriptionTypeSubscription
}
// IsFreeSubscription 检查是否为免费订阅(不扣余额但有限额)
func (g *Group) IsFreeSubscription() bool {
return g.IsSubscriptionType() && g.RateMultiplier == 0
}
// HasDailyLimit 检查是否有日限额
func (g *Group) HasDailyLimit() bool {
return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0
}
// HasWeeklyLimit 检查是否有周限额
func (g *Group) HasWeeklyLimit() bool {
return g.WeeklyLimitUSD != nil && *g.WeeklyLimitUSD > 0
}
// HasMonthlyLimit 检查是否有月限额
func (g *Group) HasMonthlyLimit() bool {
return g.MonthlyLimitUSD != nil && *g.MonthlyLimitUSD > 0
}
package model
import (
"gorm.io/gorm"
)
// AutoMigrate 自动迁移所有模型
func AutoMigrate(db *gorm.DB) error {
return db.AutoMigrate(
&User{},
&ApiKey{},
&Group{},
&Account{},
&AccountGroup{},
&Proxy{},
&RedeemCode{},
&UsageLog{},
&Setting{},
&UserSubscription{},
)
}
// 状态常量
const (
StatusActive = "active"
StatusDisabled = "disabled"
StatusError = "error"
StatusUnused = "unused"
StatusUsed = "used"
StatusExpired = "expired"
)
// 角色常量
const (
RoleAdmin = "admin"
RoleUser = "user"
)
// 平台常量
const (
PlatformAnthropic = "anthropic"
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
)
// 账号类型常量
const (
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
AccountTypeApiKey = "apikey" // API Key类型账号
)
// 卡密类型常量
const (
RedeemTypeBalance = "balance"
RedeemTypeConcurrency = "concurrency"
RedeemTypeSubscription = "subscription"
)
// 管理员调整类型常量
const (
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
AdjustmentTypeAdminConcurrency = "admin_concurrency" // 管理员调整并发数
)
package model
import (
"fmt"
"time"
"gorm.io/gorm"
)
type Proxy struct {
ID int64 `gorm:"primaryKey" json:"id"`
Name string `gorm:"size:100;not null" json:"name"`
Protocol string `gorm:"size:20;not null" json:"protocol"` // http/https/socks5
Host string `gorm:"size:255;not null" json:"host"`
Port int `gorm:"not null" json:"port"`
Username string `gorm:"size:100" json:"username"`
Password string `gorm:"size:100" json:"-"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
}
func (Proxy) TableName() string {
return "proxies"
}
// IsActive 检查是否激活
func (p *Proxy) IsActive() bool {
return p.Status == "active"
}
// URL 返回代理URL
func (p *Proxy) URL() string {
if p.Username != "" && p.Password != "" {
return fmt.Sprintf("%s://%s:%s@%s:%d", p.Protocol, p.Username, p.Password, p.Host, p.Port)
}
return fmt.Sprintf("%s://%s:%d", p.Protocol, p.Host, p.Port)
}
// ProxyWithAccountCount extends Proxy with account count information
type ProxyWithAccountCount struct {
Proxy
AccountCount int64 `json:"account_count"`
}
package model
import (
"crypto/rand"
"encoding/hex"
"time"
)
type RedeemCode struct {
ID int64 `gorm:"primaryKey" json:"id"`
Code string `gorm:"uniqueIndex;size:32;not null" json:"code"`
Type string `gorm:"size:20;default:balance;not null" json:"type"` // balance/concurrency/subscription
Value float64 `gorm:"type:decimal(20,8);not null" json:"value"` // 面值(USD)或并发数或有效天数
Status string `gorm:"size:20;default:unused;not null" json:"status"` // unused/used
UsedBy *int64 `gorm:"index" json:"used_by"`
UsedAt *time.Time `json:"used_at"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
// 订阅类型专用字段
GroupID *int64 `gorm:"index" json:"group_id"` // 订阅分组ID (仅subscription类型使用)
ValidityDays int `gorm:"default:30" json:"validity_days"` // 订阅有效天数 (仅subscription类型使用)
// 关联
User *User `gorm:"foreignKey:UsedBy" json:"user,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func (RedeemCode) TableName() string {
return "redeem_codes"
}
// IsUsed 检查是否已使用
func (r *RedeemCode) IsUsed() bool {
return r.Status == "used"
}
// CanUse 检查是否可以使用
func (r *RedeemCode) CanUse() bool {
return r.Status == "unused"
}
// GenerateRedeemCode 生成唯一的兑换码
func GenerateRedeemCode() string {
b := make([]byte, 16)
rand.Read(b)
return hex.EncodeToString(b)
}
package model
import (
"time"
)
// Setting 系统设置模型(Key-Value存储)
type Setting struct {
ID int64 `gorm:"primaryKey" json:"id"`
Key string `gorm:"uniqueIndex;size:100;not null" json:"key"`
Value string `gorm:"type:text;not null" json:"value"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
}
func (Setting) TableName() string {
return "settings"
}
// 设置Key常量
const (
// 注册设置
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
// 邮件服务设置
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
SettingKeySmtpPort = "smtp_port" // SMTP端口
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
SettingKeySmtpPassword = "smtp_password" // SMTP密码(加密存储)
SettingKeySmtpFrom = "smtp_from" // 发件人地址
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
// Cloudflare Turnstile 设置
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key
SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key
// OEM设置
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
SettingKeyApiBaseUrl = "api_base_url" // API端点地址(用于客户端配置和导入)
SettingKeyContactInfo = "contact_info" // 客服联系方式
// 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
)
// SystemSettings 系统设置结构体(用于API响应)
type SystemSettings struct {
// 注册设置
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
// 邮件服务设置
SmtpHost string `json:"smtp_host"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
SmtpPassword string `json:"smtp_password,omitempty"` // 不返回明文密码
SmtpFrom string `json:"smtp_from_email"`
SmtpFromName string `json:"smtp_from_name"`
SmtpUseTLS bool `json:"smtp_use_tls"`
// Cloudflare Turnstile 设置
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKey string `json:"turnstile_secret_key,omitempty"` // 不返回明文密钥
// OEM设置
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
}
// PublicSettings 公开设置(无需登录即可获取)
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
Version string `json:"version"`
}
package model
import (
"time"
)
// 消费类型常量
const (
BillingTypeBalance int8 = 0 // 钱包余额
BillingTypeSubscription int8 = 1 // 订阅套餐
)
type UsageLog struct {
ID int64 `gorm:"primaryKey" json:"id"`
UserID int64 `gorm:"index;not null" json:"user_id"`
ApiKeyID int64 `gorm:"index;not null" json:"api_key_id"`
AccountID int64 `gorm:"index;not null" json:"account_id"`
RequestID string `gorm:"size:64" json:"request_id"`
Model string `gorm:"size:100;index;not null" json:"model"`
// 订阅关联(可选)
GroupID *int64 `gorm:"index" json:"group_id"`
SubscriptionID *int64 `gorm:"index" json:"subscription_id"`
// Token使用量(4类)
InputTokens int `gorm:"default:0;not null" json:"input_tokens"`
OutputTokens int `gorm:"default:0;not null" json:"output_tokens"`
CacheCreationTokens int `gorm:"default:0;not null" json:"cache_creation_tokens"`
CacheReadTokens int `gorm:"default:0;not null" json:"cache_read_tokens"`
// 详细的缓存创建分类
CacheCreation5mTokens int `gorm:"default:0;not null" json:"cache_creation_5m_tokens"`
CacheCreation1hTokens int `gorm:"default:0;not null" json:"cache_creation_1h_tokens"`
// 费用(USD)
InputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"input_cost"`
OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"output_cost"`
CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_creation_cost"`
CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_read_cost"`
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"` // 原始总费用
ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"actual_cost"` // 实际扣除费用
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null" json:"rate_multiplier"` // 计费倍率
// 元数据
BillingType int8 `gorm:"type:smallint;default:0;not null" json:"billing_type"` // 0=余额 1=订阅
Stream bool `gorm:"default:false;not null" json:"stream"`
DurationMs *int `json:"duration_ms"`
FirstTokenMs *int `json:"first_token_ms"` // 首字时间(流式请求)
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
// 关联
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
ApiKey *ApiKey `gorm:"foreignKey:ApiKeyID" json:"api_key,omitempty"`
Account *Account `gorm:"foreignKey:AccountID" json:"account,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
Subscription *UserSubscription `gorm:"foreignKey:SubscriptionID" json:"subscription,omitempty"`
}
func (UsageLog) TableName() string {
return "usage_logs"
}
// TotalTokens 总token数
func (u *UsageLog) TotalTokens() int {
return u.InputTokens + u.OutputTokens + u.CacheCreationTokens + u.CacheReadTokens
}
package model
import (
"time"
"github.com/lib/pq"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
type User struct {
ID int64 `gorm:"primaryKey" json:"id"`
Email string `gorm:"uniqueIndex;size:255;not null" json:"email"`
PasswordHash string `gorm:"size:255;not null" json:"-"`
Role string `gorm:"size:20;default:user;not null" json:"role"` // admin/user
Balance float64 `gorm:"type:decimal(20,8);default:0;not null" json:"balance"`
Concurrency int `gorm:"default:5;not null" json:"concurrency"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
AllowedGroups pq.Int64Array `gorm:"type:bigint[]" json:"allowed_groups"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 关联
ApiKeys []ApiKey `gorm:"foreignKey:UserID" json:"api_keys,omitempty"`
}
func (User) TableName() string {
return "users"
}
// IsAdmin 检查是否管理员
func (u *User) IsAdmin() bool {
return u.Role == "admin"
}
// IsActive 检查是否激活
func (u *User) IsActive() bool {
return u.Status == "active"
}
// CanBindGroup 检查是否可以绑定指定分组
// 对于标准类型分组:
// - 如果 AllowedGroups 设置了值(非空数组),只能绑定列表中的分组
// - 如果 AllowedGroups 为 nil 或空数组,可以绑定所有非专属分组
func (u *User) CanBindGroup(groupID int64, isExclusive bool) bool {
// 如果设置了 allowed_groups 且不为空,只能绑定指定的分组
if len(u.AllowedGroups) > 0 {
for _, id := range u.AllowedGroups {
if id == groupID {
return true
}
}
return false
}
// 如果没有设置 allowed_groups 或为空数组,可以绑定所有非专属分组
return !isExclusive
}
// SetPassword 设置密码(哈希存储)
func (u *User) SetPassword(password string) error {
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return err
}
u.PasswordHash = string(hash)
return nil
}
// CheckPassword 验证密码
func (u *User) CheckPassword(password string) bool {
err := bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password))
return err == nil
}
package model
import (
"time"
)
// 订阅状态常量
const (
SubscriptionStatusActive = "active"
SubscriptionStatusExpired = "expired"
SubscriptionStatusSuspended = "suspended"
)
// UserSubscription 用户订阅模型
type UserSubscription struct {
ID int64 `gorm:"primaryKey" json:"id"`
UserID int64 `gorm:"index;not null" json:"user_id"`
GroupID int64 `gorm:"index;not null" json:"group_id"`
// 订阅有效期
StartsAt time.Time `gorm:"not null" json:"starts_at"`
ExpiresAt time.Time `gorm:"not null" json:"expires_at"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/expired/suspended
// 滑动窗口起始时间(nil = 未激活)
DailyWindowStart *time.Time `json:"daily_window_start"`
WeeklyWindowStart *time.Time `json:"weekly_window_start"`
MonthlyWindowStart *time.Time `json:"monthly_window_start"`
// 当前窗口已用额度(USD,基于 total_cost 计算)
DailyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null" json:"daily_usage_usd"`
WeeklyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null" json:"weekly_usage_usd"`
MonthlyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null" json:"monthly_usage_usd"`
// 管理员分配信息
AssignedBy *int64 `gorm:"index" json:"assigned_by"`
AssignedAt time.Time `gorm:"not null" json:"assigned_at"`
Notes string `gorm:"type:text" json:"notes"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
// 关联
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
AssignedByUser *User `gorm:"foreignKey:AssignedBy" json:"assigned_by_user,omitempty"`
}
func (UserSubscription) TableName() string {
return "user_subscriptions"
}
// IsActive 检查订阅是否有效(状态为active且未过期)
func (s *UserSubscription) IsActive() bool {
return s.Status == SubscriptionStatusActive && time.Now().Before(s.ExpiresAt)
}
// IsExpired 检查订阅是否已过期
func (s *UserSubscription) IsExpired() bool {
return time.Now().After(s.ExpiresAt)
}
// DaysRemaining 返回订阅剩余天数
func (s *UserSubscription) DaysRemaining() int {
if s.IsExpired() {
return 0
}
return int(time.Until(s.ExpiresAt).Hours() / 24)
}
// IsWindowActivated 检查窗口是否已激活
func (s *UserSubscription) IsWindowActivated() bool {
return s.DailyWindowStart != nil || s.WeeklyWindowStart != nil || s.MonthlyWindowStart != nil
}
// NeedsDailyReset 检查日窗口是否需要重置
func (s *UserSubscription) NeedsDailyReset() bool {
if s.DailyWindowStart == nil {
return false
}
return time.Since(*s.DailyWindowStart) >= 24*time.Hour
}
// NeedsWeeklyReset 检查周窗口是否需要重置
func (s *UserSubscription) NeedsWeeklyReset() bool {
if s.WeeklyWindowStart == nil {
return false
}
return time.Since(*s.WeeklyWindowStart) >= 7*24*time.Hour
}
// NeedsMonthlyReset 检查月窗口是否需要重置
func (s *UserSubscription) NeedsMonthlyReset() bool {
if s.MonthlyWindowStart == nil {
return false
}
return time.Since(*s.MonthlyWindowStart) >= 30*24*time.Hour
}
// DailyResetTime 返回日窗口重置时间
func (s *UserSubscription) DailyResetTime() *time.Time {
if s.DailyWindowStart == nil {
return nil
}
t := s.DailyWindowStart.Add(24 * time.Hour)
return &t
}
// WeeklyResetTime 返回周窗口重置时间
func (s *UserSubscription) WeeklyResetTime() *time.Time {
if s.WeeklyWindowStart == nil {
return nil
}
t := s.WeeklyWindowStart.Add(7 * 24 * time.Hour)
return &t
}
// MonthlyResetTime 返回月窗口重置时间
func (s *UserSubscription) MonthlyResetTime() *time.Time {
if s.MonthlyWindowStart == nil {
return nil
}
t := s.MonthlyWindowStart.Add(30 * 24 * time.Hour)
return &t
}
// CheckDailyLimit 检查是否超出日限额
func (s *UserSubscription) CheckDailyLimit(group *Group, additionalCost float64) bool {
if !group.HasDailyLimit() {
return true // 无限制
}
return s.DailyUsageUSD+additionalCost <= *group.DailyLimitUSD
}
// CheckWeeklyLimit 检查是否超出周限额
func (s *UserSubscription) CheckWeeklyLimit(group *Group, additionalCost float64) bool {
if !group.HasWeeklyLimit() {
return true // 无限制
}
return s.WeeklyUsageUSD+additionalCost <= *group.WeeklyLimitUSD
}
// CheckMonthlyLimit 检查是否超出月限额
func (s *UserSubscription) CheckMonthlyLimit(group *Group, additionalCost float64) bool {
if !group.HasMonthlyLimit() {
return true // 无限制
}
return s.MonthlyUsageUSD+additionalCost <= *group.MonthlyLimitUSD
}
// CheckAllLimits 检查所有限额
func (s *UserSubscription) CheckAllLimits(group *Group, additionalCost float64) (daily, weekly, monthly bool) {
daily = s.CheckDailyLimit(group, additionalCost)
weekly = s.CheckWeeklyLimit(group, additionalCost)
monthly = s.CheckMonthlyLimit(group, additionalCost)
return
}
package oauth
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"net/url"
"strings"
"sync"
"time"
)
// Claude OAuth Constants (from CRS project)
const (
// OAuth Client ID for Claude
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
// OAuth endpoints
AuthorizeURL = "https://claude.ai/oauth/authorize"
TokenURL = "https://console.anthropic.com/v1/oauth/token"
RedirectURI = "https://console.anthropic.com/oauth/code/callback"
// Scopes
ScopeProfile = "user:profile"
ScopeInference = "user:inference"
// Session TTL
SessionTTL = 30 * time.Minute
)
// OAuthSession stores OAuth flow state
type OAuthSession struct {
State string `json:"state"`
CodeVerifier string `json:"code_verifier"`
Scope string `json:"scope"`
ProxyURL string `json:"proxy_url,omitempty"`
CreatedAt time.Time `json:"created_at"`
}
// SessionStore manages OAuth sessions in memory
type SessionStore struct {
mu sync.RWMutex
sessions map[string]*OAuthSession
}
// NewSessionStore creates a new session store
func NewSessionStore() *SessionStore {
store := &SessionStore{
sessions: make(map[string]*OAuthSession),
}
// Start cleanup goroutine
go store.cleanup()
return store
}
// Set stores a session
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
s.mu.Lock()
defer s.mu.Unlock()
s.sessions[sessionID] = session
}
// Get retrieves a session
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
session, ok := s.sessions[sessionID]
if !ok {
return nil, false
}
// Check if expired
if time.Since(session.CreatedAt) > SessionTTL {
return nil, false
}
return session, true
}
// Delete removes a session
func (s *SessionStore) Delete(sessionID string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, sessionID)
}
// cleanup removes expired sessions periodically
func (s *SessionStore) cleanup() {
ticker := time.NewTicker(5 * time.Minute)
for range ticker.C {
s.mu.Lock()
for id, session := range s.sessions {
if time.Since(session.CreatedAt) > SessionTTL {
delete(s.sessions, id)
}
}
s.mu.Unlock()
}
}
// GenerateRandomBytes generates cryptographically secure random bytes
func GenerateRandomBytes(n int) ([]byte, error) {
b := make([]byte, n)
_, err := rand.Read(b)
if err != nil {
return nil, err
}
return b, nil
}
// GenerateState generates a random state string for OAuth
func GenerateState() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// GenerateSessionID generates a unique session ID
func GenerateSessionID() (string, error) {
bytes, err := GenerateRandomBytes(16)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// GenerateCodeVerifier generates a PKCE code verifier (32 bytes -> base64url)
func GenerateCodeVerifier() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return base64URLEncode(bytes), nil
}
// GenerateCodeChallenge generates a PKCE code challenge using S256 method
func GenerateCodeChallenge(verifier string) string {
hash := sha256.Sum256([]byte(verifier))
return base64URLEncode(hash[:])
}
// base64URLEncode encodes bytes to base64url without padding
func base64URLEncode(data []byte) string {
encoded := base64.URLEncoding.EncodeToString(data)
// Remove padding
return strings.TrimRight(encoded, "=")
}
// BuildAuthorizationURL builds the OAuth authorization URL
func BuildAuthorizationURL(state, codeChallenge, scope string) string {
params := url.Values{}
params.Set("response_type", "code")
params.Set("client_id", ClientID)
params.Set("redirect_uri", RedirectURI)
params.Set("scope", scope)
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
}
// TokenRequest represents the token exchange request body
type TokenRequest struct {
GrantType string `json:"grant_type"`
ClientID string `json:"client_id"`
Code string `json:"code"`
RedirectURI string `json:"redirect_uri"`
CodeVerifier string `json:"code_verifier"`
State string `json:"state"`
}
// TokenResponse represents the token response from OAuth provider
type TokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
Scope string `json:"scope,omitempty"`
// Organization and Account info from OAuth response
Organization *OrgInfo `json:"organization,omitempty"`
Account *AccountInfo `json:"account,omitempty"`
}
// OrgInfo represents organization info from OAuth response
type OrgInfo struct {
UUID string `json:"uuid"`
}
// AccountInfo represents account info from OAuth response
type AccountInfo struct {
UUID string `json:"uuid"`
}
// RefreshTokenRequest represents the refresh token request
type RefreshTokenRequest struct {
GrantType string `json:"grant_type"`
RefreshToken string `json:"refresh_token"`
ClientID string `json:"client_id"`
}
// BuildTokenRequest creates a token exchange request
func BuildTokenRequest(code, codeVerifier, state string) *TokenRequest {
return &TokenRequest{
GrantType: "authorization_code",
ClientID: ClientID,
Code: code,
RedirectURI: RedirectURI,
CodeVerifier: codeVerifier,
State: state,
}
}
// BuildRefreshTokenRequest creates a refresh token request
func BuildRefreshTokenRequest(refreshToken string) *RefreshTokenRequest {
return &RefreshTokenRequest{
GrantType: "refresh_token",
RefreshToken: refreshToken,
ClientID: ClientID,
}
}
package response
import (
"math"
"net/http"
"github.com/gin-gonic/gin"
)
// Response 标准API响应格式
type Response struct {
Code int `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
}
// PaginatedData 分页数据格式(匹配前端期望)
type PaginatedData struct {
Items interface{} `json:"items"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
Pages int `json:"pages"`
}
// Success 返回成功响应
func Success(c *gin.Context, data interface{}) {
c.JSON(http.StatusOK, Response{
Code: 0,
Message: "success",
Data: data,
})
}
// Created 返回创建成功响应
func Created(c *gin.Context, data interface{}) {
c.JSON(http.StatusCreated, Response{
Code: 0,
Message: "success",
Data: data,
})
}
// Error 返回错误响应
func Error(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, Response{
Code: statusCode,
Message: message,
})
}
// BadRequest 返回400错误
func BadRequest(c *gin.Context, message string) {
Error(c, http.StatusBadRequest, message)
}
// Unauthorized 返回401错误
func Unauthorized(c *gin.Context, message string) {
Error(c, http.StatusUnauthorized, message)
}
// Forbidden 返回403错误
func Forbidden(c *gin.Context, message string) {
Error(c, http.StatusForbidden, message)
}
// NotFound 返回404错误
func NotFound(c *gin.Context, message string) {
Error(c, http.StatusNotFound, message)
}
// InternalError 返回500错误
func InternalError(c *gin.Context, message string) {
Error(c, http.StatusInternalServerError, message)
}
// Paginated 返回分页数据
func Paginated(c *gin.Context, items interface{}, total int64, page, pageSize int) {
pages := int(math.Ceil(float64(total) / float64(pageSize)))
if pages < 1 {
pages = 1
}
Success(c, PaginatedData{
Items: items,
Total: total,
Page: page,
PageSize: pageSize,
Pages: pages,
})
}
// PaginationResult 分页结果(与repository.PaginationResult兼容)
type PaginationResult struct {
Total int64
Page int
PageSize int
Pages int
}
// PaginatedWithResult 使用PaginationResult返回分页数据
func PaginatedWithResult(c *gin.Context, items interface{}, pagination *PaginationResult) {
if pagination == nil {
Success(c, PaginatedData{
Items: items,
Total: 0,
Page: 1,
PageSize: 20,
Pages: 1,
})
return
}
Success(c, PaginatedData{
Items: items,
Total: pagination.Total,
Page: pagination.Page,
PageSize: pagination.PageSize,
Pages: pagination.Pages,
})
}
// ParsePagination 解析分页参数
func ParsePagination(c *gin.Context) (page, pageSize int) {
page = 1
pageSize = 20
if p := c.Query("page"); p != "" {
if val, err := parseInt(p); err == nil && val > 0 {
page = val
}
}
// 支持 page_size 和 limit 两种参数名
if ps := c.Query("page_size"); ps != "" {
if val, err := parseInt(ps); err == nil && val > 0 && val <= 100 {
pageSize = val
}
} else if l := c.Query("limit"); l != "" {
if val, err := parseInt(l); err == nil && val > 0 && val <= 100 {
pageSize = val
}
}
return page, pageSize
}
func parseInt(s string) (int, error) {
var result int
for _, c := range s {
if c < '0' || c > '9' {
return 0, nil
}
result = result*10 + int(c-'0')
}
return result, nil
}
// Package timezone provides global timezone management for the application.
// Similar to PHP's date_default_timezone_set, this package allows setting
// a global timezone that affects all time.Now() calls.
package timezone
import (
"fmt"
"log"
"time"
)
var (
// location is the global timezone location
location *time.Location
// tzName stores the timezone name for logging/debugging
tzName string
)
// Init initializes the global timezone setting.
// This should be called once at application startup.
// Example timezone values: "Asia/Shanghai", "America/New_York", "UTC"
func Init(tz string) error {
if tz == "" {
tz = "Asia/Shanghai" // Default timezone
}
loc, err := time.LoadLocation(tz)
if err != nil {
return fmt.Errorf("invalid timezone %q: %w", tz, err)
}
// Set the global Go time.Local to our timezone
// This affects time.Now() throughout the application
time.Local = loc
location = loc
tzName = tz
log.Printf("Timezone initialized: %s (UTC offset: %s)", tz, getUTCOffset(loc))
return nil
}
// getUTCOffset returns the current UTC offset for a location
func getUTCOffset(loc *time.Location) string {
_, offset := time.Now().In(loc).Zone()
hours := offset / 3600
minutes := (offset % 3600) / 60
if minutes < 0 {
minutes = -minutes
}
sign := "+"
if hours < 0 {
sign = "-"
hours = -hours
}
return fmt.Sprintf("%s%02d:%02d", sign, hours, minutes)
}
// Now returns the current time in the configured timezone.
// This is equivalent to time.Now() after Init() is called,
// but provided for explicit timezone-aware code.
func Now() time.Time {
if location == nil {
return time.Now()
}
return time.Now().In(location)
}
// Location returns the configured timezone location.
func Location() *time.Location {
if location == nil {
return time.Local
}
return location
}
// Name returns the configured timezone name.
func Name() string {
if tzName == "" {
return "Local"
}
return tzName
}
// StartOfDay returns the start of the given day (00:00:00) in the configured timezone.
func StartOfDay(t time.Time) time.Time {
loc := Location()
t = t.In(loc)
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, loc)
}
// Today returns the start of today (00:00:00) in the configured timezone.
func Today() time.Time {
return StartOfDay(Now())
}
// EndOfDay returns the end of the given day (23:59:59.999999999) in the configured timezone.
func EndOfDay(t time.Time) time.Time {
loc := Location()
t = t.In(loc)
return time.Date(t.Year(), t.Month(), t.Day(), 23, 59, 59, 999999999, loc)
}
// StartOfWeek returns the start of the week (Monday 00:00:00) for the given time.
func StartOfWeek(t time.Time) time.Time {
loc := Location()
t = t.In(loc)
weekday := int(t.Weekday())
if weekday == 0 {
weekday = 7 // Sunday is day 7
}
return time.Date(t.Year(), t.Month(), t.Day()-weekday+1, 0, 0, 0, 0, loc)
}
// StartOfMonth returns the start of the month (1st day 00:00:00) for the given time.
func StartOfMonth(t time.Time) time.Time {
loc := Location()
t = t.In(loc)
return time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, loc)
}
// ParseInLocation parses a time string in the configured timezone.
func ParseInLocation(layout, value string) (time.Time, error) {
return time.ParseInLocation(layout, value, Location())
}
package timezone
import (
"testing"
"time"
)
func TestInit(t *testing.T) {
// Test with valid timezone
err := Init("Asia/Shanghai")
if err != nil {
t.Fatalf("Init failed with valid timezone: %v", err)
}
// Verify time.Local was set
if time.Local.String() != "Asia/Shanghai" {
t.Errorf("time.Local not set correctly, got %s", time.Local.String())
}
// Verify our location variable
if Location().String() != "Asia/Shanghai" {
t.Errorf("Location() not set correctly, got %s", Location().String())
}
// Test Name()
if Name() != "Asia/Shanghai" {
t.Errorf("Name() not set correctly, got %s", Name())
}
}
func TestInitInvalidTimezone(t *testing.T) {
err := Init("Invalid/Timezone")
if err == nil {
t.Error("Init should fail with invalid timezone")
}
}
func TestTimeNowAffected(t *testing.T) {
// Reset to UTC first
Init("UTC")
utcNow := time.Now()
// Switch to Shanghai (UTC+8)
Init("Asia/Shanghai")
shanghaiNow := time.Now()
// The times should be the same instant, but different timezone representation
// Shanghai should be 8 hours ahead in display
_, utcOffset := utcNow.Zone()
_, shanghaiOffset := shanghaiNow.Zone()
expectedDiff := 8 * 3600 // 8 hours in seconds
actualDiff := shanghaiOffset - utcOffset
if actualDiff != expectedDiff {
t.Errorf("Timezone offset difference incorrect: expected %d, got %d", expectedDiff, actualDiff)
}
}
func TestToday(t *testing.T) {
Init("Asia/Shanghai")
today := Today()
now := Now()
// Today should be at 00:00:00
if today.Hour() != 0 || today.Minute() != 0 || today.Second() != 0 {
t.Errorf("Today() not at start of day: %v", today)
}
// Today should be same date as now
if today.Year() != now.Year() || today.Month() != now.Month() || today.Day() != now.Day() {
t.Errorf("Today() date mismatch: today=%v, now=%v", today, now)
}
}
func TestStartOfDay(t *testing.T) {
Init("Asia/Shanghai")
// Create a time at 15:30:45
testTime := time.Date(2024, 6, 15, 15, 30, 45, 123456789, Location())
startOfDay := StartOfDay(testTime)
expected := time.Date(2024, 6, 15, 0, 0, 0, 0, Location())
if !startOfDay.Equal(expected) {
t.Errorf("StartOfDay incorrect: expected %v, got %v", expected, startOfDay)
}
}
func TestTruncateVsStartOfDay(t *testing.T) {
// This test demonstrates why Truncate(24*time.Hour) can be problematic
// and why StartOfDay is more reliable for timezone-aware code
Init("Asia/Shanghai")
now := Now()
// Truncate operates on UTC, not local time
truncated := now.Truncate(24 * time.Hour)
// StartOfDay operates on local time
startOfDay := StartOfDay(now)
// These will likely be different for non-UTC timezones
t.Logf("Now: %v", now)
t.Logf("Truncate(24h): %v", truncated)
t.Logf("StartOfDay: %v", startOfDay)
// The truncated time may not be at local midnight
// StartOfDay is always at local midnight
if startOfDay.Hour() != 0 {
t.Errorf("StartOfDay should be at hour 0, got %d", startOfDay.Hour())
}
}
func TestDSTAwareness(t *testing.T) {
// Test with a timezone that has DST (America/New_York)
err := Init("America/New_York")
if err != nil {
t.Skipf("America/New_York timezone not available: %v", err)
}
// Just verify it doesn't crash
_ = Today()
_ = Now()
_ = StartOfDay(Now())
}
package repository
import (
"context"
"sub2api/internal/model"
"time"
"gorm.io/gorm"
)
type AccountRepository struct {
db *gorm.DB
}
func NewAccountRepository(db *gorm.DB) *AccountRepository {
return &AccountRepository{db: db}
}
func (r *AccountRepository) Create(ctx context.Context, account *model.Account) error {
return r.db.WithContext(ctx).Create(account).Error
}
func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) {
var account model.Account
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups").First(&account, id).Error
if err != nil {
return nil, err
}
// 填充 GroupIDs 虚拟字段
account.GroupIDs = make([]int64, 0, len(account.AccountGroups))
for _, ag := range account.AccountGroups {
account.GroupIDs = append(account.GroupIDs, ag.GroupID)
}
return &account, nil
}
func (r *AccountRepository) Update(ctx context.Context, account *model.Account) error {
return r.db.WithContext(ctx).Save(account).Error
}
func (r *AccountRepository) Delete(ctx context.Context, id int64) error {
// 先删除账号与分组的绑定关系
if err := r.db.WithContext(ctx).Where("account_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
return err
}
// 再删除账号
return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error
}
func (r *AccountRepository) List(ctx context.Context, params PaginationParams) ([]model.Account, *PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "", "")
}
// ListWithFilters lists accounts with optional filtering by platform, type, status, and search query
func (r *AccountRepository) ListWithFilters(ctx context.Context, params PaginationParams, platform, accountType, status, search string) ([]model.Account, *PaginationResult, error) {
var accounts []model.Account
var total int64
db := r.db.WithContext(ctx).Model(&model.Account{})
// Apply filters
if platform != "" {
db = db.Where("platform = ?", platform)
}
if accountType != "" {
db = db.Where("type = ?", accountType)
}
if status != "" {
db = db.Where("status = ?", status)
}
if search != "" {
searchPattern := "%" + search + "%"
db = db.Where("name ILIKE ?", searchPattern)
}
if err := db.Count(&total).Error; err != nil {
return nil, nil, err
}
if err := db.Preload("Proxy").Preload("AccountGroups").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&accounts).Error; err != nil {
return nil, nil, err
}
// 填充每个 Account 的 GroupIDs 虚拟字段
for i := range accounts {
accounts[i].GroupIDs = make([]int64, 0, len(accounts[i].AccountGroups))
for _, ag := range accounts[i].AccountGroups {
accounts[i].GroupIDs = append(accounts[i].GroupIDs, ag.GroupID)
}
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return accounts, &PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
}
func (r *AccountRepository) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) {
var accounts []model.Account
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
Where("account_groups.group_id = ? AND accounts.status = ?", groupID, model.StatusActive).
Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error
return accounts, err
}
func (r *AccountRepository) ListActive(ctx context.Context) ([]model.Account, error) {
var accounts []model.Account
err := r.db.WithContext(ctx).
Where("status = ?", model.StatusActive).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
return accounts, err
}
func (r *AccountRepository) UpdateLastUsed(ctx context.Context, id int64) error {
now := time.Now()
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Update("last_used_at", now).Error
}
func (r *AccountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Updates(map[string]interface{}{
"status": model.StatusError,
"error_message": errorMsg,
}).Error
}
func (r *AccountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
ag := &model.AccountGroup{
AccountID: accountID,
GroupID: groupID,
Priority: priority,
}
return r.db.WithContext(ctx).Create(ag).Error
}
func (r *AccountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
return r.db.WithContext(ctx).Where("account_id = ? AND group_id = ?", accountID, groupID).
Delete(&model.AccountGroup{}).Error
}
func (r *AccountRepository) GetGroups(ctx context.Context, accountID int64) ([]model.Group, error) {
var groups []model.Group
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.group_id = groups.id").
Where("account_groups.account_id = ?", accountID).
Find(&groups).Error
return groups, err
}
func (r *AccountRepository) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
var accounts []model.Account
err := r.db.WithContext(ctx).
Where("platform = ? AND status = ?", platform, model.StatusActive).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
return accounts, err
}
func (r *AccountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
// 删除现有绑定
if err := r.db.WithContext(ctx).Where("account_id = ?", accountID).Delete(&model.AccountGroup{}).Error; err != nil {
return err
}
// 添加新绑定
if len(groupIDs) > 0 {
accountGroups := make([]model.AccountGroup, 0, len(groupIDs))
for i, groupID := range groupIDs {
accountGroups = append(accountGroups, model.AccountGroup{
AccountID: accountID,
GroupID: groupID,
Priority: i + 1, // 使用索引作为优先级
})
}
return r.db.WithContext(ctx).Create(&accountGroups).Error
}
return nil
}
// ListSchedulable 获取所有可调度的账号
func (r *AccountRepository) ListSchedulable(ctx context.Context) ([]model.Account, error) {
var accounts []model.Account
now := time.Now()
err := r.db.WithContext(ctx).
Where("status = ? AND schedulable = ?", model.StatusActive, true).
Where("(overload_until IS NULL OR overload_until <= ?)", now).
Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
return accounts, err
}
// ListSchedulableByGroupID 按组获取可调度的账号
func (r *AccountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error) {
var accounts []model.Account
now := time.Now()
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
Where("account_groups.group_id = ?", groupID).
Where("accounts.status = ? AND accounts.schedulable = ?", model.StatusActive, true).
Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now).
Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error
return accounts, err
}
// SetRateLimited 标记账号为限流状态(429)
func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
now := time.Now()
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Updates(map[string]interface{}{
"rate_limited_at": now,
"rate_limit_reset_at": resetAt,
}).Error
}
// SetOverloaded 标记账号为过载状态(529)
func (r *AccountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Update("overload_until", until).Error
}
// ClearRateLimit 清除账号的限流状态
func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Updates(map[string]interface{}{
"rate_limited_at": nil,
"rate_limit_reset_at": nil,
"overload_until": nil,
}).Error
}
// UpdateSessionWindow 更新账号的5小时时间窗口信息
func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
updates := map[string]interface{}{
"session_window_status": status,
}
if start != nil {
updates["session_window_start"] = start
}
if end != nil {
updates["session_window_end"] = end
}
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Updates(updates).Error
}
// SetSchedulable 设置账号的调度开关
func (r *AccountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Update("schedulable", schedulable).Error
}
package repository
import (
"context"
"sub2api/internal/model"
"gorm.io/gorm"
)
type ApiKeyRepository struct {
db *gorm.DB
}
func NewApiKeyRepository(db *gorm.DB) *ApiKeyRepository {
return &ApiKeyRepository{db: db}
}
func (r *ApiKeyRepository) Create(ctx context.Context, key *model.ApiKey) error {
return r.db.WithContext(ctx).Create(key).Error
}
func (r *ApiKeyRepository) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) {
var key model.ApiKey
err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&key, id).Error
if err != nil {
return nil, err
}
return &key, nil
}
func (r *ApiKeyRepository) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) {
var apiKey model.ApiKey
err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&apiKey).Error
if err != nil {
return nil, err
}
return &apiKey, nil
}
func (r *ApiKeyRepository) Update(ctx context.Context, key *model.ApiKey) error {
return r.db.WithContext(ctx).Model(key).Select("name", "group_id", "status", "updated_at").Updates(key).Error
}
func (r *ApiKeyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error
}
func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, params PaginationParams) ([]model.ApiKey, *PaginationResult, error) {
var keys []model.ApiKey
var total int64
db := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID)
if err := db.Count(&total).Error; err != nil {
return nil, nil, err
}
if err := db.Preload("Group").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&keys).Error; err != nil {
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return keys, &PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
}
func (r *ApiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID).Count(&count).Error
return count, err
}
func (r *ApiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("key = ?", key).Count(&count).Error
return count > 0, err
}
func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params PaginationParams) ([]model.ApiKey, *PaginationResult, error) {
var keys []model.ApiKey
var total int64
db := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID)
if err := db.Count(&total).Error; err != nil {
return nil, nil, err
}
if err := db.Preload("User").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&keys).Error; err != nil {
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return keys, &PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
}
// SearchApiKeys searches API keys by user ID and/or keyword (name)
func (r *ApiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) {
var keys []model.ApiKey
db := r.db.WithContext(ctx).Model(&model.ApiKey{})
if userID > 0 {
db = db.Where("user_id = ?", userID)
}
if keyword != "" {
searchPattern := "%" + keyword + "%"
db = db.Where("name ILIKE ?", searchPattern)
}
if err := db.Limit(limit).Order("id DESC").Find(&keys).Error; err != nil {
return nil, err
}
return keys, nil
}
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
func (r *ApiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.ApiKey{}).
Where("group_id = ?", groupID).
Update("group_id", nil)
return result.RowsAffected, result.Error
}
// CountByGroupID 获取分组的 API Key 数量
func (r *ApiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID).Count(&count).Error
return count, err
}
package repository
import (
"context"
"sub2api/internal/model"
"gorm.io/gorm"
)
type GroupRepository struct {
db *gorm.DB
}
func NewGroupRepository(db *gorm.DB) *GroupRepository {
return &GroupRepository{db: db}
}
func (r *GroupRepository) Create(ctx context.Context, group *model.Group) error {
return r.db.WithContext(ctx).Create(group).Error
}
func (r *GroupRepository) GetByID(ctx context.Context, id int64) (*model.Group, error) {
var group model.Group
err := r.db.WithContext(ctx).First(&group, id).Error
if err != nil {
return nil, err
}
return &group, nil
}
func (r *GroupRepository) Update(ctx context.Context, group *model.Group) error {
return r.db.WithContext(ctx).Save(group).Error
}
func (r *GroupRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error
}
func (r *GroupRepository) List(ctx context.Context, params PaginationParams) ([]model.Group, *PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", nil)
}
// ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive
func (r *GroupRepository) ListWithFilters(ctx context.Context, params PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *PaginationResult, error) {
var groups []model.Group
var total int64
db := r.db.WithContext(ctx).Model(&model.Group{})
// Apply filters
if platform != "" {
db = db.Where("platform = ?", platform)
}
if status != "" {
db = db.Where("status = ?", status)
}
if isExclusive != nil {
db = db.Where("is_exclusive = ?", *isExclusive)
}
if err := db.Count(&total).Error; err != nil {
return nil, nil, err
}
if err := db.Offset(params.Offset()).Limit(params.Limit()).Order("id ASC").Find(&groups).Error; err != nil {
return nil, nil, err
}
// 获取每个分组的账号数量
for i := range groups {
count, _ := r.GetAccountCount(ctx, groups[i].ID)
groups[i].AccountCount = count
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return groups, &PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
}
func (r *GroupRepository) ListActive(ctx context.Context) ([]model.Group, error) {
var groups []model.Group
err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Order("id ASC").Find(&groups).Error
if err != nil {
return nil, err
}
// 获取每个分组的账号数量
for i := range groups {
count, _ := r.GetAccountCount(ctx, groups[i].ID)
groups[i].AccountCount = count
}
return groups, nil
}
func (r *GroupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error) {
var groups []model.Group
err := r.db.WithContext(ctx).Where("status = ? AND platform = ?", model.StatusActive, platform).Order("id ASC").Find(&groups).Error
if err != nil {
return nil, err
}
// 获取每个分组的账号数量
for i := range groups {
count, _ := r.GetAccountCount(ctx, groups[i].ID)
groups[i].AccountCount = count
}
return groups, nil
}
func (r *GroupRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.Group{}).Where("name = ?", name).Count(&count).Error
return count > 0, err
}
func (r *GroupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.AccountGroup{}).Where("group_id = ?", groupID).Count(&count).Error
return count, err
}
// DeleteAccountGroupsByGroupID 删除分组与账号的关联关系
func (r *GroupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.AccountGroup{})
return result.RowsAffected, result.Error
}
// DB 返回底层数据库连接,用于事务处理
func (r *GroupRepository) DB() *gorm.DB {
return r.db
}
package repository
import (
"context"
"sub2api/internal/model"
"gorm.io/gorm"
)
type ProxyRepository struct {
db *gorm.DB
}
func NewProxyRepository(db *gorm.DB) *ProxyRepository {
return &ProxyRepository{db: db}
}
func (r *ProxyRepository) Create(ctx context.Context, proxy *model.Proxy) error {
return r.db.WithContext(ctx).Create(proxy).Error
}
func (r *ProxyRepository) GetByID(ctx context.Context, id int64) (*model.Proxy, error) {
var proxy model.Proxy
err := r.db.WithContext(ctx).First(&proxy, id).Error
if err != nil {
return nil, err
}
return &proxy, nil
}
func (r *ProxyRepository) Update(ctx context.Context, proxy *model.Proxy) error {
return r.db.WithContext(ctx).Save(proxy).Error
}
func (r *ProxyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Proxy{}, id).Error
}
func (r *ProxyRepository) List(ctx context.Context, params PaginationParams) ([]model.Proxy, *PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "")
}
// ListWithFilters lists proxies with optional filtering by protocol, status, and search query
func (r *ProxyRepository) ListWithFilters(ctx context.Context, params PaginationParams, protocol, status, search string) ([]model.Proxy, *PaginationResult, error) {
var proxies []model.Proxy
var total int64
db := r.db.WithContext(ctx).Model(&model.Proxy{})
// Apply filters
if protocol != "" {
db = db.Where("protocol = ?", protocol)
}
if status != "" {
db = db.Where("status = ?", status)
}
if search != "" {
searchPattern := "%" + search + "%"
db = db.Where("name ILIKE ?", searchPattern)
}
if err := db.Count(&total).Error; err != nil {
return nil, nil, err
}
if err := db.Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&proxies).Error; err != nil {
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return proxies, &PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
}
func (r *ProxyRepository) ListActive(ctx context.Context) ([]model.Proxy, error) {
var proxies []model.Proxy
err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Find(&proxies).Error
return proxies, err
}
// ExistsByHostPortAuth checks if a proxy with the same host, port, username, and password exists
func (r *ProxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.Proxy{}).
Where("host = ? AND port = ? AND username = ? AND password = ?", host, port, username, password).
Count(&count).Error
if err != nil {
return false, err
}
return count > 0, nil
}
// CountAccountsByProxyID returns the number of accounts using a specific proxy
func (r *ProxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.Account{}).
Where("proxy_id = ?", proxyID).
Count(&count).Error
return count, err
}
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
func (r *ProxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[int64]int64, error) {
type result struct {
ProxyID int64 `gorm:"column:proxy_id"`
Count int64 `gorm:"column:count"`
}
var results []result
err := r.db.WithContext(ctx).
Model(&model.Account{}).
Select("proxy_id, COUNT(*) as count").
Where("proxy_id IS NOT NULL").
Group("proxy_id").
Scan(&results).Error
if err != nil {
return nil, err
}
counts := make(map[int64]int64)
for _, r := range results {
counts[r.ProxyID] = r.Count
}
return counts, nil
}
// ListActiveWithAccountCount returns all active proxies with account count, sorted by creation time descending
func (r *ProxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) {
var proxies []model.Proxy
err := r.db.WithContext(ctx).
Where("status = ?", model.StatusActive).
Order("created_at DESC").
Find(&proxies).Error
if err != nil {
return nil, err
}
// Get account counts
counts, err := r.GetAccountCountsForProxies(ctx)
if err != nil {
return nil, err
}
// Build result with account counts
result := make([]model.ProxyWithAccountCount, len(proxies))
for i, proxy := range proxies {
result[i] = model.ProxyWithAccountCount{
Proxy: proxy,
AccountCount: counts[proxy.ID],
}
}
return result, nil
}
package repository
import (
"context"
"sub2api/internal/model"
"time"
"gorm.io/gorm"
)
type RedeemCodeRepository struct {
db *gorm.DB
}
func NewRedeemCodeRepository(db *gorm.DB) *RedeemCodeRepository {
return &RedeemCodeRepository{db: db}
}
func (r *RedeemCodeRepository) Create(ctx context.Context, code *model.RedeemCode) error {
return r.db.WithContext(ctx).Create(code).Error
}
func (r *RedeemCodeRepository) CreateBatch(ctx context.Context, codes []model.RedeemCode) error {
return r.db.WithContext(ctx).Create(&codes).Error
}
func (r *RedeemCodeRepository) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) {
var code model.RedeemCode
err := r.db.WithContext(ctx).First(&code, id).Error
if err != nil {
return nil, err
}
return &code, nil
}
func (r *RedeemCodeRepository) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) {
var redeemCode model.RedeemCode
err := r.db.WithContext(ctx).Where("code = ?", code).First(&redeemCode).Error
if err != nil {
return nil, err
}
return &redeemCode, nil
}
func (r *RedeemCodeRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.RedeemCode{}, id).Error
}
func (r *RedeemCodeRepository) List(ctx context.Context, params PaginationParams) ([]model.RedeemCode, *PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "")
}
// ListWithFilters lists redeem codes with optional filtering by type, status, and search query
func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params PaginationParams, codeType, status, search string) ([]model.RedeemCode, *PaginationResult, error) {
var codes []model.RedeemCode
var total int64
db := r.db.WithContext(ctx).Model(&model.RedeemCode{})
// Apply filters
if codeType != "" {
db = db.Where("type = ?", codeType)
}
if status != "" {
db = db.Where("status = ?", status)
}
if search != "" {
searchPattern := "%" + search + "%"
db = db.Where("code ILIKE ?", searchPattern)
}
if err := db.Count(&total).Error; err != nil {
return nil, nil, err
}
if err := db.Preload("User").Preload("Group").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&codes).Error; err != nil {
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return codes, &PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
}
func (r *RedeemCodeRepository) Update(ctx context.Context, code *model.RedeemCode) error {
return r.db.WithContext(ctx).Save(code).Error
}
func (r *RedeemCodeRepository) Use(ctx context.Context, id, userID int64) error {
now := time.Now()
result := r.db.WithContext(ctx).Model(&model.RedeemCode{}).
Where("id = ? AND status = ?", id, model.StatusUnused).
Updates(map[string]interface{}{
"status": model.StatusUsed,
"used_by": userID,
"used_at": now,
})
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound // 兑换码不存在或已被使用
}
return nil
}
// ListByUser returns all redeem codes used by a specific user
func (r *RedeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) {
var codes []model.RedeemCode
if limit <= 0 {
limit = 10
}
err := r.db.WithContext(ctx).
Preload("Group").
Where("used_by = ?", userID).
Order("used_at DESC").
Limit(limit).
Find(&codes).Error
if err != nil {
return nil, err
}
return codes, nil
}
package repository
import (
"gorm.io/gorm"
)
// Repositories 所有仓库的集合
type Repositories struct {
User *UserRepository
ApiKey *ApiKeyRepository
Group *GroupRepository
Account *AccountRepository
Proxy *ProxyRepository
RedeemCode *RedeemCodeRepository
UsageLog *UsageLogRepository
Setting *SettingRepository
UserSubscription *UserSubscriptionRepository
}
// NewRepositories 创建所有仓库
func NewRepositories(db *gorm.DB) *Repositories {
return &Repositories{
User: NewUserRepository(db),
ApiKey: NewApiKeyRepository(db),
Group: NewGroupRepository(db),
Account: NewAccountRepository(db),
Proxy: NewProxyRepository(db),
RedeemCode: NewRedeemCodeRepository(db),
UsageLog: NewUsageLogRepository(db),
Setting: NewSettingRepository(db),
UserSubscription: NewUserSubscriptionRepository(db),
}
}
// PaginationParams 分页参数
type PaginationParams struct {
Page int
PageSize int
}
// PaginationResult 分页结果
type PaginationResult struct {
Total int64
Page int
PageSize int
Pages int
}
// DefaultPagination 默认分页参数
func DefaultPagination() PaginationParams {
return PaginationParams{
Page: 1,
PageSize: 20,
}
}
// Offset 计算偏移量
func (p PaginationParams) Offset() int {
if p.Page < 1 {
p.Page = 1
}
return (p.Page - 1) * p.PageSize
}
// Limit 获取限制数
func (p PaginationParams) Limit() int {
if p.PageSize < 1 {
return 20
}
if p.PageSize > 100 {
return 100
}
return p.PageSize
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment