Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
642842c2
Commit
642842c2
authored
Dec 18, 2025
by
shaw
Browse files
First commit
parent
569f4882
Changes
201
Hide whitespace changes
Inline
Side-by-side
Too many changes to show.
To preserve performance only
201 of 201+
files are displayed.
Plain diff
Email patch
backend/internal/model/account_group.go
0 → 100644
View file @
642842c2
package
model
import
(
"time"
)
type
AccountGroup
struct
{
AccountID
int64
`gorm:"primaryKey" json:"account_id"`
GroupID
int64
`gorm:"primaryKey" json:"group_id"`
Priority
int
`gorm:"default:50;not null" json:"priority"`
// 分组内优先级
CreatedAt
time
.
Time
`gorm:"not null" json:"created_at"`
// 关联
Account
*
Account
`gorm:"foreignKey:AccountID" json:"account,omitempty"`
Group
*
Group
`gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func
(
AccountGroup
)
TableName
()
string
{
return
"account_groups"
}
backend/internal/model/api_key.go
0 → 100644
View file @
642842c2
package
model
import
(
"time"
"gorm.io/gorm"
)
type
ApiKey
struct
{
ID
int64
`gorm:"primaryKey" json:"id"`
UserID
int64
`gorm:"index;not null" json:"user_id"`
Key
string
`gorm:"uniqueIndex;size:128;not null" json:"key"`
// sk-xxx
Name
string
`gorm:"size:100;not null" json:"name"`
GroupID
*
int64
`gorm:"index" json:"group_id"`
Status
string
`gorm:"size:20;default:active;not null" json:"status"`
// active/disabled
CreatedAt
time
.
Time
`gorm:"not null" json:"created_at"`
UpdatedAt
time
.
Time
`gorm:"not null" json:"updated_at"`
DeletedAt
gorm
.
DeletedAt
`gorm:"index" json:"-"`
// 关联
User
*
User
`gorm:"foreignKey:UserID" json:"user,omitempty"`
Group
*
Group
`gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func
(
ApiKey
)
TableName
()
string
{
return
"api_keys"
}
// IsActive 检查是否激活
func
(
k
*
ApiKey
)
IsActive
()
bool
{
return
k
.
Status
==
"active"
}
backend/internal/model/group.go
0 → 100644
View file @
642842c2
package
model
import
(
"time"
"gorm.io/gorm"
)
// 订阅类型常量
const
(
SubscriptionTypeStandard
=
"standard"
// 标准计费模式(按余额扣费)
SubscriptionTypeSubscription
=
"subscription"
// 订阅模式(按限额控制)
)
type
Group
struct
{
ID
int64
`gorm:"primaryKey" json:"id"`
Name
string
`gorm:"uniqueIndex;size:100;not null" json:"name"`
Description
string
`gorm:"type:text" json:"description"`
Platform
string
`gorm:"size:50;default:anthropic;not null" json:"platform"`
// anthropic/openai/gemini
RateMultiplier
float64
`gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"`
IsExclusive
bool
`gorm:"default:false;not null" json:"is_exclusive"`
Status
string
`gorm:"size:20;default:active;not null" json:"status"`
// active/disabled
// 订阅功能字段
SubscriptionType
string
`gorm:"size:20;default:standard;not null" json:"subscription_type"`
// standard/subscription
DailyLimitUSD
*
float64
`gorm:"type:decimal(20,8)" json:"daily_limit_usd"`
WeeklyLimitUSD
*
float64
`gorm:"type:decimal(20,8)" json:"weekly_limit_usd"`
MonthlyLimitUSD
*
float64
`gorm:"type:decimal(20,8)" json:"monthly_limit_usd"`
CreatedAt
time
.
Time
`gorm:"not null" json:"created_at"`
UpdatedAt
time
.
Time
`gorm:"not null" json:"updated_at"`
DeletedAt
gorm
.
DeletedAt
`gorm:"index" json:"-"`
// 关联
AccountGroups
[]
AccountGroup
`gorm:"foreignKey:GroupID" json:"account_groups,omitempty"`
// 虚拟字段 (不存储到数据库)
AccountCount
int64
`gorm:"-" json:"account_count,omitempty"`
}
func
(
Group
)
TableName
()
string
{
return
"groups"
}
// IsActive 检查是否激活
func
(
g
*
Group
)
IsActive
()
bool
{
return
g
.
Status
==
"active"
}
// IsSubscriptionType 检查是否为订阅类型分组
func
(
g
*
Group
)
IsSubscriptionType
()
bool
{
return
g
.
SubscriptionType
==
SubscriptionTypeSubscription
}
// IsFreeSubscription 检查是否为免费订阅(不扣余额但有限额)
func
(
g
*
Group
)
IsFreeSubscription
()
bool
{
return
g
.
IsSubscriptionType
()
&&
g
.
RateMultiplier
==
0
}
// HasDailyLimit 检查是否有日限额
func
(
g
*
Group
)
HasDailyLimit
()
bool
{
return
g
.
DailyLimitUSD
!=
nil
&&
*
g
.
DailyLimitUSD
>
0
}
// HasWeeklyLimit 检查是否有周限额
func
(
g
*
Group
)
HasWeeklyLimit
()
bool
{
return
g
.
WeeklyLimitUSD
!=
nil
&&
*
g
.
WeeklyLimitUSD
>
0
}
// HasMonthlyLimit 检查是否有月限额
func
(
g
*
Group
)
HasMonthlyLimit
()
bool
{
return
g
.
MonthlyLimitUSD
!=
nil
&&
*
g
.
MonthlyLimitUSD
>
0
}
backend/internal/model/model.go
0 → 100644
View file @
642842c2
package
model
import
(
"gorm.io/gorm"
)
// AutoMigrate 自动迁移所有模型
func
AutoMigrate
(
db
*
gorm
.
DB
)
error
{
return
db
.
AutoMigrate
(
&
User
{},
&
ApiKey
{},
&
Group
{},
&
Account
{},
&
AccountGroup
{},
&
Proxy
{},
&
RedeemCode
{},
&
UsageLog
{},
&
Setting
{},
&
UserSubscription
{},
)
}
// 状态常量
const
(
StatusActive
=
"active"
StatusDisabled
=
"disabled"
StatusError
=
"error"
StatusUnused
=
"unused"
StatusUsed
=
"used"
StatusExpired
=
"expired"
)
// 角色常量
const
(
RoleAdmin
=
"admin"
RoleUser
=
"user"
)
// 平台常量
const
(
PlatformAnthropic
=
"anthropic"
PlatformOpenAI
=
"openai"
PlatformGemini
=
"gemini"
)
// 账号类型常量
const
(
AccountTypeOAuth
=
"oauth"
// OAuth类型账号(full scope: profile + inference)
AccountTypeSetupToken
=
"setup-token"
// Setup Token类型账号(inference only scope)
AccountTypeApiKey
=
"apikey"
// API Key类型账号
)
// 卡密类型常量
const
(
RedeemTypeBalance
=
"balance"
RedeemTypeConcurrency
=
"concurrency"
RedeemTypeSubscription
=
"subscription"
)
// 管理员调整类型常量
const
(
AdjustmentTypeAdminBalance
=
"admin_balance"
// 管理员调整余额
AdjustmentTypeAdminConcurrency
=
"admin_concurrency"
// 管理员调整并发数
)
backend/internal/model/proxy.go
0 → 100644
View file @
642842c2
package
model
import
(
"fmt"
"time"
"gorm.io/gorm"
)
type
Proxy
struct
{
ID
int64
`gorm:"primaryKey" json:"id"`
Name
string
`gorm:"size:100;not null" json:"name"`
Protocol
string
`gorm:"size:20;not null" json:"protocol"`
// http/https/socks5
Host
string
`gorm:"size:255;not null" json:"host"`
Port
int
`gorm:"not null" json:"port"`
Username
string
`gorm:"size:100" json:"username"`
Password
string
`gorm:"size:100" json:"-"`
Status
string
`gorm:"size:20;default:active;not null" json:"status"`
// active/disabled
CreatedAt
time
.
Time
`gorm:"not null" json:"created_at"`
UpdatedAt
time
.
Time
`gorm:"not null" json:"updated_at"`
DeletedAt
gorm
.
DeletedAt
`gorm:"index" json:"-"`
}
func
(
Proxy
)
TableName
()
string
{
return
"proxies"
}
// IsActive 检查是否激活
func
(
p
*
Proxy
)
IsActive
()
bool
{
return
p
.
Status
==
"active"
}
// URL 返回代理URL
func
(
p
*
Proxy
)
URL
()
string
{
if
p
.
Username
!=
""
&&
p
.
Password
!=
""
{
return
fmt
.
Sprintf
(
"%s://%s:%s@%s:%d"
,
p
.
Protocol
,
p
.
Username
,
p
.
Password
,
p
.
Host
,
p
.
Port
)
}
return
fmt
.
Sprintf
(
"%s://%s:%d"
,
p
.
Protocol
,
p
.
Host
,
p
.
Port
)
}
// ProxyWithAccountCount extends Proxy with account count information
type
ProxyWithAccountCount
struct
{
Proxy
AccountCount
int64
`json:"account_count"`
}
backend/internal/model/redeem_code.go
0 → 100644
View file @
642842c2
package
model
import
(
"crypto/rand"
"encoding/hex"
"time"
)
type
RedeemCode
struct
{
ID
int64
`gorm:"primaryKey" json:"id"`
Code
string
`gorm:"uniqueIndex;size:32;not null" json:"code"`
Type
string
`gorm:"size:20;default:balance;not null" json:"type"`
// balance/concurrency/subscription
Value
float64
`gorm:"type:decimal(20,8);not null" json:"value"`
// 面值(USD)或并发数或有效天数
Status
string
`gorm:"size:20;default:unused;not null" json:"status"`
// unused/used
UsedBy
*
int64
`gorm:"index" json:"used_by"`
UsedAt
*
time
.
Time
`json:"used_at"`
CreatedAt
time
.
Time
`gorm:"not null" json:"created_at"`
// 订阅类型专用字段
GroupID
*
int64
`gorm:"index" json:"group_id"`
// 订阅分组ID (仅subscription类型使用)
ValidityDays
int
`gorm:"default:30" json:"validity_days"`
// 订阅有效天数 (仅subscription类型使用)
// 关联
User
*
User
`gorm:"foreignKey:UsedBy" json:"user,omitempty"`
Group
*
Group
`gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func
(
RedeemCode
)
TableName
()
string
{
return
"redeem_codes"
}
// IsUsed 检查是否已使用
func
(
r
*
RedeemCode
)
IsUsed
()
bool
{
return
r
.
Status
==
"used"
}
// CanUse 检查是否可以使用
func
(
r
*
RedeemCode
)
CanUse
()
bool
{
return
r
.
Status
==
"unused"
}
// GenerateRedeemCode 生成唯一的兑换码
func
GenerateRedeemCode
()
string
{
b
:=
make
([]
byte
,
16
)
rand
.
Read
(
b
)
return
hex
.
EncodeToString
(
b
)
}
backend/internal/model/setting.go
0 → 100644
View file @
642842c2
package
model
import
(
"time"
)
// Setting 系统设置模型(Key-Value存储)
type
Setting
struct
{
ID
int64
`gorm:"primaryKey" json:"id"`
Key
string
`gorm:"uniqueIndex;size:100;not null" json:"key"`
Value
string
`gorm:"type:text;not null" json:"value"`
UpdatedAt
time
.
Time
`gorm:"not null" json:"updated_at"`
}
func
(
Setting
)
TableName
()
string
{
return
"settings"
}
// 设置Key常量
const
(
// 注册设置
SettingKeyRegistrationEnabled
=
"registration_enabled"
// 是否开放注册
SettingKeyEmailVerifyEnabled
=
"email_verify_enabled"
// 是否开启邮件验证
// 邮件服务设置
SettingKeySmtpHost
=
"smtp_host"
// SMTP服务器地址
SettingKeySmtpPort
=
"smtp_port"
// SMTP端口
SettingKeySmtpUsername
=
"smtp_username"
// SMTP用户名
SettingKeySmtpPassword
=
"smtp_password"
// SMTP密码(加密存储)
SettingKeySmtpFrom
=
"smtp_from"
// 发件人地址
SettingKeySmtpFromName
=
"smtp_from_name"
// 发件人名称
SettingKeySmtpUseTLS
=
"smtp_use_tls"
// 是否使用TLS
// Cloudflare Turnstile 设置
SettingKeyTurnstileEnabled
=
"turnstile_enabled"
// 是否启用 Turnstile 验证
SettingKeyTurnstileSiteKey
=
"turnstile_site_key"
// Turnstile Site Key
SettingKeyTurnstileSecretKey
=
"turnstile_secret_key"
// Turnstile Secret Key
// OEM设置
SettingKeySiteName
=
"site_name"
// 网站名称
SettingKeySiteLogo
=
"site_logo"
// 网站Logo (base64)
SettingKeySiteSubtitle
=
"site_subtitle"
// 网站副标题
SettingKeyApiBaseUrl
=
"api_base_url"
// API端点地址(用于客户端配置和导入)
SettingKeyContactInfo
=
"contact_info"
// 客服联系方式
// 默认配置
SettingKeyDefaultConcurrency
=
"default_concurrency"
// 新用户默认并发量
SettingKeyDefaultBalance
=
"default_balance"
// 新用户默认余额
)
// SystemSettings 系统设置结构体(用于API响应)
type
SystemSettings
struct
{
// 注册设置
RegistrationEnabled
bool
`json:"registration_enabled"`
EmailVerifyEnabled
bool
`json:"email_verify_enabled"`
// 邮件服务设置
SmtpHost
string
`json:"smtp_host"`
SmtpPort
int
`json:"smtp_port"`
SmtpUsername
string
`json:"smtp_username"`
SmtpPassword
string
`json:"smtp_password,omitempty"`
// 不返回明文密码
SmtpFrom
string
`json:"smtp_from_email"`
SmtpFromName
string
`json:"smtp_from_name"`
SmtpUseTLS
bool
`json:"smtp_use_tls"`
// Cloudflare Turnstile 设置
TurnstileEnabled
bool
`json:"turnstile_enabled"`
TurnstileSiteKey
string
`json:"turnstile_site_key"`
TurnstileSecretKey
string
`json:"turnstile_secret_key,omitempty"`
// 不返回明文密钥
// OEM设置
SiteName
string
`json:"site_name"`
SiteLogo
string
`json:"site_logo"`
SiteSubtitle
string
`json:"site_subtitle"`
ApiBaseUrl
string
`json:"api_base_url"`
ContactInfo
string
`json:"contact_info"`
// 默认配置
DefaultConcurrency
int
`json:"default_concurrency"`
DefaultBalance
float64
`json:"default_balance"`
}
// PublicSettings 公开设置(无需登录即可获取)
type
PublicSettings
struct
{
RegistrationEnabled
bool
`json:"registration_enabled"`
EmailVerifyEnabled
bool
`json:"email_verify_enabled"`
TurnstileEnabled
bool
`json:"turnstile_enabled"`
TurnstileSiteKey
string
`json:"turnstile_site_key"`
SiteName
string
`json:"site_name"`
SiteLogo
string
`json:"site_logo"`
SiteSubtitle
string
`json:"site_subtitle"`
ApiBaseUrl
string
`json:"api_base_url"`
ContactInfo
string
`json:"contact_info"`
Version
string
`json:"version"`
}
backend/internal/model/usage_log.go
0 → 100644
View file @
642842c2
package
model
import
(
"time"
)
// 消费类型常量
const
(
BillingTypeBalance
int8
=
0
// 钱包余额
BillingTypeSubscription
int8
=
1
// 订阅套餐
)
type
UsageLog
struct
{
ID
int64
`gorm:"primaryKey" json:"id"`
UserID
int64
`gorm:"index;not null" json:"user_id"`
ApiKeyID
int64
`gorm:"index;not null" json:"api_key_id"`
AccountID
int64
`gorm:"index;not null" json:"account_id"`
RequestID
string
`gorm:"size:64" json:"request_id"`
Model
string
`gorm:"size:100;index;not null" json:"model"`
// 订阅关联(可选)
GroupID
*
int64
`gorm:"index" json:"group_id"`
SubscriptionID
*
int64
`gorm:"index" json:"subscription_id"`
// Token使用量(4类)
InputTokens
int
`gorm:"default:0;not null" json:"input_tokens"`
OutputTokens
int
`gorm:"default:0;not null" json:"output_tokens"`
CacheCreationTokens
int
`gorm:"default:0;not null" json:"cache_creation_tokens"`
CacheReadTokens
int
`gorm:"default:0;not null" json:"cache_read_tokens"`
// 详细的缓存创建分类
CacheCreation5mTokens
int
`gorm:"default:0;not null" json:"cache_creation_5m_tokens"`
CacheCreation1hTokens
int
`gorm:"default:0;not null" json:"cache_creation_1h_tokens"`
// 费用(USD)
InputCost
float64
`gorm:"type:decimal(20,10);default:0;not null" json:"input_cost"`
OutputCost
float64
`gorm:"type:decimal(20,10);default:0;not null" json:"output_cost"`
CacheCreationCost
float64
`gorm:"type:decimal(20,10);default:0;not null" json:"cache_creation_cost"`
CacheReadCost
float64
`gorm:"type:decimal(20,10);default:0;not null" json:"cache_read_cost"`
TotalCost
float64
`gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"`
// 原始总费用
ActualCost
float64
`gorm:"type:decimal(20,10);default:0;not null" json:"actual_cost"`
// 实际扣除费用
RateMultiplier
float64
`gorm:"type:decimal(10,4);default:1;not null" json:"rate_multiplier"`
// 计费倍率
// 元数据
BillingType
int8
`gorm:"type:smallint;default:0;not null" json:"billing_type"`
// 0=余额 1=订阅
Stream
bool
`gorm:"default:false;not null" json:"stream"`
DurationMs
*
int
`json:"duration_ms"`
FirstTokenMs
*
int
`json:"first_token_ms"`
// 首字时间(流式请求)
CreatedAt
time
.
Time
`gorm:"index;not null" json:"created_at"`
// 关联
User
*
User
`gorm:"foreignKey:UserID" json:"user,omitempty"`
ApiKey
*
ApiKey
`gorm:"foreignKey:ApiKeyID" json:"api_key,omitempty"`
Account
*
Account
`gorm:"foreignKey:AccountID" json:"account,omitempty"`
Group
*
Group
`gorm:"foreignKey:GroupID" json:"group,omitempty"`
Subscription
*
UserSubscription
`gorm:"foreignKey:SubscriptionID" json:"subscription,omitempty"`
}
func
(
UsageLog
)
TableName
()
string
{
return
"usage_logs"
}
// TotalTokens 总token数
func
(
u
*
UsageLog
)
TotalTokens
()
int
{
return
u
.
InputTokens
+
u
.
OutputTokens
+
u
.
CacheCreationTokens
+
u
.
CacheReadTokens
}
backend/internal/model/user.go
0 → 100644
View file @
642842c2
package
model
import
(
"time"
"github.com/lib/pq"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
type
User
struct
{
ID
int64
`gorm:"primaryKey" json:"id"`
Email
string
`gorm:"uniqueIndex;size:255;not null" json:"email"`
PasswordHash
string
`gorm:"size:255;not null" json:"-"`
Role
string
`gorm:"size:20;default:user;not null" json:"role"`
// admin/user
Balance
float64
`gorm:"type:decimal(20,8);default:0;not null" json:"balance"`
Concurrency
int
`gorm:"default:5;not null" json:"concurrency"`
Status
string
`gorm:"size:20;default:active;not null" json:"status"`
// active/disabled
AllowedGroups
pq
.
Int64Array
`gorm:"type:bigint[]" json:"allowed_groups"`
CreatedAt
time
.
Time
`gorm:"not null" json:"created_at"`
UpdatedAt
time
.
Time
`gorm:"not null" json:"updated_at"`
DeletedAt
gorm
.
DeletedAt
`gorm:"index" json:"-"`
// 关联
ApiKeys
[]
ApiKey
`gorm:"foreignKey:UserID" json:"api_keys,omitempty"`
}
func
(
User
)
TableName
()
string
{
return
"users"
}
// IsAdmin 检查是否管理员
func
(
u
*
User
)
IsAdmin
()
bool
{
return
u
.
Role
==
"admin"
}
// IsActive 检查是否激活
func
(
u
*
User
)
IsActive
()
bool
{
return
u
.
Status
==
"active"
}
// CanBindGroup 检查是否可以绑定指定分组
// 对于标准类型分组:
// - 如果 AllowedGroups 设置了值(非空数组),只能绑定列表中的分组
// - 如果 AllowedGroups 为 nil 或空数组,可以绑定所有非专属分组
func
(
u
*
User
)
CanBindGroup
(
groupID
int64
,
isExclusive
bool
)
bool
{
// 如果设置了 allowed_groups 且不为空,只能绑定指定的分组
if
len
(
u
.
AllowedGroups
)
>
0
{
for
_
,
id
:=
range
u
.
AllowedGroups
{
if
id
==
groupID
{
return
true
}
}
return
false
}
// 如果没有设置 allowed_groups 或为空数组,可以绑定所有非专属分组
return
!
isExclusive
}
// SetPassword 设置密码(哈希存储)
func
(
u
*
User
)
SetPassword
(
password
string
)
error
{
hash
,
err
:=
bcrypt
.
GenerateFromPassword
([]
byte
(
password
),
bcrypt
.
DefaultCost
)
if
err
!=
nil
{
return
err
}
u
.
PasswordHash
=
string
(
hash
)
return
nil
}
// CheckPassword 验证密码
func
(
u
*
User
)
CheckPassword
(
password
string
)
bool
{
err
:=
bcrypt
.
CompareHashAndPassword
([]
byte
(
u
.
PasswordHash
),
[]
byte
(
password
))
return
err
==
nil
}
backend/internal/model/user_subscription.go
0 → 100644
View file @
642842c2
package
model
import
(
"time"
)
// 订阅状态常量
const
(
SubscriptionStatusActive
=
"active"
SubscriptionStatusExpired
=
"expired"
SubscriptionStatusSuspended
=
"suspended"
)
// UserSubscription 用户订阅模型
type
UserSubscription
struct
{
ID
int64
`gorm:"primaryKey" json:"id"`
UserID
int64
`gorm:"index;not null" json:"user_id"`
GroupID
int64
`gorm:"index;not null" json:"group_id"`
// 订阅有效期
StartsAt
time
.
Time
`gorm:"not null" json:"starts_at"`
ExpiresAt
time
.
Time
`gorm:"not null" json:"expires_at"`
Status
string
`gorm:"size:20;default:active;not null" json:"status"`
// active/expired/suspended
// 滑动窗口起始时间(nil = 未激活)
DailyWindowStart
*
time
.
Time
`json:"daily_window_start"`
WeeklyWindowStart
*
time
.
Time
`json:"weekly_window_start"`
MonthlyWindowStart
*
time
.
Time
`json:"monthly_window_start"`
// 当前窗口已用额度(USD,基于 total_cost 计算)
DailyUsageUSD
float64
`gorm:"type:decimal(20,10);default:0;not null" json:"daily_usage_usd"`
WeeklyUsageUSD
float64
`gorm:"type:decimal(20,10);default:0;not null" json:"weekly_usage_usd"`
MonthlyUsageUSD
float64
`gorm:"type:decimal(20,10);default:0;not null" json:"monthly_usage_usd"`
// 管理员分配信息
AssignedBy
*
int64
`gorm:"index" json:"assigned_by"`
AssignedAt
time
.
Time
`gorm:"not null" json:"assigned_at"`
Notes
string
`gorm:"type:text" json:"notes"`
CreatedAt
time
.
Time
`gorm:"not null" json:"created_at"`
UpdatedAt
time
.
Time
`gorm:"not null" json:"updated_at"`
// 关联
User
*
User
`gorm:"foreignKey:UserID" json:"user,omitempty"`
Group
*
Group
`gorm:"foreignKey:GroupID" json:"group,omitempty"`
AssignedByUser
*
User
`gorm:"foreignKey:AssignedBy" json:"assigned_by_user,omitempty"`
}
func
(
UserSubscription
)
TableName
()
string
{
return
"user_subscriptions"
}
// IsActive 检查订阅是否有效(状态为active且未过期)
func
(
s
*
UserSubscription
)
IsActive
()
bool
{
return
s
.
Status
==
SubscriptionStatusActive
&&
time
.
Now
()
.
Before
(
s
.
ExpiresAt
)
}
// IsExpired 检查订阅是否已过期
func
(
s
*
UserSubscription
)
IsExpired
()
bool
{
return
time
.
Now
()
.
After
(
s
.
ExpiresAt
)
}
// DaysRemaining 返回订阅剩余天数
func
(
s
*
UserSubscription
)
DaysRemaining
()
int
{
if
s
.
IsExpired
()
{
return
0
}
return
int
(
time
.
Until
(
s
.
ExpiresAt
)
.
Hours
()
/
24
)
}
// IsWindowActivated 检查窗口是否已激活
func
(
s
*
UserSubscription
)
IsWindowActivated
()
bool
{
return
s
.
DailyWindowStart
!=
nil
||
s
.
WeeklyWindowStart
!=
nil
||
s
.
MonthlyWindowStart
!=
nil
}
// NeedsDailyReset 检查日窗口是否需要重置
func
(
s
*
UserSubscription
)
NeedsDailyReset
()
bool
{
if
s
.
DailyWindowStart
==
nil
{
return
false
}
return
time
.
Since
(
*
s
.
DailyWindowStart
)
>=
24
*
time
.
Hour
}
// NeedsWeeklyReset 检查周窗口是否需要重置
func
(
s
*
UserSubscription
)
NeedsWeeklyReset
()
bool
{
if
s
.
WeeklyWindowStart
==
nil
{
return
false
}
return
time
.
Since
(
*
s
.
WeeklyWindowStart
)
>=
7
*
24
*
time
.
Hour
}
// NeedsMonthlyReset 检查月窗口是否需要重置
func
(
s
*
UserSubscription
)
NeedsMonthlyReset
()
bool
{
if
s
.
MonthlyWindowStart
==
nil
{
return
false
}
return
time
.
Since
(
*
s
.
MonthlyWindowStart
)
>=
30
*
24
*
time
.
Hour
}
// DailyResetTime 返回日窗口重置时间
func
(
s
*
UserSubscription
)
DailyResetTime
()
*
time
.
Time
{
if
s
.
DailyWindowStart
==
nil
{
return
nil
}
t
:=
s
.
DailyWindowStart
.
Add
(
24
*
time
.
Hour
)
return
&
t
}
// WeeklyResetTime 返回周窗口重置时间
func
(
s
*
UserSubscription
)
WeeklyResetTime
()
*
time
.
Time
{
if
s
.
WeeklyWindowStart
==
nil
{
return
nil
}
t
:=
s
.
WeeklyWindowStart
.
Add
(
7
*
24
*
time
.
Hour
)
return
&
t
}
// MonthlyResetTime 返回月窗口重置时间
func
(
s
*
UserSubscription
)
MonthlyResetTime
()
*
time
.
Time
{
if
s
.
MonthlyWindowStart
==
nil
{
return
nil
}
t
:=
s
.
MonthlyWindowStart
.
Add
(
30
*
24
*
time
.
Hour
)
return
&
t
}
// CheckDailyLimit 检查是否超出日限额
func
(
s
*
UserSubscription
)
CheckDailyLimit
(
group
*
Group
,
additionalCost
float64
)
bool
{
if
!
group
.
HasDailyLimit
()
{
return
true
// 无限制
}
return
s
.
DailyUsageUSD
+
additionalCost
<=
*
group
.
DailyLimitUSD
}
// CheckWeeklyLimit 检查是否超出周限额
func
(
s
*
UserSubscription
)
CheckWeeklyLimit
(
group
*
Group
,
additionalCost
float64
)
bool
{
if
!
group
.
HasWeeklyLimit
()
{
return
true
// 无限制
}
return
s
.
WeeklyUsageUSD
+
additionalCost
<=
*
group
.
WeeklyLimitUSD
}
// CheckMonthlyLimit 检查是否超出月限额
func
(
s
*
UserSubscription
)
CheckMonthlyLimit
(
group
*
Group
,
additionalCost
float64
)
bool
{
if
!
group
.
HasMonthlyLimit
()
{
return
true
// 无限制
}
return
s
.
MonthlyUsageUSD
+
additionalCost
<=
*
group
.
MonthlyLimitUSD
}
// CheckAllLimits 检查所有限额
func
(
s
*
UserSubscription
)
CheckAllLimits
(
group
*
Group
,
additionalCost
float64
)
(
daily
,
weekly
,
monthly
bool
)
{
daily
=
s
.
CheckDailyLimit
(
group
,
additionalCost
)
weekly
=
s
.
CheckWeeklyLimit
(
group
,
additionalCost
)
monthly
=
s
.
CheckMonthlyLimit
(
group
,
additionalCost
)
return
}
backend/internal/pkg/oauth/oauth.go
0 → 100644
View file @
642842c2
package
oauth
import
(
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"net/url"
"strings"
"sync"
"time"
)
// Claude OAuth Constants (from CRS project)
const
(
// OAuth Client ID for Claude
ClientID
=
"9d1c250a-e61b-44d9-88ed-5944d1962f5e"
// OAuth endpoints
AuthorizeURL
=
"https://claude.ai/oauth/authorize"
TokenURL
=
"https://console.anthropic.com/v1/oauth/token"
RedirectURI
=
"https://console.anthropic.com/oauth/code/callback"
// Scopes
ScopeProfile
=
"user:profile"
ScopeInference
=
"user:inference"
// Session TTL
SessionTTL
=
30
*
time
.
Minute
)
// OAuthSession stores OAuth flow state
type
OAuthSession
struct
{
State
string
`json:"state"`
CodeVerifier
string
`json:"code_verifier"`
Scope
string
`json:"scope"`
ProxyURL
string
`json:"proxy_url,omitempty"`
CreatedAt
time
.
Time
`json:"created_at"`
}
// SessionStore manages OAuth sessions in memory
type
SessionStore
struct
{
mu
sync
.
RWMutex
sessions
map
[
string
]
*
OAuthSession
}
// NewSessionStore creates a new session store
func
NewSessionStore
()
*
SessionStore
{
store
:=
&
SessionStore
{
sessions
:
make
(
map
[
string
]
*
OAuthSession
),
}
// Start cleanup goroutine
go
store
.
cleanup
()
return
store
}
// Set stores a session
func
(
s
*
SessionStore
)
Set
(
sessionID
string
,
session
*
OAuthSession
)
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
s
.
sessions
[
sessionID
]
=
session
}
// Get retrieves a session
func
(
s
*
SessionStore
)
Get
(
sessionID
string
)
(
*
OAuthSession
,
bool
)
{
s
.
mu
.
RLock
()
defer
s
.
mu
.
RUnlock
()
session
,
ok
:=
s
.
sessions
[
sessionID
]
if
!
ok
{
return
nil
,
false
}
// Check if expired
if
time
.
Since
(
session
.
CreatedAt
)
>
SessionTTL
{
return
nil
,
false
}
return
session
,
true
}
// Delete removes a session
func
(
s
*
SessionStore
)
Delete
(
sessionID
string
)
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
delete
(
s
.
sessions
,
sessionID
)
}
// cleanup removes expired sessions periodically
func
(
s
*
SessionStore
)
cleanup
()
{
ticker
:=
time
.
NewTicker
(
5
*
time
.
Minute
)
for
range
ticker
.
C
{
s
.
mu
.
Lock
()
for
id
,
session
:=
range
s
.
sessions
{
if
time
.
Since
(
session
.
CreatedAt
)
>
SessionTTL
{
delete
(
s
.
sessions
,
id
)
}
}
s
.
mu
.
Unlock
()
}
}
// GenerateRandomBytes generates cryptographically secure random bytes
func
GenerateRandomBytes
(
n
int
)
([]
byte
,
error
)
{
b
:=
make
([]
byte
,
n
)
_
,
err
:=
rand
.
Read
(
b
)
if
err
!=
nil
{
return
nil
,
err
}
return
b
,
nil
}
// GenerateState generates a random state string for OAuth
func
GenerateState
()
(
string
,
error
)
{
bytes
,
err
:=
GenerateRandomBytes
(
32
)
if
err
!=
nil
{
return
""
,
err
}
return
hex
.
EncodeToString
(
bytes
),
nil
}
// GenerateSessionID generates a unique session ID
func
GenerateSessionID
()
(
string
,
error
)
{
bytes
,
err
:=
GenerateRandomBytes
(
16
)
if
err
!=
nil
{
return
""
,
err
}
return
hex
.
EncodeToString
(
bytes
),
nil
}
// GenerateCodeVerifier generates a PKCE code verifier (32 bytes -> base64url)
func
GenerateCodeVerifier
()
(
string
,
error
)
{
bytes
,
err
:=
GenerateRandomBytes
(
32
)
if
err
!=
nil
{
return
""
,
err
}
return
base64URLEncode
(
bytes
),
nil
}
// GenerateCodeChallenge generates a PKCE code challenge using S256 method
func
GenerateCodeChallenge
(
verifier
string
)
string
{
hash
:=
sha256
.
Sum256
([]
byte
(
verifier
))
return
base64URLEncode
(
hash
[
:
])
}
// base64URLEncode encodes bytes to base64url without padding
func
base64URLEncode
(
data
[]
byte
)
string
{
encoded
:=
base64
.
URLEncoding
.
EncodeToString
(
data
)
// Remove padding
return
strings
.
TrimRight
(
encoded
,
"="
)
}
// BuildAuthorizationURL builds the OAuth authorization URL
func
BuildAuthorizationURL
(
state
,
codeChallenge
,
scope
string
)
string
{
params
:=
url
.
Values
{}
params
.
Set
(
"response_type"
,
"code"
)
params
.
Set
(
"client_id"
,
ClientID
)
params
.
Set
(
"redirect_uri"
,
RedirectURI
)
params
.
Set
(
"scope"
,
scope
)
params
.
Set
(
"state"
,
state
)
params
.
Set
(
"code_challenge"
,
codeChallenge
)
params
.
Set
(
"code_challenge_method"
,
"S256"
)
return
fmt
.
Sprintf
(
"%s?%s"
,
AuthorizeURL
,
params
.
Encode
())
}
// TokenRequest represents the token exchange request body
type
TokenRequest
struct
{
GrantType
string
`json:"grant_type"`
ClientID
string
`json:"client_id"`
Code
string
`json:"code"`
RedirectURI
string
`json:"redirect_uri"`
CodeVerifier
string
`json:"code_verifier"`
State
string
`json:"state"`
}
// TokenResponse represents the token response from OAuth provider
type
TokenResponse
struct
{
AccessToken
string
`json:"access_token"`
TokenType
string
`json:"token_type"`
ExpiresIn
int64
`json:"expires_in"`
RefreshToken
string
`json:"refresh_token,omitempty"`
Scope
string
`json:"scope,omitempty"`
// Organization and Account info from OAuth response
Organization
*
OrgInfo
`json:"organization,omitempty"`
Account
*
AccountInfo
`json:"account,omitempty"`
}
// OrgInfo represents organization info from OAuth response
type
OrgInfo
struct
{
UUID
string
`json:"uuid"`
}
// AccountInfo represents account info from OAuth response
type
AccountInfo
struct
{
UUID
string
`json:"uuid"`
}
// RefreshTokenRequest represents the refresh token request
type
RefreshTokenRequest
struct
{
GrantType
string
`json:"grant_type"`
RefreshToken
string
`json:"refresh_token"`
ClientID
string
`json:"client_id"`
}
// BuildTokenRequest creates a token exchange request
func
BuildTokenRequest
(
code
,
codeVerifier
,
state
string
)
*
TokenRequest
{
return
&
TokenRequest
{
GrantType
:
"authorization_code"
,
ClientID
:
ClientID
,
Code
:
code
,
RedirectURI
:
RedirectURI
,
CodeVerifier
:
codeVerifier
,
State
:
state
,
}
}
// BuildRefreshTokenRequest creates a refresh token request
func
BuildRefreshTokenRequest
(
refreshToken
string
)
*
RefreshTokenRequest
{
return
&
RefreshTokenRequest
{
GrantType
:
"refresh_token"
,
RefreshToken
:
refreshToken
,
ClientID
:
ClientID
,
}
}
backend/internal/pkg/response/response.go
0 → 100644
View file @
642842c2
package
response
import
(
"math"
"net/http"
"github.com/gin-gonic/gin"
)
// Response 标准API响应格式
type
Response
struct
{
Code
int
`json:"code"`
Message
string
`json:"message"`
Data
interface
{}
`json:"data,omitempty"`
}
// PaginatedData 分页数据格式(匹配前端期望)
type
PaginatedData
struct
{
Items
interface
{}
`json:"items"`
Total
int64
`json:"total"`
Page
int
`json:"page"`
PageSize
int
`json:"page_size"`
Pages
int
`json:"pages"`
}
// Success 返回成功响应
func
Success
(
c
*
gin
.
Context
,
data
interface
{})
{
c
.
JSON
(
http
.
StatusOK
,
Response
{
Code
:
0
,
Message
:
"success"
,
Data
:
data
,
})
}
// Created 返回创建成功响应
func
Created
(
c
*
gin
.
Context
,
data
interface
{})
{
c
.
JSON
(
http
.
StatusCreated
,
Response
{
Code
:
0
,
Message
:
"success"
,
Data
:
data
,
})
}
// Error 返回错误响应
func
Error
(
c
*
gin
.
Context
,
statusCode
int
,
message
string
)
{
c
.
JSON
(
statusCode
,
Response
{
Code
:
statusCode
,
Message
:
message
,
})
}
// BadRequest 返回400错误
func
BadRequest
(
c
*
gin
.
Context
,
message
string
)
{
Error
(
c
,
http
.
StatusBadRequest
,
message
)
}
// Unauthorized 返回401错误
func
Unauthorized
(
c
*
gin
.
Context
,
message
string
)
{
Error
(
c
,
http
.
StatusUnauthorized
,
message
)
}
// Forbidden 返回403错误
func
Forbidden
(
c
*
gin
.
Context
,
message
string
)
{
Error
(
c
,
http
.
StatusForbidden
,
message
)
}
// NotFound 返回404错误
func
NotFound
(
c
*
gin
.
Context
,
message
string
)
{
Error
(
c
,
http
.
StatusNotFound
,
message
)
}
// InternalError 返回500错误
func
InternalError
(
c
*
gin
.
Context
,
message
string
)
{
Error
(
c
,
http
.
StatusInternalServerError
,
message
)
}
// Paginated 返回分页数据
func
Paginated
(
c
*
gin
.
Context
,
items
interface
{},
total
int64
,
page
,
pageSize
int
)
{
pages
:=
int
(
math
.
Ceil
(
float64
(
total
)
/
float64
(
pageSize
)))
if
pages
<
1
{
pages
=
1
}
Success
(
c
,
PaginatedData
{
Items
:
items
,
Total
:
total
,
Page
:
page
,
PageSize
:
pageSize
,
Pages
:
pages
,
})
}
// PaginationResult 分页结果(与repository.PaginationResult兼容)
type
PaginationResult
struct
{
Total
int64
Page
int
PageSize
int
Pages
int
}
// PaginatedWithResult 使用PaginationResult返回分页数据
func
PaginatedWithResult
(
c
*
gin
.
Context
,
items
interface
{},
pagination
*
PaginationResult
)
{
if
pagination
==
nil
{
Success
(
c
,
PaginatedData
{
Items
:
items
,
Total
:
0
,
Page
:
1
,
PageSize
:
20
,
Pages
:
1
,
})
return
}
Success
(
c
,
PaginatedData
{
Items
:
items
,
Total
:
pagination
.
Total
,
Page
:
pagination
.
Page
,
PageSize
:
pagination
.
PageSize
,
Pages
:
pagination
.
Pages
,
})
}
// ParsePagination 解析分页参数
func
ParsePagination
(
c
*
gin
.
Context
)
(
page
,
pageSize
int
)
{
page
=
1
pageSize
=
20
if
p
:=
c
.
Query
(
"page"
);
p
!=
""
{
if
val
,
err
:=
parseInt
(
p
);
err
==
nil
&&
val
>
0
{
page
=
val
}
}
// 支持 page_size 和 limit 两种参数名
if
ps
:=
c
.
Query
(
"page_size"
);
ps
!=
""
{
if
val
,
err
:=
parseInt
(
ps
);
err
==
nil
&&
val
>
0
&&
val
<=
100
{
pageSize
=
val
}
}
else
if
l
:=
c
.
Query
(
"limit"
);
l
!=
""
{
if
val
,
err
:=
parseInt
(
l
);
err
==
nil
&&
val
>
0
&&
val
<=
100
{
pageSize
=
val
}
}
return
page
,
pageSize
}
func
parseInt
(
s
string
)
(
int
,
error
)
{
var
result
int
for
_
,
c
:=
range
s
{
if
c
<
'0'
||
c
>
'9'
{
return
0
,
nil
}
result
=
result
*
10
+
int
(
c
-
'0'
)
}
return
result
,
nil
}
backend/internal/pkg/timezone/timezone.go
0 → 100644
View file @
642842c2
// Package timezone provides global timezone management for the application.
// Similar to PHP's date_default_timezone_set, this package allows setting
// a global timezone that affects all time.Now() calls.
package
timezone
import
(
"fmt"
"log"
"time"
)
var
(
// location is the global timezone location
location
*
time
.
Location
// tzName stores the timezone name for logging/debugging
tzName
string
)
// Init initializes the global timezone setting.
// This should be called once at application startup.
// Example timezone values: "Asia/Shanghai", "America/New_York", "UTC"
func
Init
(
tz
string
)
error
{
if
tz
==
""
{
tz
=
"Asia/Shanghai"
// Default timezone
}
loc
,
err
:=
time
.
LoadLocation
(
tz
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"invalid timezone %q: %w"
,
tz
,
err
)
}
// Set the global Go time.Local to our timezone
// This affects time.Now() throughout the application
time
.
Local
=
loc
location
=
loc
tzName
=
tz
log
.
Printf
(
"Timezone initialized: %s (UTC offset: %s)"
,
tz
,
getUTCOffset
(
loc
))
return
nil
}
// getUTCOffset returns the current UTC offset for a location
func
getUTCOffset
(
loc
*
time
.
Location
)
string
{
_
,
offset
:=
time
.
Now
()
.
In
(
loc
)
.
Zone
()
hours
:=
offset
/
3600
minutes
:=
(
offset
%
3600
)
/
60
if
minutes
<
0
{
minutes
=
-
minutes
}
sign
:=
"+"
if
hours
<
0
{
sign
=
"-"
hours
=
-
hours
}
return
fmt
.
Sprintf
(
"%s%02d:%02d"
,
sign
,
hours
,
minutes
)
}
// Now returns the current time in the configured timezone.
// This is equivalent to time.Now() after Init() is called,
// but provided for explicit timezone-aware code.
func
Now
()
time
.
Time
{
if
location
==
nil
{
return
time
.
Now
()
}
return
time
.
Now
()
.
In
(
location
)
}
// Location returns the configured timezone location.
func
Location
()
*
time
.
Location
{
if
location
==
nil
{
return
time
.
Local
}
return
location
}
// Name returns the configured timezone name.
func
Name
()
string
{
if
tzName
==
""
{
return
"Local"
}
return
tzName
}
// StartOfDay returns the start of the given day (00:00:00) in the configured timezone.
func
StartOfDay
(
t
time
.
Time
)
time
.
Time
{
loc
:=
Location
()
t
=
t
.
In
(
loc
)
return
time
.
Date
(
t
.
Year
(),
t
.
Month
(),
t
.
Day
(),
0
,
0
,
0
,
0
,
loc
)
}
// Today returns the start of today (00:00:00) in the configured timezone.
func
Today
()
time
.
Time
{
return
StartOfDay
(
Now
())
}
// EndOfDay returns the end of the given day (23:59:59.999999999) in the configured timezone.
func
EndOfDay
(
t
time
.
Time
)
time
.
Time
{
loc
:=
Location
()
t
=
t
.
In
(
loc
)
return
time
.
Date
(
t
.
Year
(),
t
.
Month
(),
t
.
Day
(),
23
,
59
,
59
,
999999999
,
loc
)
}
// StartOfWeek returns the start of the week (Monday 00:00:00) for the given time.
func
StartOfWeek
(
t
time
.
Time
)
time
.
Time
{
loc
:=
Location
()
t
=
t
.
In
(
loc
)
weekday
:=
int
(
t
.
Weekday
())
if
weekday
==
0
{
weekday
=
7
// Sunday is day 7
}
return
time
.
Date
(
t
.
Year
(),
t
.
Month
(),
t
.
Day
()
-
weekday
+
1
,
0
,
0
,
0
,
0
,
loc
)
}
// StartOfMonth returns the start of the month (1st day 00:00:00) for the given time.
func
StartOfMonth
(
t
time
.
Time
)
time
.
Time
{
loc
:=
Location
()
t
=
t
.
In
(
loc
)
return
time
.
Date
(
t
.
Year
(),
t
.
Month
(),
1
,
0
,
0
,
0
,
0
,
loc
)
}
// ParseInLocation parses a time string in the configured timezone.
func
ParseInLocation
(
layout
,
value
string
)
(
time
.
Time
,
error
)
{
return
time
.
ParseInLocation
(
layout
,
value
,
Location
())
}
backend/internal/pkg/timezone/timezone_test.go
0 → 100644
View file @
642842c2
package
timezone
import
(
"testing"
"time"
)
func
TestInit
(
t
*
testing
.
T
)
{
// Test with valid timezone
err
:=
Init
(
"Asia/Shanghai"
)
if
err
!=
nil
{
t
.
Fatalf
(
"Init failed with valid timezone: %v"
,
err
)
}
// Verify time.Local was set
if
time
.
Local
.
String
()
!=
"Asia/Shanghai"
{
t
.
Errorf
(
"time.Local not set correctly, got %s"
,
time
.
Local
.
String
())
}
// Verify our location variable
if
Location
()
.
String
()
!=
"Asia/Shanghai"
{
t
.
Errorf
(
"Location() not set correctly, got %s"
,
Location
()
.
String
())
}
// Test Name()
if
Name
()
!=
"Asia/Shanghai"
{
t
.
Errorf
(
"Name() not set correctly, got %s"
,
Name
())
}
}
func
TestInitInvalidTimezone
(
t
*
testing
.
T
)
{
err
:=
Init
(
"Invalid/Timezone"
)
if
err
==
nil
{
t
.
Error
(
"Init should fail with invalid timezone"
)
}
}
func
TestTimeNowAffected
(
t
*
testing
.
T
)
{
// Reset to UTC first
Init
(
"UTC"
)
utcNow
:=
time
.
Now
()
// Switch to Shanghai (UTC+8)
Init
(
"Asia/Shanghai"
)
shanghaiNow
:=
time
.
Now
()
// The times should be the same instant, but different timezone representation
// Shanghai should be 8 hours ahead in display
_
,
utcOffset
:=
utcNow
.
Zone
()
_
,
shanghaiOffset
:=
shanghaiNow
.
Zone
()
expectedDiff
:=
8
*
3600
// 8 hours in seconds
actualDiff
:=
shanghaiOffset
-
utcOffset
if
actualDiff
!=
expectedDiff
{
t
.
Errorf
(
"Timezone offset difference incorrect: expected %d, got %d"
,
expectedDiff
,
actualDiff
)
}
}
func
TestToday
(
t
*
testing
.
T
)
{
Init
(
"Asia/Shanghai"
)
today
:=
Today
()
now
:=
Now
()
// Today should be at 00:00:00
if
today
.
Hour
()
!=
0
||
today
.
Minute
()
!=
0
||
today
.
Second
()
!=
0
{
t
.
Errorf
(
"Today() not at start of day: %v"
,
today
)
}
// Today should be same date as now
if
today
.
Year
()
!=
now
.
Year
()
||
today
.
Month
()
!=
now
.
Month
()
||
today
.
Day
()
!=
now
.
Day
()
{
t
.
Errorf
(
"Today() date mismatch: today=%v, now=%v"
,
today
,
now
)
}
}
func
TestStartOfDay
(
t
*
testing
.
T
)
{
Init
(
"Asia/Shanghai"
)
// Create a time at 15:30:45
testTime
:=
time
.
Date
(
2024
,
6
,
15
,
15
,
30
,
45
,
123456789
,
Location
())
startOfDay
:=
StartOfDay
(
testTime
)
expected
:=
time
.
Date
(
2024
,
6
,
15
,
0
,
0
,
0
,
0
,
Location
())
if
!
startOfDay
.
Equal
(
expected
)
{
t
.
Errorf
(
"StartOfDay incorrect: expected %v, got %v"
,
expected
,
startOfDay
)
}
}
func
TestTruncateVsStartOfDay
(
t
*
testing
.
T
)
{
// This test demonstrates why Truncate(24*time.Hour) can be problematic
// and why StartOfDay is more reliable for timezone-aware code
Init
(
"Asia/Shanghai"
)
now
:=
Now
()
// Truncate operates on UTC, not local time
truncated
:=
now
.
Truncate
(
24
*
time
.
Hour
)
// StartOfDay operates on local time
startOfDay
:=
StartOfDay
(
now
)
// These will likely be different for non-UTC timezones
t
.
Logf
(
"Now: %v"
,
now
)
t
.
Logf
(
"Truncate(24h): %v"
,
truncated
)
t
.
Logf
(
"StartOfDay: %v"
,
startOfDay
)
// The truncated time may not be at local midnight
// StartOfDay is always at local midnight
if
startOfDay
.
Hour
()
!=
0
{
t
.
Errorf
(
"StartOfDay should be at hour 0, got %d"
,
startOfDay
.
Hour
())
}
}
func
TestDSTAwareness
(
t
*
testing
.
T
)
{
// Test with a timezone that has DST (America/New_York)
err
:=
Init
(
"America/New_York"
)
if
err
!=
nil
{
t
.
Skipf
(
"America/New_York timezone not available: %v"
,
err
)
}
// Just verify it doesn't crash
_
=
Today
()
_
=
Now
()
_
=
StartOfDay
(
Now
())
}
backend/internal/repository/account_repo.go
0 → 100644
View file @
642842c2
package
repository
import
(
"context"
"sub2api/internal/model"
"time"
"gorm.io/gorm"
)
type
AccountRepository
struct
{
db
*
gorm
.
DB
}
func
NewAccountRepository
(
db
*
gorm
.
DB
)
*
AccountRepository
{
return
&
AccountRepository
{
db
:
db
}
}
func
(
r
*
AccountRepository
)
Create
(
ctx
context
.
Context
,
account
*
model
.
Account
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Create
(
account
)
.
Error
}
func
(
r
*
AccountRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
Account
,
error
)
{
var
account
model
.
Account
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Preload
(
"Proxy"
)
.
Preload
(
"AccountGroups"
)
.
First
(
&
account
,
id
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
// 填充 GroupIDs 虚拟字段
account
.
GroupIDs
=
make
([]
int64
,
0
,
len
(
account
.
AccountGroups
))
for
_
,
ag
:=
range
account
.
AccountGroups
{
account
.
GroupIDs
=
append
(
account
.
GroupIDs
,
ag
.
GroupID
)
}
return
&
account
,
nil
}
func
(
r
*
AccountRepository
)
Update
(
ctx
context
.
Context
,
account
*
model
.
Account
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Save
(
account
)
.
Error
}
func
(
r
*
AccountRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
// 先删除账号与分组的绑定关系
if
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"account_id = ?"
,
id
)
.
Delete
(
&
model
.
AccountGroup
{})
.
Error
;
err
!=
nil
{
return
err
}
// 再删除账号
return
r
.
db
.
WithContext
(
ctx
)
.
Delete
(
&
model
.
Account
{},
id
)
.
Error
}
func
(
r
*
AccountRepository
)
List
(
ctx
context
.
Context
,
params
PaginationParams
)
([]
model
.
Account
,
*
PaginationResult
,
error
)
{
return
r
.
ListWithFilters
(
ctx
,
params
,
""
,
""
,
""
,
""
)
}
// ListWithFilters lists accounts with optional filtering by platform, type, status, and search query
func
(
r
*
AccountRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
PaginationParams
,
platform
,
accountType
,
status
,
search
string
)
([]
model
.
Account
,
*
PaginationResult
,
error
)
{
var
accounts
[]
model
.
Account
var
total
int64
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Account
{})
// Apply filters
if
platform
!=
""
{
db
=
db
.
Where
(
"platform = ?"
,
platform
)
}
if
accountType
!=
""
{
db
=
db
.
Where
(
"type = ?"
,
accountType
)
}
if
status
!=
""
{
db
=
db
.
Where
(
"status = ?"
,
status
)
}
if
search
!=
""
{
searchPattern
:=
"%"
+
search
+
"%"
db
=
db
.
Where
(
"name ILIKE ?"
,
searchPattern
)
}
if
err
:=
db
.
Count
(
&
total
)
.
Error
;
err
!=
nil
{
return
nil
,
nil
,
err
}
if
err
:=
db
.
Preload
(
"Proxy"
)
.
Preload
(
"AccountGroups"
)
.
Offset
(
params
.
Offset
())
.
Limit
(
params
.
Limit
())
.
Order
(
"id DESC"
)
.
Find
(
&
accounts
)
.
Error
;
err
!=
nil
{
return
nil
,
nil
,
err
}
// 填充每个 Account 的 GroupIDs 虚拟字段
for
i
:=
range
accounts
{
accounts
[
i
]
.
GroupIDs
=
make
([]
int64
,
0
,
len
(
accounts
[
i
]
.
AccountGroups
))
for
_
,
ag
:=
range
accounts
[
i
]
.
AccountGroups
{
accounts
[
i
]
.
GroupIDs
=
append
(
accounts
[
i
]
.
GroupIDs
,
ag
.
GroupID
)
}
}
pages
:=
int
(
total
)
/
params
.
Limit
()
if
int
(
total
)
%
params
.
Limit
()
>
0
{
pages
++
}
return
accounts
,
&
PaginationResult
{
Total
:
total
,
Page
:
params
.
Page
,
PageSize
:
params
.
Limit
(),
Pages
:
pages
,
},
nil
}
func
(
r
*
AccountRepository
)
ListByGroup
(
ctx
context
.
Context
,
groupID
int64
)
([]
model
.
Account
,
error
)
{
var
accounts
[]
model
.
Account
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Joins
(
"JOIN account_groups ON account_groups.account_id = accounts.id"
)
.
Where
(
"account_groups.group_id = ? AND accounts.status = ?"
,
groupID
,
model
.
StatusActive
)
.
Preload
(
"Proxy"
)
.
Order
(
"account_groups.priority ASC, accounts.priority ASC"
)
.
Find
(
&
accounts
)
.
Error
return
accounts
,
err
}
func
(
r
*
AccountRepository
)
ListActive
(
ctx
context
.
Context
)
([]
model
.
Account
,
error
)
{
var
accounts
[]
model
.
Account
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"status = ?"
,
model
.
StatusActive
)
.
Preload
(
"Proxy"
)
.
Order
(
"priority ASC"
)
.
Find
(
&
accounts
)
.
Error
return
accounts
,
err
}
func
(
r
*
AccountRepository
)
UpdateLastUsed
(
ctx
context
.
Context
,
id
int64
)
error
{
now
:=
time
.
Now
()
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Account
{})
.
Where
(
"id = ?"
,
id
)
.
Update
(
"last_used_at"
,
now
)
.
Error
}
func
(
r
*
AccountRepository
)
SetError
(
ctx
context
.
Context
,
id
int64
,
errorMsg
string
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Account
{})
.
Where
(
"id = ?"
,
id
)
.
Updates
(
map
[
string
]
interface
{}{
"status"
:
model
.
StatusError
,
"error_message"
:
errorMsg
,
})
.
Error
}
func
(
r
*
AccountRepository
)
AddToGroup
(
ctx
context
.
Context
,
accountID
,
groupID
int64
,
priority
int
)
error
{
ag
:=
&
model
.
AccountGroup
{
AccountID
:
accountID
,
GroupID
:
groupID
,
Priority
:
priority
,
}
return
r
.
db
.
WithContext
(
ctx
)
.
Create
(
ag
)
.
Error
}
func
(
r
*
AccountRepository
)
RemoveFromGroup
(
ctx
context
.
Context
,
accountID
,
groupID
int64
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"account_id = ? AND group_id = ?"
,
accountID
,
groupID
)
.
Delete
(
&
model
.
AccountGroup
{})
.
Error
}
func
(
r
*
AccountRepository
)
GetGroups
(
ctx
context
.
Context
,
accountID
int64
)
([]
model
.
Group
,
error
)
{
var
groups
[]
model
.
Group
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Joins
(
"JOIN account_groups ON account_groups.group_id = groups.id"
)
.
Where
(
"account_groups.account_id = ?"
,
accountID
)
.
Find
(
&
groups
)
.
Error
return
groups
,
err
}
func
(
r
*
AccountRepository
)
ListByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
model
.
Account
,
error
)
{
var
accounts
[]
model
.
Account
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"platform = ? AND status = ?"
,
platform
,
model
.
StatusActive
)
.
Preload
(
"Proxy"
)
.
Order
(
"priority ASC"
)
.
Find
(
&
accounts
)
.
Error
return
accounts
,
err
}
func
(
r
*
AccountRepository
)
BindGroups
(
ctx
context
.
Context
,
accountID
int64
,
groupIDs
[]
int64
)
error
{
// 删除现有绑定
if
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"account_id = ?"
,
accountID
)
.
Delete
(
&
model
.
AccountGroup
{})
.
Error
;
err
!=
nil
{
return
err
}
// 添加新绑定
if
len
(
groupIDs
)
>
0
{
accountGroups
:=
make
([]
model
.
AccountGroup
,
0
,
len
(
groupIDs
))
for
i
,
groupID
:=
range
groupIDs
{
accountGroups
=
append
(
accountGroups
,
model
.
AccountGroup
{
AccountID
:
accountID
,
GroupID
:
groupID
,
Priority
:
i
+
1
,
// 使用索引作为优先级
})
}
return
r
.
db
.
WithContext
(
ctx
)
.
Create
(
&
accountGroups
)
.
Error
}
return
nil
}
// ListSchedulable 获取所有可调度的账号
func
(
r
*
AccountRepository
)
ListSchedulable
(
ctx
context
.
Context
)
([]
model
.
Account
,
error
)
{
var
accounts
[]
model
.
Account
now
:=
time
.
Now
()
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"status = ? AND schedulable = ?"
,
model
.
StatusActive
,
true
)
.
Where
(
"(overload_until IS NULL OR overload_until <= ?)"
,
now
)
.
Where
(
"(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)"
,
now
)
.
Preload
(
"Proxy"
)
.
Order
(
"priority ASC"
)
.
Find
(
&
accounts
)
.
Error
return
accounts
,
err
}
// ListSchedulableByGroupID 按组获取可调度的账号
func
(
r
*
AccountRepository
)
ListSchedulableByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
model
.
Account
,
error
)
{
var
accounts
[]
model
.
Account
now
:=
time
.
Now
()
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Joins
(
"JOIN account_groups ON account_groups.account_id = accounts.id"
)
.
Where
(
"account_groups.group_id = ?"
,
groupID
)
.
Where
(
"accounts.status = ? AND accounts.schedulable = ?"
,
model
.
StatusActive
,
true
)
.
Where
(
"(accounts.overload_until IS NULL OR accounts.overload_until <= ?)"
,
now
)
.
Where
(
"(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)"
,
now
)
.
Preload
(
"Proxy"
)
.
Order
(
"account_groups.priority ASC, accounts.priority ASC"
)
.
Find
(
&
accounts
)
.
Error
return
accounts
,
err
}
// SetRateLimited 标记账号为限流状态(429)
func
(
r
*
AccountRepository
)
SetRateLimited
(
ctx
context
.
Context
,
id
int64
,
resetAt
time
.
Time
)
error
{
now
:=
time
.
Now
()
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Account
{})
.
Where
(
"id = ?"
,
id
)
.
Updates
(
map
[
string
]
interface
{}{
"rate_limited_at"
:
now
,
"rate_limit_reset_at"
:
resetAt
,
})
.
Error
}
// SetOverloaded 标记账号为过载状态(529)
func
(
r
*
AccountRepository
)
SetOverloaded
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Account
{})
.
Where
(
"id = ?"
,
id
)
.
Update
(
"overload_until"
,
until
)
.
Error
}
// ClearRateLimit 清除账号的限流状态
func
(
r
*
AccountRepository
)
ClearRateLimit
(
ctx
context
.
Context
,
id
int64
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Account
{})
.
Where
(
"id = ?"
,
id
)
.
Updates
(
map
[
string
]
interface
{}{
"rate_limited_at"
:
nil
,
"rate_limit_reset_at"
:
nil
,
"overload_until"
:
nil
,
})
.
Error
}
// UpdateSessionWindow 更新账号的5小时时间窗口信息
func
(
r
*
AccountRepository
)
UpdateSessionWindow
(
ctx
context
.
Context
,
id
int64
,
start
,
end
*
time
.
Time
,
status
string
)
error
{
updates
:=
map
[
string
]
interface
{}{
"session_window_status"
:
status
,
}
if
start
!=
nil
{
updates
[
"session_window_start"
]
=
start
}
if
end
!=
nil
{
updates
[
"session_window_end"
]
=
end
}
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Account
{})
.
Where
(
"id = ?"
,
id
)
.
Updates
(
updates
)
.
Error
}
// SetSchedulable 设置账号的调度开关
func
(
r
*
AccountRepository
)
SetSchedulable
(
ctx
context
.
Context
,
id
int64
,
schedulable
bool
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Account
{})
.
Where
(
"id = ?"
,
id
)
.
Update
(
"schedulable"
,
schedulable
)
.
Error
}
backend/internal/repository/api_key_repo.go
0 → 100644
View file @
642842c2
package
repository
import
(
"context"
"sub2api/internal/model"
"gorm.io/gorm"
)
type
ApiKeyRepository
struct
{
db
*
gorm
.
DB
}
func
NewApiKeyRepository
(
db
*
gorm
.
DB
)
*
ApiKeyRepository
{
return
&
ApiKeyRepository
{
db
:
db
}
}
func
(
r
*
ApiKeyRepository
)
Create
(
ctx
context
.
Context
,
key
*
model
.
ApiKey
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Create
(
key
)
.
Error
}
func
(
r
*
ApiKeyRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
ApiKey
,
error
)
{
var
key
model
.
ApiKey
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Preload
(
"User"
)
.
Preload
(
"Group"
)
.
First
(
&
key
,
id
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
return
&
key
,
nil
}
func
(
r
*
ApiKeyRepository
)
GetByKey
(
ctx
context
.
Context
,
key
string
)
(
*
model
.
ApiKey
,
error
)
{
var
apiKey
model
.
ApiKey
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Preload
(
"User"
)
.
Preload
(
"Group"
)
.
Where
(
"key = ?"
,
key
)
.
First
(
&
apiKey
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
return
&
apiKey
,
nil
}
func
(
r
*
ApiKeyRepository
)
Update
(
ctx
context
.
Context
,
key
*
model
.
ApiKey
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
key
)
.
Select
(
"name"
,
"group_id"
,
"status"
,
"updated_at"
)
.
Updates
(
key
)
.
Error
}
func
(
r
*
ApiKeyRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Delete
(
&
model
.
ApiKey
{},
id
)
.
Error
}
func
(
r
*
ApiKeyRepository
)
ListByUserID
(
ctx
context
.
Context
,
userID
int64
,
params
PaginationParams
)
([]
model
.
ApiKey
,
*
PaginationResult
,
error
)
{
var
keys
[]
model
.
ApiKey
var
total
int64
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
ApiKey
{})
.
Where
(
"user_id = ?"
,
userID
)
if
err
:=
db
.
Count
(
&
total
)
.
Error
;
err
!=
nil
{
return
nil
,
nil
,
err
}
if
err
:=
db
.
Preload
(
"Group"
)
.
Offset
(
params
.
Offset
())
.
Limit
(
params
.
Limit
())
.
Order
(
"id DESC"
)
.
Find
(
&
keys
)
.
Error
;
err
!=
nil
{
return
nil
,
nil
,
err
}
pages
:=
int
(
total
)
/
params
.
Limit
()
if
int
(
total
)
%
params
.
Limit
()
>
0
{
pages
++
}
return
keys
,
&
PaginationResult
{
Total
:
total
,
Page
:
params
.
Page
,
PageSize
:
params
.
Limit
(),
Pages
:
pages
,
},
nil
}
func
(
r
*
ApiKeyRepository
)
CountByUserID
(
ctx
context
.
Context
,
userID
int64
)
(
int64
,
error
)
{
var
count
int64
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
ApiKey
{})
.
Where
(
"user_id = ?"
,
userID
)
.
Count
(
&
count
)
.
Error
return
count
,
err
}
func
(
r
*
ApiKeyRepository
)
ExistsByKey
(
ctx
context
.
Context
,
key
string
)
(
bool
,
error
)
{
var
count
int64
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
ApiKey
{})
.
Where
(
"key = ?"
,
key
)
.
Count
(
&
count
)
.
Error
return
count
>
0
,
err
}
func
(
r
*
ApiKeyRepository
)
ListByGroupID
(
ctx
context
.
Context
,
groupID
int64
,
params
PaginationParams
)
([]
model
.
ApiKey
,
*
PaginationResult
,
error
)
{
var
keys
[]
model
.
ApiKey
var
total
int64
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
ApiKey
{})
.
Where
(
"group_id = ?"
,
groupID
)
if
err
:=
db
.
Count
(
&
total
)
.
Error
;
err
!=
nil
{
return
nil
,
nil
,
err
}
if
err
:=
db
.
Preload
(
"User"
)
.
Offset
(
params
.
Offset
())
.
Limit
(
params
.
Limit
())
.
Order
(
"id DESC"
)
.
Find
(
&
keys
)
.
Error
;
err
!=
nil
{
return
nil
,
nil
,
err
}
pages
:=
int
(
total
)
/
params
.
Limit
()
if
int
(
total
)
%
params
.
Limit
()
>
0
{
pages
++
}
return
keys
,
&
PaginationResult
{
Total
:
total
,
Page
:
params
.
Page
,
PageSize
:
params
.
Limit
(),
Pages
:
pages
,
},
nil
}
// SearchApiKeys searches API keys by user ID and/or keyword (name)
func
(
r
*
ApiKeyRepository
)
SearchApiKeys
(
ctx
context
.
Context
,
userID
int64
,
keyword
string
,
limit
int
)
([]
model
.
ApiKey
,
error
)
{
var
keys
[]
model
.
ApiKey
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
ApiKey
{})
if
userID
>
0
{
db
=
db
.
Where
(
"user_id = ?"
,
userID
)
}
if
keyword
!=
""
{
searchPattern
:=
"%"
+
keyword
+
"%"
db
=
db
.
Where
(
"name ILIKE ?"
,
searchPattern
)
}
if
err
:=
db
.
Limit
(
limit
)
.
Order
(
"id DESC"
)
.
Find
(
&
keys
)
.
Error
;
err
!=
nil
{
return
nil
,
err
}
return
keys
,
nil
}
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
func
(
r
*
ApiKeyRepository
)
ClearGroupIDByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
result
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
ApiKey
{})
.
Where
(
"group_id = ?"
,
groupID
)
.
Update
(
"group_id"
,
nil
)
return
result
.
RowsAffected
,
result
.
Error
}
// CountByGroupID 获取分组的 API Key 数量
func
(
r
*
ApiKeyRepository
)
CountByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
var
count
int64
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
ApiKey
{})
.
Where
(
"group_id = ?"
,
groupID
)
.
Count
(
&
count
)
.
Error
return
count
,
err
}
backend/internal/repository/group_repo.go
0 → 100644
View file @
642842c2
package
repository
import
(
"context"
"sub2api/internal/model"
"gorm.io/gorm"
)
type
GroupRepository
struct
{
db
*
gorm
.
DB
}
func
NewGroupRepository
(
db
*
gorm
.
DB
)
*
GroupRepository
{
return
&
GroupRepository
{
db
:
db
}
}
func
(
r
*
GroupRepository
)
Create
(
ctx
context
.
Context
,
group
*
model
.
Group
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Create
(
group
)
.
Error
}
func
(
r
*
GroupRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
Group
,
error
)
{
var
group
model
.
Group
err
:=
r
.
db
.
WithContext
(
ctx
)
.
First
(
&
group
,
id
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
return
&
group
,
nil
}
func
(
r
*
GroupRepository
)
Update
(
ctx
context
.
Context
,
group
*
model
.
Group
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Save
(
group
)
.
Error
}
func
(
r
*
GroupRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Delete
(
&
model
.
Group
{},
id
)
.
Error
}
func
(
r
*
GroupRepository
)
List
(
ctx
context
.
Context
,
params
PaginationParams
)
([]
model
.
Group
,
*
PaginationResult
,
error
)
{
return
r
.
ListWithFilters
(
ctx
,
params
,
""
,
""
,
nil
)
}
// ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive
func
(
r
*
GroupRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
PaginationParams
,
platform
,
status
string
,
isExclusive
*
bool
)
([]
model
.
Group
,
*
PaginationResult
,
error
)
{
var
groups
[]
model
.
Group
var
total
int64
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Group
{})
// Apply filters
if
platform
!=
""
{
db
=
db
.
Where
(
"platform = ?"
,
platform
)
}
if
status
!=
""
{
db
=
db
.
Where
(
"status = ?"
,
status
)
}
if
isExclusive
!=
nil
{
db
=
db
.
Where
(
"is_exclusive = ?"
,
*
isExclusive
)
}
if
err
:=
db
.
Count
(
&
total
)
.
Error
;
err
!=
nil
{
return
nil
,
nil
,
err
}
if
err
:=
db
.
Offset
(
params
.
Offset
())
.
Limit
(
params
.
Limit
())
.
Order
(
"id ASC"
)
.
Find
(
&
groups
)
.
Error
;
err
!=
nil
{
return
nil
,
nil
,
err
}
// 获取每个分组的账号数量
for
i
:=
range
groups
{
count
,
_
:=
r
.
GetAccountCount
(
ctx
,
groups
[
i
]
.
ID
)
groups
[
i
]
.
AccountCount
=
count
}
pages
:=
int
(
total
)
/
params
.
Limit
()
if
int
(
total
)
%
params
.
Limit
()
>
0
{
pages
++
}
return
groups
,
&
PaginationResult
{
Total
:
total
,
Page
:
params
.
Page
,
PageSize
:
params
.
Limit
(),
Pages
:
pages
,
},
nil
}
func
(
r
*
GroupRepository
)
ListActive
(
ctx
context
.
Context
)
([]
model
.
Group
,
error
)
{
var
groups
[]
model
.
Group
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"status = ?"
,
model
.
StatusActive
)
.
Order
(
"id ASC"
)
.
Find
(
&
groups
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
// 获取每个分组的账号数量
for
i
:=
range
groups
{
count
,
_
:=
r
.
GetAccountCount
(
ctx
,
groups
[
i
]
.
ID
)
groups
[
i
]
.
AccountCount
=
count
}
return
groups
,
nil
}
func
(
r
*
GroupRepository
)
ListActiveByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
model
.
Group
,
error
)
{
var
groups
[]
model
.
Group
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"status = ? AND platform = ?"
,
model
.
StatusActive
,
platform
)
.
Order
(
"id ASC"
)
.
Find
(
&
groups
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
// 获取每个分组的账号数量
for
i
:=
range
groups
{
count
,
_
:=
r
.
GetAccountCount
(
ctx
,
groups
[
i
]
.
ID
)
groups
[
i
]
.
AccountCount
=
count
}
return
groups
,
nil
}
func
(
r
*
GroupRepository
)
ExistsByName
(
ctx
context
.
Context
,
name
string
)
(
bool
,
error
)
{
var
count
int64
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Group
{})
.
Where
(
"name = ?"
,
name
)
.
Count
(
&
count
)
.
Error
return
count
>
0
,
err
}
func
(
r
*
GroupRepository
)
GetAccountCount
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
var
count
int64
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
AccountGroup
{})
.
Where
(
"group_id = ?"
,
groupID
)
.
Count
(
&
count
)
.
Error
return
count
,
err
}
// DeleteAccountGroupsByGroupID 删除分组与账号的关联关系
func
(
r
*
GroupRepository
)
DeleteAccountGroupsByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
result
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"group_id = ?"
,
groupID
)
.
Delete
(
&
model
.
AccountGroup
{})
return
result
.
RowsAffected
,
result
.
Error
}
// DB 返回底层数据库连接,用于事务处理
func
(
r
*
GroupRepository
)
DB
()
*
gorm
.
DB
{
return
r
.
db
}
backend/internal/repository/proxy_repo.go
0 → 100644
View file @
642842c2
package
repository
import
(
"context"
"sub2api/internal/model"
"gorm.io/gorm"
)
type
ProxyRepository
struct
{
db
*
gorm
.
DB
}
func
NewProxyRepository
(
db
*
gorm
.
DB
)
*
ProxyRepository
{
return
&
ProxyRepository
{
db
:
db
}
}
func
(
r
*
ProxyRepository
)
Create
(
ctx
context
.
Context
,
proxy
*
model
.
Proxy
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Create
(
proxy
)
.
Error
}
func
(
r
*
ProxyRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
Proxy
,
error
)
{
var
proxy
model
.
Proxy
err
:=
r
.
db
.
WithContext
(
ctx
)
.
First
(
&
proxy
,
id
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
return
&
proxy
,
nil
}
func
(
r
*
ProxyRepository
)
Update
(
ctx
context
.
Context
,
proxy
*
model
.
Proxy
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Save
(
proxy
)
.
Error
}
func
(
r
*
ProxyRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Delete
(
&
model
.
Proxy
{},
id
)
.
Error
}
func
(
r
*
ProxyRepository
)
List
(
ctx
context
.
Context
,
params
PaginationParams
)
([]
model
.
Proxy
,
*
PaginationResult
,
error
)
{
return
r
.
ListWithFilters
(
ctx
,
params
,
""
,
""
,
""
)
}
// ListWithFilters lists proxies with optional filtering by protocol, status, and search query
func
(
r
*
ProxyRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
PaginationParams
,
protocol
,
status
,
search
string
)
([]
model
.
Proxy
,
*
PaginationResult
,
error
)
{
var
proxies
[]
model
.
Proxy
var
total
int64
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Proxy
{})
// Apply filters
if
protocol
!=
""
{
db
=
db
.
Where
(
"protocol = ?"
,
protocol
)
}
if
status
!=
""
{
db
=
db
.
Where
(
"status = ?"
,
status
)
}
if
search
!=
""
{
searchPattern
:=
"%"
+
search
+
"%"
db
=
db
.
Where
(
"name ILIKE ?"
,
searchPattern
)
}
if
err
:=
db
.
Count
(
&
total
)
.
Error
;
err
!=
nil
{
return
nil
,
nil
,
err
}
if
err
:=
db
.
Offset
(
params
.
Offset
())
.
Limit
(
params
.
Limit
())
.
Order
(
"id DESC"
)
.
Find
(
&
proxies
)
.
Error
;
err
!=
nil
{
return
nil
,
nil
,
err
}
pages
:=
int
(
total
)
/
params
.
Limit
()
if
int
(
total
)
%
params
.
Limit
()
>
0
{
pages
++
}
return
proxies
,
&
PaginationResult
{
Total
:
total
,
Page
:
params
.
Page
,
PageSize
:
params
.
Limit
(),
Pages
:
pages
,
},
nil
}
func
(
r
*
ProxyRepository
)
ListActive
(
ctx
context
.
Context
)
([]
model
.
Proxy
,
error
)
{
var
proxies
[]
model
.
Proxy
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"status = ?"
,
model
.
StatusActive
)
.
Find
(
&
proxies
)
.
Error
return
proxies
,
err
}
// ExistsByHostPortAuth checks if a proxy with the same host, port, username, and password exists
func
(
r
*
ProxyRepository
)
ExistsByHostPortAuth
(
ctx
context
.
Context
,
host
string
,
port
int
,
username
,
password
string
)
(
bool
,
error
)
{
var
count
int64
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Proxy
{})
.
Where
(
"host = ? AND port = ? AND username = ? AND password = ?"
,
host
,
port
,
username
,
password
)
.
Count
(
&
count
)
.
Error
if
err
!=
nil
{
return
false
,
err
}
return
count
>
0
,
nil
}
// CountAccountsByProxyID returns the number of accounts using a specific proxy
func
(
r
*
ProxyRepository
)
CountAccountsByProxyID
(
ctx
context
.
Context
,
proxyID
int64
)
(
int64
,
error
)
{
var
count
int64
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Account
{})
.
Where
(
"proxy_id = ?"
,
proxyID
)
.
Count
(
&
count
)
.
Error
return
count
,
err
}
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
func
(
r
*
ProxyRepository
)
GetAccountCountsForProxies
(
ctx
context
.
Context
)
(
map
[
int64
]
int64
,
error
)
{
type
result
struct
{
ProxyID
int64
`gorm:"column:proxy_id"`
Count
int64
`gorm:"column:count"`
}
var
results
[]
result
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Account
{})
.
Select
(
"proxy_id, COUNT(*) as count"
)
.
Where
(
"proxy_id IS NOT NULL"
)
.
Group
(
"proxy_id"
)
.
Scan
(
&
results
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
counts
:=
make
(
map
[
int64
]
int64
)
for
_
,
r
:=
range
results
{
counts
[
r
.
ProxyID
]
=
r
.
Count
}
return
counts
,
nil
}
// ListActiveWithAccountCount returns all active proxies with account count, sorted by creation time descending
func
(
r
*
ProxyRepository
)
ListActiveWithAccountCount
(
ctx
context
.
Context
)
([]
model
.
ProxyWithAccountCount
,
error
)
{
var
proxies
[]
model
.
Proxy
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"status = ?"
,
model
.
StatusActive
)
.
Order
(
"created_at DESC"
)
.
Find
(
&
proxies
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
// Get account counts
counts
,
err
:=
r
.
GetAccountCountsForProxies
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
// Build result with account counts
result
:=
make
([]
model
.
ProxyWithAccountCount
,
len
(
proxies
))
for
i
,
proxy
:=
range
proxies
{
result
[
i
]
=
model
.
ProxyWithAccountCount
{
Proxy
:
proxy
,
AccountCount
:
counts
[
proxy
.
ID
],
}
}
return
result
,
nil
}
backend/internal/repository/redeem_code_repo.go
0 → 100644
View file @
642842c2
package
repository
import
(
"context"
"sub2api/internal/model"
"time"
"gorm.io/gorm"
)
type
RedeemCodeRepository
struct
{
db
*
gorm
.
DB
}
func
NewRedeemCodeRepository
(
db
*
gorm
.
DB
)
*
RedeemCodeRepository
{
return
&
RedeemCodeRepository
{
db
:
db
}
}
func
(
r
*
RedeemCodeRepository
)
Create
(
ctx
context
.
Context
,
code
*
model
.
RedeemCode
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Create
(
code
)
.
Error
}
func
(
r
*
RedeemCodeRepository
)
CreateBatch
(
ctx
context
.
Context
,
codes
[]
model
.
RedeemCode
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Create
(
&
codes
)
.
Error
}
func
(
r
*
RedeemCodeRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
RedeemCode
,
error
)
{
var
code
model
.
RedeemCode
err
:=
r
.
db
.
WithContext
(
ctx
)
.
First
(
&
code
,
id
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
return
&
code
,
nil
}
func
(
r
*
RedeemCodeRepository
)
GetByCode
(
ctx
context
.
Context
,
code
string
)
(
*
model
.
RedeemCode
,
error
)
{
var
redeemCode
model
.
RedeemCode
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"code = ?"
,
code
)
.
First
(
&
redeemCode
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
return
&
redeemCode
,
nil
}
func
(
r
*
RedeemCodeRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Delete
(
&
model
.
RedeemCode
{},
id
)
.
Error
}
func
(
r
*
RedeemCodeRepository
)
List
(
ctx
context
.
Context
,
params
PaginationParams
)
([]
model
.
RedeemCode
,
*
PaginationResult
,
error
)
{
return
r
.
ListWithFilters
(
ctx
,
params
,
""
,
""
,
""
)
}
// ListWithFilters lists redeem codes with optional filtering by type, status, and search query
func
(
r
*
RedeemCodeRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
PaginationParams
,
codeType
,
status
,
search
string
)
([]
model
.
RedeemCode
,
*
PaginationResult
,
error
)
{
var
codes
[]
model
.
RedeemCode
var
total
int64
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
RedeemCode
{})
// Apply filters
if
codeType
!=
""
{
db
=
db
.
Where
(
"type = ?"
,
codeType
)
}
if
status
!=
""
{
db
=
db
.
Where
(
"status = ?"
,
status
)
}
if
search
!=
""
{
searchPattern
:=
"%"
+
search
+
"%"
db
=
db
.
Where
(
"code ILIKE ?"
,
searchPattern
)
}
if
err
:=
db
.
Count
(
&
total
)
.
Error
;
err
!=
nil
{
return
nil
,
nil
,
err
}
if
err
:=
db
.
Preload
(
"User"
)
.
Preload
(
"Group"
)
.
Offset
(
params
.
Offset
())
.
Limit
(
params
.
Limit
())
.
Order
(
"id DESC"
)
.
Find
(
&
codes
)
.
Error
;
err
!=
nil
{
return
nil
,
nil
,
err
}
pages
:=
int
(
total
)
/
params
.
Limit
()
if
int
(
total
)
%
params
.
Limit
()
>
0
{
pages
++
}
return
codes
,
&
PaginationResult
{
Total
:
total
,
Page
:
params
.
Page
,
PageSize
:
params
.
Limit
(),
Pages
:
pages
,
},
nil
}
func
(
r
*
RedeemCodeRepository
)
Update
(
ctx
context
.
Context
,
code
*
model
.
RedeemCode
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Save
(
code
)
.
Error
}
func
(
r
*
RedeemCodeRepository
)
Use
(
ctx
context
.
Context
,
id
,
userID
int64
)
error
{
now
:=
time
.
Now
()
result
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
RedeemCode
{})
.
Where
(
"id = ? AND status = ?"
,
id
,
model
.
StatusUnused
)
.
Updates
(
map
[
string
]
interface
{}{
"status"
:
model
.
StatusUsed
,
"used_by"
:
userID
,
"used_at"
:
now
,
})
if
result
.
Error
!=
nil
{
return
result
.
Error
}
if
result
.
RowsAffected
==
0
{
return
gorm
.
ErrRecordNotFound
// 兑换码不存在或已被使用
}
return
nil
}
// ListByUser returns all redeem codes used by a specific user
func
(
r
*
RedeemCodeRepository
)
ListByUser
(
ctx
context
.
Context
,
userID
int64
,
limit
int
)
([]
model
.
RedeemCode
,
error
)
{
var
codes
[]
model
.
RedeemCode
if
limit
<=
0
{
limit
=
10
}
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Preload
(
"Group"
)
.
Where
(
"used_by = ?"
,
userID
)
.
Order
(
"used_at DESC"
)
.
Limit
(
limit
)
.
Find
(
&
codes
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
return
codes
,
nil
}
backend/internal/repository/repository.go
0 → 100644
View file @
642842c2
package
repository
import
(
"gorm.io/gorm"
)
// Repositories 所有仓库的集合
type
Repositories
struct
{
User
*
UserRepository
ApiKey
*
ApiKeyRepository
Group
*
GroupRepository
Account
*
AccountRepository
Proxy
*
ProxyRepository
RedeemCode
*
RedeemCodeRepository
UsageLog
*
UsageLogRepository
Setting
*
SettingRepository
UserSubscription
*
UserSubscriptionRepository
}
// NewRepositories 创建所有仓库
func
NewRepositories
(
db
*
gorm
.
DB
)
*
Repositories
{
return
&
Repositories
{
User
:
NewUserRepository
(
db
),
ApiKey
:
NewApiKeyRepository
(
db
),
Group
:
NewGroupRepository
(
db
),
Account
:
NewAccountRepository
(
db
),
Proxy
:
NewProxyRepository
(
db
),
RedeemCode
:
NewRedeemCodeRepository
(
db
),
UsageLog
:
NewUsageLogRepository
(
db
),
Setting
:
NewSettingRepository
(
db
),
UserSubscription
:
NewUserSubscriptionRepository
(
db
),
}
}
// PaginationParams 分页参数
type
PaginationParams
struct
{
Page
int
PageSize
int
}
// PaginationResult 分页结果
type
PaginationResult
struct
{
Total
int64
Page
int
PageSize
int
Pages
int
}
// DefaultPagination 默认分页参数
func
DefaultPagination
()
PaginationParams
{
return
PaginationParams
{
Page
:
1
,
PageSize
:
20
,
}
}
// Offset 计算偏移量
func
(
p
PaginationParams
)
Offset
()
int
{
if
p
.
Page
<
1
{
p
.
Page
=
1
}
return
(
p
.
Page
-
1
)
*
p
.
PageSize
}
// Limit 获取限制数
func
(
p
PaginationParams
)
Limit
()
int
{
if
p
.
PageSize
<
1
{
return
20
}
if
p
.
PageSize
>
100
{
return
100
}
return
p
.
PageSize
}
Prev
1
2
3
4
5
6
7
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment