Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
642842c2
Commit
642842c2
authored
Dec 18, 2025
by
shaw
Browse files
First commit
parent
569f4882
Changes
201
Show whitespace changes
Inline
Side-by-side
Too many changes to show.
To preserve performance only
201 of 201+
files are displayed.
Plain diff
Email patch
backend/internal/repository/setting_repo.go
0 → 100644
View file @
642842c2
package
repository
import
(
"context"
"sub2api/internal/model"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// SettingRepository 系统设置数据访问层
type
SettingRepository
struct
{
db
*
gorm
.
DB
}
// NewSettingRepository 创建系统设置仓库实例
func
NewSettingRepository
(
db
*
gorm
.
DB
)
*
SettingRepository
{
return
&
SettingRepository
{
db
:
db
}
}
// Get 根据Key获取设置值
func
(
r
*
SettingRepository
)
Get
(
ctx
context
.
Context
,
key
string
)
(
*
model
.
Setting
,
error
)
{
var
setting
model
.
Setting
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"key = ?"
,
key
)
.
First
(
&
setting
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
return
&
setting
,
nil
}
// GetValue 获取设置值字符串
func
(
r
*
SettingRepository
)
GetValue
(
ctx
context
.
Context
,
key
string
)
(
string
,
error
)
{
setting
,
err
:=
r
.
Get
(
ctx
,
key
)
if
err
!=
nil
{
return
""
,
err
}
return
setting
.
Value
,
nil
}
// Set 设置值(存在则更新,不存在则创建)
func
(
r
*
SettingRepository
)
Set
(
ctx
context
.
Context
,
key
,
value
string
)
error
{
setting
:=
&
model
.
Setting
{
Key
:
key
,
Value
:
value
,
UpdatedAt
:
time
.
Now
(),
}
return
r
.
db
.
WithContext
(
ctx
)
.
Clauses
(
clause
.
OnConflict
{
Columns
:
[]
clause
.
Column
{{
Name
:
"key"
}},
DoUpdates
:
clause
.
AssignmentColumns
([]
string
{
"value"
,
"updated_at"
}),
})
.
Create
(
setting
)
.
Error
}
// GetMultiple 批量获取设置
func
(
r
*
SettingRepository
)
GetMultiple
(
ctx
context
.
Context
,
keys
[]
string
)
(
map
[
string
]
string
,
error
)
{
var
settings
[]
model
.
Setting
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"key IN ?"
,
keys
)
.
Find
(
&
settings
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
result
:=
make
(
map
[
string
]
string
)
for
_
,
s
:=
range
settings
{
result
[
s
.
Key
]
=
s
.
Value
}
return
result
,
nil
}
// SetMultiple 批量设置值
func
(
r
*
SettingRepository
)
SetMultiple
(
ctx
context
.
Context
,
settings
map
[
string
]
string
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Transaction
(
func
(
tx
*
gorm
.
DB
)
error
{
for
key
,
value
:=
range
settings
{
setting
:=
&
model
.
Setting
{
Key
:
key
,
Value
:
value
,
UpdatedAt
:
time
.
Now
(),
}
if
err
:=
tx
.
Clauses
(
clause
.
OnConflict
{
Columns
:
[]
clause
.
Column
{{
Name
:
"key"
}},
DoUpdates
:
clause
.
AssignmentColumns
([]
string
{
"value"
,
"updated_at"
}),
})
.
Create
(
setting
)
.
Error
;
err
!=
nil
{
return
err
}
}
return
nil
})
}
// GetAll 获取所有设置
func
(
r
*
SettingRepository
)
GetAll
(
ctx
context
.
Context
)
(
map
[
string
]
string
,
error
)
{
var
settings
[]
model
.
Setting
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Find
(
&
settings
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
result
:=
make
(
map
[
string
]
string
)
for
_
,
s
:=
range
settings
{
result
[
s
.
Key
]
=
s
.
Value
}
return
result
,
nil
}
// Delete 删除设置
func
(
r
*
SettingRepository
)
Delete
(
ctx
context
.
Context
,
key
string
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"key = ?"
,
key
)
.
Delete
(
&
model
.
Setting
{})
.
Error
}
backend/internal/repository/usage_log_repo.go
0 → 100644
View file @
642842c2
package
repository
import
(
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/timezone"
"time"
"gorm.io/gorm"
)
type
UsageLogRepository
struct
{
db
*
gorm
.
DB
}
func
NewUsageLogRepository
(
db
*
gorm
.
DB
)
*
UsageLogRepository
{
return
&
UsageLogRepository
{
db
:
db
}
}
func
(
r
*
UsageLogRepository
)
Create
(
ctx
context
.
Context
,
log
*
model
.
UsageLog
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Create
(
log
)
.
Error
}
func
(
r
*
UsageLogRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
UsageLog
,
error
)
{
var
log
model
.
UsageLog
err
:=
r
.
db
.
WithContext
(
ctx
)
.
First
(
&
log
,
id
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
return
&
log
,
nil
}
func
(
r
*
UsageLogRepository
)
ListByUser
(
ctx
context
.
Context
,
userID
int64
,
params
PaginationParams
)
([]
model
.
UsageLog
,
*
PaginationResult
,
error
)
{
var
logs
[]
model
.
UsageLog
var
total
int64
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UsageLog
{})
.
Where
(
"user_id = ?"
,
userID
)
if
err
:=
db
.
Count
(
&
total
)
.
Error
;
err
!=
nil
{
return
nil
,
nil
,
err
}
if
err
:=
db
.
Offset
(
params
.
Offset
())
.
Limit
(
params
.
Limit
())
.
Order
(
"id DESC"
)
.
Find
(
&
logs
)
.
Error
;
err
!=
nil
{
return
nil
,
nil
,
err
}
pages
:=
int
(
total
)
/
params
.
Limit
()
if
int
(
total
)
%
params
.
Limit
()
>
0
{
pages
++
}
return
logs
,
&
PaginationResult
{
Total
:
total
,
Page
:
params
.
Page
,
PageSize
:
params
.
Limit
(),
Pages
:
pages
,
},
nil
}
func
(
r
*
UsageLogRepository
)
ListByApiKey
(
ctx
context
.
Context
,
apiKeyID
int64
,
params
PaginationParams
)
([]
model
.
UsageLog
,
*
PaginationResult
,
error
)
{
var
logs
[]
model
.
UsageLog
var
total
int64
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UsageLog
{})
.
Where
(
"api_key_id = ?"
,
apiKeyID
)
if
err
:=
db
.
Count
(
&
total
)
.
Error
;
err
!=
nil
{
return
nil
,
nil
,
err
}
if
err
:=
db
.
Offset
(
params
.
Offset
())
.
Limit
(
params
.
Limit
())
.
Order
(
"id DESC"
)
.
Find
(
&
logs
)
.
Error
;
err
!=
nil
{
return
nil
,
nil
,
err
}
pages
:=
int
(
total
)
/
params
.
Limit
()
if
int
(
total
)
%
params
.
Limit
()
>
0
{
pages
++
}
return
logs
,
&
PaginationResult
{
Total
:
total
,
Page
:
params
.
Page
,
PageSize
:
params
.
Limit
(),
Pages
:
pages
,
},
nil
}
// UserStats 用户使用统计
type
UserStats
struct
{
TotalRequests
int64
`json:"total_requests"`
TotalTokens
int64
`json:"total_tokens"`
TotalCost
float64
`json:"total_cost"`
InputTokens
int64
`json:"input_tokens"`
OutputTokens
int64
`json:"output_tokens"`
CacheReadTokens
int64
`json:"cache_read_tokens"`
}
func
(
r
*
UsageLogRepository
)
GetUserStats
(
ctx
context
.
Context
,
userID
int64
,
startTime
,
endTime
time
.
Time
)
(
*
UserStats
,
error
)
{
var
stats
UserStats
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UsageLog
{})
.
Select
(
`
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(actual_cost), 0) as total_cost,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens
`
)
.
Where
(
"user_id = ? AND created_at >= ? AND created_at < ?"
,
userID
,
startTime
,
endTime
)
.
Scan
(
&
stats
)
.
Error
return
&
stats
,
err
}
// DashboardStats 仪表盘统计
type
DashboardStats
struct
{
// 用户统计
TotalUsers
int64
`json:"total_users"`
TodayNewUsers
int64
`json:"today_new_users"`
// 今日新增用户数
ActiveUsers
int64
`json:"active_users"`
// 今日有请求的用户数
// API Key 统计
TotalApiKeys
int64
`json:"total_api_keys"`
ActiveApiKeys
int64
`json:"active_api_keys"`
// 状态为 active 的 API Key 数
// 账户统计
TotalAccounts
int64
`json:"total_accounts"`
NormalAccounts
int64
`json:"normal_accounts"`
// 正常账户数 (schedulable=true, status=active)
ErrorAccounts
int64
`json:"error_accounts"`
// 异常账户数 (status=error)
RateLimitAccounts
int64
`json:"ratelimit_accounts"`
// 限流账户数
OverloadAccounts
int64
`json:"overload_accounts"`
// 过载账户数
// 累计 Token 使用统计
TotalRequests
int64
`json:"total_requests"`
TotalInputTokens
int64
`json:"total_input_tokens"`
TotalOutputTokens
int64
`json:"total_output_tokens"`
TotalCacheCreationTokens
int64
`json:"total_cache_creation_tokens"`
TotalCacheReadTokens
int64
`json:"total_cache_read_tokens"`
TotalTokens
int64
`json:"total_tokens"`
TotalCost
float64
`json:"total_cost"`
// 累计标准计费
TotalActualCost
float64
`json:"total_actual_cost"`
// 累计实际扣除
// 今日 Token 使用统计
TodayRequests
int64
`json:"today_requests"`
TodayInputTokens
int64
`json:"today_input_tokens"`
TodayOutputTokens
int64
`json:"today_output_tokens"`
TodayCacheCreationTokens
int64
`json:"today_cache_creation_tokens"`
TodayCacheReadTokens
int64
`json:"today_cache_read_tokens"`
TodayTokens
int64
`json:"today_tokens"`
TodayCost
float64
`json:"today_cost"`
// 今日标准计费
TodayActualCost
float64
`json:"today_actual_cost"`
// 今日实际扣除
// 系统运行统计
AverageDurationMs
float64
`json:"average_duration_ms"`
// 平均响应时间
}
func
(
r
*
UsageLogRepository
)
GetDashboardStats
(
ctx
context
.
Context
)
(
*
DashboardStats
,
error
)
{
var
stats
DashboardStats
today
:=
timezone
.
Today
()
// 总用户数
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
User
{})
.
Count
(
&
stats
.
TotalUsers
)
// 今日新增用户数
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
User
{})
.
Where
(
"created_at >= ?"
,
today
)
.
Count
(
&
stats
.
TodayNewUsers
)
// 今日活跃用户数 (今日有请求的用户)
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UsageLog
{})
.
Distinct
(
"user_id"
)
.
Where
(
"created_at >= ?"
,
today
)
.
Count
(
&
stats
.
ActiveUsers
)
// 总 API Key 数
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
ApiKey
{})
.
Count
(
&
stats
.
TotalApiKeys
)
// 活跃 API Key 数
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
ApiKey
{})
.
Where
(
"status = ?"
,
model
.
StatusActive
)
.
Count
(
&
stats
.
ActiveApiKeys
)
// 总账户数
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Account
{})
.
Count
(
&
stats
.
TotalAccounts
)
// 正常账户数 (schedulable=true, status=active)
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Account
{})
.
Where
(
"status = ? AND schedulable = ?"
,
model
.
StatusActive
,
true
)
.
Count
(
&
stats
.
NormalAccounts
)
// 异常账户数 (status=error)
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Account
{})
.
Where
(
"status = ?"
,
model
.
StatusError
)
.
Count
(
&
stats
.
ErrorAccounts
)
// 限流账户数
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Account
{})
.
Where
(
"rate_limited_at IS NOT NULL AND rate_limit_reset_at > ?"
,
time
.
Now
())
.
Count
(
&
stats
.
RateLimitAccounts
)
// 过载账户数
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Account
{})
.
Where
(
"overload_until IS NOT NULL AND overload_until > ?"
,
time
.
Now
())
.
Count
(
&
stats
.
OverloadAccounts
)
// 累计 Token 统计
var
totalStats
struct
{
TotalRequests
int64
`gorm:"column:total_requests"`
TotalInputTokens
int64
`gorm:"column:total_input_tokens"`
TotalOutputTokens
int64
`gorm:"column:total_output_tokens"`
TotalCacheCreationTokens
int64
`gorm:"column:total_cache_creation_tokens"`
TotalCacheReadTokens
int64
`gorm:"column:total_cache_read_tokens"`
TotalCost
float64
`gorm:"column:total_cost"`
TotalActualCost
float64
`gorm:"column:total_actual_cost"`
AverageDurationMs
float64
`gorm:"column:avg_duration_ms"`
}
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UsageLog
{})
.
Select
(
`
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
`
)
.
Scan
(
&
totalStats
)
stats
.
TotalRequests
=
totalStats
.
TotalRequests
stats
.
TotalInputTokens
=
totalStats
.
TotalInputTokens
stats
.
TotalOutputTokens
=
totalStats
.
TotalOutputTokens
stats
.
TotalCacheCreationTokens
=
totalStats
.
TotalCacheCreationTokens
stats
.
TotalCacheReadTokens
=
totalStats
.
TotalCacheReadTokens
stats
.
TotalTokens
=
stats
.
TotalInputTokens
+
stats
.
TotalOutputTokens
+
stats
.
TotalCacheCreationTokens
+
stats
.
TotalCacheReadTokens
stats
.
TotalCost
=
totalStats
.
TotalCost
stats
.
TotalActualCost
=
totalStats
.
TotalActualCost
stats
.
AverageDurationMs
=
totalStats
.
AverageDurationMs
// 今日 Token 统计
var
todayStats
struct
{
TodayRequests
int64
`gorm:"column:today_requests"`
TodayInputTokens
int64
`gorm:"column:today_input_tokens"`
TodayOutputTokens
int64
`gorm:"column:today_output_tokens"`
TodayCacheCreationTokens
int64
`gorm:"column:today_cache_creation_tokens"`
TodayCacheReadTokens
int64
`gorm:"column:today_cache_read_tokens"`
TodayCost
float64
`gorm:"column:today_cost"`
TodayActualCost
float64
`gorm:"column:today_actual_cost"`
}
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UsageLog
{})
.
Select
(
`
COUNT(*) as today_requests,
COALESCE(SUM(input_tokens), 0) as today_input_tokens,
COALESCE(SUM(output_tokens), 0) as today_output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as today_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens,
COALESCE(SUM(total_cost), 0) as today_cost,
COALESCE(SUM(actual_cost), 0) as today_actual_cost
`
)
.
Where
(
"created_at >= ?"
,
today
)
.
Scan
(
&
todayStats
)
stats
.
TodayRequests
=
todayStats
.
TodayRequests
stats
.
TodayInputTokens
=
todayStats
.
TodayInputTokens
stats
.
TodayOutputTokens
=
todayStats
.
TodayOutputTokens
stats
.
TodayCacheCreationTokens
=
todayStats
.
TodayCacheCreationTokens
stats
.
TodayCacheReadTokens
=
todayStats
.
TodayCacheReadTokens
stats
.
TodayTokens
=
stats
.
TodayInputTokens
+
stats
.
TodayOutputTokens
+
stats
.
TodayCacheCreationTokens
+
stats
.
TodayCacheReadTokens
stats
.
TodayCost
=
todayStats
.
TodayCost
stats
.
TodayActualCost
=
todayStats
.
TodayActualCost
return
&
stats
,
nil
}
func
(
r
*
UsageLogRepository
)
ListByAccount
(
ctx
context
.
Context
,
accountID
int64
,
params
PaginationParams
)
([]
model
.
UsageLog
,
*
PaginationResult
,
error
)
{
var
logs
[]
model
.
UsageLog
var
total
int64
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UsageLog
{})
.
Where
(
"account_id = ?"
,
accountID
)
if
err
:=
db
.
Count
(
&
total
)
.
Error
;
err
!=
nil
{
return
nil
,
nil
,
err
}
if
err
:=
db
.
Offset
(
params
.
Offset
())
.
Limit
(
params
.
Limit
())
.
Order
(
"id DESC"
)
.
Find
(
&
logs
)
.
Error
;
err
!=
nil
{
return
nil
,
nil
,
err
}
pages
:=
int
(
total
)
/
params
.
Limit
()
if
int
(
total
)
%
params
.
Limit
()
>
0
{
pages
++
}
return
logs
,
&
PaginationResult
{
Total
:
total
,
Page
:
params
.
Page
,
PageSize
:
params
.
Limit
(),
Pages
:
pages
,
},
nil
}
func
(
r
*
UsageLogRepository
)
ListByUserAndTimeRange
(
ctx
context
.
Context
,
userID
int64
,
startTime
,
endTime
time
.
Time
)
([]
model
.
UsageLog
,
*
PaginationResult
,
error
)
{
var
logs
[]
model
.
UsageLog
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"user_id = ? AND created_at >= ? AND created_at < ?"
,
userID
,
startTime
,
endTime
)
.
Order
(
"id DESC"
)
.
Find
(
&
logs
)
.
Error
return
logs
,
nil
,
err
}
func
(
r
*
UsageLogRepository
)
ListByApiKeyAndTimeRange
(
ctx
context
.
Context
,
apiKeyID
int64
,
startTime
,
endTime
time
.
Time
)
([]
model
.
UsageLog
,
*
PaginationResult
,
error
)
{
var
logs
[]
model
.
UsageLog
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"api_key_id = ? AND created_at >= ? AND created_at < ?"
,
apiKeyID
,
startTime
,
endTime
)
.
Order
(
"id DESC"
)
.
Find
(
&
logs
)
.
Error
return
logs
,
nil
,
err
}
func
(
r
*
UsageLogRepository
)
ListByAccountAndTimeRange
(
ctx
context
.
Context
,
accountID
int64
,
startTime
,
endTime
time
.
Time
)
([]
model
.
UsageLog
,
*
PaginationResult
,
error
)
{
var
logs
[]
model
.
UsageLog
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"account_id = ? AND created_at >= ? AND created_at < ?"
,
accountID
,
startTime
,
endTime
)
.
Order
(
"id DESC"
)
.
Find
(
&
logs
)
.
Error
return
logs
,
nil
,
err
}
func
(
r
*
UsageLogRepository
)
ListByModelAndTimeRange
(
ctx
context
.
Context
,
modelName
string
,
startTime
,
endTime
time
.
Time
)
([]
model
.
UsageLog
,
*
PaginationResult
,
error
)
{
var
logs
[]
model
.
UsageLog
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"model = ? AND created_at >= ? AND created_at < ?"
,
modelName
,
startTime
,
endTime
)
.
Order
(
"id DESC"
)
.
Find
(
&
logs
)
.
Error
return
logs
,
nil
,
err
}
func
(
r
*
UsageLogRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Delete
(
&
model
.
UsageLog
{},
id
)
.
Error
}
// AccountStats 账号使用统计
type
AccountStats
struct
{
Requests
int64
`json:"requests"`
Tokens
int64
`json:"tokens"`
Cost
float64
`json:"cost"`
}
// GetAccountTodayStats 获取账号今日统计
func
(
r
*
UsageLogRepository
)
GetAccountTodayStats
(
ctx
context
.
Context
,
accountID
int64
)
(
*
AccountStats
,
error
)
{
today
:=
timezone
.
Today
()
var
stats
struct
{
Requests
int64
`gorm:"column:requests"`
Tokens
int64
`gorm:"column:tokens"`
Cost
float64
`gorm:"column:cost"`
}
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UsageLog
{})
.
Select
(
`
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(actual_cost), 0) as cost
`
)
.
Where
(
"account_id = ? AND created_at >= ?"
,
accountID
,
today
)
.
Scan
(
&
stats
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
return
&
AccountStats
{
Requests
:
stats
.
Requests
,
Tokens
:
stats
.
Tokens
,
Cost
:
stats
.
Cost
,
},
nil
}
// GetAccountWindowStats 获取账号时间窗口内的统计
func
(
r
*
UsageLogRepository
)
GetAccountWindowStats
(
ctx
context
.
Context
,
accountID
int64
,
startTime
time
.
Time
)
(
*
AccountStats
,
error
)
{
var
stats
struct
{
Requests
int64
`gorm:"column:requests"`
Tokens
int64
`gorm:"column:tokens"`
Cost
float64
`gorm:"column:cost"`
}
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UsageLog
{})
.
Select
(
`
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(actual_cost), 0) as cost
`
)
.
Where
(
"account_id = ? AND created_at >= ?"
,
accountID
,
startTime
)
.
Scan
(
&
stats
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
return
&
AccountStats
{
Requests
:
stats
.
Requests
,
Tokens
:
stats
.
Tokens
,
Cost
:
stats
.
Cost
,
},
nil
}
// TrendDataPoint represents a single point in trend data
type
TrendDataPoint
struct
{
Date
string
`json:"date"`
Requests
int64
`json:"requests"`
InputTokens
int64
`json:"input_tokens"`
OutputTokens
int64
`json:"output_tokens"`
CacheTokens
int64
`json:"cache_tokens"`
TotalTokens
int64
`json:"total_tokens"`
Cost
float64
`json:"cost"`
// 标准计费
ActualCost
float64
`json:"actual_cost"`
// 实际扣除
}
// ModelStat represents usage statistics for a single model
type
ModelStat
struct
{
Model
string
`json:"model"`
Requests
int64
`json:"requests"`
InputTokens
int64
`json:"input_tokens"`
OutputTokens
int64
`json:"output_tokens"`
TotalTokens
int64
`json:"total_tokens"`
Cost
float64
`json:"cost"`
// 标准计费
ActualCost
float64
`json:"actual_cost"`
// 实际扣除
}
// UserUsageTrendPoint represents user usage trend data point
type
UserUsageTrendPoint
struct
{
Date
string
`json:"date"`
UserID
int64
`json:"user_id"`
Email
string
`json:"email"`
Requests
int64
`json:"requests"`
Tokens
int64
`json:"tokens"`
Cost
float64
`json:"cost"`
// 标准计费
ActualCost
float64
`json:"actual_cost"`
// 实际扣除
}
// GetUsageTrend returns usage trend data grouped by date
// granularity: "day" or "hour"
func
(
r
*
UsageLogRepository
)
GetUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
)
([]
TrendDataPoint
,
error
)
{
var
results
[]
TrendDataPoint
// Choose date format based on granularity
var
dateFormat
string
if
granularity
==
"hour"
{
dateFormat
=
"YYYY-MM-DD HH24:00"
}
else
{
dateFormat
=
"YYYY-MM-DD"
}
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UsageLog
{})
.
Select
(
`
TO_CHAR(created_at, ?) as date,
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
`
,
dateFormat
)
.
Where
(
"created_at >= ? AND created_at < ?"
,
startTime
,
endTime
)
.
Group
(
"date"
)
.
Order
(
"date ASC"
)
.
Scan
(
&
results
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
return
results
,
nil
}
// GetModelStats returns usage statistics grouped by model
func
(
r
*
UsageLogRepository
)
GetModelStats
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
)
([]
ModelStat
,
error
)
{
var
results
[]
ModelStat
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UsageLog
{})
.
Select
(
`
model,
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
`
)
.
Where
(
"created_at >= ? AND created_at < ?"
,
startTime
,
endTime
)
.
Group
(
"model"
)
.
Order
(
"total_tokens DESC"
)
.
Scan
(
&
results
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
return
results
,
nil
}
// ApiKeyUsageTrendPoint represents API key usage trend data point
type
ApiKeyUsageTrendPoint
struct
{
Date
string
`json:"date"`
ApiKeyID
int64
`json:"api_key_id"`
KeyName
string
`json:"key_name"`
Requests
int64
`json:"requests"`
Tokens
int64
`json:"tokens"`
}
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date
func
(
r
*
UsageLogRepository
)
GetApiKeyUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
limit
int
)
([]
ApiKeyUsageTrendPoint
,
error
)
{
var
results
[]
ApiKeyUsageTrendPoint
// Choose date format based on granularity
var
dateFormat
string
if
granularity
==
"hour"
{
dateFormat
=
"YYYY-MM-DD HH24:00"
}
else
{
dateFormat
=
"YYYY-MM-DD"
}
// Use raw SQL for complex subquery
query
:=
`
WITH top_keys AS (
SELECT api_key_id
FROM usage_logs
WHERE created_at >= ? AND created_at < ?
GROUP BY api_key_id
ORDER BY SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) DESC
LIMIT ?
)
SELECT
TO_CHAR(u.created_at, '`
+
dateFormat
+
`') as date,
u.api_key_id,
COALESCE(k.name, '') as key_name,
COUNT(*) as requests,
COALESCE(SUM(u.input_tokens + u.output_tokens + u.cache_creation_tokens + u.cache_read_tokens), 0) as tokens
FROM usage_logs u
LEFT JOIN api_keys k ON u.api_key_id = k.id
WHERE u.api_key_id IN (SELECT api_key_id FROM top_keys)
AND u.created_at >= ? AND u.created_at < ?
GROUP BY date, u.api_key_id, k.name
ORDER BY date ASC, tokens DESC
`
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Raw
(
query
,
startTime
,
endTime
,
limit
,
startTime
,
endTime
)
.
Scan
(
&
results
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
return
results
,
nil
}
// GetUserUsageTrend returns usage trend data grouped by user and date
func
(
r
*
UsageLogRepository
)
GetUserUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
limit
int
)
([]
UserUsageTrendPoint
,
error
)
{
var
results
[]
UserUsageTrendPoint
// Choose date format based on granularity
var
dateFormat
string
if
granularity
==
"hour"
{
dateFormat
=
"YYYY-MM-DD HH24:00"
}
else
{
dateFormat
=
"YYYY-MM-DD"
}
// Use raw SQL for complex subquery
query
:=
`
WITH top_users AS (
SELECT user_id
FROM usage_logs
WHERE created_at >= ? AND created_at < ?
GROUP BY user_id
ORDER BY SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) DESC
LIMIT ?
)
SELECT
TO_CHAR(u.created_at, '`
+
dateFormat
+
`') as date,
u.user_id,
COALESCE(us.email, '') as email,
COUNT(*) as requests,
COALESCE(SUM(u.input_tokens + u.output_tokens + u.cache_creation_tokens + u.cache_read_tokens), 0) as tokens,
COALESCE(SUM(u.total_cost), 0) as cost,
COALESCE(SUM(u.actual_cost), 0) as actual_cost
FROM usage_logs u
LEFT JOIN users us ON u.user_id = us.id
WHERE u.user_id IN (SELECT user_id FROM top_users)
AND u.created_at >= ? AND u.created_at < ?
GROUP BY date, u.user_id, us.email
ORDER BY date ASC, tokens DESC
`
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Raw
(
query
,
startTime
,
endTime
,
limit
,
startTime
,
endTime
)
.
Scan
(
&
results
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
return
results
,
nil
}
// UserDashboardStats 用户仪表盘统计
type
UserDashboardStats
struct
{
// API Key 统计
TotalApiKeys
int64
`json:"total_api_keys"`
ActiveApiKeys
int64
`json:"active_api_keys"`
// 累计 Token 使用统计
TotalRequests
int64
`json:"total_requests"`
TotalInputTokens
int64
`json:"total_input_tokens"`
TotalOutputTokens
int64
`json:"total_output_tokens"`
TotalCacheCreationTokens
int64
`json:"total_cache_creation_tokens"`
TotalCacheReadTokens
int64
`json:"total_cache_read_tokens"`
TotalTokens
int64
`json:"total_tokens"`
TotalCost
float64
`json:"total_cost"`
// 累计标准计费
TotalActualCost
float64
`json:"total_actual_cost"`
// 累计实际扣除
// 今日 Token 使用统计
TodayRequests
int64
`json:"today_requests"`
TodayInputTokens
int64
`json:"today_input_tokens"`
TodayOutputTokens
int64
`json:"today_output_tokens"`
TodayCacheCreationTokens
int64
`json:"today_cache_creation_tokens"`
TodayCacheReadTokens
int64
`json:"today_cache_read_tokens"`
TodayTokens
int64
`json:"today_tokens"`
TodayCost
float64
`json:"today_cost"`
// 今日标准计费
TodayActualCost
float64
`json:"today_actual_cost"`
// 今日实际扣除
// 性能统计
AverageDurationMs
float64
`json:"average_duration_ms"`
}
// GetUserDashboardStats 获取用户专属的仪表盘统计
func
(
r
*
UsageLogRepository
)
GetUserDashboardStats
(
ctx
context
.
Context
,
userID
int64
)
(
*
UserDashboardStats
,
error
)
{
var
stats
UserDashboardStats
today
:=
timezone
.
Today
()
// API Key 统计
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
ApiKey
{})
.
Where
(
"user_id = ?"
,
userID
)
.
Count
(
&
stats
.
TotalApiKeys
)
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
ApiKey
{})
.
Where
(
"user_id = ? AND status = ?"
,
userID
,
model
.
StatusActive
)
.
Count
(
&
stats
.
ActiveApiKeys
)
// 累计 Token 统计
var
totalStats
struct
{
TotalRequests
int64
`gorm:"column:total_requests"`
TotalInputTokens
int64
`gorm:"column:total_input_tokens"`
TotalOutputTokens
int64
`gorm:"column:total_output_tokens"`
TotalCacheCreationTokens
int64
`gorm:"column:total_cache_creation_tokens"`
TotalCacheReadTokens
int64
`gorm:"column:total_cache_read_tokens"`
TotalCost
float64
`gorm:"column:total_cost"`
TotalActualCost
float64
`gorm:"column:total_actual_cost"`
AverageDurationMs
float64
`gorm:"column:avg_duration_ms"`
}
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UsageLog
{})
.
Select
(
`
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
`
)
.
Where
(
"user_id = ?"
,
userID
)
.
Scan
(
&
totalStats
)
stats
.
TotalRequests
=
totalStats
.
TotalRequests
stats
.
TotalInputTokens
=
totalStats
.
TotalInputTokens
stats
.
TotalOutputTokens
=
totalStats
.
TotalOutputTokens
stats
.
TotalCacheCreationTokens
=
totalStats
.
TotalCacheCreationTokens
stats
.
TotalCacheReadTokens
=
totalStats
.
TotalCacheReadTokens
stats
.
TotalTokens
=
stats
.
TotalInputTokens
+
stats
.
TotalOutputTokens
+
stats
.
TotalCacheCreationTokens
+
stats
.
TotalCacheReadTokens
stats
.
TotalCost
=
totalStats
.
TotalCost
stats
.
TotalActualCost
=
totalStats
.
TotalActualCost
stats
.
AverageDurationMs
=
totalStats
.
AverageDurationMs
// 今日 Token 统计
var
todayStats
struct
{
TodayRequests
int64
`gorm:"column:today_requests"`
TodayInputTokens
int64
`gorm:"column:today_input_tokens"`
TodayOutputTokens
int64
`gorm:"column:today_output_tokens"`
TodayCacheCreationTokens
int64
`gorm:"column:today_cache_creation_tokens"`
TodayCacheReadTokens
int64
`gorm:"column:today_cache_read_tokens"`
TodayCost
float64
`gorm:"column:today_cost"`
TodayActualCost
float64
`gorm:"column:today_actual_cost"`
}
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UsageLog
{})
.
Select
(
`
COUNT(*) as today_requests,
COALESCE(SUM(input_tokens), 0) as today_input_tokens,
COALESCE(SUM(output_tokens), 0) as today_output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as today_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens,
COALESCE(SUM(total_cost), 0) as today_cost,
COALESCE(SUM(actual_cost), 0) as today_actual_cost
`
)
.
Where
(
"user_id = ? AND created_at >= ?"
,
userID
,
today
)
.
Scan
(
&
todayStats
)
stats
.
TodayRequests
=
todayStats
.
TodayRequests
stats
.
TodayInputTokens
=
todayStats
.
TodayInputTokens
stats
.
TodayOutputTokens
=
todayStats
.
TodayOutputTokens
stats
.
TodayCacheCreationTokens
=
todayStats
.
TodayCacheCreationTokens
stats
.
TodayCacheReadTokens
=
todayStats
.
TodayCacheReadTokens
stats
.
TodayTokens
=
stats
.
TodayInputTokens
+
stats
.
TodayOutputTokens
+
stats
.
TodayCacheCreationTokens
+
stats
.
TodayCacheReadTokens
stats
.
TodayCost
=
todayStats
.
TodayCost
stats
.
TodayActualCost
=
todayStats
.
TodayActualCost
return
&
stats
,
nil
}
// GetUserUsageTrendByUserID 获取指定用户的使用趋势
func
(
r
*
UsageLogRepository
)
GetUserUsageTrendByUserID
(
ctx
context
.
Context
,
userID
int64
,
startTime
,
endTime
time
.
Time
,
granularity
string
)
([]
TrendDataPoint
,
error
)
{
var
results
[]
TrendDataPoint
var
dateFormat
string
if
granularity
==
"hour"
{
dateFormat
=
"YYYY-MM-DD HH24:00"
}
else
{
dateFormat
=
"YYYY-MM-DD"
}
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UsageLog
{})
.
Select
(
`
TO_CHAR(created_at, ?) as date,
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
`
,
dateFormat
)
.
Where
(
"user_id = ? AND created_at >= ? AND created_at < ?"
,
userID
,
startTime
,
endTime
)
.
Group
(
"date"
)
.
Order
(
"date ASC"
)
.
Scan
(
&
results
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
return
results
,
nil
}
// GetUserModelStats 获取指定用户的模型统计
func
(
r
*
UsageLogRepository
)
GetUserModelStats
(
ctx
context
.
Context
,
userID
int64
,
startTime
,
endTime
time
.
Time
)
([]
ModelStat
,
error
)
{
var
results
[]
ModelStat
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UsageLog
{})
.
Select
(
`
model,
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
`
)
.
Where
(
"user_id = ? AND created_at >= ? AND created_at < ?"
,
userID
,
startTime
,
endTime
)
.
Group
(
"model"
)
.
Order
(
"total_tokens DESC"
)
.
Scan
(
&
results
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
return
results
,
nil
}
// UsageLogFilters represents filters for usage log queries
type
UsageLogFilters
struct
{
UserID
int64
ApiKeyID
int64
StartTime
*
time
.
Time
EndTime
*
time
.
Time
}
// ListWithFilters lists usage logs with optional filters (for admin)
func
(
r
*
UsageLogRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
PaginationParams
,
filters
UsageLogFilters
)
([]
model
.
UsageLog
,
*
PaginationResult
,
error
)
{
var
logs
[]
model
.
UsageLog
var
total
int64
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UsageLog
{})
// Apply filters
if
filters
.
UserID
>
0
{
db
=
db
.
Where
(
"user_id = ?"
,
filters
.
UserID
)
}
if
filters
.
ApiKeyID
>
0
{
db
=
db
.
Where
(
"api_key_id = ?"
,
filters
.
ApiKeyID
)
}
if
filters
.
StartTime
!=
nil
{
db
=
db
.
Where
(
"created_at >= ?"
,
*
filters
.
StartTime
)
}
if
filters
.
EndTime
!=
nil
{
db
=
db
.
Where
(
"created_at <= ?"
,
*
filters
.
EndTime
)
}
if
err
:=
db
.
Count
(
&
total
)
.
Error
;
err
!=
nil
{
return
nil
,
nil
,
err
}
// Preload user and api_key for display
if
err
:=
db
.
Preload
(
"User"
)
.
Preload
(
"ApiKey"
)
.
Offset
(
params
.
Offset
())
.
Limit
(
params
.
Limit
())
.
Order
(
"id DESC"
)
.
Find
(
&
logs
)
.
Error
;
err
!=
nil
{
return
nil
,
nil
,
err
}
pages
:=
int
(
total
)
/
params
.
Limit
()
if
int
(
total
)
%
params
.
Limit
()
>
0
{
pages
++
}
return
logs
,
&
PaginationResult
{
Total
:
total
,
Page
:
params
.
Page
,
PageSize
:
params
.
Limit
(),
Pages
:
pages
,
},
nil
}
// UsageStats represents usage statistics
type
UsageStats
struct
{
TotalRequests
int64
`json:"total_requests"`
TotalInputTokens
int64
`json:"total_input_tokens"`
TotalOutputTokens
int64
`json:"total_output_tokens"`
TotalCacheTokens
int64
`json:"total_cache_tokens"`
TotalTokens
int64
`json:"total_tokens"`
TotalCost
float64
`json:"total_cost"`
TotalActualCost
float64
`json:"total_actual_cost"`
AverageDurationMs
float64
`json:"average_duration_ms"`
}
// BatchUserUsageStats represents usage stats for a single user
type
BatchUserUsageStats
struct
{
UserID
int64
`json:"user_id"`
TodayActualCost
float64
`json:"today_actual_cost"`
TotalActualCost
float64
`json:"total_actual_cost"`
}
// GetBatchUserUsageStats gets today and total actual_cost for multiple users
func
(
r
*
UsageLogRepository
)
GetBatchUserUsageStats
(
ctx
context
.
Context
,
userIDs
[]
int64
)
(
map
[
int64
]
*
BatchUserUsageStats
,
error
)
{
if
len
(
userIDs
)
==
0
{
return
make
(
map
[
int64
]
*
BatchUserUsageStats
),
nil
}
today
:=
timezone
.
Today
()
result
:=
make
(
map
[
int64
]
*
BatchUserUsageStats
)
// Initialize result map
for
_
,
id
:=
range
userIDs
{
result
[
id
]
=
&
BatchUserUsageStats
{
UserID
:
id
}
}
// Get total actual_cost per user
var
totalStats
[]
struct
{
UserID
int64
`gorm:"column:user_id"`
TotalCost
float64
`gorm:"column:total_cost"`
}
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UsageLog
{})
.
Select
(
"user_id, COALESCE(SUM(actual_cost), 0) as total_cost"
)
.
Where
(
"user_id IN ?"
,
userIDs
)
.
Group
(
"user_id"
)
.
Scan
(
&
totalStats
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
for
_
,
stat
:=
range
totalStats
{
if
s
,
ok
:=
result
[
stat
.
UserID
];
ok
{
s
.
TotalActualCost
=
stat
.
TotalCost
}
}
// Get today actual_cost per user
var
todayStats
[]
struct
{
UserID
int64
`gorm:"column:user_id"`
TodayCost
float64
`gorm:"column:today_cost"`
}
err
=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UsageLog
{})
.
Select
(
"user_id, COALESCE(SUM(actual_cost), 0) as today_cost"
)
.
Where
(
"user_id IN ? AND created_at >= ?"
,
userIDs
,
today
)
.
Group
(
"user_id"
)
.
Scan
(
&
todayStats
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
for
_
,
stat
:=
range
todayStats
{
if
s
,
ok
:=
result
[
stat
.
UserID
];
ok
{
s
.
TodayActualCost
=
stat
.
TodayCost
}
}
return
result
,
nil
}
// BatchApiKeyUsageStats represents usage stats for a single API key
type
BatchApiKeyUsageStats
struct
{
ApiKeyID
int64
`json:"api_key_id"`
TodayActualCost
float64
`json:"today_actual_cost"`
TotalActualCost
float64
`json:"total_actual_cost"`
}
// GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys
func
(
r
*
UsageLogRepository
)
GetBatchApiKeyUsageStats
(
ctx
context
.
Context
,
apiKeyIDs
[]
int64
)
(
map
[
int64
]
*
BatchApiKeyUsageStats
,
error
)
{
if
len
(
apiKeyIDs
)
==
0
{
return
make
(
map
[
int64
]
*
BatchApiKeyUsageStats
),
nil
}
today
:=
timezone
.
Today
()
result
:=
make
(
map
[
int64
]
*
BatchApiKeyUsageStats
)
// Initialize result map
for
_
,
id
:=
range
apiKeyIDs
{
result
[
id
]
=
&
BatchApiKeyUsageStats
{
ApiKeyID
:
id
}
}
// Get total actual_cost per api key
var
totalStats
[]
struct
{
ApiKeyID
int64
`gorm:"column:api_key_id"`
TotalCost
float64
`gorm:"column:total_cost"`
}
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UsageLog
{})
.
Select
(
"api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost"
)
.
Where
(
"api_key_id IN ?"
,
apiKeyIDs
)
.
Group
(
"api_key_id"
)
.
Scan
(
&
totalStats
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
for
_
,
stat
:=
range
totalStats
{
if
s
,
ok
:=
result
[
stat
.
ApiKeyID
];
ok
{
s
.
TotalActualCost
=
stat
.
TotalCost
}
}
// Get today actual_cost per api key
var
todayStats
[]
struct
{
ApiKeyID
int64
`gorm:"column:api_key_id"`
TodayCost
float64
`gorm:"column:today_cost"`
}
err
=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UsageLog
{})
.
Select
(
"api_key_id, COALESCE(SUM(actual_cost), 0) as today_cost"
)
.
Where
(
"api_key_id IN ? AND created_at >= ?"
,
apiKeyIDs
,
today
)
.
Group
(
"api_key_id"
)
.
Scan
(
&
todayStats
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
for
_
,
stat
:=
range
todayStats
{
if
s
,
ok
:=
result
[
stat
.
ApiKeyID
];
ok
{
s
.
TodayActualCost
=
stat
.
TodayCost
}
}
return
result
,
nil
}
// GetGlobalStats gets usage statistics for all users within a time range
func
(
r
*
UsageLogRepository
)
GetGlobalStats
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
)
(
*
UsageStats
,
error
)
{
var
stats
struct
{
TotalRequests
int64
`gorm:"column:total_requests"`
TotalInputTokens
int64
`gorm:"column:total_input_tokens"`
TotalOutputTokens
int64
`gorm:"column:total_output_tokens"`
TotalCacheTokens
int64
`gorm:"column:total_cache_tokens"`
TotalCost
float64
`gorm:"column:total_cost"`
TotalActualCost
float64
`gorm:"column:total_actual_cost"`
AverageDurationMs
float64
`gorm:"column:avg_duration_ms"`
}
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UsageLog
{})
.
Select
(
`
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
`
)
.
Where
(
"created_at >= ? AND created_at <= ?"
,
startTime
,
endTime
)
.
Scan
(
&
stats
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
return
&
UsageStats
{
TotalRequests
:
stats
.
TotalRequests
,
TotalInputTokens
:
stats
.
TotalInputTokens
,
TotalOutputTokens
:
stats
.
TotalOutputTokens
,
TotalCacheTokens
:
stats
.
TotalCacheTokens
,
TotalTokens
:
stats
.
TotalInputTokens
+
stats
.
TotalOutputTokens
+
stats
.
TotalCacheTokens
,
TotalCost
:
stats
.
TotalCost
,
TotalActualCost
:
stats
.
TotalActualCost
,
AverageDurationMs
:
stats
.
AverageDurationMs
,
},
nil
}
backend/internal/repository/user_repo.go
0 → 100644
View file @
642842c2
package
repository
import
(
"context"
"sub2api/internal/model"
"gorm.io/gorm"
)
type
UserRepository
struct
{
db
*
gorm
.
DB
}
func
NewUserRepository
(
db
*
gorm
.
DB
)
*
UserRepository
{
return
&
UserRepository
{
db
:
db
}
}
func
(
r
*
UserRepository
)
Create
(
ctx
context
.
Context
,
user
*
model
.
User
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Create
(
user
)
.
Error
}
func
(
r
*
UserRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
User
,
error
)
{
var
user
model
.
User
err
:=
r
.
db
.
WithContext
(
ctx
)
.
First
(
&
user
,
id
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
return
&
user
,
nil
}
func
(
r
*
UserRepository
)
GetByEmail
(
ctx
context
.
Context
,
email
string
)
(
*
model
.
User
,
error
)
{
var
user
model
.
User
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"email = ?"
,
email
)
.
First
(
&
user
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
return
&
user
,
nil
}
func
(
r
*
UserRepository
)
Update
(
ctx
context
.
Context
,
user
*
model
.
User
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Save
(
user
)
.
Error
}
func
(
r
*
UserRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Delete
(
&
model
.
User
{},
id
)
.
Error
}
func
(
r
*
UserRepository
)
List
(
ctx
context
.
Context
,
params
PaginationParams
)
([]
model
.
User
,
*
PaginationResult
,
error
)
{
return
r
.
ListWithFilters
(
ctx
,
params
,
""
,
""
,
""
)
}
// ListWithFilters lists users with optional filtering by status, role, and search query
func
(
r
*
UserRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
PaginationParams
,
status
,
role
,
search
string
)
([]
model
.
User
,
*
PaginationResult
,
error
)
{
var
users
[]
model
.
User
var
total
int64
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
User
{})
// Apply filters
if
status
!=
""
{
db
=
db
.
Where
(
"status = ?"
,
status
)
}
if
role
!=
""
{
db
=
db
.
Where
(
"role = ?"
,
role
)
}
if
search
!=
""
{
searchPattern
:=
"%"
+
search
+
"%"
db
=
db
.
Where
(
"email ILIKE ?"
,
searchPattern
)
}
if
err
:=
db
.
Count
(
&
total
)
.
Error
;
err
!=
nil
{
return
nil
,
nil
,
err
}
if
err
:=
db
.
Offset
(
params
.
Offset
())
.
Limit
(
params
.
Limit
())
.
Order
(
"id DESC"
)
.
Find
(
&
users
)
.
Error
;
err
!=
nil
{
return
nil
,
nil
,
err
}
pages
:=
int
(
total
)
/
params
.
Limit
()
if
int
(
total
)
%
params
.
Limit
()
>
0
{
pages
++
}
return
users
,
&
PaginationResult
{
Total
:
total
,
Page
:
params
.
Page
,
PageSize
:
params
.
Limit
(),
Pages
:
pages
,
},
nil
}
func
(
r
*
UserRepository
)
UpdateBalance
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
User
{})
.
Where
(
"id = ?"
,
id
)
.
Update
(
"balance"
,
gorm
.
Expr
(
"balance + ?"
,
amount
))
.
Error
}
// DeductBalance 扣减用户余额,仅当余额充足时执行
func
(
r
*
UserRepository
)
DeductBalance
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
error
{
result
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
User
{})
.
Where
(
"id = ? AND balance >= ?"
,
id
,
amount
)
.
Update
(
"balance"
,
gorm
.
Expr
(
"balance - ?"
,
amount
))
if
result
.
Error
!=
nil
{
return
result
.
Error
}
if
result
.
RowsAffected
==
0
{
return
gorm
.
ErrRecordNotFound
// 余额不足或用户不存在
}
return
nil
}
func
(
r
*
UserRepository
)
UpdateConcurrency
(
ctx
context
.
Context
,
id
int64
,
amount
int
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
User
{})
.
Where
(
"id = ?"
,
id
)
.
Update
(
"concurrency"
,
gorm
.
Expr
(
"concurrency + ?"
,
amount
))
.
Error
}
func
(
r
*
UserRepository
)
ExistsByEmail
(
ctx
context
.
Context
,
email
string
)
(
bool
,
error
)
{
var
count
int64
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
User
{})
.
Where
(
"email = ?"
,
email
)
.
Count
(
&
count
)
.
Error
return
count
>
0
,
err
}
// RemoveGroupFromAllowedGroups 从所有用户的 allowed_groups 数组中移除指定的分组ID
// 使用 PostgreSQL 的 array_remove 函数
func
(
r
*
UserRepository
)
RemoveGroupFromAllowedGroups
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
result
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
User
{})
.
Where
(
"? = ANY(allowed_groups)"
,
groupID
)
.
Update
(
"allowed_groups"
,
gorm
.
Expr
(
"array_remove(allowed_groups, ?)"
,
groupID
))
return
result
.
RowsAffected
,
result
.
Error
}
backend/internal/repository/user_subscription_repo.go
0 → 100644
View file @
642842c2
package
repository
import
(
"context"
"time"
"sub2api/internal/model"
"gorm.io/gorm"
)
// UserSubscriptionRepository 用户订阅仓库
type
UserSubscriptionRepository
struct
{
db
*
gorm
.
DB
}
// NewUserSubscriptionRepository 创建用户订阅仓库
func
NewUserSubscriptionRepository
(
db
*
gorm
.
DB
)
*
UserSubscriptionRepository
{
return
&
UserSubscriptionRepository
{
db
:
db
}
}
// Create 创建订阅
func
(
r
*
UserSubscriptionRepository
)
Create
(
ctx
context
.
Context
,
sub
*
model
.
UserSubscription
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Create
(
sub
)
.
Error
}
// GetByID 根据ID获取订阅
func
(
r
*
UserSubscriptionRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
UserSubscription
,
error
)
{
var
sub
model
.
UserSubscription
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Preload
(
"User"
)
.
Preload
(
"Group"
)
.
Preload
(
"AssignedByUser"
)
.
First
(
&
sub
,
id
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
return
&
sub
,
nil
}
// GetByUserIDAndGroupID 根据用户ID和分组ID获取订阅
func
(
r
*
UserSubscriptionRepository
)
GetByUserIDAndGroupID
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
model
.
UserSubscription
,
error
)
{
var
sub
model
.
UserSubscription
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Preload
(
"Group"
)
.
Where
(
"user_id = ? AND group_id = ?"
,
userID
,
groupID
)
.
First
(
&
sub
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
return
&
sub
,
nil
}
// GetActiveByUserIDAndGroupID 获取用户对特定分组的有效订阅
func
(
r
*
UserSubscriptionRepository
)
GetActiveByUserIDAndGroupID
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
model
.
UserSubscription
,
error
)
{
var
sub
model
.
UserSubscription
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Preload
(
"Group"
)
.
Where
(
"user_id = ? AND group_id = ? AND status = ? AND expires_at > ?"
,
userID
,
groupID
,
model
.
SubscriptionStatusActive
,
time
.
Now
())
.
First
(
&
sub
)
.
Error
if
err
!=
nil
{
return
nil
,
err
}
return
&
sub
,
nil
}
// Update 更新订阅
func
(
r
*
UserSubscriptionRepository
)
Update
(
ctx
context
.
Context
,
sub
*
model
.
UserSubscription
)
error
{
sub
.
UpdatedAt
=
time
.
Now
()
return
r
.
db
.
WithContext
(
ctx
)
.
Save
(
sub
)
.
Error
}
// Delete 删除订阅
func
(
r
*
UserSubscriptionRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Delete
(
&
model
.
UserSubscription
{},
id
)
.
Error
}
// ListByUserID 获取用户的所有订阅
func
(
r
*
UserSubscriptionRepository
)
ListByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
model
.
UserSubscription
,
error
)
{
var
subs
[]
model
.
UserSubscription
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Preload
(
"Group"
)
.
Where
(
"user_id = ?"
,
userID
)
.
Order
(
"created_at DESC"
)
.
Find
(
&
subs
)
.
Error
return
subs
,
err
}
// ListActiveByUserID 获取用户的所有有效订阅
func
(
r
*
UserSubscriptionRepository
)
ListActiveByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
model
.
UserSubscription
,
error
)
{
var
subs
[]
model
.
UserSubscription
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Preload
(
"Group"
)
.
Where
(
"user_id = ? AND status = ? AND expires_at > ?"
,
userID
,
model
.
SubscriptionStatusActive
,
time
.
Now
())
.
Order
(
"created_at DESC"
)
.
Find
(
&
subs
)
.
Error
return
subs
,
err
}
// ListByGroupID 获取分组的所有订阅(分页)
func
(
r
*
UserSubscriptionRepository
)
ListByGroupID
(
ctx
context
.
Context
,
groupID
int64
,
params
PaginationParams
)
([]
model
.
UserSubscription
,
*
PaginationResult
,
error
)
{
var
subs
[]
model
.
UserSubscription
var
total
int64
query
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UserSubscription
{})
.
Where
(
"group_id = ?"
,
groupID
)
if
err
:=
query
.
Count
(
&
total
)
.
Error
;
err
!=
nil
{
return
nil
,
nil
,
err
}
err
:=
query
.
Preload
(
"User"
)
.
Preload
(
"Group"
)
.
Order
(
"created_at DESC"
)
.
Offset
(
params
.
Offset
())
.
Limit
(
params
.
Limit
())
.
Find
(
&
subs
)
.
Error
if
err
!=
nil
{
return
nil
,
nil
,
err
}
pages
:=
int
(
total
)
/
params
.
Limit
()
if
int
(
total
)
%
params
.
Limit
()
>
0
{
pages
++
}
return
subs
,
&
PaginationResult
{
Total
:
total
,
Page
:
params
.
Page
,
PageSize
:
params
.
Limit
(),
Pages
:
pages
,
},
nil
}
// List 获取所有订阅(分页,支持筛选)
func
(
r
*
UserSubscriptionRepository
)
List
(
ctx
context
.
Context
,
params
PaginationParams
,
userID
,
groupID
*
int64
,
status
string
)
([]
model
.
UserSubscription
,
*
PaginationResult
,
error
)
{
var
subs
[]
model
.
UserSubscription
var
total
int64
query
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UserSubscription
{})
if
userID
!=
nil
{
query
=
query
.
Where
(
"user_id = ?"
,
*
userID
)
}
if
groupID
!=
nil
{
query
=
query
.
Where
(
"group_id = ?"
,
*
groupID
)
}
if
status
!=
""
{
query
=
query
.
Where
(
"status = ?"
,
status
)
}
if
err
:=
query
.
Count
(
&
total
)
.
Error
;
err
!=
nil
{
return
nil
,
nil
,
err
}
err
:=
query
.
Preload
(
"User"
)
.
Preload
(
"Group"
)
.
Preload
(
"AssignedByUser"
)
.
Order
(
"created_at DESC"
)
.
Offset
(
params
.
Offset
())
.
Limit
(
params
.
Limit
())
.
Find
(
&
subs
)
.
Error
if
err
!=
nil
{
return
nil
,
nil
,
err
}
pages
:=
int
(
total
)
/
params
.
Limit
()
if
int
(
total
)
%
params
.
Limit
()
>
0
{
pages
++
}
return
subs
,
&
PaginationResult
{
Total
:
total
,
Page
:
params
.
Page
,
PageSize
:
params
.
Limit
(),
Pages
:
pages
,
},
nil
}
// IncrementUsage 增加使用量
func
(
r
*
UserSubscriptionRepository
)
IncrementUsage
(
ctx
context
.
Context
,
id
int64
,
costUSD
float64
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UserSubscription
{})
.
Where
(
"id = ?"
,
id
)
.
Updates
(
map
[
string
]
interface
{}{
"daily_usage_usd"
:
gorm
.
Expr
(
"daily_usage_usd + ?"
,
costUSD
),
"weekly_usage_usd"
:
gorm
.
Expr
(
"weekly_usage_usd + ?"
,
costUSD
),
"monthly_usage_usd"
:
gorm
.
Expr
(
"monthly_usage_usd + ?"
,
costUSD
),
"updated_at"
:
time
.
Now
(),
})
.
Error
}
// ResetDailyUsage 重置日使用量
func
(
r
*
UserSubscriptionRepository
)
ResetDailyUsage
(
ctx
context
.
Context
,
id
int64
,
newWindowStart
time
.
Time
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UserSubscription
{})
.
Where
(
"id = ?"
,
id
)
.
Updates
(
map
[
string
]
interface
{}{
"daily_usage_usd"
:
0
,
"daily_window_start"
:
newWindowStart
,
"updated_at"
:
time
.
Now
(),
})
.
Error
}
// ResetWeeklyUsage 重置周使用量
func
(
r
*
UserSubscriptionRepository
)
ResetWeeklyUsage
(
ctx
context
.
Context
,
id
int64
,
newWindowStart
time
.
Time
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UserSubscription
{})
.
Where
(
"id = ?"
,
id
)
.
Updates
(
map
[
string
]
interface
{}{
"weekly_usage_usd"
:
0
,
"weekly_window_start"
:
newWindowStart
,
"updated_at"
:
time
.
Now
(),
})
.
Error
}
// ResetMonthlyUsage 重置月使用量
func
(
r
*
UserSubscriptionRepository
)
ResetMonthlyUsage
(
ctx
context
.
Context
,
id
int64
,
newWindowStart
time
.
Time
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UserSubscription
{})
.
Where
(
"id = ?"
,
id
)
.
Updates
(
map
[
string
]
interface
{}{
"monthly_usage_usd"
:
0
,
"monthly_window_start"
:
newWindowStart
,
"updated_at"
:
time
.
Now
(),
})
.
Error
}
// ActivateWindows 激活所有窗口(首次使用时)
func
(
r
*
UserSubscriptionRepository
)
ActivateWindows
(
ctx
context
.
Context
,
id
int64
,
activateTime
time
.
Time
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UserSubscription
{})
.
Where
(
"id = ?"
,
id
)
.
Updates
(
map
[
string
]
interface
{}{
"daily_window_start"
:
activateTime
,
"weekly_window_start"
:
activateTime
,
"monthly_window_start"
:
activateTime
,
"updated_at"
:
time
.
Now
(),
})
.
Error
}
// UpdateStatus 更新订阅状态
func
(
r
*
UserSubscriptionRepository
)
UpdateStatus
(
ctx
context
.
Context
,
id
int64
,
status
string
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UserSubscription
{})
.
Where
(
"id = ?"
,
id
)
.
Updates
(
map
[
string
]
interface
{}{
"status"
:
status
,
"updated_at"
:
time
.
Now
(),
})
.
Error
}
// ExtendExpiry 延长订阅过期时间
func
(
r
*
UserSubscriptionRepository
)
ExtendExpiry
(
ctx
context
.
Context
,
id
int64
,
newExpiresAt
time
.
Time
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UserSubscription
{})
.
Where
(
"id = ?"
,
id
)
.
Updates
(
map
[
string
]
interface
{}{
"expires_at"
:
newExpiresAt
,
"updated_at"
:
time
.
Now
(),
})
.
Error
}
// UpdateNotes 更新订阅备注
func
(
r
*
UserSubscriptionRepository
)
UpdateNotes
(
ctx
context
.
Context
,
id
int64
,
notes
string
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UserSubscription
{})
.
Where
(
"id = ?"
,
id
)
.
Updates
(
map
[
string
]
interface
{}{
"notes"
:
notes
,
"updated_at"
:
time
.
Now
(),
})
.
Error
}
// ListExpired 获取所有已过期但状态仍为active的订阅
func
(
r
*
UserSubscriptionRepository
)
ListExpired
(
ctx
context
.
Context
)
([]
model
.
UserSubscription
,
error
)
{
var
subs
[]
model
.
UserSubscription
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"status = ? AND expires_at <= ?"
,
model
.
SubscriptionStatusActive
,
time
.
Now
())
.
Find
(
&
subs
)
.
Error
return
subs
,
err
}
// BatchUpdateExpiredStatus 批量更新过期订阅状态
func
(
r
*
UserSubscriptionRepository
)
BatchUpdateExpiredStatus
(
ctx
context
.
Context
)
(
int64
,
error
)
{
result
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UserSubscription
{})
.
Where
(
"status = ? AND expires_at <= ?"
,
model
.
SubscriptionStatusActive
,
time
.
Now
())
.
Updates
(
map
[
string
]
interface
{}{
"status"
:
model
.
SubscriptionStatusExpired
,
"updated_at"
:
time
.
Now
(),
})
return
result
.
RowsAffected
,
result
.
Error
}
// ExistsByUserIDAndGroupID 检查用户是否已有该分组的订阅
func
(
r
*
UserSubscriptionRepository
)
ExistsByUserIDAndGroupID
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
bool
,
error
)
{
var
count
int64
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UserSubscription
{})
.
Where
(
"user_id = ? AND group_id = ?"
,
userID
,
groupID
)
.
Count
(
&
count
)
.
Error
return
count
>
0
,
err
}
// CountByGroupID 获取分组的订阅数量
func
(
r
*
UserSubscriptionRepository
)
CountByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
var
count
int64
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UserSubscription
{})
.
Where
(
"group_id = ?"
,
groupID
)
.
Count
(
&
count
)
.
Error
return
count
,
err
}
// CountActiveByGroupID 获取分组的有效订阅数量
func
(
r
*
UserSubscriptionRepository
)
CountActiveByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
var
count
int64
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UserSubscription
{})
.
Where
(
"group_id = ? AND status = ? AND expires_at > ?"
,
groupID
,
model
.
SubscriptionStatusActive
,
time
.
Now
())
.
Count
(
&
count
)
.
Error
return
count
,
err
}
// DeleteByGroupID 删除分组相关的所有订阅记录
func
(
r
*
UserSubscriptionRepository
)
DeleteByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
result
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"group_id = ?"
,
groupID
)
.
Delete
(
&
model
.
UserSubscription
{})
return
result
.
RowsAffected
,
result
.
Error
}
backend/internal/service/account_service.go
0 → 100644
View file @
642842c2
package
service
import
(
"context"
"errors"
"fmt"
"sub2api/internal/model"
"sub2api/internal/repository"
"gorm.io/gorm"
)
var
(
ErrAccountNotFound
=
errors
.
New
(
"account not found"
)
)
// CreateAccountRequest 创建账号请求
type
CreateAccountRequest
struct
{
Name
string
`json:"name"`
Platform
string
`json:"platform"`
Type
string
`json:"type"`
Credentials
map
[
string
]
interface
{}
`json:"credentials"`
Extra
map
[
string
]
interface
{}
`json:"extra"`
ProxyID
*
int64
`json:"proxy_id"`
Concurrency
int
`json:"concurrency"`
Priority
int
`json:"priority"`
GroupIDs
[]
int64
`json:"group_ids"`
}
// UpdateAccountRequest 更新账号请求
type
UpdateAccountRequest
struct
{
Name
*
string
`json:"name"`
Credentials
*
map
[
string
]
interface
{}
`json:"credentials"`
Extra
*
map
[
string
]
interface
{}
`json:"extra"`
ProxyID
*
int64
`json:"proxy_id"`
Concurrency
*
int
`json:"concurrency"`
Priority
*
int
`json:"priority"`
Status
*
string
`json:"status"`
GroupIDs
*
[]
int64
`json:"group_ids"`
}
// AccountService 账号管理服务
type
AccountService
struct
{
accountRepo
*
repository
.
AccountRepository
groupRepo
*
repository
.
GroupRepository
}
// NewAccountService 创建账号服务实例
func
NewAccountService
(
accountRepo
*
repository
.
AccountRepository
,
groupRepo
*
repository
.
GroupRepository
)
*
AccountService
{
return
&
AccountService
{
accountRepo
:
accountRepo
,
groupRepo
:
groupRepo
,
}
}
// Create 创建账号
func
(
s
*
AccountService
)
Create
(
ctx
context
.
Context
,
req
CreateAccountRequest
)
(
*
model
.
Account
,
error
)
{
// 验证分组是否存在(如果指定了分组)
if
len
(
req
.
GroupIDs
)
>
0
{
for
_
,
groupID
:=
range
req
.
GroupIDs
{
_
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
groupID
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
gorm
.
ErrRecordNotFound
)
{
return
nil
,
fmt
.
Errorf
(
"group %d not found"
,
groupID
)
}
return
nil
,
fmt
.
Errorf
(
"get group: %w"
,
err
)
}
}
}
// 创建账号
account
:=
&
model
.
Account
{
Name
:
req
.
Name
,
Platform
:
req
.
Platform
,
Type
:
req
.
Type
,
Credentials
:
req
.
Credentials
,
Extra
:
req
.
Extra
,
ProxyID
:
req
.
ProxyID
,
Concurrency
:
req
.
Concurrency
,
Priority
:
req
.
Priority
,
Status
:
model
.
StatusActive
,
}
if
err
:=
s
.
accountRepo
.
Create
(
ctx
,
account
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"create account: %w"
,
err
)
}
// 绑定分组
if
len
(
req
.
GroupIDs
)
>
0
{
if
err
:=
s
.
accountRepo
.
BindGroups
(
ctx
,
account
.
ID
,
req
.
GroupIDs
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"bind groups: %w"
,
err
)
}
}
return
account
,
nil
}
// GetByID 根据ID获取账号
func
(
s
*
AccountService
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
Account
,
error
)
{
account
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
gorm
.
ErrRecordNotFound
)
{
return
nil
,
ErrAccountNotFound
}
return
nil
,
fmt
.
Errorf
(
"get account: %w"
,
err
)
}
return
account
,
nil
}
// List 获取账号列表
func
(
s
*
AccountService
)
List
(
ctx
context
.
Context
,
params
repository
.
PaginationParams
)
([]
model
.
Account
,
*
repository
.
PaginationResult
,
error
)
{
accounts
,
pagination
,
err
:=
s
.
accountRepo
.
List
(
ctx
,
params
)
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"list accounts: %w"
,
err
)
}
return
accounts
,
pagination
,
nil
}
// ListByPlatform 根据平台获取账号列表
func
(
s
*
AccountService
)
ListByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
model
.
Account
,
error
)
{
accounts
,
err
:=
s
.
accountRepo
.
ListByPlatform
(
ctx
,
platform
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"list accounts by platform: %w"
,
err
)
}
return
accounts
,
nil
}
// ListByGroup 根据分组获取账号列表
func
(
s
*
AccountService
)
ListByGroup
(
ctx
context
.
Context
,
groupID
int64
)
([]
model
.
Account
,
error
)
{
accounts
,
err
:=
s
.
accountRepo
.
ListByGroup
(
ctx
,
groupID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"list accounts by group: %w"
,
err
)
}
return
accounts
,
nil
}
// Update 更新账号
func
(
s
*
AccountService
)
Update
(
ctx
context
.
Context
,
id
int64
,
req
UpdateAccountRequest
)
(
*
model
.
Account
,
error
)
{
account
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
gorm
.
ErrRecordNotFound
)
{
return
nil
,
ErrAccountNotFound
}
return
nil
,
fmt
.
Errorf
(
"get account: %w"
,
err
)
}
// 更新字段
if
req
.
Name
!=
nil
{
account
.
Name
=
*
req
.
Name
}
if
req
.
Credentials
!=
nil
{
account
.
Credentials
=
*
req
.
Credentials
}
if
req
.
Extra
!=
nil
{
account
.
Extra
=
*
req
.
Extra
}
if
req
.
ProxyID
!=
nil
{
account
.
ProxyID
=
req
.
ProxyID
}
if
req
.
Concurrency
!=
nil
{
account
.
Concurrency
=
*
req
.
Concurrency
}
if
req
.
Priority
!=
nil
{
account
.
Priority
=
*
req
.
Priority
}
if
req
.
Status
!=
nil
{
account
.
Status
=
*
req
.
Status
}
if
err
:=
s
.
accountRepo
.
Update
(
ctx
,
account
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"update account: %w"
,
err
)
}
// 更新分组绑定
if
req
.
GroupIDs
!=
nil
{
// 验证分组是否存在
for
_
,
groupID
:=
range
*
req
.
GroupIDs
{
_
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
groupID
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
gorm
.
ErrRecordNotFound
)
{
return
nil
,
fmt
.
Errorf
(
"group %d not found"
,
groupID
)
}
return
nil
,
fmt
.
Errorf
(
"get group: %w"
,
err
)
}
}
if
err
:=
s
.
accountRepo
.
BindGroups
(
ctx
,
account
.
ID
,
*
req
.
GroupIDs
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"bind groups: %w"
,
err
)
}
}
return
account
,
nil
}
// Delete 删除账号
func
(
s
*
AccountService
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
// 检查账号是否存在
_
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
gorm
.
ErrRecordNotFound
)
{
return
ErrAccountNotFound
}
return
fmt
.
Errorf
(
"get account: %w"
,
err
)
}
if
err
:=
s
.
accountRepo
.
Delete
(
ctx
,
id
);
err
!=
nil
{
return
fmt
.
Errorf
(
"delete account: %w"
,
err
)
}
return
nil
}
// UpdateStatus 更新账号状态
func
(
s
*
AccountService
)
UpdateStatus
(
ctx
context
.
Context
,
id
int64
,
status
string
,
errorMessage
string
)
error
{
account
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
gorm
.
ErrRecordNotFound
)
{
return
ErrAccountNotFound
}
return
fmt
.
Errorf
(
"get account: %w"
,
err
)
}
account
.
Status
=
status
account
.
ErrorMessage
=
errorMessage
if
err
:=
s
.
accountRepo
.
Update
(
ctx
,
account
);
err
!=
nil
{
return
fmt
.
Errorf
(
"update account: %w"
,
err
)
}
return
nil
}
// UpdateLastUsed 更新最后使用时间
func
(
s
*
AccountService
)
UpdateLastUsed
(
ctx
context
.
Context
,
id
int64
)
error
{
if
err
:=
s
.
accountRepo
.
UpdateLastUsed
(
ctx
,
id
);
err
!=
nil
{
return
fmt
.
Errorf
(
"update last used: %w"
,
err
)
}
return
nil
}
// GetCredential 获取账号凭证(安全访问)
func
(
s
*
AccountService
)
GetCredential
(
ctx
context
.
Context
,
id
int64
,
key
string
)
(
string
,
error
)
{
account
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
gorm
.
ErrRecordNotFound
)
{
return
""
,
ErrAccountNotFound
}
return
""
,
fmt
.
Errorf
(
"get account: %w"
,
err
)
}
return
account
.
GetCredential
(
key
),
nil
}
// TestCredentials 测试账号凭证是否有效(需要实现具体平台的测试逻辑)
func
(
s
*
AccountService
)
TestCredentials
(
ctx
context
.
Context
,
id
int64
)
error
{
account
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
gorm
.
ErrRecordNotFound
)
{
return
ErrAccountNotFound
}
return
fmt
.
Errorf
(
"get account: %w"
,
err
)
}
// 根据平台执行不同的测试逻辑
switch
account
.
Platform
{
case
model
.
PlatformAnthropic
:
// TODO: 测试Anthropic API凭证
return
nil
case
model
.
PlatformOpenAI
:
// TODO: 测试OpenAI API凭证
return
nil
case
model
.
PlatformGemini
:
// TODO: 测试Gemini API凭证
return
nil
default
:
return
fmt
.
Errorf
(
"unsupported platform: %s"
,
account
.
Platform
)
}
}
backend/internal/service/account_test_service.go
0 → 100644
View file @
642842c2
package
service
import
(
"bufio"
"bytes"
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"sub2api/internal/repository"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
const
(
testClaudeAPIURL
=
"https://api.anthropic.com/v1/messages"
testModel
=
"claude-sonnet-4-5-20250929"
)
// TestEvent represents a SSE event for account testing
type
TestEvent
struct
{
Type
string
`json:"type"`
Text
string
`json:"text,omitempty"`
Model
string
`json:"model,omitempty"`
Success
bool
`json:"success,omitempty"`
Error
string
`json:"error,omitempty"`
}
// AccountTestService handles account testing operations
type
AccountTestService
struct
{
repos
*
repository
.
Repositories
oauthService
*
OAuthService
httpClient
*
http
.
Client
}
// NewAccountTestService creates a new AccountTestService
func
NewAccountTestService
(
repos
*
repository
.
Repositories
,
oauthService
*
OAuthService
)
*
AccountTestService
{
return
&
AccountTestService
{
repos
:
repos
,
oauthService
:
oauthService
,
httpClient
:
&
http
.
Client
{
Timeout
:
60
*
time
.
Second
,
},
}
}
// generateSessionString generates a Claude Code style session string
func
generateSessionString
()
string
{
bytes
:=
make
([]
byte
,
32
)
rand
.
Read
(
bytes
)
hex64
:=
hex
.
EncodeToString
(
bytes
)
sessionUUID
:=
uuid
.
New
()
.
String
()
return
fmt
.
Sprintf
(
"user_%s_account__session_%s"
,
hex64
,
sessionUUID
)
}
// createTestPayload creates a minimal test request payload for OAuth/Setup Token accounts
func
createTestPayload
()
map
[
string
]
interface
{}
{
return
map
[
string
]
interface
{}{
"model"
:
testModel
,
"messages"
:
[]
map
[
string
]
interface
{}{
{
"role"
:
"user"
,
"content"
:
[]
map
[
string
]
interface
{}{
{
"type"
:
"text"
,
"text"
:
"hi"
,
"cache_control"
:
map
[
string
]
string
{
"type"
:
"ephemeral"
,
},
},
},
},
},
"system"
:
[]
map
[
string
]
interface
{}{
{
"type"
:
"text"
,
"text"
:
"You are Claude Code, Anthropic's official CLI for Claude."
,
"cache_control"
:
map
[
string
]
string
{
"type"
:
"ephemeral"
,
},
},
},
"metadata"
:
map
[
string
]
string
{
"user_id"
:
generateSessionString
(),
},
"max_tokens"
:
1024
,
"temperature"
:
1
,
"stream"
:
true
,
}
}
// createApiKeyTestPayload creates a simpler test request payload for API Key accounts
func
createApiKeyTestPayload
(
model
string
)
map
[
string
]
interface
{}
{
return
map
[
string
]
interface
{}{
"model"
:
model
,
"messages"
:
[]
map
[
string
]
interface
{}{
{
"role"
:
"user"
,
"content"
:
"hi"
,
},
},
"max_tokens"
:
1024
,
"stream"
:
true
,
}
}
// TestAccountConnection tests an account's connection by sending a test request
func
(
s
*
AccountTestService
)
TestAccountConnection
(
c
*
gin
.
Context
,
accountID
int64
)
error
{
ctx
:=
c
.
Request
.
Context
()
// Get account
account
,
err
:=
s
.
repos
.
Account
.
GetByID
(
ctx
,
accountID
)
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
"Account not found"
)
}
// Determine authentication method based on account type
var
authToken
string
var
authType
string
// "bearer" for OAuth, "apikey" for API Key
var
apiURL
string
if
account
.
IsOAuth
()
{
// OAuth or Setup Token account
authType
=
"bearer"
apiURL
=
testClaudeAPIURL
authToken
=
account
.
GetCredential
(
"access_token"
)
if
authToken
==
""
{
return
s
.
sendErrorAndEnd
(
c
,
"No access token available"
)
}
// Check if token needs refresh
needRefresh
:=
false
if
expiresAtStr
:=
account
.
GetCredential
(
"expires_at"
);
expiresAtStr
!=
""
{
expiresAt
,
err
:=
strconv
.
ParseInt
(
expiresAtStr
,
10
,
64
)
if
err
==
nil
&&
time
.
Now
()
.
Unix
()
+
300
>
expiresAt
{
// 5 minute buffer
needRefresh
=
true
}
}
if
needRefresh
&&
s
.
oauthService
!=
nil
{
tokenInfo
,
err
:=
s
.
oauthService
.
RefreshAccountToken
(
ctx
,
account
)
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Failed to refresh token: %s"
,
err
.
Error
()))
}
authToken
=
tokenInfo
.
AccessToken
}
}
else
if
account
.
Type
==
"apikey"
{
// API Key account
authType
=
"apikey"
authToken
=
account
.
GetCredential
(
"api_key"
)
if
authToken
==
""
{
return
s
.
sendErrorAndEnd
(
c
,
"No API key available"
)
}
// Get base URL (use default if not set)
apiURL
=
account
.
GetBaseURL
()
if
apiURL
==
""
{
apiURL
=
"https://api.anthropic.com"
}
// Append /v1/messages endpoint
apiURL
=
strings
.
TrimSuffix
(
apiURL
,
"/"
)
+
"/v1/messages"
}
else
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Unsupported account type: %s"
,
account
.
Type
))
}
// Set SSE headers
c
.
Writer
.
Header
()
.
Set
(
"Content-Type"
,
"text/event-stream"
)
c
.
Writer
.
Header
()
.
Set
(
"Cache-Control"
,
"no-cache"
)
c
.
Writer
.
Header
()
.
Set
(
"Connection"
,
"keep-alive"
)
c
.
Writer
.
Header
()
.
Set
(
"X-Accel-Buffering"
,
"no"
)
c
.
Writer
.
Flush
()
// Create test request payload
var
payload
map
[
string
]
interface
{}
var
actualModel
string
if
authType
==
"apikey"
{
// Use simpler payload for API Key (without Claude Code specific fields)
// Apply model mapping if configured
actualModel
=
account
.
GetMappedModel
(
testModel
)
payload
=
createApiKeyTestPayload
(
actualModel
)
}
else
{
actualModel
=
testModel
payload
=
createTestPayload
()
}
payloadBytes
,
_
:=
json
.
Marshal
(
payload
)
// Send test_start event with model info
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"test_start"
,
Model
:
actualModel
})
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
"POST"
,
apiURL
,
bytes
.
NewReader
(
payloadBytes
))
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
"Failed to create request"
)
}
// Set headers based on auth type
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
Header
.
Set
(
"anthropic-version"
,
"2023-06-01"
)
if
authType
==
"bearer"
{
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
authToken
)
req
.
Header
.
Set
(
"anthropic-beta"
,
"prompt-caching-2024-07-31,interleaved-thinking-2025-05-14,output-128k-2025-02-19"
)
}
else
{
// API Key uses x-api-key header
req
.
Header
.
Set
(
"x-api-key"
,
authToken
)
}
// Configure proxy if account has one
transport
:=
http
.
DefaultTransport
.
(
*
http
.
Transport
)
.
Clone
()
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
proxyURL
:=
account
.
Proxy
.
URL
()
if
proxyURL
!=
""
{
if
parsedURL
,
err
:=
url
.
Parse
(
proxyURL
);
err
==
nil
{
transport
.
Proxy
=
http
.
ProxyURL
(
parsedURL
)
}
}
}
client
:=
&
http
.
Client
{
Transport
:
transport
,
Timeout
:
60
*
time
.
Second
,
}
resp
,
err
:=
client
.
Do
(
req
)
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Request failed: %s"
,
err
.
Error
()))
}
defer
resp
.
Body
.
Close
()
if
resp
.
StatusCode
!=
http
.
StatusOK
{
body
,
_
:=
io
.
ReadAll
(
resp
.
Body
)
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"API returned %d: %s"
,
resp
.
StatusCode
,
string
(
body
)))
}
// Process SSE stream
return
s
.
processStream
(
c
,
resp
.
Body
)
}
// processStream processes the SSE stream from Claude API
func
(
s
*
AccountTestService
)
processStream
(
c
*
gin
.
Context
,
body
io
.
Reader
)
error
{
reader
:=
bufio
.
NewReader
(
body
)
for
{
line
,
err
:=
reader
.
ReadString
(
'\n'
)
if
err
!=
nil
{
if
err
==
io
.
EOF
{
// Stream ended, send complete event
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"test_complete"
,
Success
:
true
})
return
nil
}
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Stream read error: %s"
,
err
.
Error
()))
}
line
=
strings
.
TrimSpace
(
line
)
if
line
==
""
||
!
strings
.
HasPrefix
(
line
,
"data: "
)
{
continue
}
jsonStr
:=
strings
.
TrimPrefix
(
line
,
"data: "
)
if
jsonStr
==
"[DONE]"
{
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"test_complete"
,
Success
:
true
})
return
nil
}
var
data
map
[
string
]
interface
{}
if
err
:=
json
.
Unmarshal
([]
byte
(
jsonStr
),
&
data
);
err
!=
nil
{
continue
}
eventType
,
_
:=
data
[
"type"
]
.
(
string
)
switch
eventType
{
case
"content_block_delta"
:
if
delta
,
ok
:=
data
[
"delta"
]
.
(
map
[
string
]
interface
{});
ok
{
if
text
,
ok
:=
delta
[
"text"
]
.
(
string
);
ok
{
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"content"
,
Text
:
text
})
}
}
case
"message_stop"
:
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"test_complete"
,
Success
:
true
})
return
nil
case
"error"
:
errorMsg
:=
"Unknown error"
if
errData
,
ok
:=
data
[
"error"
]
.
(
map
[
string
]
interface
{});
ok
{
if
msg
,
ok
:=
errData
[
"message"
]
.
(
string
);
ok
{
errorMsg
=
msg
}
}
return
s
.
sendErrorAndEnd
(
c
,
errorMsg
)
}
}
}
// sendEvent sends a SSE event to the client
func
(
s
*
AccountTestService
)
sendEvent
(
c
*
gin
.
Context
,
event
TestEvent
)
{
eventJSON
,
_
:=
json
.
Marshal
(
event
)
fmt
.
Fprintf
(
c
.
Writer
,
"data: %s
\n\n
"
,
eventJSON
)
c
.
Writer
.
Flush
()
}
// sendErrorAndEnd sends an error event and ends the stream
func
(
s
*
AccountTestService
)
sendErrorAndEnd
(
c
*
gin
.
Context
,
errorMsg
string
)
error
{
log
.
Printf
(
"Account test error: %s"
,
errorMsg
)
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"error"
,
Error
:
errorMsg
})
return
fmt
.
Errorf
(
errorMsg
)
}
backend/internal/service/account_usage_service.go
0 → 100644
View file @
642842c2
package
service
import
(
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/url"
"sync"
"time"
"sub2api/internal/model"
"sub2api/internal/repository"
)
// usageCache 用于缓存usage数据
type
usageCache
struct
{
data
*
UsageInfo
timestamp
time
.
Time
}
var
(
usageCacheMap
=
sync
.
Map
{}
cacheTTL
=
10
*
time
.
Minute
)
// WindowStats 窗口期统计
type
WindowStats
struct
{
Requests
int64
`json:"requests"`
Tokens
int64
`json:"tokens"`
Cost
float64
`json:"cost"`
}
// UsageProgress 使用量进度
type
UsageProgress
struct
{
Utilization
float64
`json:"utilization"`
// 使用率百分比 (0-100+,100表示100%)
ResetsAt
*
time
.
Time
`json:"resets_at"`
// 重置时间
RemainingSeconds
int
`json:"remaining_seconds"`
// 距重置剩余秒数
WindowStats
*
WindowStats
`json:"window_stats,omitempty"`
// 窗口期统计(从窗口开始到当前的使用量)
}
// UsageInfo 账号使用量信息
type
UsageInfo
struct
{
UpdatedAt
*
time
.
Time
`json:"updated_at,omitempty"`
// 更新时间
FiveHour
*
UsageProgress
`json:"five_hour"`
// 5小时窗口
SevenDay
*
UsageProgress
`json:"seven_day,omitempty"`
// 7天窗口
SevenDaySonnet
*
UsageProgress
`json:"seven_day_sonnet,omitempty"`
// 7天Sonnet窗口
}
// ClaudeUsageResponse Anthropic API返回的usage结构
type
ClaudeUsageResponse
struct
{
FiveHour
struct
{
Utilization
float64
`json:"utilization"`
ResetsAt
string
`json:"resets_at"`
}
`json:"five_hour"`
SevenDay
struct
{
Utilization
float64
`json:"utilization"`
ResetsAt
string
`json:"resets_at"`
}
`json:"seven_day"`
SevenDaySonnet
struct
{
Utilization
float64
`json:"utilization"`
ResetsAt
string
`json:"resets_at"`
}
`json:"seven_day_sonnet"`
}
// AccountUsageService 账号使用量查询服务
type
AccountUsageService
struct
{
repos
*
repository
.
Repositories
oauthService
*
OAuthService
httpClient
*
http
.
Client
}
// NewAccountUsageService 创建AccountUsageService实例
func
NewAccountUsageService
(
repos
*
repository
.
Repositories
,
oauthService
*
OAuthService
)
*
AccountUsageService
{
return
&
AccountUsageService
{
repos
:
repos
,
oauthService
:
oauthService
,
httpClient
:
&
http
.
Client
{
Timeout
:
30
*
time
.
Second
,
},
}
}
// GetUsage 获取账号使用量
// OAuth账号: 调用Anthropic API获取真实数据(需要profile scope),缓存10分钟
// Setup Token账号: 根据session_window推算5h窗口,7d数据不可用(没有profile scope)
// API Key账号: 不支持usage查询
func
(
s
*
AccountUsageService
)
GetUsage
(
ctx
context
.
Context
,
accountID
int64
)
(
*
UsageInfo
,
error
)
{
account
,
err
:=
s
.
repos
.
Account
.
GetByID
(
ctx
,
accountID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get account failed: %w"
,
err
)
}
// 只有oauth类型账号可以通过API获取usage(有profile scope)
if
account
.
CanGetUsage
()
{
// 检查缓存
if
cached
,
ok
:=
usageCacheMap
.
Load
(
accountID
);
ok
{
cache
:=
cached
.
(
*
usageCache
)
if
time
.
Since
(
cache
.
timestamp
)
<
cacheTTL
{
return
cache
.
data
,
nil
}
}
// 从API获取数据
usage
,
err
:=
s
.
fetchOAuthUsage
(
ctx
,
account
)
if
err
!=
nil
{
return
nil
,
err
}
// 添加5h窗口统计数据
s
.
addWindowStats
(
ctx
,
account
,
usage
)
// 缓存结果
usageCacheMap
.
Store
(
accountID
,
&
usageCache
{
data
:
usage
,
timestamp
:
time
.
Now
(),
})
return
usage
,
nil
}
// Setup Token账号:根据session_window推算(没有profile scope,无法调用usage API)
if
account
.
Type
==
model
.
AccountTypeSetupToken
{
usage
:=
s
.
estimateSetupTokenUsage
(
account
)
// 添加窗口统计
s
.
addWindowStats
(
ctx
,
account
,
usage
)
return
usage
,
nil
}
// API Key账号不支持usage查询
return
nil
,
fmt
.
Errorf
(
"account type %s does not support usage query"
,
account
.
Type
)
}
// addWindowStats 为usage数据添加窗口期统计
func
(
s
*
AccountUsageService
)
addWindowStats
(
ctx
context
.
Context
,
account
*
model
.
Account
,
usage
*
UsageInfo
)
{
if
usage
.
FiveHour
==
nil
{
return
}
// 使用session_window_start作为统计起始时间
var
startTime
time
.
Time
if
account
.
SessionWindowStart
!=
nil
{
startTime
=
*
account
.
SessionWindowStart
}
else
{
// 如果没有窗口信息,使用5小时前作为默认
startTime
=
time
.
Now
()
.
Add
(
-
5
*
time
.
Hour
)
}
stats
,
err
:=
s
.
repos
.
UsageLog
.
GetAccountWindowStats
(
ctx
,
account
.
ID
,
startTime
)
if
err
!=
nil
{
log
.
Printf
(
"Failed to get window stats for account %d: %v"
,
account
.
ID
,
err
)
return
}
usage
.
FiveHour
.
WindowStats
=
&
WindowStats
{
Requests
:
stats
.
Requests
,
Tokens
:
stats
.
Tokens
,
Cost
:
stats
.
Cost
,
}
}
// GetTodayStats 获取账号今日统计
func
(
s
*
AccountUsageService
)
GetTodayStats
(
ctx
context
.
Context
,
accountID
int64
)
(
*
WindowStats
,
error
)
{
stats
,
err
:=
s
.
repos
.
UsageLog
.
GetAccountTodayStats
(
ctx
,
accountID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get today stats failed: %w"
,
err
)
}
return
&
WindowStats
{
Requests
:
stats
.
Requests
,
Tokens
:
stats
.
Tokens
,
Cost
:
stats
.
Cost
,
},
nil
}
// fetchOAuthUsage 从Anthropic API获取OAuth账号的使用量
func
(
s
*
AccountUsageService
)
fetchOAuthUsage
(
ctx
context
.
Context
,
account
*
model
.
Account
)
(
*
UsageInfo
,
error
)
{
// 获取access token(从credentials中获取)
accessToken
:=
account
.
GetCredential
(
"access_token"
)
if
accessToken
==
""
{
return
nil
,
fmt
.
Errorf
(
"no access token available"
)
}
// 获取代理配置
transport
:=
http
.
DefaultTransport
.
(
*
http
.
Transport
)
.
Clone
()
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
proxyURL
:=
account
.
Proxy
.
URL
()
if
proxyURL
!=
""
{
if
parsedURL
,
err
:=
url
.
Parse
(
proxyURL
);
err
==
nil
{
transport
.
Proxy
=
http
.
ProxyURL
(
parsedURL
)
}
}
}
client
:=
&
http
.
Client
{
Transport
:
transport
,
Timeout
:
30
*
time
.
Second
,
}
// 构建请求
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
"GET"
,
"https://api.anthropic.com/api/oauth/usage"
,
nil
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"create request failed: %w"
,
err
)
}
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
accessToken
)
req
.
Header
.
Set
(
"anthropic-beta"
,
"oauth-2025-04-20"
)
// 发送请求
resp
,
err
:=
client
.
Do
(
req
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"request failed: %w"
,
err
)
}
defer
resp
.
Body
.
Close
()
if
resp
.
StatusCode
!=
http
.
StatusOK
{
body
,
_
:=
io
.
ReadAll
(
resp
.
Body
)
return
nil
,
fmt
.
Errorf
(
"API returned status %d: %s"
,
resp
.
StatusCode
,
string
(
body
))
}
// 解析响应
var
usageResp
ClaudeUsageResponse
if
err
:=
json
.
NewDecoder
(
resp
.
Body
)
.
Decode
(
&
usageResp
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"decode response failed: %w"
,
err
)
}
// 转换为UsageInfo
now
:=
time
.
Now
()
return
s
.
buildUsageInfo
(
&
usageResp
,
&
now
),
nil
}
// parseTime 尝试多种格式解析时间
func
parseTime
(
s
string
)
(
time
.
Time
,
error
)
{
formats
:=
[]
string
{
time
.
RFC3339
,
time
.
RFC3339Nano
,
"2006-01-02T15:04:05Z"
,
"2006-01-02T15:04:05.000Z"
,
}
for
_
,
format
:=
range
formats
{
if
t
,
err
:=
time
.
Parse
(
format
,
s
);
err
==
nil
{
return
t
,
nil
}
}
return
time
.
Time
{},
fmt
.
Errorf
(
"unable to parse time: %s"
,
s
)
}
// buildUsageInfo 构建UsageInfo
func
(
s
*
AccountUsageService
)
buildUsageInfo
(
resp
*
ClaudeUsageResponse
,
updatedAt
*
time
.
Time
)
*
UsageInfo
{
info
:=
&
UsageInfo
{
UpdatedAt
:
updatedAt
,
}
// 5小时窗口
if
resp
.
FiveHour
.
ResetsAt
!=
""
{
if
fiveHourReset
,
err
:=
parseTime
(
resp
.
FiveHour
.
ResetsAt
);
err
==
nil
{
info
.
FiveHour
=
&
UsageProgress
{
Utilization
:
resp
.
FiveHour
.
Utilization
,
ResetsAt
:
&
fiveHourReset
,
RemainingSeconds
:
int
(
time
.
Until
(
fiveHourReset
)
.
Seconds
()),
}
}
else
{
log
.
Printf
(
"Failed to parse FiveHour.ResetsAt: %s, error: %v"
,
resp
.
FiveHour
.
ResetsAt
,
err
)
// 即使解析失败也返回utilization
info
.
FiveHour
=
&
UsageProgress
{
Utilization
:
resp
.
FiveHour
.
Utilization
,
}
}
}
// 7天窗口
if
resp
.
SevenDay
.
ResetsAt
!=
""
{
if
sevenDayReset
,
err
:=
parseTime
(
resp
.
SevenDay
.
ResetsAt
);
err
==
nil
{
info
.
SevenDay
=
&
UsageProgress
{
Utilization
:
resp
.
SevenDay
.
Utilization
,
ResetsAt
:
&
sevenDayReset
,
RemainingSeconds
:
int
(
time
.
Until
(
sevenDayReset
)
.
Seconds
()),
}
}
else
{
log
.
Printf
(
"Failed to parse SevenDay.ResetsAt: %s, error: %v"
,
resp
.
SevenDay
.
ResetsAt
,
err
)
info
.
SevenDay
=
&
UsageProgress
{
Utilization
:
resp
.
SevenDay
.
Utilization
,
}
}
}
// 7天Sonnet窗口
if
resp
.
SevenDaySonnet
.
ResetsAt
!=
""
{
if
sonnetReset
,
err
:=
parseTime
(
resp
.
SevenDaySonnet
.
ResetsAt
);
err
==
nil
{
info
.
SevenDaySonnet
=
&
UsageProgress
{
Utilization
:
resp
.
SevenDaySonnet
.
Utilization
,
ResetsAt
:
&
sonnetReset
,
RemainingSeconds
:
int
(
time
.
Until
(
sonnetReset
)
.
Seconds
()),
}
}
else
{
log
.
Printf
(
"Failed to parse SevenDaySonnet.ResetsAt: %s, error: %v"
,
resp
.
SevenDaySonnet
.
ResetsAt
,
err
)
info
.
SevenDaySonnet
=
&
UsageProgress
{
Utilization
:
resp
.
SevenDaySonnet
.
Utilization
,
}
}
}
return
info
}
// estimateSetupTokenUsage 根据session_window推算Setup Token账号的使用量
func
(
s
*
AccountUsageService
)
estimateSetupTokenUsage
(
account
*
model
.
Account
)
*
UsageInfo
{
info
:=
&
UsageInfo
{}
// 如果有session_window信息
if
account
.
SessionWindowEnd
!=
nil
{
remaining
:=
int
(
time
.
Until
(
*
account
.
SessionWindowEnd
)
.
Seconds
())
if
remaining
<
0
{
remaining
=
0
}
// 根据状态估算使用率 (百分比形式,100 = 100%)
var
utilization
float64
switch
account
.
SessionWindowStatus
{
case
"rejected"
:
utilization
=
100.0
case
"allowed_warning"
:
utilization
=
80.0
default
:
utilization
=
0.0
}
info
.
FiveHour
=
&
UsageProgress
{
Utilization
:
utilization
,
ResetsAt
:
account
.
SessionWindowEnd
,
RemainingSeconds
:
remaining
,
}
}
else
{
// 没有窗口信息,返回空数据
info
.
FiveHour
=
&
UsageProgress
{
Utilization
:
0
,
RemainingSeconds
:
0
,
}
}
// Setup Token无法获取7d数据
return
info
}
backend/internal/service/admin_service.go
0 → 100644
View file @
642842c2
package
service
import
(
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"time"
"sub2api/internal/model"
"sub2api/internal/repository"
"golang.org/x/net/proxy"
"gorm.io/gorm"
)
// AdminService interface defines admin management operations
type
AdminService
interface
{
// User management
ListUsers
(
ctx
context
.
Context
,
page
,
pageSize
int
,
status
,
role
,
search
string
)
([]
model
.
User
,
int64
,
error
)
GetUser
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
User
,
error
)
CreateUser
(
ctx
context
.
Context
,
input
*
CreateUserInput
)
(
*
model
.
User
,
error
)
UpdateUser
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateUserInput
)
(
*
model
.
User
,
error
)
DeleteUser
(
ctx
context
.
Context
,
id
int64
)
error
UpdateUserBalance
(
ctx
context
.
Context
,
userID
int64
,
balance
float64
,
operation
string
)
(
*
model
.
User
,
error
)
GetUserAPIKeys
(
ctx
context
.
Context
,
userID
int64
,
page
,
pageSize
int
)
([]
model
.
ApiKey
,
int64
,
error
)
GetUserUsageStats
(
ctx
context
.
Context
,
userID
int64
,
period
string
)
(
interface
{},
error
)
// Group management
ListGroups
(
ctx
context
.
Context
,
page
,
pageSize
int
,
platform
,
status
string
,
isExclusive
*
bool
)
([]
model
.
Group
,
int64
,
error
)
GetAllGroups
(
ctx
context
.
Context
)
([]
model
.
Group
,
error
)
GetAllGroupsByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
model
.
Group
,
error
)
GetGroup
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
Group
,
error
)
CreateGroup
(
ctx
context
.
Context
,
input
*
CreateGroupInput
)
(
*
model
.
Group
,
error
)
UpdateGroup
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateGroupInput
)
(
*
model
.
Group
,
error
)
DeleteGroup
(
ctx
context
.
Context
,
id
int64
)
error
GetGroupAPIKeys
(
ctx
context
.
Context
,
groupID
int64
,
page
,
pageSize
int
)
([]
model
.
ApiKey
,
int64
,
error
)
// Account management
ListAccounts
(
ctx
context
.
Context
,
page
,
pageSize
int
,
platform
,
accountType
,
status
,
search
string
)
([]
model
.
Account
,
int64
,
error
)
GetAccount
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
Account
,
error
)
CreateAccount
(
ctx
context
.
Context
,
input
*
CreateAccountInput
)
(
*
model
.
Account
,
error
)
UpdateAccount
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateAccountInput
)
(
*
model
.
Account
,
error
)
DeleteAccount
(
ctx
context
.
Context
,
id
int64
)
error
RefreshAccountCredentials
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
Account
,
error
)
ClearAccountError
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
Account
,
error
)
SetAccountSchedulable
(
ctx
context
.
Context
,
id
int64
,
schedulable
bool
)
(
*
model
.
Account
,
error
)
// Proxy management
ListProxies
(
ctx
context
.
Context
,
page
,
pageSize
int
,
protocol
,
status
,
search
string
)
([]
model
.
Proxy
,
int64
,
error
)
GetAllProxies
(
ctx
context
.
Context
)
([]
model
.
Proxy
,
error
)
GetAllProxiesWithAccountCount
(
ctx
context
.
Context
)
([]
model
.
ProxyWithAccountCount
,
error
)
GetProxy
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
Proxy
,
error
)
CreateProxy
(
ctx
context
.
Context
,
input
*
CreateProxyInput
)
(
*
model
.
Proxy
,
error
)
UpdateProxy
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateProxyInput
)
(
*
model
.
Proxy
,
error
)
DeleteProxy
(
ctx
context
.
Context
,
id
int64
)
error
GetProxyAccounts
(
ctx
context
.
Context
,
proxyID
int64
,
page
,
pageSize
int
)
([]
model
.
Account
,
int64
,
error
)
CheckProxyExists
(
ctx
context
.
Context
,
host
string
,
port
int
,
username
,
password
string
)
(
bool
,
error
)
TestProxy
(
ctx
context
.
Context
,
id
int64
)
(
*
ProxyTestResult
,
error
)
// Redeem code management
ListRedeemCodes
(
ctx
context
.
Context
,
page
,
pageSize
int
,
codeType
,
status
,
search
string
)
([]
model
.
RedeemCode
,
int64
,
error
)
GetRedeemCode
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
RedeemCode
,
error
)
GenerateRedeemCodes
(
ctx
context
.
Context
,
input
*
GenerateRedeemCodesInput
)
([]
model
.
RedeemCode
,
error
)
DeleteRedeemCode
(
ctx
context
.
Context
,
id
int64
)
error
BatchDeleteRedeemCodes
(
ctx
context
.
Context
,
ids
[]
int64
)
(
int64
,
error
)
ExpireRedeemCode
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
RedeemCode
,
error
)
}
// Input types for admin operations
type
CreateUserInput
struct
{
Email
string
Password
string
Balance
float64
Concurrency
int
AllowedGroups
[]
int64
}
type
UpdateUserInput
struct
{
Email
string
Password
string
Balance
*
float64
// 使用指针区分"未提供"和"设置为0"
Concurrency
*
int
// 使用指针区分"未提供"和"设置为0"
Status
string
AllowedGroups
*
[]
int64
// 使用指针区分"未提供"和"设置为空数组"
}
type
CreateGroupInput
struct
{
Name
string
Description
string
Platform
string
RateMultiplier
float64
IsExclusive
bool
SubscriptionType
string
// standard/subscription
DailyLimitUSD
*
float64
// 日限额 (USD)
WeeklyLimitUSD
*
float64
// 周限额 (USD)
MonthlyLimitUSD
*
float64
// 月限额 (USD)
}
type
UpdateGroupInput
struct
{
Name
string
Description
string
Platform
string
RateMultiplier
*
float64
// 使用指针以支持设置为0
IsExclusive
*
bool
Status
string
SubscriptionType
string
// standard/subscription
DailyLimitUSD
*
float64
// 日限额 (USD)
WeeklyLimitUSD
*
float64
// 周限额 (USD)
MonthlyLimitUSD
*
float64
// 月限额 (USD)
}
type
CreateAccountInput
struct
{
Name
string
Platform
string
Type
string
Credentials
map
[
string
]
interface
{}
Extra
map
[
string
]
interface
{}
ProxyID
*
int64
Concurrency
int
Priority
int
GroupIDs
[]
int64
}
type
UpdateAccountInput
struct
{
Name
string
Type
string
// Account type: oauth, setup-token, apikey
Credentials
map
[
string
]
interface
{}
Extra
map
[
string
]
interface
{}
ProxyID
*
int64
Concurrency
*
int
// 使用指针区分"未提供"和"设置为0"
Priority
*
int
// 使用指针区分"未提供"和"设置为0"
Status
string
GroupIDs
*
[]
int64
}
type
CreateProxyInput
struct
{
Name
string
Protocol
string
Host
string
Port
int
Username
string
Password
string
}
type
UpdateProxyInput
struct
{
Name
string
Protocol
string
Host
string
Port
int
Username
string
Password
string
Status
string
}
type
GenerateRedeemCodesInput
struct
{
Count
int
Type
string
Value
float64
GroupID
*
int64
// 订阅类型专用:关联的分组ID
ValidityDays
int
// 订阅类型专用:有效天数
}
// ProxyTestResult represents the result of testing a proxy
type
ProxyTestResult
struct
{
Success
bool
`json:"success"`
Message
string
`json:"message"`
LatencyMs
int64
`json:"latency_ms,omitempty"`
IPAddress
string
`json:"ip_address,omitempty"`
City
string
`json:"city,omitempty"`
Region
string
`json:"region,omitempty"`
Country
string
`json:"country,omitempty"`
}
// adminServiceImpl implements AdminService
type
adminServiceImpl
struct
{
userRepo
*
repository
.
UserRepository
groupRepo
*
repository
.
GroupRepository
accountRepo
*
repository
.
AccountRepository
proxyRepo
*
repository
.
ProxyRepository
apiKeyRepo
*
repository
.
ApiKeyRepository
redeemCodeRepo
*
repository
.
RedeemCodeRepository
usageLogRepo
*
repository
.
UsageLogRepository
userSubRepo
*
repository
.
UserSubscriptionRepository
billingCacheService
*
BillingCacheService
}
// NewAdminService creates a new AdminService
func
NewAdminService
(
repos
*
repository
.
Repositories
)
AdminService
{
return
&
adminServiceImpl
{
userRepo
:
repos
.
User
,
groupRepo
:
repos
.
Group
,
accountRepo
:
repos
.
Account
,
proxyRepo
:
repos
.
Proxy
,
apiKeyRepo
:
repos
.
ApiKey
,
redeemCodeRepo
:
repos
.
RedeemCode
,
usageLogRepo
:
repos
.
UsageLog
,
userSubRepo
:
repos
.
UserSubscription
,
}
}
// SetBillingCacheService 设置计费缓存服务(用于缓存失效)
// 注意:AdminService是接口,需要类型断言
func
SetAdminServiceBillingCache
(
adminService
AdminService
,
billingCacheService
*
BillingCacheService
)
{
if
impl
,
ok
:=
adminService
.
(
*
adminServiceImpl
);
ok
{
impl
.
billingCacheService
=
billingCacheService
}
}
// User management implementations
func
(
s
*
adminServiceImpl
)
ListUsers
(
ctx
context
.
Context
,
page
,
pageSize
int
,
status
,
role
,
search
string
)
([]
model
.
User
,
int64
,
error
)
{
params
:=
repository
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
users
,
result
,
err
:=
s
.
userRepo
.
ListWithFilters
(
ctx
,
params
,
status
,
role
,
search
)
if
err
!=
nil
{
return
nil
,
0
,
err
}
return
users
,
result
.
Total
,
nil
}
func
(
s
*
adminServiceImpl
)
GetUser
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
User
,
error
)
{
return
s
.
userRepo
.
GetByID
(
ctx
,
id
)
}
func
(
s
*
adminServiceImpl
)
CreateUser
(
ctx
context
.
Context
,
input
*
CreateUserInput
)
(
*
model
.
User
,
error
)
{
user
:=
&
model
.
User
{
Email
:
input
.
Email
,
Role
:
"user"
,
// Always create as regular user, never admin
Balance
:
input
.
Balance
,
Concurrency
:
input
.
Concurrency
,
Status
:
model
.
StatusActive
,
}
if
err
:=
user
.
SetPassword
(
input
.
Password
);
err
!=
nil
{
return
nil
,
err
}
if
err
:=
s
.
userRepo
.
Create
(
ctx
,
user
);
err
!=
nil
{
return
nil
,
err
}
return
user
,
nil
}
func
(
s
*
adminServiceImpl
)
UpdateUser
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateUserInput
)
(
*
model
.
User
,
error
)
{
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
err
}
// Protect admin users: cannot disable admin accounts
if
user
.
Role
==
"admin"
&&
input
.
Status
==
"disabled"
{
return
nil
,
errors
.
New
(
"cannot disable admin user"
)
}
// Track balance and concurrency changes for logging
oldBalance
:=
user
.
Balance
oldConcurrency
:=
user
.
Concurrency
if
input
.
Email
!=
""
{
user
.
Email
=
input
.
Email
}
if
input
.
Password
!=
""
{
if
err
:=
user
.
SetPassword
(
input
.
Password
);
err
!=
nil
{
return
nil
,
err
}
}
// Role is not allowed to be changed via API to prevent privilege escalation
if
input
.
Status
!=
""
{
user
.
Status
=
input
.
Status
}
// 只在指针非 nil 时更新 Balance(支持设置为 0)
if
input
.
Balance
!=
nil
{
user
.
Balance
=
*
input
.
Balance
}
// 只在指针非 nil 时更新 Concurrency(支持设置为任意值)
if
input
.
Concurrency
!=
nil
{
user
.
Concurrency
=
*
input
.
Concurrency
}
// 只在指针非 nil 时更新 AllowedGroups
if
input
.
AllowedGroups
!=
nil
{
user
.
AllowedGroups
=
*
input
.
AllowedGroups
}
if
err
:=
s
.
userRepo
.
Update
(
ctx
,
user
);
err
!=
nil
{
return
nil
,
err
}
// 余额变化时失效缓存
if
input
.
Balance
!=
nil
&&
*
input
.
Balance
!=
oldBalance
{
if
s
.
billingCacheService
!=
nil
{
go
func
()
{
cacheCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
5
*
time
.
Second
)
defer
cancel
()
s
.
billingCacheService
.
InvalidateUserBalance
(
cacheCtx
,
id
)
}()
}
}
// Create adjustment records for balance/concurrency changes
balanceDiff
:=
user
.
Balance
-
oldBalance
if
balanceDiff
!=
0
{
adjustmentRecord
:=
&
model
.
RedeemCode
{
Code
:
model
.
GenerateRedeemCode
(),
Type
:
model
.
AdjustmentTypeAdminBalance
,
Value
:
balanceDiff
,
Status
:
model
.
StatusUsed
,
UsedBy
:
&
user
.
ID
,
}
now
:=
time
.
Now
()
adjustmentRecord
.
UsedAt
=
&
now
if
err
:=
s
.
redeemCodeRepo
.
Create
(
ctx
,
adjustmentRecord
);
err
!=
nil
{
// Log error but don't fail the update
// The user update has already succeeded
}
}
concurrencyDiff
:=
user
.
Concurrency
-
oldConcurrency
if
concurrencyDiff
!=
0
{
adjustmentRecord
:=
&
model
.
RedeemCode
{
Code
:
model
.
GenerateRedeemCode
(),
Type
:
model
.
AdjustmentTypeAdminConcurrency
,
Value
:
float64
(
concurrencyDiff
),
Status
:
model
.
StatusUsed
,
UsedBy
:
&
user
.
ID
,
}
now
:=
time
.
Now
()
adjustmentRecord
.
UsedAt
=
&
now
if
err
:=
s
.
redeemCodeRepo
.
Create
(
ctx
,
adjustmentRecord
);
err
!=
nil
{
// Log error but don't fail the update
// The user update has already succeeded
}
}
return
user
,
nil
}
func
(
s
*
adminServiceImpl
)
DeleteUser
(
ctx
context
.
Context
,
id
int64
)
error
{
// Protect admin users: cannot delete admin accounts
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
err
}
if
user
.
Role
==
"admin"
{
return
errors
.
New
(
"cannot delete admin user"
)
}
return
s
.
userRepo
.
Delete
(
ctx
,
id
)
}
func
(
s
*
adminServiceImpl
)
UpdateUserBalance
(
ctx
context
.
Context
,
userID
int64
,
balance
float64
,
operation
string
)
(
*
model
.
User
,
error
)
{
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
nil
,
err
}
switch
operation
{
case
"set"
:
user
.
Balance
=
balance
case
"add"
:
user
.
Balance
+=
balance
case
"subtract"
:
user
.
Balance
-=
balance
}
if
err
:=
s
.
userRepo
.
Update
(
ctx
,
user
);
err
!=
nil
{
return
nil
,
err
}
// 失效余额缓存
if
s
.
billingCacheService
!=
nil
{
go
func
()
{
cacheCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
5
*
time
.
Second
)
defer
cancel
()
s
.
billingCacheService
.
InvalidateUserBalance
(
cacheCtx
,
userID
)
}()
}
return
user
,
nil
}
func
(
s
*
adminServiceImpl
)
GetUserAPIKeys
(
ctx
context
.
Context
,
userID
int64
,
page
,
pageSize
int
)
([]
model
.
ApiKey
,
int64
,
error
)
{
params
:=
repository
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
keys
,
result
,
err
:=
s
.
apiKeyRepo
.
ListByUserID
(
ctx
,
userID
,
params
)
if
err
!=
nil
{
return
nil
,
0
,
err
}
return
keys
,
result
.
Total
,
nil
}
func
(
s
*
adminServiceImpl
)
GetUserUsageStats
(
ctx
context
.
Context
,
userID
int64
,
period
string
)
(
interface
{},
error
)
{
// Return mock data for now
return
map
[
string
]
interface
{}{
"period"
:
period
,
"total_requests"
:
0
,
"total_cost"
:
0.0
,
"total_tokens"
:
0
,
"avg_duration_ms"
:
0
,
},
nil
}
// Group management implementations
func
(
s
*
adminServiceImpl
)
ListGroups
(
ctx
context
.
Context
,
page
,
pageSize
int
,
platform
,
status
string
,
isExclusive
*
bool
)
([]
model
.
Group
,
int64
,
error
)
{
params
:=
repository
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
groups
,
result
,
err
:=
s
.
groupRepo
.
ListWithFilters
(
ctx
,
params
,
platform
,
status
,
isExclusive
)
if
err
!=
nil
{
return
nil
,
0
,
err
}
return
groups
,
result
.
Total
,
nil
}
func
(
s
*
adminServiceImpl
)
GetAllGroups
(
ctx
context
.
Context
)
([]
model
.
Group
,
error
)
{
return
s
.
groupRepo
.
ListActive
(
ctx
)
}
func
(
s
*
adminServiceImpl
)
GetAllGroupsByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
model
.
Group
,
error
)
{
return
s
.
groupRepo
.
ListActiveByPlatform
(
ctx
,
platform
)
}
func
(
s
*
adminServiceImpl
)
GetGroup
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
Group
,
error
)
{
return
s
.
groupRepo
.
GetByID
(
ctx
,
id
)
}
func
(
s
*
adminServiceImpl
)
CreateGroup
(
ctx
context
.
Context
,
input
*
CreateGroupInput
)
(
*
model
.
Group
,
error
)
{
platform
:=
input
.
Platform
if
platform
==
""
{
platform
=
model
.
PlatformAnthropic
}
subscriptionType
:=
input
.
SubscriptionType
if
subscriptionType
==
""
{
subscriptionType
=
model
.
SubscriptionTypeStandard
}
group
:=
&
model
.
Group
{
Name
:
input
.
Name
,
Description
:
input
.
Description
,
Platform
:
platform
,
RateMultiplier
:
input
.
RateMultiplier
,
IsExclusive
:
input
.
IsExclusive
,
Status
:
model
.
StatusActive
,
SubscriptionType
:
subscriptionType
,
DailyLimitUSD
:
input
.
DailyLimitUSD
,
WeeklyLimitUSD
:
input
.
WeeklyLimitUSD
,
MonthlyLimitUSD
:
input
.
MonthlyLimitUSD
,
}
if
err
:=
s
.
groupRepo
.
Create
(
ctx
,
group
);
err
!=
nil
{
return
nil
,
err
}
return
group
,
nil
}
func
(
s
*
adminServiceImpl
)
UpdateGroup
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateGroupInput
)
(
*
model
.
Group
,
error
)
{
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
err
}
if
input
.
Name
!=
""
{
group
.
Name
=
input
.
Name
}
if
input
.
Description
!=
""
{
group
.
Description
=
input
.
Description
}
if
input
.
Platform
!=
""
{
group
.
Platform
=
input
.
Platform
}
if
input
.
RateMultiplier
!=
nil
{
group
.
RateMultiplier
=
*
input
.
RateMultiplier
}
if
input
.
IsExclusive
!=
nil
{
group
.
IsExclusive
=
*
input
.
IsExclusive
}
if
input
.
Status
!=
""
{
group
.
Status
=
input
.
Status
}
// 订阅相关字段
if
input
.
SubscriptionType
!=
""
{
group
.
SubscriptionType
=
input
.
SubscriptionType
}
// 限额字段支持设置为nil(清除限额)或具体值
if
input
.
DailyLimitUSD
!=
nil
{
group
.
DailyLimitUSD
=
input
.
DailyLimitUSD
}
if
input
.
WeeklyLimitUSD
!=
nil
{
group
.
WeeklyLimitUSD
=
input
.
WeeklyLimitUSD
}
if
input
.
MonthlyLimitUSD
!=
nil
{
group
.
MonthlyLimitUSD
=
input
.
MonthlyLimitUSD
}
if
err
:=
s
.
groupRepo
.
Update
(
ctx
,
group
);
err
!=
nil
{
return
nil
,
err
}
return
group
,
nil
}
func
(
s
*
adminServiceImpl
)
DeleteGroup
(
ctx
context
.
Context
,
id
int64
)
error
{
// 先获取分组信息,检查是否存在
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"group not found: %w"
,
err
)
}
// 订阅类型分组:先获取受影响的用户ID列表(用于事务后失效缓存)
var
affectedUserIDs
[]
int64
if
group
.
IsSubscriptionType
()
&&
s
.
billingCacheService
!=
nil
{
var
subscriptions
[]
model
.
UserSubscription
if
err
:=
s
.
groupRepo
.
DB
()
.
WithContext
(
ctx
)
.
Where
(
"group_id = ?"
,
id
)
.
Select
(
"user_id"
)
.
Find
(
&
subscriptions
)
.
Error
;
err
==
nil
{
for
_
,
sub
:=
range
subscriptions
{
affectedUserIDs
=
append
(
affectedUserIDs
,
sub
.
UserID
)
}
}
}
// 使用事务处理所有级联删除
db
:=
s
.
groupRepo
.
DB
()
err
=
db
.
WithContext
(
ctx
)
.
Transaction
(
func
(
tx
*
gorm
.
DB
)
error
{
// 1. 如果是订阅类型分组,删除 user_subscriptions 中的相关记录
if
group
.
IsSubscriptionType
()
{
if
err
:=
tx
.
Where
(
"group_id = ?"
,
id
)
.
Delete
(
&
model
.
UserSubscription
{})
.
Error
;
err
!=
nil
{
return
fmt
.
Errorf
(
"delete user subscriptions: %w"
,
err
)
}
}
// 2. 将 api_keys 中绑定该分组的 group_id 设为 nil(任何类型的分组都需要)
if
err
:=
tx
.
Model
(
&
model
.
ApiKey
{})
.
Where
(
"group_id = ?"
,
id
)
.
Update
(
"group_id"
,
nil
)
.
Error
;
err
!=
nil
{
return
fmt
.
Errorf
(
"clear api key group_id: %w"
,
err
)
}
// 3. 从 users.allowed_groups 数组中移除该分组 ID
if
err
:=
tx
.
Model
(
&
model
.
User
{})
.
Where
(
"? = ANY(allowed_groups)"
,
id
)
.
Update
(
"allowed_groups"
,
gorm
.
Expr
(
"array_remove(allowed_groups, ?)"
,
id
))
.
Error
;
err
!=
nil
{
return
fmt
.
Errorf
(
"remove from allowed_groups: %w"
,
err
)
}
// 4. 删除 account_groups 中间表的数据
if
err
:=
tx
.
Where
(
"group_id = ?"
,
id
)
.
Delete
(
&
model
.
AccountGroup
{})
.
Error
;
err
!=
nil
{
return
fmt
.
Errorf
(
"delete account groups: %w"
,
err
)
}
// 5. 删除分组本身
if
err
:=
tx
.
Delete
(
&
model
.
Group
{},
id
)
.
Error
;
err
!=
nil
{
return
fmt
.
Errorf
(
"delete group: %w"
,
err
)
}
return
nil
})
if
err
!=
nil
{
return
err
}
// 事务成功后,异步失效受影响用户的订阅缓存
if
len
(
affectedUserIDs
)
>
0
&&
s
.
billingCacheService
!=
nil
{
groupID
:=
id
go
func
()
{
cacheCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
30
*
time
.
Second
)
defer
cancel
()
for
_
,
userID
:=
range
affectedUserIDs
{
s
.
billingCacheService
.
InvalidateSubscription
(
cacheCtx
,
userID
,
groupID
)
}
}()
}
return
nil
}
func
(
s
*
adminServiceImpl
)
GetGroupAPIKeys
(
ctx
context
.
Context
,
groupID
int64
,
page
,
pageSize
int
)
([]
model
.
ApiKey
,
int64
,
error
)
{
params
:=
repository
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
keys
,
result
,
err
:=
s
.
apiKeyRepo
.
ListByGroupID
(
ctx
,
groupID
,
params
)
if
err
!=
nil
{
return
nil
,
0
,
err
}
return
keys
,
result
.
Total
,
nil
}
// Account management implementations
func
(
s
*
adminServiceImpl
)
ListAccounts
(
ctx
context
.
Context
,
page
,
pageSize
int
,
platform
,
accountType
,
status
,
search
string
)
([]
model
.
Account
,
int64
,
error
)
{
params
:=
repository
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
accounts
,
result
,
err
:=
s
.
accountRepo
.
ListWithFilters
(
ctx
,
params
,
platform
,
accountType
,
status
,
search
)
if
err
!=
nil
{
return
nil
,
0
,
err
}
return
accounts
,
result
.
Total
,
nil
}
func
(
s
*
adminServiceImpl
)
GetAccount
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
Account
,
error
)
{
return
s
.
accountRepo
.
GetByID
(
ctx
,
id
)
}
func
(
s
*
adminServiceImpl
)
CreateAccount
(
ctx
context
.
Context
,
input
*
CreateAccountInput
)
(
*
model
.
Account
,
error
)
{
account
:=
&
model
.
Account
{
Name
:
input
.
Name
,
Platform
:
input
.
Platform
,
Type
:
input
.
Type
,
Credentials
:
model
.
JSONB
(
input
.
Credentials
),
Extra
:
model
.
JSONB
(
input
.
Extra
),
ProxyID
:
input
.
ProxyID
,
Concurrency
:
input
.
Concurrency
,
Priority
:
input
.
Priority
,
Status
:
model
.
StatusActive
,
}
if
err
:=
s
.
accountRepo
.
Create
(
ctx
,
account
);
err
!=
nil
{
return
nil
,
err
}
// 绑定分组
if
len
(
input
.
GroupIDs
)
>
0
{
if
err
:=
s
.
accountRepo
.
BindGroups
(
ctx
,
account
.
ID
,
input
.
GroupIDs
);
err
!=
nil
{
return
nil
,
err
}
}
return
account
,
nil
}
func
(
s
*
adminServiceImpl
)
UpdateAccount
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateAccountInput
)
(
*
model
.
Account
,
error
)
{
account
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
err
}
if
input
.
Name
!=
""
{
account
.
Name
=
input
.
Name
}
if
input
.
Type
!=
""
{
account
.
Type
=
input
.
Type
}
if
input
.
Credentials
!=
nil
&&
len
(
input
.
Credentials
)
>
0
{
account
.
Credentials
=
model
.
JSONB
(
input
.
Credentials
)
}
if
input
.
Extra
!=
nil
&&
len
(
input
.
Extra
)
>
0
{
account
.
Extra
=
model
.
JSONB
(
input
.
Extra
)
}
if
input
.
ProxyID
!=
nil
{
account
.
ProxyID
=
input
.
ProxyID
}
// 只在指针非 nil 时更新 Concurrency(支持设置为 0)
if
input
.
Concurrency
!=
nil
{
account
.
Concurrency
=
*
input
.
Concurrency
}
// 只在指针非 nil 时更新 Priority(支持设置为 0)
if
input
.
Priority
!=
nil
{
account
.
Priority
=
*
input
.
Priority
}
if
input
.
Status
!=
""
{
account
.
Status
=
input
.
Status
}
if
err
:=
s
.
accountRepo
.
Update
(
ctx
,
account
);
err
!=
nil
{
return
nil
,
err
}
// 更新分组绑定
if
input
.
GroupIDs
!=
nil
{
if
err
:=
s
.
accountRepo
.
BindGroups
(
ctx
,
account
.
ID
,
*
input
.
GroupIDs
);
err
!=
nil
{
return
nil
,
err
}
}
return
account
,
nil
}
func
(
s
*
adminServiceImpl
)
DeleteAccount
(
ctx
context
.
Context
,
id
int64
)
error
{
return
s
.
accountRepo
.
Delete
(
ctx
,
id
)
}
func
(
s
*
adminServiceImpl
)
RefreshAccountCredentials
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
Account
,
error
)
{
account
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
err
}
// TODO: Implement refresh logic
return
account
,
nil
}
func
(
s
*
adminServiceImpl
)
ClearAccountError
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
Account
,
error
)
{
account
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
err
}
account
.
Status
=
model
.
StatusActive
account
.
ErrorMessage
=
""
if
err
:=
s
.
accountRepo
.
Update
(
ctx
,
account
);
err
!=
nil
{
return
nil
,
err
}
return
account
,
nil
}
func
(
s
*
adminServiceImpl
)
SetAccountSchedulable
(
ctx
context
.
Context
,
id
int64
,
schedulable
bool
)
(
*
model
.
Account
,
error
)
{
if
err
:=
s
.
accountRepo
.
SetSchedulable
(
ctx
,
id
,
schedulable
);
err
!=
nil
{
return
nil
,
err
}
return
s
.
accountRepo
.
GetByID
(
ctx
,
id
)
}
// Proxy management implementations
func
(
s
*
adminServiceImpl
)
ListProxies
(
ctx
context
.
Context
,
page
,
pageSize
int
,
protocol
,
status
,
search
string
)
([]
model
.
Proxy
,
int64
,
error
)
{
params
:=
repository
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
proxies
,
result
,
err
:=
s
.
proxyRepo
.
ListWithFilters
(
ctx
,
params
,
protocol
,
status
,
search
)
if
err
!=
nil
{
return
nil
,
0
,
err
}
return
proxies
,
result
.
Total
,
nil
}
func
(
s
*
adminServiceImpl
)
GetAllProxies
(
ctx
context
.
Context
)
([]
model
.
Proxy
,
error
)
{
return
s
.
proxyRepo
.
ListActive
(
ctx
)
}
func
(
s
*
adminServiceImpl
)
GetAllProxiesWithAccountCount
(
ctx
context
.
Context
)
([]
model
.
ProxyWithAccountCount
,
error
)
{
return
s
.
proxyRepo
.
ListActiveWithAccountCount
(
ctx
)
}
func
(
s
*
adminServiceImpl
)
GetProxy
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
Proxy
,
error
)
{
return
s
.
proxyRepo
.
GetByID
(
ctx
,
id
)
}
func
(
s
*
adminServiceImpl
)
CreateProxy
(
ctx
context
.
Context
,
input
*
CreateProxyInput
)
(
*
model
.
Proxy
,
error
)
{
proxy
:=
&
model
.
Proxy
{
Name
:
input
.
Name
,
Protocol
:
input
.
Protocol
,
Host
:
input
.
Host
,
Port
:
input
.
Port
,
Username
:
input
.
Username
,
Password
:
input
.
Password
,
Status
:
model
.
StatusActive
,
}
if
err
:=
s
.
proxyRepo
.
Create
(
ctx
,
proxy
);
err
!=
nil
{
return
nil
,
err
}
return
proxy
,
nil
}
func
(
s
*
adminServiceImpl
)
UpdateProxy
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateProxyInput
)
(
*
model
.
Proxy
,
error
)
{
proxy
,
err
:=
s
.
proxyRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
err
}
if
input
.
Name
!=
""
{
proxy
.
Name
=
input
.
Name
}
if
input
.
Protocol
!=
""
{
proxy
.
Protocol
=
input
.
Protocol
}
if
input
.
Host
!=
""
{
proxy
.
Host
=
input
.
Host
}
if
input
.
Port
!=
0
{
proxy
.
Port
=
input
.
Port
}
if
input
.
Username
!=
""
{
proxy
.
Username
=
input
.
Username
}
if
input
.
Password
!=
""
{
proxy
.
Password
=
input
.
Password
}
if
input
.
Status
!=
""
{
proxy
.
Status
=
input
.
Status
}
if
err
:=
s
.
proxyRepo
.
Update
(
ctx
,
proxy
);
err
!=
nil
{
return
nil
,
err
}
return
proxy
,
nil
}
func
(
s
*
adminServiceImpl
)
DeleteProxy
(
ctx
context
.
Context
,
id
int64
)
error
{
return
s
.
proxyRepo
.
Delete
(
ctx
,
id
)
}
func
(
s
*
adminServiceImpl
)
GetProxyAccounts
(
ctx
context
.
Context
,
proxyID
int64
,
page
,
pageSize
int
)
([]
model
.
Account
,
int64
,
error
)
{
// Return mock data for now - would need a dedicated repository method
return
[]
model
.
Account
{},
0
,
nil
}
func
(
s
*
adminServiceImpl
)
CheckProxyExists
(
ctx
context
.
Context
,
host
string
,
port
int
,
username
,
password
string
)
(
bool
,
error
)
{
return
s
.
proxyRepo
.
ExistsByHostPortAuth
(
ctx
,
host
,
port
,
username
,
password
)
}
// Redeem code management implementations
func
(
s
*
adminServiceImpl
)
ListRedeemCodes
(
ctx
context
.
Context
,
page
,
pageSize
int
,
codeType
,
status
,
search
string
)
([]
model
.
RedeemCode
,
int64
,
error
)
{
params
:=
repository
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
codes
,
result
,
err
:=
s
.
redeemCodeRepo
.
ListWithFilters
(
ctx
,
params
,
codeType
,
status
,
search
)
if
err
!=
nil
{
return
nil
,
0
,
err
}
return
codes
,
result
.
Total
,
nil
}
func
(
s
*
adminServiceImpl
)
GetRedeemCode
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
RedeemCode
,
error
)
{
return
s
.
redeemCodeRepo
.
GetByID
(
ctx
,
id
)
}
func
(
s
*
adminServiceImpl
)
GenerateRedeemCodes
(
ctx
context
.
Context
,
input
*
GenerateRedeemCodesInput
)
([]
model
.
RedeemCode
,
error
)
{
// 如果是订阅类型,验证必须有 GroupID
if
input
.
Type
==
model
.
RedeemTypeSubscription
{
if
input
.
GroupID
==
nil
{
return
nil
,
errors
.
New
(
"group_id is required for subscription type"
)
}
// 验证分组存在且为订阅类型
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
*
input
.
GroupID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"group not found: %w"
,
err
)
}
if
!
group
.
IsSubscriptionType
()
{
return
nil
,
errors
.
New
(
"group must be subscription type"
)
}
}
codes
:=
make
([]
model
.
RedeemCode
,
0
,
input
.
Count
)
for
i
:=
0
;
i
<
input
.
Count
;
i
++
{
code
:=
model
.
RedeemCode
{
Code
:
model
.
GenerateRedeemCode
(),
Type
:
input
.
Type
,
Value
:
input
.
Value
,
Status
:
model
.
StatusUnused
,
}
// 订阅类型专用字段
if
input
.
Type
==
model
.
RedeemTypeSubscription
{
code
.
GroupID
=
input
.
GroupID
code
.
ValidityDays
=
input
.
ValidityDays
if
code
.
ValidityDays
<=
0
{
code
.
ValidityDays
=
30
// 默认30天
}
}
if
err
:=
s
.
redeemCodeRepo
.
Create
(
ctx
,
&
code
);
err
!=
nil
{
return
nil
,
err
}
codes
=
append
(
codes
,
code
)
}
return
codes
,
nil
}
func
(
s
*
adminServiceImpl
)
DeleteRedeemCode
(
ctx
context
.
Context
,
id
int64
)
error
{
return
s
.
redeemCodeRepo
.
Delete
(
ctx
,
id
)
}
func
(
s
*
adminServiceImpl
)
BatchDeleteRedeemCodes
(
ctx
context
.
Context
,
ids
[]
int64
)
(
int64
,
error
)
{
var
deleted
int64
for
_
,
id
:=
range
ids
{
if
err
:=
s
.
redeemCodeRepo
.
Delete
(
ctx
,
id
);
err
==
nil
{
deleted
++
}
}
return
deleted
,
nil
}
func
(
s
*
adminServiceImpl
)
ExpireRedeemCode
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
RedeemCode
,
error
)
{
code
,
err
:=
s
.
redeemCodeRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
err
}
code
.
Status
=
model
.
StatusExpired
if
err
:=
s
.
redeemCodeRepo
.
Update
(
ctx
,
code
);
err
!=
nil
{
return
nil
,
err
}
return
code
,
nil
}
func
(
s
*
adminServiceImpl
)
TestProxy
(
ctx
context
.
Context
,
id
int64
)
(
*
ProxyTestResult
,
error
)
{
proxy
,
err
:=
s
.
proxyRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
err
}
return
testProxyConnection
(
ctx
,
proxy
)
}
// testProxyConnection tests proxy connectivity by requesting ipinfo.io/json
func
testProxyConnection
(
ctx
context
.
Context
,
proxy
*
model
.
Proxy
)
(
*
ProxyTestResult
,
error
)
{
proxyURL
:=
proxy
.
URL
()
// Create HTTP client with proxy
transport
,
err
:=
createProxyTransport
(
proxyURL
)
if
err
!=
nil
{
return
&
ProxyTestResult
{
Success
:
false
,
Message
:
fmt
.
Sprintf
(
"Failed to create proxy transport: %v"
,
err
),
},
nil
}
client
:=
&
http
.
Client
{
Transport
:
transport
,
Timeout
:
15
*
time
.
Second
,
}
// Measure latency
startTime
:=
time
.
Now
()
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
"GET"
,
"https://ipinfo.io/json"
,
nil
)
if
err
!=
nil
{
return
&
ProxyTestResult
{
Success
:
false
,
Message
:
fmt
.
Sprintf
(
"Failed to create request: %v"
,
err
),
},
nil
}
resp
,
err
:=
client
.
Do
(
req
)
if
err
!=
nil
{
return
&
ProxyTestResult
{
Success
:
false
,
Message
:
fmt
.
Sprintf
(
"Proxy connection failed: %v"
,
err
),
},
nil
}
defer
resp
.
Body
.
Close
()
latencyMs
:=
time
.
Since
(
startTime
)
.
Milliseconds
()
if
resp
.
StatusCode
!=
http
.
StatusOK
{
return
&
ProxyTestResult
{
Success
:
false
,
Message
:
fmt
.
Sprintf
(
"Request failed with status: %d"
,
resp
.
StatusCode
),
LatencyMs
:
latencyMs
,
},
nil
}
// Parse ipinfo.io response
var
ipInfo
struct
{
IP
string
`json:"ip"`
City
string
`json:"city"`
Region
string
`json:"region"`
Country
string
`json:"country"`
}
body
,
err
:=
io
.
ReadAll
(
resp
.
Body
)
if
err
!=
nil
{
return
&
ProxyTestResult
{
Success
:
true
,
Message
:
"Proxy is accessible but failed to read response"
,
LatencyMs
:
latencyMs
,
},
nil
}
if
err
:=
json
.
Unmarshal
(
body
,
&
ipInfo
);
err
!=
nil
{
return
&
ProxyTestResult
{
Success
:
true
,
Message
:
"Proxy is accessible but failed to parse response"
,
LatencyMs
:
latencyMs
,
},
nil
}
return
&
ProxyTestResult
{
Success
:
true
,
Message
:
"Proxy is accessible"
,
LatencyMs
:
latencyMs
,
IPAddress
:
ipInfo
.
IP
,
City
:
ipInfo
.
City
,
Region
:
ipInfo
.
Region
,
Country
:
ipInfo
.
Country
,
},
nil
}
// createProxyTransport creates an HTTP transport with the given proxy URL
func
createProxyTransport
(
proxyURL
string
)
(
*
http
.
Transport
,
error
)
{
parsedURL
,
err
:=
url
.
Parse
(
proxyURL
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"invalid proxy URL: %w"
,
err
)
}
transport
:=
&
http
.
Transport
{
TLSClientConfig
:
&
tls
.
Config
{
InsecureSkipVerify
:
true
},
}
switch
parsedURL
.
Scheme
{
case
"http"
,
"https"
:
transport
.
Proxy
=
http
.
ProxyURL
(
parsedURL
)
case
"socks5"
:
dialer
,
err
:=
proxy
.
FromURL
(
parsedURL
,
proxy
.
Direct
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to create socks5 dialer: %w"
,
err
)
}
transport
.
DialContext
=
func
(
ctx
context
.
Context
,
network
,
addr
string
)
(
net
.
Conn
,
error
)
{
return
dialer
.
Dial
(
network
,
addr
)
}
default
:
return
nil
,
fmt
.
Errorf
(
"unsupported proxy protocol: %s"
,
parsedURL
.
Scheme
)
}
return
transport
,
nil
}
backend/internal/service/api_key_service.go
0 → 100644
View file @
642842c2
package
service
import
(
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/pkg/timezone"
"sub2api/internal/repository"
"time"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
var
(
ErrApiKeyNotFound
=
errors
.
New
(
"api key not found"
)
ErrGroupNotAllowed
=
errors
.
New
(
"user is not allowed to bind this group"
)
ErrApiKeyExists
=
errors
.
New
(
"api key already exists"
)
ErrApiKeyTooShort
=
errors
.
New
(
"api key must be at least 16 characters"
)
ErrApiKeyInvalidChars
=
errors
.
New
(
"api key can only contain letters, numbers, underscores, and hyphens"
)
ErrApiKeyRateLimited
=
errors
.
New
(
"too many failed attempts, please try again later"
)
)
const
(
apiKeyRateLimitKeyPrefix
=
"apikey:create_rate_limit:"
apiKeyMaxErrorsPerHour
=
20
apiKeyRateLimitDuration
=
time
.
Hour
)
// CreateApiKeyRequest 创建API Key请求
type
CreateApiKeyRequest
struct
{
Name
string
`json:"name"`
GroupID
*
int64
`json:"group_id"`
CustomKey
*
string
`json:"custom_key"`
// 可选的自定义key
}
// UpdateApiKeyRequest 更新API Key请求
type
UpdateApiKeyRequest
struct
{
Name
*
string
`json:"name"`
GroupID
*
int64
`json:"group_id"`
Status
*
string
`json:"status"`
}
// ApiKeyService API Key服务
type
ApiKeyService
struct
{
apiKeyRepo
*
repository
.
ApiKeyRepository
userRepo
*
repository
.
UserRepository
groupRepo
*
repository
.
GroupRepository
userSubRepo
*
repository
.
UserSubscriptionRepository
rdb
*
redis
.
Client
cfg
*
config
.
Config
}
// NewApiKeyService 创建API Key服务实例
func
NewApiKeyService
(
apiKeyRepo
*
repository
.
ApiKeyRepository
,
userRepo
*
repository
.
UserRepository
,
groupRepo
*
repository
.
GroupRepository
,
userSubRepo
*
repository
.
UserSubscriptionRepository
,
rdb
*
redis
.
Client
,
cfg
*
config
.
Config
,
)
*
ApiKeyService
{
return
&
ApiKeyService
{
apiKeyRepo
:
apiKeyRepo
,
userRepo
:
userRepo
,
groupRepo
:
groupRepo
,
userSubRepo
:
userSubRepo
,
rdb
:
rdb
,
cfg
:
cfg
,
}
}
// GenerateKey 生成随机API Key
func
(
s
*
ApiKeyService
)
GenerateKey
()
(
string
,
error
)
{
// 生成32字节随机数据
bytes
:=
make
([]
byte
,
32
)
if
_
,
err
:=
rand
.
Read
(
bytes
);
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"generate random bytes: %w"
,
err
)
}
// 转换为十六进制字符串并添加前缀
prefix
:=
s
.
cfg
.
Default
.
ApiKeyPrefix
if
prefix
==
""
{
prefix
=
"sk-"
}
key
:=
prefix
+
hex
.
EncodeToString
(
bytes
)
return
key
,
nil
}
// ValidateCustomKey 验证自定义API Key格式
func
(
s
*
ApiKeyService
)
ValidateCustomKey
(
key
string
)
error
{
// 检查长度
if
len
(
key
)
<
16
{
return
ErrApiKeyTooShort
}
// 检查字符:只允许字母、数字、下划线、连字符
for
_
,
c
:=
range
key
{
if
!
((
c
>=
'a'
&&
c
<=
'z'
)
||
(
c
>=
'A'
&&
c
<=
'Z'
)
||
(
c
>=
'0'
&&
c
<=
'9'
)
||
c
==
'_'
||
c
==
'-'
)
{
return
ErrApiKeyInvalidChars
}
}
return
nil
}
// checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
func
(
s
*
ApiKeyService
)
checkApiKeyRateLimit
(
ctx
context
.
Context
,
userID
int64
)
error
{
if
s
.
rdb
==
nil
{
return
nil
}
key
:=
fmt
.
Sprintf
(
"%s%d"
,
apiKeyRateLimitKeyPrefix
,
userID
)
count
,
err
:=
s
.
rdb
.
Get
(
ctx
,
key
)
.
Int
()
if
err
!=
nil
&&
!
errors
.
Is
(
err
,
redis
.
Nil
)
{
// Redis 出错时不阻止用户操作
return
nil
}
if
count
>=
apiKeyMaxErrorsPerHour
{
return
ErrApiKeyRateLimited
}
return
nil
}
// incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数
func
(
s
*
ApiKeyService
)
incrementApiKeyErrorCount
(
ctx
context
.
Context
,
userID
int64
)
{
if
s
.
rdb
==
nil
{
return
}
key
:=
fmt
.
Sprintf
(
"%s%d"
,
apiKeyRateLimitKeyPrefix
,
userID
)
pipe
:=
s
.
rdb
.
Pipeline
()
pipe
.
Incr
(
ctx
,
key
)
pipe
.
Expire
(
ctx
,
key
,
apiKeyRateLimitDuration
)
_
,
_
=
pipe
.
Exec
(
ctx
)
}
// canUserBindGroup 检查用户是否可以绑定指定分组
// 对于订阅类型分组:检查用户是否有有效订阅
// 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑
func
(
s
*
ApiKeyService
)
canUserBindGroup
(
ctx
context
.
Context
,
user
*
model
.
User
,
group
*
model
.
Group
)
bool
{
// 订阅类型分组:需要有效订阅
if
group
.
IsSubscriptionType
()
{
_
,
err
:=
s
.
userSubRepo
.
GetActiveByUserIDAndGroupID
(
ctx
,
user
.
ID
,
group
.
ID
)
return
err
==
nil
// 有有效订阅则允许
}
// 标准类型分组:使用原有逻辑
return
user
.
CanBindGroup
(
group
.
ID
,
group
.
IsExclusive
)
}
// Create 创建API Key
func
(
s
*
ApiKeyService
)
Create
(
ctx
context
.
Context
,
userID
int64
,
req
CreateApiKeyRequest
)
(
*
model
.
ApiKey
,
error
)
{
// 验证用户存在
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
gorm
.
ErrRecordNotFound
)
{
return
nil
,
ErrUserNotFound
}
return
nil
,
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
// 验证分组权限(如果指定了分组)
if
req
.
GroupID
!=
nil
{
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
*
req
.
GroupID
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
gorm
.
ErrRecordNotFound
)
{
return
nil
,
errors
.
New
(
"group not found"
)
}
return
nil
,
fmt
.
Errorf
(
"get group: %w"
,
err
)
}
// 检查用户是否可以绑定该分组
if
!
s
.
canUserBindGroup
(
ctx
,
user
,
group
)
{
return
nil
,
ErrGroupNotAllowed
}
}
var
key
string
// 判断是否使用自定义Key
if
req
.
CustomKey
!=
nil
&&
*
req
.
CustomKey
!=
""
{
// 检查限流(仅对自定义key进行限流)
if
err
:=
s
.
checkApiKeyRateLimit
(
ctx
,
userID
);
err
!=
nil
{
return
nil
,
err
}
// 验证自定义Key格式
if
err
:=
s
.
ValidateCustomKey
(
*
req
.
CustomKey
);
err
!=
nil
{
return
nil
,
err
}
// 检查Key是否已存在
exists
,
err
:=
s
.
apiKeyRepo
.
ExistsByKey
(
ctx
,
*
req
.
CustomKey
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"check key exists: %w"
,
err
)
}
if
exists
{
// Key已存在,增加错误计数
s
.
incrementApiKeyErrorCount
(
ctx
,
userID
)
return
nil
,
ErrApiKeyExists
}
key
=
*
req
.
CustomKey
}
else
{
// 生成随机API Key
var
err
error
key
,
err
=
s
.
GenerateKey
()
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"generate key: %w"
,
err
)
}
}
// 创建API Key记录
apiKey
:=
&
model
.
ApiKey
{
UserID
:
userID
,
Key
:
key
,
Name
:
req
.
Name
,
GroupID
:
req
.
GroupID
,
Status
:
model
.
StatusActive
,
}
if
err
:=
s
.
apiKeyRepo
.
Create
(
ctx
,
apiKey
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"create api key: %w"
,
err
)
}
return
apiKey
,
nil
}
// List 获取用户的API Key列表
func
(
s
*
ApiKeyService
)
List
(
ctx
context
.
Context
,
userID
int64
,
params
repository
.
PaginationParams
)
([]
model
.
ApiKey
,
*
repository
.
PaginationResult
,
error
)
{
keys
,
pagination
,
err
:=
s
.
apiKeyRepo
.
ListByUserID
(
ctx
,
userID
,
params
)
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"list api keys: %w"
,
err
)
}
return
keys
,
pagination
,
nil
}
// GetByID 根据ID获取API Key
func
(
s
*
ApiKeyService
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
ApiKey
,
error
)
{
apiKey
,
err
:=
s
.
apiKeyRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
gorm
.
ErrRecordNotFound
)
{
return
nil
,
ErrApiKeyNotFound
}
return
nil
,
fmt
.
Errorf
(
"get api key: %w"
,
err
)
}
return
apiKey
,
nil
}
// GetByKey 根据Key字符串获取API Key(用于认证)
func
(
s
*
ApiKeyService
)
GetByKey
(
ctx
context
.
Context
,
key
string
)
(
*
model
.
ApiKey
,
error
)
{
// 尝试从Redis缓存获取
cacheKey
:=
fmt
.
Sprintf
(
"apikey:%s"
,
key
)
// 这里可以添加Redis缓存逻辑,暂时直接查询数据库
apiKey
,
err
:=
s
.
apiKeyRepo
.
GetByKey
(
ctx
,
key
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
gorm
.
ErrRecordNotFound
)
{
return
nil
,
ErrApiKeyNotFound
}
return
nil
,
fmt
.
Errorf
(
"get api key: %w"
,
err
)
}
// 缓存到Redis(可选,TTL设置为5分钟)
if
s
.
rdb
!=
nil
{
// 这里可以序列化并缓存API Key
_
=
cacheKey
// 使用变量避免未使用错误
}
return
apiKey
,
nil
}
// Update 更新API Key
func
(
s
*
ApiKeyService
)
Update
(
ctx
context
.
Context
,
id
int64
,
userID
int64
,
req
UpdateApiKeyRequest
)
(
*
model
.
ApiKey
,
error
)
{
apiKey
,
err
:=
s
.
apiKeyRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
gorm
.
ErrRecordNotFound
)
{
return
nil
,
ErrApiKeyNotFound
}
return
nil
,
fmt
.
Errorf
(
"get api key: %w"
,
err
)
}
// 验证所有权
if
apiKey
.
UserID
!=
userID
{
return
nil
,
ErrInsufficientPerms
}
// 更新字段
if
req
.
Name
!=
nil
{
apiKey
.
Name
=
*
req
.
Name
}
if
req
.
GroupID
!=
nil
{
// 验证分组权限
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
*
req
.
GroupID
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
gorm
.
ErrRecordNotFound
)
{
return
nil
,
errors
.
New
(
"group not found"
)
}
return
nil
,
fmt
.
Errorf
(
"get group: %w"
,
err
)
}
if
!
s
.
canUserBindGroup
(
ctx
,
user
,
group
)
{
return
nil
,
ErrGroupNotAllowed
}
apiKey
.
GroupID
=
req
.
GroupID
}
if
req
.
Status
!=
nil
{
apiKey
.
Status
=
*
req
.
Status
// 如果状态改变,清除Redis缓存
if
s
.
rdb
!=
nil
{
cacheKey
:=
fmt
.
Sprintf
(
"apikey:%s"
,
apiKey
.
Key
)
_
=
s
.
rdb
.
Del
(
ctx
,
cacheKey
)
}
}
if
err
:=
s
.
apiKeyRepo
.
Update
(
ctx
,
apiKey
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"update api key: %w"
,
err
)
}
return
apiKey
,
nil
}
// Delete 删除API Key
func
(
s
*
ApiKeyService
)
Delete
(
ctx
context
.
Context
,
id
int64
,
userID
int64
)
error
{
apiKey
,
err
:=
s
.
apiKeyRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
gorm
.
ErrRecordNotFound
)
{
return
ErrApiKeyNotFound
}
return
fmt
.
Errorf
(
"get api key: %w"
,
err
)
}
// 验证所有权
if
apiKey
.
UserID
!=
userID
{
return
ErrInsufficientPerms
}
// 清除Redis缓存
if
s
.
rdb
!=
nil
{
cacheKey
:=
fmt
.
Sprintf
(
"apikey:%s"
,
apiKey
.
Key
)
_
=
s
.
rdb
.
Del
(
ctx
,
cacheKey
)
}
if
err
:=
s
.
apiKeyRepo
.
Delete
(
ctx
,
id
);
err
!=
nil
{
return
fmt
.
Errorf
(
"delete api key: %w"
,
err
)
}
return
nil
}
// ValidateKey 验证API Key是否有效(用于认证中间件)
func
(
s
*
ApiKeyService
)
ValidateKey
(
ctx
context
.
Context
,
key
string
)
(
*
model
.
ApiKey
,
*
model
.
User
,
error
)
{
// 获取API Key
apiKey
,
err
:=
s
.
GetByKey
(
ctx
,
key
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
// 检查API Key状态
if
!
apiKey
.
IsActive
()
{
return
nil
,
nil
,
errors
.
New
(
"api key is not active"
)
}
// 获取用户信息
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
apiKey
.
UserID
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
gorm
.
ErrRecordNotFound
)
{
return
nil
,
nil
,
ErrUserNotFound
}
return
nil
,
nil
,
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
// 检查用户状态
if
!
user
.
IsActive
()
{
return
nil
,
nil
,
ErrUserNotActive
}
return
apiKey
,
user
,
nil
}
// IncrementUsage 增加API Key使用次数(可选:用于统计)
func
(
s
*
ApiKeyService
)
IncrementUsage
(
ctx
context
.
Context
,
keyID
int64
)
error
{
// 使用Redis计数器
if
s
.
rdb
!=
nil
{
cacheKey
:=
fmt
.
Sprintf
(
"apikey:usage:%d:%s"
,
keyID
,
timezone
.
Now
()
.
Format
(
"2006-01-02"
))
if
err
:=
s
.
rdb
.
Incr
(
ctx
,
cacheKey
)
.
Err
();
err
!=
nil
{
return
fmt
.
Errorf
(
"increment usage: %w"
,
err
)
}
// 设置24小时过期
_
=
s
.
rdb
.
Expire
(
ctx
,
cacheKey
,
24
*
time
.
Hour
)
}
return
nil
}
// GetAvailableGroups 获取用户有权限绑定的分组列表
// 返回用户可以选择的分组:
// - 标准类型分组:公开的(非专属)或用户被明确允许的
// - 订阅类型分组:用户有有效订阅的
func
(
s
*
ApiKeyService
)
GetAvailableGroups
(
ctx
context
.
Context
,
userID
int64
)
([]
model
.
Group
,
error
)
{
// 获取用户信息
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
gorm
.
ErrRecordNotFound
)
{
return
nil
,
ErrUserNotFound
}
return
nil
,
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
// 获取所有活跃分组
allGroups
,
err
:=
s
.
groupRepo
.
ListActive
(
ctx
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"list active groups: %w"
,
err
)
}
// 获取用户的所有有效订阅
activeSubscriptions
,
err
:=
s
.
userSubRepo
.
ListActiveByUserID
(
ctx
,
userID
)
if
err
!=
nil
&&
!
errors
.
Is
(
err
,
gorm
.
ErrRecordNotFound
)
{
return
nil
,
fmt
.
Errorf
(
"list active subscriptions: %w"
,
err
)
}
// 构建订阅分组 ID 集合
subscribedGroupIDs
:=
make
(
map
[
int64
]
bool
)
for
_
,
sub
:=
range
activeSubscriptions
{
subscribedGroupIDs
[
sub
.
GroupID
]
=
true
}
// 过滤出用户有权限的分组
availableGroups
:=
make
([]
model
.
Group
,
0
)
for
_
,
group
:=
range
allGroups
{
if
s
.
canUserBindGroupInternal
(
user
,
&
group
,
subscribedGroupIDs
)
{
availableGroups
=
append
(
availableGroups
,
group
)
}
}
return
availableGroups
,
nil
}
// canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据)
func
(
s
*
ApiKeyService
)
canUserBindGroupInternal
(
user
*
model
.
User
,
group
*
model
.
Group
,
subscribedGroupIDs
map
[
int64
]
bool
)
bool
{
// 订阅类型分组:需要有效订阅
if
group
.
IsSubscriptionType
()
{
return
subscribedGroupIDs
[
group
.
ID
]
}
// 标准类型分组:使用原有逻辑
return
user
.
CanBindGroup
(
group
.
ID
,
group
.
IsExclusive
)
}
backend/internal/service/auth_service.go
0 → 100644
View file @
642842c2
package
service
import
(
"context"
"errors"
"fmt"
"log"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/repository"
"time"
"github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
var
(
ErrInvalidCredentials
=
errors
.
New
(
"invalid email or password"
)
ErrUserNotActive
=
errors
.
New
(
"user is not active"
)
ErrEmailExists
=
errors
.
New
(
"email already exists"
)
ErrInvalidToken
=
errors
.
New
(
"invalid token"
)
ErrTokenExpired
=
errors
.
New
(
"token has expired"
)
ErrEmailVerifyRequired
=
errors
.
New
(
"email verification is required"
)
ErrRegDisabled
=
errors
.
New
(
"registration is currently disabled"
)
)
// JWTClaims JWT载荷数据
type
JWTClaims
struct
{
UserID
int64
`json:"user_id"`
Email
string
`json:"email"`
Role
string
`json:"role"`
jwt
.
RegisteredClaims
}
// AuthService 认证服务
type
AuthService
struct
{
userRepo
*
repository
.
UserRepository
cfg
*
config
.
Config
settingService
*
SettingService
emailService
*
EmailService
turnstileService
*
TurnstileService
emailQueueService
*
EmailQueueService
}
// NewAuthService 创建认证服务实例
func
NewAuthService
(
userRepo
*
repository
.
UserRepository
,
cfg
*
config
.
Config
)
*
AuthService
{
return
&
AuthService
{
userRepo
:
userRepo
,
cfg
:
cfg
,
}
}
// SetSettingService 设置系统设置服务(用于检查注册开关和邮件验证)
func
(
s
*
AuthService
)
SetSettingService
(
settingService
*
SettingService
)
{
s
.
settingService
=
settingService
}
// SetEmailService 设置邮件服务(用于邮件验证)
func
(
s
*
AuthService
)
SetEmailService
(
emailService
*
EmailService
)
{
s
.
emailService
=
emailService
}
// SetTurnstileService 设置Turnstile服务(用于验证码校验)
func
(
s
*
AuthService
)
SetTurnstileService
(
turnstileService
*
TurnstileService
)
{
s
.
turnstileService
=
turnstileService
}
// SetEmailQueueService 设置邮件队列服务(用于异步发送邮件)
func
(
s
*
AuthService
)
SetEmailQueueService
(
emailQueueService
*
EmailQueueService
)
{
s
.
emailQueueService
=
emailQueueService
}
// Register 用户注册,返回token和用户
func
(
s
*
AuthService
)
Register
(
ctx
context
.
Context
,
email
,
password
string
)
(
string
,
*
model
.
User
,
error
)
{
return
s
.
RegisterWithVerification
(
ctx
,
email
,
password
,
""
)
}
// RegisterWithVerification 用户注册(支持邮件验证),返回token和用户
func
(
s
*
AuthService
)
RegisterWithVerification
(
ctx
context
.
Context
,
email
,
password
,
verifyCode
string
)
(
string
,
*
model
.
User
,
error
)
{
// 检查是否开放注册
if
s
.
settingService
!=
nil
&&
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
return
""
,
nil
,
ErrRegDisabled
}
// 检查是否需要邮件验证
if
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsEmailVerifyEnabled
(
ctx
)
{
if
verifyCode
==
""
{
return
""
,
nil
,
ErrEmailVerifyRequired
}
// 验证邮箱验证码
if
s
.
emailService
!=
nil
{
if
err
:=
s
.
emailService
.
VerifyCode
(
ctx
,
email
,
verifyCode
);
err
!=
nil
{
return
""
,
nil
,
fmt
.
Errorf
(
"verify code: %w"
,
err
)
}
}
}
// 检查邮箱是否已存在
existsEmail
,
err
:=
s
.
userRepo
.
ExistsByEmail
(
ctx
,
email
)
if
err
!=
nil
{
return
""
,
nil
,
fmt
.
Errorf
(
"check email exists: %w"
,
err
)
}
if
existsEmail
{
return
""
,
nil
,
ErrEmailExists
}
// 密码哈希
hashedPassword
,
err
:=
s
.
HashPassword
(
password
)
if
err
!=
nil
{
return
""
,
nil
,
fmt
.
Errorf
(
"hash password: %w"
,
err
)
}
// 获取默认配置
defaultBalance
:=
s
.
cfg
.
Default
.
UserBalance
defaultConcurrency
:=
s
.
cfg
.
Default
.
UserConcurrency
if
s
.
settingService
!=
nil
{
defaultBalance
=
s
.
settingService
.
GetDefaultBalance
(
ctx
)
defaultConcurrency
=
s
.
settingService
.
GetDefaultConcurrency
(
ctx
)
}
// 创建用户
user
:=
&
model
.
User
{
Email
:
email
,
PasswordHash
:
hashedPassword
,
Role
:
model
.
RoleUser
,
Balance
:
defaultBalance
,
Concurrency
:
defaultConcurrency
,
Status
:
model
.
StatusActive
,
}
if
err
:=
s
.
userRepo
.
Create
(
ctx
,
user
);
err
!=
nil
{
return
""
,
nil
,
fmt
.
Errorf
(
"create user: %w"
,
err
)
}
// 生成token
token
,
err
:=
s
.
GenerateToken
(
user
)
if
err
!=
nil
{
return
""
,
nil
,
fmt
.
Errorf
(
"generate token: %w"
,
err
)
}
return
token
,
user
,
nil
}
// SendVerifyCodeResult 发送验证码返回结果
type
SendVerifyCodeResult
struct
{
Countdown
int
`json:"countdown"`
// 倒计时秒数
}
// SendVerifyCode 发送邮箱验证码(同步方式)
func
(
s
*
AuthService
)
SendVerifyCode
(
ctx
context
.
Context
,
email
string
)
error
{
// 检查是否开放注册
if
s
.
settingService
!=
nil
&&
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
return
ErrRegDisabled
}
// 检查邮箱是否已存在
existsEmail
,
err
:=
s
.
userRepo
.
ExistsByEmail
(
ctx
,
email
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"check email exists: %w"
,
err
)
}
if
existsEmail
{
return
ErrEmailExists
}
// 发送验证码
if
s
.
emailService
==
nil
{
return
errors
.
New
(
"email service not configured"
)
}
// 获取网站名称
siteName
:=
"Sub2API"
if
s
.
settingService
!=
nil
{
siteName
=
s
.
settingService
.
GetSiteName
(
ctx
)
}
return
s
.
emailService
.
SendVerifyCode
(
ctx
,
email
,
siteName
)
}
// SendVerifyCodeAsync 异步发送邮箱验证码并返回倒计时
func
(
s
*
AuthService
)
SendVerifyCodeAsync
(
ctx
context
.
Context
,
email
string
)
(
*
SendVerifyCodeResult
,
error
)
{
log
.
Printf
(
"[Auth] SendVerifyCodeAsync called for email: %s"
,
email
)
// 检查是否开放注册
if
s
.
settingService
!=
nil
&&
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
log
.
Println
(
"[Auth] Registration is disabled"
)
return
nil
,
ErrRegDisabled
}
// 检查邮箱是否已存在
existsEmail
,
err
:=
s
.
userRepo
.
ExistsByEmail
(
ctx
,
email
)
if
err
!=
nil
{
log
.
Printf
(
"[Auth] Error checking email exists: %v"
,
err
)
return
nil
,
fmt
.
Errorf
(
"check email exists: %w"
,
err
)
}
if
existsEmail
{
log
.
Printf
(
"[Auth] Email already exists: %s"
,
email
)
return
nil
,
ErrEmailExists
}
// 检查邮件队列服务是否配置
if
s
.
emailQueueService
==
nil
{
log
.
Println
(
"[Auth] Email queue service not configured"
)
return
nil
,
errors
.
New
(
"email queue service not configured"
)
}
// 获取网站名称
siteName
:=
"Sub2API"
if
s
.
settingService
!=
nil
{
siteName
=
s
.
settingService
.
GetSiteName
(
ctx
)
}
// 异步发送
log
.
Printf
(
"[Auth] Enqueueing verify code for: %s"
,
email
)
if
err
:=
s
.
emailQueueService
.
EnqueueVerifyCode
(
email
,
siteName
);
err
!=
nil
{
log
.
Printf
(
"[Auth] Failed to enqueue: %v"
,
err
)
return
nil
,
fmt
.
Errorf
(
"enqueue verify code: %w"
,
err
)
}
log
.
Printf
(
"[Auth] Verify code enqueued successfully for: %s"
,
email
)
return
&
SendVerifyCodeResult
{
Countdown
:
60
,
// 60秒倒计时
},
nil
}
// VerifyTurnstile 验证Turnstile token
func
(
s
*
AuthService
)
VerifyTurnstile
(
ctx
context
.
Context
,
token
string
,
remoteIP
string
)
error
{
if
s
.
turnstileService
==
nil
{
return
nil
// 服务未配置则跳过验证
}
return
s
.
turnstileService
.
VerifyToken
(
ctx
,
token
,
remoteIP
)
}
// IsTurnstileEnabled 检查是否启用Turnstile验证
func
(
s
*
AuthService
)
IsTurnstileEnabled
(
ctx
context
.
Context
)
bool
{
if
s
.
turnstileService
==
nil
{
return
false
}
return
s
.
turnstileService
.
IsEnabled
(
ctx
)
}
// IsRegistrationEnabled 检查是否开放注册
func
(
s
*
AuthService
)
IsRegistrationEnabled
(
ctx
context
.
Context
)
bool
{
if
s
.
settingService
==
nil
{
return
true
}
return
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
}
// IsEmailVerifyEnabled 检查是否开启邮件验证
func
(
s
*
AuthService
)
IsEmailVerifyEnabled
(
ctx
context
.
Context
)
bool
{
if
s
.
settingService
==
nil
{
return
false
}
return
s
.
settingService
.
IsEmailVerifyEnabled
(
ctx
)
}
// Login 用户登录,返回JWT token
func
(
s
*
AuthService
)
Login
(
ctx
context
.
Context
,
email
,
password
string
)
(
string
,
*
model
.
User
,
error
)
{
// 查找用户
user
,
err
:=
s
.
userRepo
.
GetByEmail
(
ctx
,
email
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
gorm
.
ErrRecordNotFound
)
{
return
""
,
nil
,
ErrInvalidCredentials
}
return
""
,
nil
,
fmt
.
Errorf
(
"get user by email: %w"
,
err
)
}
// 验证密码
if
!
s
.
CheckPassword
(
password
,
user
.
PasswordHash
)
{
return
""
,
nil
,
ErrInvalidCredentials
}
// 检查用户状态
if
!
user
.
IsActive
()
{
return
""
,
nil
,
ErrUserNotActive
}
// 生成JWT token
token
,
err
:=
s
.
GenerateToken
(
user
)
if
err
!=
nil
{
return
""
,
nil
,
fmt
.
Errorf
(
"generate token: %w"
,
err
)
}
return
token
,
user
,
nil
}
// ValidateToken 验证JWT token并返回用户声明
func
(
s
*
AuthService
)
ValidateToken
(
tokenString
string
)
(
*
JWTClaims
,
error
)
{
token
,
err
:=
jwt
.
ParseWithClaims
(
tokenString
,
&
JWTClaims
{},
func
(
token
*
jwt
.
Token
)
(
interface
{},
error
)
{
// 验证签名方法
if
_
,
ok
:=
token
.
Method
.
(
*
jwt
.
SigningMethodHMAC
);
!
ok
{
return
nil
,
fmt
.
Errorf
(
"unexpected signing method: %v"
,
token
.
Header
[
"alg"
])
}
return
[]
byte
(
s
.
cfg
.
JWT
.
Secret
),
nil
})
if
err
!=
nil
{
if
errors
.
Is
(
err
,
jwt
.
ErrTokenExpired
)
{
return
nil
,
ErrTokenExpired
}
return
nil
,
ErrInvalidToken
}
if
claims
,
ok
:=
token
.
Claims
.
(
*
JWTClaims
);
ok
&&
token
.
Valid
{
return
claims
,
nil
}
return
nil
,
ErrInvalidToken
}
// GenerateToken 生成JWT token
func
(
s
*
AuthService
)
GenerateToken
(
user
*
model
.
User
)
(
string
,
error
)
{
now
:=
time
.
Now
()
expiresAt
:=
now
.
Add
(
time
.
Duration
(
s
.
cfg
.
JWT
.
ExpireHour
)
*
time
.
Hour
)
claims
:=
&
JWTClaims
{
UserID
:
user
.
ID
,
Email
:
user
.
Email
,
Role
:
user
.
Role
,
RegisteredClaims
:
jwt
.
RegisteredClaims
{
ExpiresAt
:
jwt
.
NewNumericDate
(
expiresAt
),
IssuedAt
:
jwt
.
NewNumericDate
(
now
),
NotBefore
:
jwt
.
NewNumericDate
(
now
),
},
}
token
:=
jwt
.
NewWithClaims
(
jwt
.
SigningMethodHS256
,
claims
)
tokenString
,
err
:=
token
.
SignedString
([]
byte
(
s
.
cfg
.
JWT
.
Secret
))
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"sign token: %w"
,
err
)
}
return
tokenString
,
nil
}
// HashPassword 使用bcrypt加密密码
func
(
s
*
AuthService
)
HashPassword
(
password
string
)
(
string
,
error
)
{
hashedBytes
,
err
:=
bcrypt
.
GenerateFromPassword
([]
byte
(
password
),
bcrypt
.
DefaultCost
)
if
err
!=
nil
{
return
""
,
err
}
return
string
(
hashedBytes
),
nil
}
// CheckPassword 验证密码是否匹配
func
(
s
*
AuthService
)
CheckPassword
(
password
,
hashedPassword
string
)
bool
{
err
:=
bcrypt
.
CompareHashAndPassword
([]
byte
(
hashedPassword
),
[]
byte
(
password
))
return
err
==
nil
}
// RefreshToken 刷新token
func
(
s
*
AuthService
)
RefreshToken
(
ctx
context
.
Context
,
oldTokenString
string
)
(
string
,
error
)
{
// 验证旧token(即使过期也允许,用于刷新)
claims
,
err
:=
s
.
ValidateToken
(
oldTokenString
)
if
err
!=
nil
&&
!
errors
.
Is
(
err
,
ErrTokenExpired
)
{
return
""
,
err
}
// 获取最新的用户信息
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
claims
.
UserID
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
gorm
.
ErrRecordNotFound
)
{
return
""
,
ErrInvalidToken
}
return
""
,
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
// 检查用户状态
if
!
user
.
IsActive
()
{
return
""
,
ErrUserNotActive
}
// 生成新token
return
s
.
GenerateToken
(
user
)
}
backend/internal/service/billing_cache_service.go
0 → 100644
View file @
642842c2
package
service
import
(
"context"
"errors"
"fmt"
"log"
"strconv"
"time"
"sub2api/internal/model"
"sub2api/internal/repository"
"github.com/redis/go-redis/v9"
)
// 缓存Key前缀和TTL
const
(
billingBalanceKeyPrefix
=
"billing:balance:"
billingSubKeyPrefix
=
"billing:sub:"
billingCacheTTL
=
5
*
time
.
Minute
)
// 订阅缓存Hash字段
const
(
subFieldStatus
=
"status"
subFieldExpiresAt
=
"expires_at"
subFieldDailyUsage
=
"daily_usage"
subFieldWeeklyUsage
=
"weekly_usage"
subFieldMonthlyUsage
=
"monthly_usage"
subFieldVersion
=
"version"
)
// 错误定义
// 注:ErrInsufficientBalance在redeem_service.go中定义
// 注:ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
var
(
ErrSubscriptionInvalid
=
errors
.
New
(
"subscription is invalid or expired"
)
)
// 预编译的Lua脚本
var
(
// deductBalanceScript: 扣减余额缓存,key不存在则忽略
deductBalanceScript
=
redis
.
NewScript
(
`
local current = redis.call('GET', KEYS[1])
if current == false then
return 0
end
local newVal = tonumber(current) - tonumber(ARGV[1])
redis.call('SET', KEYS[1], newVal)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`
)
// updateSubUsageScript: 更新订阅用量缓存,key不存在则忽略
updateSubUsageScript
=
redis
.
NewScript
(
`
local exists = redis.call('EXISTS', KEYS[1])
if exists == 0 then
return 0
end
local cost = tonumber(ARGV[1])
redis.call('HINCRBYFLOAT', KEYS[1], 'daily_usage', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'weekly_usage', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'monthly_usage', cost)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`
)
)
// subscriptionCacheData 订阅缓存数据结构(内部使用)
type
subscriptionCacheData
struct
{
Status
string
ExpiresAt
time
.
Time
DailyUsage
float64
WeeklyUsage
float64
MonthlyUsage
float64
Version
int64
}
// BillingCacheService 计费缓存服务
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
type
BillingCacheService
struct
{
rdb
*
redis
.
Client
userRepo
*
repository
.
UserRepository
subRepo
*
repository
.
UserSubscriptionRepository
}
// NewBillingCacheService 创建计费缓存服务
func
NewBillingCacheService
(
rdb
*
redis
.
Client
,
userRepo
*
repository
.
UserRepository
,
subRepo
*
repository
.
UserSubscriptionRepository
)
*
BillingCacheService
{
return
&
BillingCacheService
{
rdb
:
rdb
,
userRepo
:
userRepo
,
subRepo
:
subRepo
,
}
}
// ============================================
// 余额缓存方法
// ============================================
// GetUserBalance 获取用户余额(优先从缓存读取)
func
(
s
*
BillingCacheService
)
GetUserBalance
(
ctx
context
.
Context
,
userID
int64
)
(
float64
,
error
)
{
if
s
.
rdb
==
nil
{
// Redis不可用,直接查询数据库
return
s
.
getUserBalanceFromDB
(
ctx
,
userID
)
}
key
:=
fmt
.
Sprintf
(
"%s%d"
,
billingBalanceKeyPrefix
,
userID
)
// 尝试从缓存读取
val
,
err
:=
s
.
rdb
.
Get
(
ctx
,
key
)
.
Result
()
if
err
==
nil
{
balance
,
parseErr
:=
strconv
.
ParseFloat
(
val
,
64
)
if
parseErr
==
nil
{
return
balance
,
nil
}
}
// 缓存未命中或解析错误,从数据库读取
balance
,
err
:=
s
.
getUserBalanceFromDB
(
ctx
,
userID
)
if
err
!=
nil
{
return
0
,
err
}
// 异步建立缓存
go
func
()
{
cacheCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Second
)
defer
cancel
()
s
.
setBalanceCache
(
cacheCtx
,
userID
,
balance
)
}()
return
balance
,
nil
}
// getUserBalanceFromDB 从数据库获取用户余额
func
(
s
*
BillingCacheService
)
getUserBalanceFromDB
(
ctx
context
.
Context
,
userID
int64
)
(
float64
,
error
)
{
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"get user balance: %w"
,
err
)
}
return
user
.
Balance
,
nil
}
// setBalanceCache 设置余额缓存
func
(
s
*
BillingCacheService
)
setBalanceCache
(
ctx
context
.
Context
,
userID
int64
,
balance
float64
)
{
if
s
.
rdb
==
nil
{
return
}
key
:=
fmt
.
Sprintf
(
"%s%d"
,
billingBalanceKeyPrefix
,
userID
)
if
err
:=
s
.
rdb
.
Set
(
ctx
,
key
,
balance
,
billingCacheTTL
)
.
Err
();
err
!=
nil
{
log
.
Printf
(
"Warning: set balance cache failed for user %d: %v"
,
userID
,
err
)
}
}
// DeductBalanceCache 扣减余额缓存(异步调用,用于扣费后更新缓存)
func
(
s
*
BillingCacheService
)
DeductBalanceCache
(
ctx
context
.
Context
,
userID
int64
,
amount
float64
)
error
{
if
s
.
rdb
==
nil
{
return
nil
}
key
:=
fmt
.
Sprintf
(
"%s%d"
,
billingBalanceKeyPrefix
,
userID
)
// 使用预编译的Lua脚本原子性扣减,如果key不存在则忽略
_
,
err
:=
deductBalanceScript
.
Run
(
ctx
,
s
.
rdb
,
[]
string
{
key
},
amount
,
int
(
billingCacheTTL
.
Seconds
()))
.
Result
()
if
err
!=
nil
&&
!
errors
.
Is
(
err
,
redis
.
Nil
)
{
log
.
Printf
(
"Warning: deduct balance cache failed for user %d: %v"
,
userID
,
err
)
}
return
nil
}
// InvalidateUserBalance 失效用户余额缓存
func
(
s
*
BillingCacheService
)
InvalidateUserBalance
(
ctx
context
.
Context
,
userID
int64
)
error
{
if
s
.
rdb
==
nil
{
return
nil
}
key
:=
fmt
.
Sprintf
(
"%s%d"
,
billingBalanceKeyPrefix
,
userID
)
if
err
:=
s
.
rdb
.
Del
(
ctx
,
key
)
.
Err
();
err
!=
nil
{
log
.
Printf
(
"Warning: invalidate balance cache failed for user %d: %v"
,
userID
,
err
)
return
err
}
return
nil
}
// ============================================
// 订阅缓存方法
// ============================================
// GetSubscriptionStatus 获取订阅状态(优先从缓存读取)
func
(
s
*
BillingCacheService
)
GetSubscriptionStatus
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
subscriptionCacheData
,
error
)
{
if
s
.
rdb
==
nil
{
return
s
.
getSubscriptionFromDB
(
ctx
,
userID
,
groupID
)
}
key
:=
fmt
.
Sprintf
(
"%s%d:%d"
,
billingSubKeyPrefix
,
userID
,
groupID
)
// 尝试从缓存读取
result
,
err
:=
s
.
rdb
.
HGetAll
(
ctx
,
key
)
.
Result
()
if
err
==
nil
&&
len
(
result
)
>
0
{
data
,
parseErr
:=
s
.
parseSubscriptionCache
(
result
)
if
parseErr
==
nil
{
return
data
,
nil
}
}
// 缓存未命中,从数据库读取
data
,
err
:=
s
.
getSubscriptionFromDB
(
ctx
,
userID
,
groupID
)
if
err
!=
nil
{
return
nil
,
err
}
// 异步建立缓存
go
func
()
{
cacheCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Second
)
defer
cancel
()
s
.
setSubscriptionCache
(
cacheCtx
,
userID
,
groupID
,
data
)
}()
return
data
,
nil
}
// getSubscriptionFromDB 从数据库获取订阅数据
func
(
s
*
BillingCacheService
)
getSubscriptionFromDB
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
subscriptionCacheData
,
error
)
{
sub
,
err
:=
s
.
subRepo
.
GetActiveByUserIDAndGroupID
(
ctx
,
userID
,
groupID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get subscription: %w"
,
err
)
}
return
&
subscriptionCacheData
{
Status
:
sub
.
Status
,
ExpiresAt
:
sub
.
ExpiresAt
,
DailyUsage
:
sub
.
DailyUsageUSD
,
WeeklyUsage
:
sub
.
WeeklyUsageUSD
,
MonthlyUsage
:
sub
.
MonthlyUsageUSD
,
Version
:
sub
.
UpdatedAt
.
Unix
(),
},
nil
}
// parseSubscriptionCache 解析订阅缓存数据
func
(
s
*
BillingCacheService
)
parseSubscriptionCache
(
data
map
[
string
]
string
)
(
*
subscriptionCacheData
,
error
)
{
result
:=
&
subscriptionCacheData
{}
result
.
Status
=
data
[
subFieldStatus
]
if
result
.
Status
==
""
{
return
nil
,
errors
.
New
(
"invalid cache: missing status"
)
}
if
expiresStr
,
ok
:=
data
[
subFieldExpiresAt
];
ok
{
expiresAt
,
err
:=
strconv
.
ParseInt
(
expiresStr
,
10
,
64
)
if
err
==
nil
{
result
.
ExpiresAt
=
time
.
Unix
(
expiresAt
,
0
)
}
}
if
dailyStr
,
ok
:=
data
[
subFieldDailyUsage
];
ok
{
result
.
DailyUsage
,
_
=
strconv
.
ParseFloat
(
dailyStr
,
64
)
}
if
weeklyStr
,
ok
:=
data
[
subFieldWeeklyUsage
];
ok
{
result
.
WeeklyUsage
,
_
=
strconv
.
ParseFloat
(
weeklyStr
,
64
)
}
if
monthlyStr
,
ok
:=
data
[
subFieldMonthlyUsage
];
ok
{
result
.
MonthlyUsage
,
_
=
strconv
.
ParseFloat
(
monthlyStr
,
64
)
}
if
versionStr
,
ok
:=
data
[
subFieldVersion
];
ok
{
result
.
Version
,
_
=
strconv
.
ParseInt
(
versionStr
,
10
,
64
)
}
return
result
,
nil
}
// setSubscriptionCache 设置订阅缓存
func
(
s
*
BillingCacheService
)
setSubscriptionCache
(
ctx
context
.
Context
,
userID
,
groupID
int64
,
data
*
subscriptionCacheData
)
{
if
s
.
rdb
==
nil
||
data
==
nil
{
return
}
key
:=
fmt
.
Sprintf
(
"%s%d:%d"
,
billingSubKeyPrefix
,
userID
,
groupID
)
fields
:=
map
[
string
]
interface
{}{
subFieldStatus
:
data
.
Status
,
subFieldExpiresAt
:
data
.
ExpiresAt
.
Unix
(),
subFieldDailyUsage
:
data
.
DailyUsage
,
subFieldWeeklyUsage
:
data
.
WeeklyUsage
,
subFieldMonthlyUsage
:
data
.
MonthlyUsage
,
subFieldVersion
:
data
.
Version
,
}
pipe
:=
s
.
rdb
.
Pipeline
()
pipe
.
HSet
(
ctx
,
key
,
fields
)
pipe
.
Expire
(
ctx
,
key
,
billingCacheTTL
)
if
_
,
err
:=
pipe
.
Exec
(
ctx
);
err
!=
nil
{
log
.
Printf
(
"Warning: set subscription cache failed for user %d group %d: %v"
,
userID
,
groupID
,
err
)
}
}
// UpdateSubscriptionUsage 更新订阅用量缓存(异步调用,用于扣费后更新缓存)
func
(
s
*
BillingCacheService
)
UpdateSubscriptionUsage
(
ctx
context
.
Context
,
userID
,
groupID
int64
,
costUSD
float64
)
error
{
if
s
.
rdb
==
nil
{
return
nil
}
key
:=
fmt
.
Sprintf
(
"%s%d:%d"
,
billingSubKeyPrefix
,
userID
,
groupID
)
// 使用预编译的Lua脚本原子性增加用量,如果key不存在则忽略
_
,
err
:=
updateSubUsageScript
.
Run
(
ctx
,
s
.
rdb
,
[]
string
{
key
},
costUSD
,
int
(
billingCacheTTL
.
Seconds
()))
.
Result
()
if
err
!=
nil
&&
!
errors
.
Is
(
err
,
redis
.
Nil
)
{
log
.
Printf
(
"Warning: update subscription usage cache failed for user %d group %d: %v"
,
userID
,
groupID
,
err
)
}
return
nil
}
// InvalidateSubscription 失效指定订阅缓存
func
(
s
*
BillingCacheService
)
InvalidateSubscription
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
error
{
if
s
.
rdb
==
nil
{
return
nil
}
key
:=
fmt
.
Sprintf
(
"%s%d:%d"
,
billingSubKeyPrefix
,
userID
,
groupID
)
if
err
:=
s
.
rdb
.
Del
(
ctx
,
key
)
.
Err
();
err
!=
nil
{
log
.
Printf
(
"Warning: invalidate subscription cache failed for user %d group %d: %v"
,
userID
,
groupID
,
err
)
return
err
}
return
nil
}
// ============================================
// 统一检查方法
// ============================================
// CheckBillingEligibility 检查用户是否有资格发起请求
// 余额模式:检查缓存余额 > 0
// 订阅模式:检查缓存用量未超过限额(Group限额从参数传入)
func
(
s
*
BillingCacheService
)
CheckBillingEligibility
(
ctx
context
.
Context
,
user
*
model
.
User
,
apiKey
*
model
.
ApiKey
,
group
*
model
.
Group
,
subscription
*
model
.
UserSubscription
)
error
{
// 判断计费模式
isSubscriptionMode
:=
group
!=
nil
&&
group
.
IsSubscriptionType
()
&&
subscription
!=
nil
if
isSubscriptionMode
{
return
s
.
checkSubscriptionEligibility
(
ctx
,
user
.
ID
,
group
,
subscription
)
}
return
s
.
checkBalanceEligibility
(
ctx
,
user
.
ID
)
}
// checkBalanceEligibility 检查余额模式资格
func
(
s
*
BillingCacheService
)
checkBalanceEligibility
(
ctx
context
.
Context
,
userID
int64
)
error
{
balance
,
err
:=
s
.
GetUserBalance
(
ctx
,
userID
)
if
err
!=
nil
{
// 缓存/数据库错误,允许通过(降级处理)
log
.
Printf
(
"Warning: get user balance failed, allowing request: %v"
,
err
)
return
nil
}
if
balance
<=
0
{
return
ErrInsufficientBalance
}
return
nil
}
// checkSubscriptionEligibility 检查订阅模式资格
func
(
s
*
BillingCacheService
)
checkSubscriptionEligibility
(
ctx
context
.
Context
,
userID
int64
,
group
*
model
.
Group
,
subscription
*
model
.
UserSubscription
)
error
{
// 获取订阅缓存数据
subData
,
err
:=
s
.
GetSubscriptionStatus
(
ctx
,
userID
,
group
.
ID
)
if
err
!=
nil
{
// 缓存/数据库错误,降级使用传入的subscription进行检查
log
.
Printf
(
"Warning: get subscription cache failed, using fallback: %v"
,
err
)
return
s
.
checkSubscriptionLimitsFallback
(
subscription
,
group
)
}
// 检查订阅状态
if
subData
.
Status
!=
model
.
SubscriptionStatusActive
{
return
ErrSubscriptionInvalid
}
// 检查是否过期
if
time
.
Now
()
.
After
(
subData
.
ExpiresAt
)
{
return
ErrSubscriptionInvalid
}
// 检查限额(使用传入的Group限额配置)
if
group
.
HasDailyLimit
()
&&
subData
.
DailyUsage
>=
*
group
.
DailyLimitUSD
{
return
ErrDailyLimitExceeded
}
if
group
.
HasWeeklyLimit
()
&&
subData
.
WeeklyUsage
>=
*
group
.
WeeklyLimitUSD
{
return
ErrWeeklyLimitExceeded
}
if
group
.
HasMonthlyLimit
()
&&
subData
.
MonthlyUsage
>=
*
group
.
MonthlyLimitUSD
{
return
ErrMonthlyLimitExceeded
}
return
nil
}
// checkSubscriptionLimitsFallback 降级检查订阅限额
func
(
s
*
BillingCacheService
)
checkSubscriptionLimitsFallback
(
subscription
*
model
.
UserSubscription
,
group
*
model
.
Group
)
error
{
if
subscription
==
nil
{
return
ErrSubscriptionInvalid
}
if
!
subscription
.
IsActive
()
{
return
ErrSubscriptionInvalid
}
if
!
subscription
.
CheckDailyLimit
(
group
,
0
)
{
return
ErrDailyLimitExceeded
}
if
!
subscription
.
CheckWeeklyLimit
(
group
,
0
)
{
return
ErrWeeklyLimitExceeded
}
if
!
subscription
.
CheckMonthlyLimit
(
group
,
0
)
{
return
ErrMonthlyLimitExceeded
}
return
nil
}
backend/internal/service/billing_service.go
0 → 100644
View file @
642842c2
package
service
import
(
"fmt"
"log"
"strings"
"sub2api/internal/config"
)
// ModelPricing 模型价格配置(per-token价格,与LiteLLM格式一致)
type
ModelPricing
struct
{
InputPricePerToken
float64
// 每token输入价格 (USD)
OutputPricePerToken
float64
// 每token输出价格 (USD)
CacheCreationPricePerToken
float64
// 缓存创建每token价格 (USD)
CacheReadPricePerToken
float64
// 缓存读取每token价格 (USD)
CacheCreation5mPrice
float64
// 5分钟缓存创建价格(每百万token)- 仅用于硬编码回退
CacheCreation1hPrice
float64
// 1小时缓存创建价格(每百万token)- 仅用于硬编码回退
SupportsCacheBreakdown
bool
// 是否支持详细的缓存分类
}
// UsageTokens 使用的token数量
type
UsageTokens
struct
{
InputTokens
int
OutputTokens
int
CacheCreationTokens
int
CacheReadTokens
int
CacheCreation5mTokens
int
CacheCreation1hTokens
int
}
// CostBreakdown 费用明细
type
CostBreakdown
struct
{
InputCost
float64
OutputCost
float64
CacheCreationCost
float64
CacheReadCost
float64
TotalCost
float64
ActualCost
float64
// 应用倍率后的实际费用
}
// BillingService 计费服务
type
BillingService
struct
{
cfg
*
config
.
Config
pricingService
*
PricingService
fallbackPrices
map
[
string
]
*
ModelPricing
// 硬编码回退价格
}
// NewBillingService 创建计费服务实例
func
NewBillingService
(
cfg
*
config
.
Config
,
pricingService
*
PricingService
)
*
BillingService
{
s
:=
&
BillingService
{
cfg
:
cfg
,
pricingService
:
pricingService
,
fallbackPrices
:
make
(
map
[
string
]
*
ModelPricing
),
}
// 初始化硬编码回退价格(当动态价格不可用时使用)
s
.
initFallbackPricing
()
return
s
}
// initFallbackPricing 初始化硬编码回退价格(当动态价格不可用时使用)
// 价格单位:USD per token(与LiteLLM格式一致)
func
(
s
*
BillingService
)
initFallbackPricing
()
{
// Claude 4.5 Opus
s
.
fallbackPrices
[
"claude-opus-4.5"
]
=
&
ModelPricing
{
InputPricePerToken
:
5e-6
,
// $5 per MTok
OutputPricePerToken
:
25e-6
,
// $25 per MTok
CacheCreationPricePerToken
:
6.25e-6
,
// $6.25 per MTok
CacheReadPricePerToken
:
0.5e-6
,
// $0.50 per MTok
SupportsCacheBreakdown
:
false
,
}
// Claude 4 Sonnet
s
.
fallbackPrices
[
"claude-sonnet-4"
]
=
&
ModelPricing
{
InputPricePerToken
:
3e-6
,
// $3 per MTok
OutputPricePerToken
:
15e-6
,
// $15 per MTok
CacheCreationPricePerToken
:
3.75e-6
,
// $3.75 per MTok
CacheReadPricePerToken
:
0.3e-6
,
// $0.30 per MTok
SupportsCacheBreakdown
:
false
,
}
// Claude 3.5 Sonnet
s
.
fallbackPrices
[
"claude-3-5-sonnet"
]
=
&
ModelPricing
{
InputPricePerToken
:
3e-6
,
// $3 per MTok
OutputPricePerToken
:
15e-6
,
// $15 per MTok
CacheCreationPricePerToken
:
3.75e-6
,
// $3.75 per MTok
CacheReadPricePerToken
:
0.3e-6
,
// $0.30 per MTok
SupportsCacheBreakdown
:
false
,
}
// Claude 3.5 Haiku
s
.
fallbackPrices
[
"claude-3-5-haiku"
]
=
&
ModelPricing
{
InputPricePerToken
:
1e-6
,
// $1 per MTok
OutputPricePerToken
:
5e-6
,
// $5 per MTok
CacheCreationPricePerToken
:
1.25e-6
,
// $1.25 per MTok
CacheReadPricePerToken
:
0.1e-6
,
// $0.10 per MTok
SupportsCacheBreakdown
:
false
,
}
// Claude 3 Opus
s
.
fallbackPrices
[
"claude-3-opus"
]
=
&
ModelPricing
{
InputPricePerToken
:
15e-6
,
// $15 per MTok
OutputPricePerToken
:
75e-6
,
// $75 per MTok
CacheCreationPricePerToken
:
18.75e-6
,
// $18.75 per MTok
CacheReadPricePerToken
:
1.5e-6
,
// $1.50 per MTok
SupportsCacheBreakdown
:
false
,
}
// Claude 3 Haiku
s
.
fallbackPrices
[
"claude-3-haiku"
]
=
&
ModelPricing
{
InputPricePerToken
:
0.25e-6
,
// $0.25 per MTok
OutputPricePerToken
:
1.25e-6
,
// $1.25 per MTok
CacheCreationPricePerToken
:
0.3e-6
,
// $0.30 per MTok
CacheReadPricePerToken
:
0.03e-6
,
// $0.03 per MTok
SupportsCacheBreakdown
:
false
,
}
}
// getFallbackPricing 根据模型系列获取回退价格
func
(
s
*
BillingService
)
getFallbackPricing
(
model
string
)
*
ModelPricing
{
modelLower
:=
strings
.
ToLower
(
model
)
// 按模型系列匹配
if
strings
.
Contains
(
modelLower
,
"opus"
)
{
if
strings
.
Contains
(
modelLower
,
"4.5"
)
||
strings
.
Contains
(
modelLower
,
"4-5"
)
{
return
s
.
fallbackPrices
[
"claude-opus-4.5"
]
}
return
s
.
fallbackPrices
[
"claude-3-opus"
]
}
if
strings
.
Contains
(
modelLower
,
"sonnet"
)
{
if
strings
.
Contains
(
modelLower
,
"4"
)
&&
!
strings
.
Contains
(
modelLower
,
"3"
)
{
return
s
.
fallbackPrices
[
"claude-sonnet-4"
]
}
return
s
.
fallbackPrices
[
"claude-3-5-sonnet"
]
}
if
strings
.
Contains
(
modelLower
,
"haiku"
)
{
if
strings
.
Contains
(
modelLower
,
"3-5"
)
||
strings
.
Contains
(
modelLower
,
"3.5"
)
{
return
s
.
fallbackPrices
[
"claude-3-5-haiku"
]
}
return
s
.
fallbackPrices
[
"claude-3-haiku"
]
}
// 默认使用Sonnet价格
return
s
.
fallbackPrices
[
"claude-sonnet-4"
]
}
// GetModelPricing 获取模型价格配置
func
(
s
*
BillingService
)
GetModelPricing
(
model
string
)
(
*
ModelPricing
,
error
)
{
// 标准化模型名称(转小写)
model
=
strings
.
ToLower
(
model
)
// 1. 优先从动态价格服务获取
if
s
.
pricingService
!=
nil
{
litellmPricing
:=
s
.
pricingService
.
GetModelPricing
(
model
)
if
litellmPricing
!=
nil
{
return
&
ModelPricing
{
InputPricePerToken
:
litellmPricing
.
InputCostPerToken
,
OutputPricePerToken
:
litellmPricing
.
OutputCostPerToken
,
CacheCreationPricePerToken
:
litellmPricing
.
CacheCreationInputTokenCost
,
CacheReadPricePerToken
:
litellmPricing
.
CacheReadInputTokenCost
,
SupportsCacheBreakdown
:
false
,
},
nil
}
}
// 2. 使用硬编码回退价格
fallback
:=
s
.
getFallbackPricing
(
model
)
if
fallback
!=
nil
{
log
.
Printf
(
"[Billing] Using fallback pricing for model: %s"
,
model
)
return
fallback
,
nil
}
return
nil
,
fmt
.
Errorf
(
"pricing not found for model: %s"
,
model
)
}
// CalculateCost 计算使用费用
func
(
s
*
BillingService
)
CalculateCost
(
model
string
,
tokens
UsageTokens
,
rateMultiplier
float64
)
(
*
CostBreakdown
,
error
)
{
pricing
,
err
:=
s
.
GetModelPricing
(
model
)
if
err
!=
nil
{
return
nil
,
err
}
breakdown
:=
&
CostBreakdown
{}
// 计算输入token费用(使用per-token价格)
breakdown
.
InputCost
=
float64
(
tokens
.
InputTokens
)
*
pricing
.
InputPricePerToken
// 计算输出token费用
breakdown
.
OutputCost
=
float64
(
tokens
.
OutputTokens
)
*
pricing
.
OutputPricePerToken
// 计算缓存费用
if
pricing
.
SupportsCacheBreakdown
&&
(
pricing
.
CacheCreation5mPrice
>
0
||
pricing
.
CacheCreation1hPrice
>
0
)
{
// 支持详细缓存分类的模型(5分钟/1小时缓存)
breakdown
.
CacheCreationCost
=
float64
(
tokens
.
CacheCreation5mTokens
)
/
1
_000_000
*
pricing
.
CacheCreation5mPrice
+
float64
(
tokens
.
CacheCreation1hTokens
)
/
1
_000_000
*
pricing
.
CacheCreation1hPrice
}
else
{
// 标准缓存创建价格(per-token)
breakdown
.
CacheCreationCost
=
float64
(
tokens
.
CacheCreationTokens
)
*
pricing
.
CacheCreationPricePerToken
}
breakdown
.
CacheReadCost
=
float64
(
tokens
.
CacheReadTokens
)
*
pricing
.
CacheReadPricePerToken
// 计算总费用
breakdown
.
TotalCost
=
breakdown
.
InputCost
+
breakdown
.
OutputCost
+
breakdown
.
CacheCreationCost
+
breakdown
.
CacheReadCost
// 应用倍率计算实际费用
if
rateMultiplier
<=
0
{
rateMultiplier
=
1.0
}
breakdown
.
ActualCost
=
breakdown
.
TotalCost
*
rateMultiplier
return
breakdown
,
nil
}
// CalculateCostWithConfig 使用配置中的默认倍率计算费用
func
(
s
*
BillingService
)
CalculateCostWithConfig
(
model
string
,
tokens
UsageTokens
)
(
*
CostBreakdown
,
error
)
{
multiplier
:=
s
.
cfg
.
Default
.
RateMultiplier
if
multiplier
<=
0
{
multiplier
=
1.0
}
return
s
.
CalculateCost
(
model
,
tokens
,
multiplier
)
}
// ListSupportedModels 列出所有支持的模型(现在总是返回true,因为有模糊匹配)
func
(
s
*
BillingService
)
ListSupportedModels
()
[]
string
{
models
:=
make
([]
string
,
0
)
// 返回回退价格支持的模型系列
for
model
:=
range
s
.
fallbackPrices
{
models
=
append
(
models
,
model
)
}
return
models
}
// IsModelSupported 检查模型是否支持(现在总是返回true,因为有模糊匹配回退)
func
(
s
*
BillingService
)
IsModelSupported
(
model
string
)
bool
{
// 所有Claude模型都有回退价格支持
modelLower
:=
strings
.
ToLower
(
model
)
return
strings
.
Contains
(
modelLower
,
"claude"
)
||
strings
.
Contains
(
modelLower
,
"opus"
)
||
strings
.
Contains
(
modelLower
,
"sonnet"
)
||
strings
.
Contains
(
modelLower
,
"haiku"
)
}
// GetEstimatedCost 估算费用(用于前端展示)
func
(
s
*
BillingService
)
GetEstimatedCost
(
model
string
,
estimatedInputTokens
,
estimatedOutputTokens
int
)
(
float64
,
error
)
{
tokens
:=
UsageTokens
{
InputTokens
:
estimatedInputTokens
,
OutputTokens
:
estimatedOutputTokens
,
}
breakdown
,
err
:=
s
.
CalculateCostWithConfig
(
model
,
tokens
)
if
err
!=
nil
{
return
0
,
err
}
return
breakdown
.
ActualCost
,
nil
}
// GetPricingServiceStatus 获取价格服务状态
func
(
s
*
BillingService
)
GetPricingServiceStatus
()
map
[
string
]
interface
{}
{
if
s
.
pricingService
!=
nil
{
return
s
.
pricingService
.
GetStatus
()
}
return
map
[
string
]
interface
{}{
"model_count"
:
len
(
s
.
fallbackPrices
),
"last_updated"
:
"using fallback"
,
"local_hash"
:
"N/A"
,
}
}
// ForceUpdatePricing 强制更新价格数据
func
(
s
*
BillingService
)
ForceUpdatePricing
()
error
{
if
s
.
pricingService
!=
nil
{
return
s
.
pricingService
.
ForceUpdate
()
}
return
fmt
.
Errorf
(
"pricing service not initialized"
)
}
backend/internal/service/concurrency_service.go
0 → 100644
View file @
642842c2
package
service
import
(
"context"
"fmt"
"log"
"time"
"github.com/redis/go-redis/v9"
)
const
(
// Redis key prefixes
accountConcurrencyKey
=
"concurrency:account:"
userConcurrencyKey
=
"concurrency:user:"
userWaitCountKey
=
"concurrency:wait:"
// TTL for concurrency keys (auto-release safety net)
concurrencyKeyTTL
=
10
*
time
.
Minute
// Wait polling interval
waitPollInterval
=
100
*
time
.
Millisecond
// Default max wait time
defaultMaxWait
=
60
*
time
.
Second
// Default extra wait slots beyond concurrency limit
defaultExtraWaitSlots
=
20
)
// Pre-compiled Lua scripts for better performance
var
(
// acquireScript: increment counter if below max, return 1 if successful
acquireScript
=
redis
.
NewScript
(
`
local current = redis.call('GET', KEYS[1])
if current == false then
current = 0
else
current = tonumber(current)
end
if current < tonumber(ARGV[1]) then
redis.call('INCR', KEYS[1])
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
end
return 0
`
)
// releaseScript: decrement counter, but don't go below 0
releaseScript
=
redis
.
NewScript
(
`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`
)
// incrementWaitScript: increment wait counter if below max, return 1 if successful
incrementWaitScript
=
redis
.
NewScript
(
`
local waitKey = KEYS[1]
local maxWait = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
local current = redis.call('GET', waitKey)
if current == false then
current = 0
else
current = tonumber(current)
end
if current >= maxWait then
return 0
end
redis.call('INCR', waitKey)
redis.call('EXPIRE', waitKey, ttl)
return 1
`
)
// decrementWaitScript: decrement wait counter, but don't go below 0
decrementWaitScript
=
redis
.
NewScript
(
`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`
)
)
// ConcurrencyService manages concurrent request limiting for accounts and users
type
ConcurrencyService
struct
{
rdb
*
redis
.
Client
}
// NewConcurrencyService creates a new ConcurrencyService
func
NewConcurrencyService
(
rdb
*
redis
.
Client
)
*
ConcurrencyService
{
return
&
ConcurrencyService
{
rdb
:
rdb
}
}
// AcquireResult represents the result of acquiring a concurrency slot
type
AcquireResult
struct
{
Acquired
bool
ReleaseFunc
func
()
// Must be called when done (typically via defer)
}
// AcquireAccountSlot attempts to acquire a concurrency slot for an account.
// If the account is at max concurrency, it waits until a slot is available or timeout.
// Returns a release function that MUST be called when the request completes.
func
(
s
*
ConcurrencyService
)
AcquireAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
maxConcurrency
int
)
(
*
AcquireResult
,
error
)
{
key
:=
fmt
.
Sprintf
(
"%s%d"
,
accountConcurrencyKey
,
accountID
)
return
s
.
acquireSlot
(
ctx
,
key
,
maxConcurrency
)
}
// AcquireUserSlot attempts to acquire a concurrency slot for a user.
// If the user is at max concurrency, it waits until a slot is available or timeout.
// Returns a release function that MUST be called when the request completes.
func
(
s
*
ConcurrencyService
)
AcquireUserSlot
(
ctx
context
.
Context
,
userID
int64
,
maxConcurrency
int
)
(
*
AcquireResult
,
error
)
{
key
:=
fmt
.
Sprintf
(
"%s%d"
,
userConcurrencyKey
,
userID
)
return
s
.
acquireSlot
(
ctx
,
key
,
maxConcurrency
)
}
// acquireSlot is the core implementation for acquiring a concurrency slot
func
(
s
*
ConcurrencyService
)
acquireSlot
(
ctx
context
.
Context
,
key
string
,
maxConcurrency
int
)
(
*
AcquireResult
,
error
)
{
// If maxConcurrency is 0 or negative, no limit
if
maxConcurrency
<=
0
{
return
&
AcquireResult
{
Acquired
:
true
,
ReleaseFunc
:
func
()
{},
// no-op
},
nil
}
// Try to acquire immediately
acquired
,
err
:=
s
.
tryAcquire
(
ctx
,
key
,
maxConcurrency
)
if
err
!=
nil
{
return
nil
,
err
}
if
acquired
{
return
&
AcquireResult
{
Acquired
:
true
,
ReleaseFunc
:
s
.
makeReleaseFunc
(
key
),
},
nil
}
// Not acquired, return with Acquired=false
// The caller (gateway handler) will handle waiting with ping support
return
&
AcquireResult
{
Acquired
:
false
,
ReleaseFunc
:
nil
,
},
nil
}
// tryAcquire attempts to increment the counter if below max
// Uses pre-compiled Lua script for atomicity and performance
func
(
s
*
ConcurrencyService
)
tryAcquire
(
ctx
context
.
Context
,
key
string
,
maxConcurrency
int
)
(
bool
,
error
)
{
result
,
err
:=
acquireScript
.
Run
(
ctx
,
s
.
rdb
,
[]
string
{
key
},
maxConcurrency
,
int
(
concurrencyKeyTTL
.
Seconds
()))
.
Int
()
if
err
!=
nil
{
return
false
,
fmt
.
Errorf
(
"acquire slot failed: %w"
,
err
)
}
return
result
==
1
,
nil
}
// makeReleaseFunc creates a function to release a concurrency slot
func
(
s
*
ConcurrencyService
)
makeReleaseFunc
(
key
string
)
func
()
{
return
func
()
{
// Use background context to ensure release even if original context is cancelled
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
5
*
time
.
Second
)
defer
cancel
()
if
err
:=
releaseScript
.
Run
(
ctx
,
s
.
rdb
,
[]
string
{
key
})
.
Err
();
err
!=
nil
{
// Log error but don't panic - TTL will eventually clean up
log
.
Printf
(
"Warning: failed to release concurrency slot for %s: %v"
,
key
,
err
)
}
}
}
// GetCurrentCount returns the current concurrency count for debugging/monitoring
func
(
s
*
ConcurrencyService
)
GetCurrentCount
(
ctx
context
.
Context
,
key
string
)
(
int
,
error
)
{
val
,
err
:=
s
.
rdb
.
Get
(
ctx
,
key
)
.
Int
()
if
err
==
redis
.
Nil
{
return
0
,
nil
}
if
err
!=
nil
{
return
0
,
err
}
return
val
,
nil
}
// GetAccountCurrentCount returns current concurrency count for an account
func
(
s
*
ConcurrencyService
)
GetAccountCurrentCount
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
{
key
:=
fmt
.
Sprintf
(
"%s%d"
,
accountConcurrencyKey
,
accountID
)
return
s
.
GetCurrentCount
(
ctx
,
key
)
}
// GetUserCurrentCount returns current concurrency count for a user
func
(
s
*
ConcurrencyService
)
GetUserCurrentCount
(
ctx
context
.
Context
,
userID
int64
)
(
int
,
error
)
{
key
:=
fmt
.
Sprintf
(
"%s%d"
,
userConcurrencyKey
,
userID
)
return
s
.
GetCurrentCount
(
ctx
,
key
)
}
// ============================================
// Wait Queue Count Methods
// ============================================
// IncrementWaitCount attempts to increment the wait queue counter for a user.
// Returns true if successful, false if the wait queue is full.
// maxWait should be user.Concurrency + defaultExtraWaitSlots
func
(
s
*
ConcurrencyService
)
IncrementWaitCount
(
ctx
context
.
Context
,
userID
int64
,
maxWait
int
)
(
bool
,
error
)
{
if
s
.
rdb
==
nil
{
// Redis not available, allow request
return
true
,
nil
}
key
:=
fmt
.
Sprintf
(
"%s%d"
,
userWaitCountKey
,
userID
)
result
,
err
:=
incrementWaitScript
.
Run
(
ctx
,
s
.
rdb
,
[]
string
{
key
},
maxWait
,
int
(
concurrencyKeyTTL
.
Seconds
()))
.
Int
()
if
err
!=
nil
{
// On error, allow the request to proceed (fail open)
log
.
Printf
(
"Warning: increment wait count failed for user %d: %v"
,
userID
,
err
)
return
true
,
nil
}
return
result
==
1
,
nil
}
// DecrementWaitCount decrements the wait queue counter for a user.
// Should be called when a request completes or exits the wait queue.
func
(
s
*
ConcurrencyService
)
DecrementWaitCount
(
ctx
context
.
Context
,
userID
int64
)
{
if
s
.
rdb
==
nil
{
return
}
key
:=
fmt
.
Sprintf
(
"%s%d"
,
userWaitCountKey
,
userID
)
// Use background context to ensure decrement even if original context is cancelled
bgCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
5
*
time
.
Second
)
defer
cancel
()
if
err
:=
decrementWaitScript
.
Run
(
bgCtx
,
s
.
rdb
,
[]
string
{
key
})
.
Err
();
err
!=
nil
{
log
.
Printf
(
"Warning: decrement wait count failed for user %d: %v"
,
userID
,
err
)
}
}
// GetUserWaitCount returns current wait queue count for a user
func
(
s
*
ConcurrencyService
)
GetUserWaitCount
(
ctx
context
.
Context
,
userID
int64
)
(
int
,
error
)
{
key
:=
fmt
.
Sprintf
(
"%s%d"
,
userWaitCountKey
,
userID
)
return
s
.
GetCurrentCount
(
ctx
,
key
)
}
// CalculateMaxWait calculates the maximum wait queue size for a user
// maxWait = userConcurrency + defaultExtraWaitSlots
func
CalculateMaxWait
(
userConcurrency
int
)
int
{
if
userConcurrency
<=
0
{
userConcurrency
=
1
}
return
userConcurrency
+
defaultExtraWaitSlots
}
backend/internal/service/email_queue_service.go
0 → 100644
View file @
642842c2
package
service
import
(
"context"
"fmt"
"log"
"sync"
"time"
)
// EmailTask 邮件发送任务
type
EmailTask
struct
{
Email
string
SiteName
string
TaskType
string
// "verify_code"
}
// EmailQueueService 异步邮件队列服务
type
EmailQueueService
struct
{
emailService
*
EmailService
taskChan
chan
EmailTask
wg
sync
.
WaitGroup
stopChan
chan
struct
{}
workers
int
}
// NewEmailQueueService 创建邮件队列服务
func
NewEmailQueueService
(
emailService
*
EmailService
,
workers
int
)
*
EmailQueueService
{
if
workers
<=
0
{
workers
=
3
// 默认3个工作协程
}
service
:=
&
EmailQueueService
{
emailService
:
emailService
,
taskChan
:
make
(
chan
EmailTask
,
100
),
// 缓冲100个任务
stopChan
:
make
(
chan
struct
{}),
workers
:
workers
,
}
// 启动工作协程
service
.
start
()
return
service
}
// start 启动工作协程
func
(
s
*
EmailQueueService
)
start
()
{
for
i
:=
0
;
i
<
s
.
workers
;
i
++
{
s
.
wg
.
Add
(
1
)
go
s
.
worker
(
i
)
}
log
.
Printf
(
"[EmailQueue] Started %d workers"
,
s
.
workers
)
}
// worker 工作协程
func
(
s
*
EmailQueueService
)
worker
(
id
int
)
{
defer
s
.
wg
.
Done
()
for
{
select
{
case
task
:=
<-
s
.
taskChan
:
s
.
processTask
(
id
,
task
)
case
<-
s
.
stopChan
:
log
.
Printf
(
"[EmailQueue] Worker %d stopping"
,
id
)
return
}
}
}
// processTask 处理任务
func
(
s
*
EmailQueueService
)
processTask
(
workerID
int
,
task
EmailTask
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
30
*
time
.
Second
)
defer
cancel
()
switch
task
.
TaskType
{
case
"verify_code"
:
if
err
:=
s
.
emailService
.
SendVerifyCode
(
ctx
,
task
.
Email
,
task
.
SiteName
);
err
!=
nil
{
log
.
Printf
(
"[EmailQueue] Worker %d failed to send verify code to %s: %v"
,
workerID
,
task
.
Email
,
err
)
}
else
{
log
.
Printf
(
"[EmailQueue] Worker %d sent verify code to %s"
,
workerID
,
task
.
Email
)
}
default
:
log
.
Printf
(
"[EmailQueue] Worker %d unknown task type: %s"
,
workerID
,
task
.
TaskType
)
}
}
// EnqueueVerifyCode 将验证码发送任务加入队列
func
(
s
*
EmailQueueService
)
EnqueueVerifyCode
(
email
,
siteName
string
)
error
{
task
:=
EmailTask
{
Email
:
email
,
SiteName
:
siteName
,
TaskType
:
"verify_code"
,
}
select
{
case
s
.
taskChan
<-
task
:
log
.
Printf
(
"[EmailQueue] Enqueued verify code task for %s"
,
email
)
return
nil
default
:
return
fmt
.
Errorf
(
"email queue is full"
)
}
}
// Stop 停止队列服务
func
(
s
*
EmailQueueService
)
Stop
()
{
close
(
s
.
stopChan
)
s
.
wg
.
Wait
()
log
.
Println
(
"[EmailQueue] All workers stopped"
)
}
backend/internal/service/email_service.go
0 → 100644
View file @
642842c2
package
service
import
(
"context"
"crypto/rand"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"math/big"
"net/smtp"
"strconv"
"sub2api/internal/model"
"sub2api/internal/repository"
"time"
"github.com/redis/go-redis/v9"
)
var
(
ErrEmailNotConfigured
=
errors
.
New
(
"email service not configured"
)
ErrInvalidVerifyCode
=
errors
.
New
(
"invalid or expired verification code"
)
ErrVerifyCodeTooFrequent
=
errors
.
New
(
"please wait before requesting a new code"
)
ErrVerifyCodeMaxAttempts
=
errors
.
New
(
"too many failed attempts, please request a new code"
)
)
const
(
verifyCodeKeyPrefix
=
"email_verify:"
verifyCodeTTL
=
15
*
time
.
Minute
verifyCodeCooldown
=
1
*
time
.
Minute
maxVerifyCodeAttempts
=
5
)
// verifyCodeData Redis 中存储的验证码数据
type
verifyCodeData
struct
{
Code
string
`json:"code"`
Attempts
int
`json:"attempts"`
CreatedAt
time
.
Time
`json:"created_at"`
}
// SmtpConfig SMTP配置
type
SmtpConfig
struct
{
Host
string
Port
int
Username
string
Password
string
From
string
FromName
string
UseTLS
bool
}
// EmailService 邮件服务
type
EmailService
struct
{
settingRepo
*
repository
.
SettingRepository
rdb
*
redis
.
Client
}
// NewEmailService 创建邮件服务实例
func
NewEmailService
(
settingRepo
*
repository
.
SettingRepository
,
rdb
*
redis
.
Client
)
*
EmailService
{
return
&
EmailService
{
settingRepo
:
settingRepo
,
rdb
:
rdb
,
}
}
// GetSmtpConfig 从数据库获取SMTP配置
func
(
s
*
EmailService
)
GetSmtpConfig
(
ctx
context
.
Context
)
(
*
SmtpConfig
,
error
)
{
keys
:=
[]
string
{
model
.
SettingKeySmtpHost
,
model
.
SettingKeySmtpPort
,
model
.
SettingKeySmtpUsername
,
model
.
SettingKeySmtpPassword
,
model
.
SettingKeySmtpFrom
,
model
.
SettingKeySmtpFromName
,
model
.
SettingKeySmtpUseTLS
,
}
settings
,
err
:=
s
.
settingRepo
.
GetMultiple
(
ctx
,
keys
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get smtp settings: %w"
,
err
)
}
host
:=
settings
[
model
.
SettingKeySmtpHost
]
if
host
==
""
{
return
nil
,
ErrEmailNotConfigured
}
port
:=
587
// 默认端口
if
portStr
:=
settings
[
model
.
SettingKeySmtpPort
];
portStr
!=
""
{
if
p
,
err
:=
strconv
.
Atoi
(
portStr
);
err
==
nil
{
port
=
p
}
}
useTLS
:=
settings
[
model
.
SettingKeySmtpUseTLS
]
==
"true"
return
&
SmtpConfig
{
Host
:
host
,
Port
:
port
,
Username
:
settings
[
model
.
SettingKeySmtpUsername
],
Password
:
settings
[
model
.
SettingKeySmtpPassword
],
From
:
settings
[
model
.
SettingKeySmtpFrom
],
FromName
:
settings
[
model
.
SettingKeySmtpFromName
],
UseTLS
:
useTLS
,
},
nil
}
// SendEmail 发送邮件(使用数据库中保存的配置)
func
(
s
*
EmailService
)
SendEmail
(
ctx
context
.
Context
,
to
,
subject
,
body
string
)
error
{
config
,
err
:=
s
.
GetSmtpConfig
(
ctx
)
if
err
!=
nil
{
return
err
}
return
s
.
SendEmailWithConfig
(
config
,
to
,
subject
,
body
)
}
// SendEmailWithConfig 使用指定配置发送邮件
func
(
s
*
EmailService
)
SendEmailWithConfig
(
config
*
SmtpConfig
,
to
,
subject
,
body
string
)
error
{
from
:=
config
.
From
if
config
.
FromName
!=
""
{
from
=
fmt
.
Sprintf
(
"%s <%s>"
,
config
.
FromName
,
config
.
From
)
}
msg
:=
fmt
.
Sprintf
(
"From: %s
\r\n
To: %s
\r\n
Subject: %s
\r\n
MIME-Version: 1.0
\r\n
Content-Type: text/html; charset=UTF-8
\r\n\r\n
%s"
,
from
,
to
,
subject
,
body
)
addr
:=
fmt
.
Sprintf
(
"%s:%d"
,
config
.
Host
,
config
.
Port
)
auth
:=
smtp
.
PlainAuth
(
""
,
config
.
Username
,
config
.
Password
,
config
.
Host
)
if
config
.
UseTLS
{
return
s
.
sendMailTLS
(
addr
,
auth
,
config
.
From
,
to
,
[]
byte
(
msg
),
config
.
Host
)
}
return
smtp
.
SendMail
(
addr
,
auth
,
config
.
From
,
[]
string
{
to
},
[]
byte
(
msg
))
}
// sendMailTLS 使用TLS发送邮件
func
(
s
*
EmailService
)
sendMailTLS
(
addr
string
,
auth
smtp
.
Auth
,
from
,
to
string
,
msg
[]
byte
,
host
string
)
error
{
tlsConfig
:=
&
tls
.
Config
{
ServerName
:
host
,
}
conn
,
err
:=
tls
.
Dial
(
"tcp"
,
addr
,
tlsConfig
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"tls dial: %w"
,
err
)
}
defer
conn
.
Close
()
client
,
err
:=
smtp
.
NewClient
(
conn
,
host
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"new smtp client: %w"
,
err
)
}
defer
client
.
Close
()
if
err
=
client
.
Auth
(
auth
);
err
!=
nil
{
return
fmt
.
Errorf
(
"smtp auth: %w"
,
err
)
}
if
err
=
client
.
Mail
(
from
);
err
!=
nil
{
return
fmt
.
Errorf
(
"smtp mail: %w"
,
err
)
}
if
err
=
client
.
Rcpt
(
to
);
err
!=
nil
{
return
fmt
.
Errorf
(
"smtp rcpt: %w"
,
err
)
}
w
,
err
:=
client
.
Data
()
if
err
!=
nil
{
return
fmt
.
Errorf
(
"smtp data: %w"
,
err
)
}
_
,
err
=
w
.
Write
(
msg
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"write msg: %w"
,
err
)
}
err
=
w
.
Close
()
if
err
!=
nil
{
return
fmt
.
Errorf
(
"close writer: %w"
,
err
)
}
// Email is sent successfully after w.Close(), ignore Quit errors
// Some SMTP servers return non-standard responses on QUIT
_
=
client
.
Quit
()
return
nil
}
// GenerateVerifyCode 生成6位数字验证码
func
(
s
*
EmailService
)
GenerateVerifyCode
()
(
string
,
error
)
{
const
digits
=
"0123456789"
code
:=
make
([]
byte
,
6
)
for
i
:=
range
code
{
num
,
err
:=
rand
.
Int
(
rand
.
Reader
,
big
.
NewInt
(
int64
(
len
(
digits
))))
if
err
!=
nil
{
return
""
,
err
}
code
[
i
]
=
digits
[
num
.
Int64
()]
}
return
string
(
code
),
nil
}
// SendVerifyCode 发送验证码邮件
func
(
s
*
EmailService
)
SendVerifyCode
(
ctx
context
.
Context
,
email
,
siteName
string
)
error
{
key
:=
verifyCodeKeyPrefix
+
email
// 检查是否在冷却期内
existing
,
err
:=
s
.
getVerifyCodeData
(
ctx
,
key
)
if
err
==
nil
&&
existing
!=
nil
{
if
time
.
Since
(
existing
.
CreatedAt
)
<
verifyCodeCooldown
{
return
ErrVerifyCodeTooFrequent
}
}
// 生成验证码
code
,
err
:=
s
.
GenerateVerifyCode
()
if
err
!=
nil
{
return
fmt
.
Errorf
(
"generate code: %w"
,
err
)
}
// 保存验证码到 Redis
data
:=
&
verifyCodeData
{
Code
:
code
,
Attempts
:
0
,
CreatedAt
:
time
.
Now
(),
}
if
err
:=
s
.
setVerifyCodeData
(
ctx
,
key
,
data
);
err
!=
nil
{
return
fmt
.
Errorf
(
"save verify code: %w"
,
err
)
}
// 构建邮件内容
subject
:=
fmt
.
Sprintf
(
"[%s] Email Verification Code"
,
siteName
)
body
:=
s
.
buildVerifyCodeEmailBody
(
code
,
siteName
)
// 发送邮件
if
err
:=
s
.
SendEmail
(
ctx
,
email
,
subject
,
body
);
err
!=
nil
{
return
fmt
.
Errorf
(
"send email: %w"
,
err
)
}
return
nil
}
// VerifyCode 验证验证码
func
(
s
*
EmailService
)
VerifyCode
(
ctx
context
.
Context
,
email
,
code
string
)
error
{
key
:=
verifyCodeKeyPrefix
+
email
data
,
err
:=
s
.
getVerifyCodeData
(
ctx
,
key
)
if
err
!=
nil
||
data
==
nil
{
return
ErrInvalidVerifyCode
}
// 检查是否已达到最大尝试次数
if
data
.
Attempts
>=
maxVerifyCodeAttempts
{
return
ErrVerifyCodeMaxAttempts
}
// 验证码不匹配
if
data
.
Code
!=
code
{
data
.
Attempts
++
_
=
s
.
setVerifyCodeData
(
ctx
,
key
,
data
)
if
data
.
Attempts
>=
maxVerifyCodeAttempts
{
return
ErrVerifyCodeMaxAttempts
}
return
ErrInvalidVerifyCode
}
// 验证成功,删除验证码
s
.
rdb
.
Del
(
ctx
,
key
)
return
nil
}
// getVerifyCodeData 从 Redis 获取验证码数据
func
(
s
*
EmailService
)
getVerifyCodeData
(
ctx
context
.
Context
,
key
string
)
(
*
verifyCodeData
,
error
)
{
val
,
err
:=
s
.
rdb
.
Get
(
ctx
,
key
)
.
Result
()
if
err
!=
nil
{
return
nil
,
err
}
var
data
verifyCodeData
if
err
:=
json
.
Unmarshal
([]
byte
(
val
),
&
data
);
err
!=
nil
{
return
nil
,
err
}
return
&
data
,
nil
}
// setVerifyCodeData 保存验证码数据到 Redis
func
(
s
*
EmailService
)
setVerifyCodeData
(
ctx
context
.
Context
,
key
string
,
data
*
verifyCodeData
)
error
{
val
,
err
:=
json
.
Marshal
(
data
)
if
err
!=
nil
{
return
err
}
return
s
.
rdb
.
Set
(
ctx
,
key
,
val
,
verifyCodeTTL
)
.
Err
()
}
// buildVerifyCodeEmailBody 构建验证码邮件HTML内容
func
(
s
*
EmailService
)
buildVerifyCodeEmailBody
(
code
,
siteName
string
)
string
{
return
fmt
.
Sprintf
(
`
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
.container { max-width: 600px; margin: 0 auto; background-color: #ffffff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
.header { background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 30px; text-align: center; }
.header h1 { margin: 0; font-size: 24px; }
.content { padding: 40px 30px; text-align: center; }
.code { font-size: 36px; font-weight: bold; letter-spacing: 8px; color: #333; background-color: #f8f9fa; padding: 20px 30px; border-radius: 8px; display: inline-block; margin: 20px 0; font-family: monospace; }
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; }
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>%s</h1>
</div>
<div class="content">
<p style="font-size: 18px; color: #333;">Your verification code is:</p>
<div class="code">%s</div>
<div class="info">
<p>This code will expire in <strong>15 minutes</strong>.</p>
<p>If you did not request this code, please ignore this email.</p>
</div>
</div>
<div class="footer">
<p>This is an automated message, please do not reply.</p>
</div>
</div>
</body>
</html>
`
,
siteName
,
code
)
}
// TestSmtpConnectionWithConfig 使用指定配置测试SMTP连接
func
(
s
*
EmailService
)
TestSmtpConnectionWithConfig
(
config
*
SmtpConfig
)
error
{
addr
:=
fmt
.
Sprintf
(
"%s:%d"
,
config
.
Host
,
config
.
Port
)
if
config
.
UseTLS
{
tlsConfig
:=
&
tls
.
Config
{
ServerName
:
config
.
Host
}
conn
,
err
:=
tls
.
Dial
(
"tcp"
,
addr
,
tlsConfig
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"tls connection failed: %w"
,
err
)
}
defer
conn
.
Close
()
client
,
err
:=
smtp
.
NewClient
(
conn
,
config
.
Host
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"smtp client creation failed: %w"
,
err
)
}
defer
client
.
Close
()
auth
:=
smtp
.
PlainAuth
(
""
,
config
.
Username
,
config
.
Password
,
config
.
Host
)
if
err
=
client
.
Auth
(
auth
);
err
!=
nil
{
return
fmt
.
Errorf
(
"smtp authentication failed: %w"
,
err
)
}
return
client
.
Quit
()
}
// 非TLS连接测试
client
,
err
:=
smtp
.
Dial
(
addr
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"smtp connection failed: %w"
,
err
)
}
defer
client
.
Close
()
auth
:=
smtp
.
PlainAuth
(
""
,
config
.
Username
,
config
.
Password
,
config
.
Host
)
if
err
=
client
.
Auth
(
auth
);
err
!=
nil
{
return
fmt
.
Errorf
(
"smtp authentication failed: %w"
,
err
)
}
return
client
.
Quit
()
}
backend/internal/service/gateway_service.go
0 → 100644
View file @
642842c2
package
service
import
(
"bufio"
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/repository"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
)
const
(
claudeAPIURL
=
"https://api.anthropic.com/v1/messages?beta=true"
stickySessionPrefix
=
"sticky_session:"
stickySessionTTL
=
time
.
Hour
// 粘性会话TTL
tokenRefreshBuffer
=
5
*
60
// 提前5分钟刷新token
)
// allowedHeaders 白名单headers(参考CRS项目)
var
allowedHeaders
=
map
[
string
]
bool
{
"accept"
:
true
,
"x-stainless-retry-count"
:
true
,
"x-stainless-timeout"
:
true
,
"x-stainless-lang"
:
true
,
"x-stainless-package-version"
:
true
,
"x-stainless-os"
:
true
,
"x-stainless-arch"
:
true
,
"x-stainless-runtime"
:
true
,
"x-stainless-runtime-version"
:
true
,
"x-stainless-helper-method"
:
true
,
"anthropic-dangerous-direct-browser-access"
:
true
,
"anthropic-version"
:
true
,
"x-app"
:
true
,
"anthropic-beta"
:
true
,
"accept-language"
:
true
,
"sec-fetch-mode"
:
true
,
"accept-encoding"
:
true
,
"user-agent"
:
true
,
"content-type"
:
true
,
}
// ClaudeUsage 表示Claude API返回的usage信息
type
ClaudeUsage
struct
{
InputTokens
int
`json:"input_tokens"`
OutputTokens
int
`json:"output_tokens"`
CacheCreationInputTokens
int
`json:"cache_creation_input_tokens"`
CacheReadInputTokens
int
`json:"cache_read_input_tokens"`
}
// ForwardResult 转发结果
type
ForwardResult
struct
{
RequestID
string
Usage
ClaudeUsage
Model
string
Stream
bool
Duration
time
.
Duration
FirstTokenMs
*
int
// 首字时间(流式请求)
}
// GatewayService handles API gateway operations
type
GatewayService
struct
{
repos
*
repository
.
Repositories
rdb
*
redis
.
Client
cfg
*
config
.
Config
oauthService
*
OAuthService
billingService
*
BillingService
rateLimitService
*
RateLimitService
billingCacheService
*
BillingCacheService
identityService
*
IdentityService
httpClient
*
http
.
Client
}
// NewGatewayService creates a new GatewayService
func
NewGatewayService
(
repos
*
repository
.
Repositories
,
rdb
*
redis
.
Client
,
cfg
*
config
.
Config
,
oauthService
*
OAuthService
,
billingService
*
BillingService
,
rateLimitService
*
RateLimitService
,
billingCacheService
*
BillingCacheService
,
identityService
*
IdentityService
)
*
GatewayService
{
// 计算响应头超时时间
responseHeaderTimeout
:=
time
.
Duration
(
cfg
.
Gateway
.
ResponseHeaderTimeout
)
*
time
.
Second
if
responseHeaderTimeout
==
0
{
responseHeaderTimeout
=
300
*
time
.
Second
// 默认5分钟,LLM高负载时可能排队较久
}
transport
:=
&
http
.
Transport
{
MaxIdleConns
:
100
,
MaxIdleConnsPerHost
:
10
,
IdleConnTimeout
:
90
*
time
.
Second
,
ResponseHeaderTimeout
:
responseHeaderTimeout
,
// 等待上游响应头的超时
// 注意:不设置整体 Timeout,让流式响应可以无限时间传输
}
return
&
GatewayService
{
repos
:
repos
,
rdb
:
rdb
,
cfg
:
cfg
,
oauthService
:
oauthService
,
billingService
:
billingService
,
rateLimitService
:
rateLimitService
,
billingCacheService
:
billingCacheService
,
identityService
:
identityService
,
httpClient
:
&
http
.
Client
{
Transport
:
transport
,
// 不设置 Timeout:流式请求可能持续十几分钟
// 超时控制由 Transport.ResponseHeaderTimeout 负责(只控制等待响应头)
},
}
}
// GenerateSessionHash 从请求体计算粘性会话hash
func
(
s
*
GatewayService
)
GenerateSessionHash
(
body
[]
byte
)
string
{
var
req
map
[
string
]
interface
{}
if
err
:=
json
.
Unmarshal
(
body
,
&
req
);
err
!=
nil
{
return
""
}
// 1. 最高优先级:从metadata.user_id提取session_xxx
if
metadata
,
ok
:=
req
[
"metadata"
]
.
(
map
[
string
]
interface
{});
ok
{
if
userID
,
ok
:=
metadata
[
"user_id"
]
.
(
string
);
ok
{
re
:=
regexp
.
MustCompile
(
`session_([a-f0-9-]{36})`
)
if
match
:=
re
.
FindStringSubmatch
(
userID
);
len
(
match
)
>
1
{
return
match
[
1
]
}
}
}
// 2. 提取带cache_control: {type: "ephemeral"}的内容
cacheableContent
:=
s
.
extractCacheableContent
(
req
)
if
cacheableContent
!=
""
{
return
s
.
hashContent
(
cacheableContent
)
}
// 3. Fallback: 使用system内容
if
system
:=
req
[
"system"
];
system
!=
nil
{
systemText
:=
s
.
extractTextFromSystem
(
system
)
if
systemText
!=
""
{
return
s
.
hashContent
(
systemText
)
}
}
// 4. 最后fallback: 使用第一条消息
if
messages
,
ok
:=
req
[
"messages"
]
.
([]
interface
{});
ok
&&
len
(
messages
)
>
0
{
if
firstMsg
,
ok
:=
messages
[
0
]
.
(
map
[
string
]
interface
{});
ok
{
msgText
:=
s
.
extractTextFromContent
(
firstMsg
[
"content"
])
if
msgText
!=
""
{
return
s
.
hashContent
(
msgText
)
}
}
}
return
""
}
func
(
s
*
GatewayService
)
extractCacheableContent
(
req
map
[
string
]
interface
{})
string
{
var
content
string
// 检查system中的cacheable内容
if
system
,
ok
:=
req
[
"system"
]
.
([]
interface
{});
ok
{
for
_
,
part
:=
range
system
{
if
partMap
,
ok
:=
part
.
(
map
[
string
]
interface
{});
ok
{
if
cc
,
ok
:=
partMap
[
"cache_control"
]
.
(
map
[
string
]
interface
{});
ok
{
if
cc
[
"type"
]
==
"ephemeral"
{
if
text
,
ok
:=
partMap
[
"text"
]
.
(
string
);
ok
{
content
+=
text
}
}
}
}
}
}
// 检查messages中的cacheable内容
if
messages
,
ok
:=
req
[
"messages"
]
.
([]
interface
{});
ok
{
for
_
,
msg
:=
range
messages
{
if
msgMap
,
ok
:=
msg
.
(
map
[
string
]
interface
{});
ok
{
if
msgContent
,
ok
:=
msgMap
[
"content"
]
.
([]
interface
{});
ok
{
for
_
,
part
:=
range
msgContent
{
if
partMap
,
ok
:=
part
.
(
map
[
string
]
interface
{});
ok
{
if
cc
,
ok
:=
partMap
[
"cache_control"
]
.
(
map
[
string
]
interface
{});
ok
{
if
cc
[
"type"
]
==
"ephemeral"
{
// 找到cacheable内容,提取第一条消息的文本
return
s
.
extractTextFromContent
(
msgMap
[
"content"
])
}
}
}
}
}
}
}
}
return
content
}
func
(
s
*
GatewayService
)
extractTextFromSystem
(
system
interface
{})
string
{
switch
v
:=
system
.
(
type
)
{
case
string
:
return
v
case
[]
interface
{}
:
var
texts
[]
string
for
_
,
part
:=
range
v
{
if
partMap
,
ok
:=
part
.
(
map
[
string
]
interface
{});
ok
{
if
text
,
ok
:=
partMap
[
"text"
]
.
(
string
);
ok
{
texts
=
append
(
texts
,
text
)
}
}
}
return
strings
.
Join
(
texts
,
""
)
}
return
""
}
func
(
s
*
GatewayService
)
extractTextFromContent
(
content
interface
{})
string
{
switch
v
:=
content
.
(
type
)
{
case
string
:
return
v
case
[]
interface
{}
:
var
texts
[]
string
for
_
,
part
:=
range
v
{
if
partMap
,
ok
:=
part
.
(
map
[
string
]
interface
{});
ok
{
if
partMap
[
"type"
]
==
"text"
{
if
text
,
ok
:=
partMap
[
"text"
]
.
(
string
);
ok
{
texts
=
append
(
texts
,
text
)
}
}
}
}
return
strings
.
Join
(
texts
,
""
)
}
return
""
}
func
(
s
*
GatewayService
)
hashContent
(
content
string
)
string
{
hash
:=
sha256
.
Sum256
([]
byte
(
content
))
return
hex
.
EncodeToString
(
hash
[
:
16
])
// 32字符
}
// replaceModelInBody 替换请求体中的model字段
func
(
s
*
GatewayService
)
replaceModelInBody
(
body
[]
byte
,
newModel
string
)
[]
byte
{
var
req
map
[
string
]
interface
{}
if
err
:=
json
.
Unmarshal
(
body
,
&
req
);
err
!=
nil
{
return
body
}
req
[
"model"
]
=
newModel
newBody
,
err
:=
json
.
Marshal
(
req
)
if
err
!=
nil
{
return
body
}
return
newBody
}
// SelectAccount 选择账号(粘性会话+优先级)
func
(
s
*
GatewayService
)
SelectAccount
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
)
(
*
model
.
Account
,
error
)
{
return
s
.
SelectAccountForModel
(
ctx
,
groupID
,
sessionHash
,
""
)
}
// SelectAccountForModel 选择支持指定模型的账号(粘性会话+优先级+模型映射)
func
(
s
*
GatewayService
)
SelectAccountForModel
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
)
(
*
model
.
Account
,
error
)
{
// 1. 查询粘性会话
if
sessionHash
!=
""
{
accountID
,
err
:=
s
.
rdb
.
Get
(
ctx
,
stickySessionPrefix
+
sessionHash
)
.
Int64
()
if
err
==
nil
&&
accountID
>
0
{
account
,
err
:=
s
.
repos
.
Account
.
GetByID
(
ctx
,
accountID
)
// 使用IsSchedulable代替IsActive,确保限流/过载账号不会被选中
// 同时检查模型支持
if
err
==
nil
&&
account
.
IsSchedulable
()
&&
(
requestedModel
==
""
||
account
.
IsModelSupported
(
requestedModel
))
{
// 续期粘性会话
s
.
rdb
.
Expire
(
ctx
,
stickySessionPrefix
+
sessionHash
,
stickySessionTTL
)
return
account
,
nil
}
}
}
// 2. 获取可调度账号列表(排除限流和过载的账号)
var
accounts
[]
model
.
Account
var
err
error
if
groupID
!=
nil
{
accounts
,
err
=
s
.
repos
.
Account
.
ListSchedulableByGroupID
(
ctx
,
*
groupID
)
}
else
{
accounts
,
err
=
s
.
repos
.
Account
.
ListSchedulable
(
ctx
)
}
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
}
// 3. 按优先级+最久未用选择(考虑模型支持)
var
selected
*
model
.
Account
for
i
:=
range
accounts
{
acc
:=
&
accounts
[
i
]
// 检查模型支持
if
requestedModel
!=
""
&&
!
acc
.
IsModelSupported
(
requestedModel
)
{
continue
}
if
selected
==
nil
{
selected
=
acc
continue
}
// 优先选择priority值更小的(priority值越小优先级越高)
if
acc
.
Priority
<
selected
.
Priority
{
selected
=
acc
}
else
if
acc
.
Priority
==
selected
.
Priority
{
// 优先级相同时,选最久未用的
if
acc
.
LastUsedAt
==
nil
||
(
selected
.
LastUsedAt
!=
nil
&&
acc
.
LastUsedAt
.
Before
(
*
selected
.
LastUsedAt
))
{
selected
=
acc
}
}
}
if
selected
==
nil
{
if
requestedModel
!=
""
{
return
nil
,
fmt
.
Errorf
(
"no available accounts supporting model: %s"
,
requestedModel
)
}
return
nil
,
errors
.
New
(
"no available accounts"
)
}
// 4. 建立粘性绑定
if
sessionHash
!=
""
{
s
.
rdb
.
Set
(
ctx
,
stickySessionPrefix
+
sessionHash
,
selected
.
ID
,
stickySessionTTL
)
}
return
selected
,
nil
}
// GetAccessToken 获取账号凭证
func
(
s
*
GatewayService
)
GetAccessToken
(
ctx
context
.
Context
,
account
*
model
.
Account
)
(
string
,
string
,
error
)
{
switch
account
.
Type
{
case
model
.
AccountTypeOAuth
,
model
.
AccountTypeSetupToken
:
// Both oauth and setup-token use OAuth token flow
return
s
.
getOAuthToken
(
ctx
,
account
)
case
model
.
AccountTypeApiKey
:
apiKey
:=
account
.
GetCredential
(
"api_key"
)
if
apiKey
==
""
{
return
""
,
""
,
errors
.
New
(
"api_key not found in credentials"
)
}
return
apiKey
,
"apikey"
,
nil
default
:
return
""
,
""
,
fmt
.
Errorf
(
"unsupported account type: %s"
,
account
.
Type
)
}
}
func
(
s
*
GatewayService
)
getOAuthToken
(
ctx
context
.
Context
,
account
*
model
.
Account
)
(
string
,
string
,
error
)
{
accessToken
:=
account
.
GetCredential
(
"access_token"
)
expiresAtStr
:=
account
.
GetCredential
(
"expires_at"
)
// 检查是否需要刷新
needRefresh
:=
false
if
expiresAtStr
!=
""
{
expiresAt
,
err
:=
strconv
.
ParseInt
(
expiresAtStr
,
10
,
64
)
if
err
==
nil
&&
time
.
Now
()
.
Unix
()
+
tokenRefreshBuffer
>
expiresAt
{
needRefresh
=
true
}
}
if
needRefresh
||
accessToken
==
""
{
tokenInfo
,
err
:=
s
.
oauthService
.
RefreshAccountToken
(
ctx
,
account
)
if
err
!=
nil
{
return
""
,
""
,
fmt
.
Errorf
(
"refresh token failed: %w"
,
err
)
}
// 更新账号凭证
account
.
Credentials
[
"access_token"
]
=
tokenInfo
.
AccessToken
account
.
Credentials
[
"expires_at"
]
=
strconv
.
FormatInt
(
tokenInfo
.
ExpiresAt
,
10
)
if
tokenInfo
.
RefreshToken
!=
""
{
account
.
Credentials
[
"refresh_token"
]
=
tokenInfo
.
RefreshToken
}
if
err
:=
s
.
repos
.
Account
.
Update
(
ctx
,
account
);
err
!=
nil
{
log
.
Printf
(
"Failed to update account credentials: %v"
,
err
)
}
return
tokenInfo
.
AccessToken
,
"oauth"
,
nil
}
return
accessToken
,
"oauth"
,
nil
}
// Forward 转发请求到Claude API
func
(
s
*
GatewayService
)
Forward
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
model
.
Account
,
body
[]
byte
)
(
*
ForwardResult
,
error
)
{
startTime
:=
time
.
Now
()
// 解析请求获取model和stream
var
req
struct
{
Model
string
`json:"model"`
Stream
bool
`json:"stream"`
}
if
err
:=
json
.
Unmarshal
(
body
,
&
req
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"parse request: %w"
,
err
)
}
// 应用模型映射(仅对apikey类型账号)
originalModel
:=
req
.
Model
if
account
.
Type
==
model
.
AccountTypeApiKey
{
mappedModel
:=
account
.
GetMappedModel
(
req
.
Model
)
if
mappedModel
!=
req
.
Model
{
// 替换请求体中的模型名
body
=
s
.
replaceModelInBody
(
body
,
mappedModel
)
req
.
Model
=
mappedModel
log
.
Printf
(
"Model mapping applied: %s -> %s (account: %s)"
,
originalModel
,
mappedModel
,
account
.
Name
)
}
}
// 获取凭证
token
,
tokenType
,
err
:=
s
.
GetAccessToken
(
ctx
,
account
)
if
err
!=
nil
{
return
nil
,
err
}
// 构建上游请求
upstreamReq
,
err
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
body
,
token
,
tokenType
)
if
err
!=
nil
{
return
nil
,
err
}
// 发送请求
resp
,
err
:=
s
.
httpClient
.
Do
(
upstreamReq
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"upstream request failed: %w"
,
err
)
}
defer
resp
.
Body
.
Close
()
// 处理401错误:刷新token重试
if
resp
.
StatusCode
==
http
.
StatusUnauthorized
&&
tokenType
==
"oauth"
{
resp
.
Body
.
Close
()
token
,
tokenType
,
err
=
s
.
forceRefreshToken
(
ctx
,
account
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"token refresh failed: %w"
,
err
)
}
upstreamReq
,
err
=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
body
,
token
,
tokenType
)
if
err
!=
nil
{
return
nil
,
err
}
resp
,
err
=
s
.
httpClient
.
Do
(
upstreamReq
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"retry request failed: %w"
,
err
)
}
defer
resp
.
Body
.
Close
()
}
// 处理错误响应
if
resp
.
StatusCode
>=
400
{
return
s
.
handleErrorResponse
(
ctx
,
resp
,
c
,
account
)
}
// 处理正常响应
var
usage
*
ClaudeUsage
var
firstTokenMs
*
int
if
req
.
Stream
{
streamResult
,
err
:=
s
.
handleStreamingResponse
(
ctx
,
resp
,
c
,
account
,
startTime
,
originalModel
,
req
.
Model
)
if
err
!=
nil
{
return
nil
,
err
}
usage
=
streamResult
.
usage
firstTokenMs
=
streamResult
.
firstTokenMs
}
else
{
usage
,
err
=
s
.
handleNonStreamingResponse
(
ctx
,
resp
,
c
,
account
,
originalModel
,
req
.
Model
)
if
err
!=
nil
{
return
nil
,
err
}
}
return
&
ForwardResult
{
RequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Usage
:
*
usage
,
Model
:
originalModel
,
// 使用原始模型用于计费和日志
Stream
:
req
.
Stream
,
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
},
nil
}
func
(
s
*
GatewayService
)
buildUpstreamRequest
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
model
.
Account
,
body
[]
byte
,
token
,
tokenType
string
)
(
*
http
.
Request
,
error
)
{
// 确定目标URL
targetURL
:=
claudeAPIURL
if
account
.
Type
==
model
.
AccountTypeApiKey
{
baseURL
:=
account
.
GetBaseURL
()
targetURL
=
baseURL
+
"/v1/messages"
}
// OAuth账号:应用统一指纹
var
fingerprint
*
Fingerprint
if
account
.
IsOAuth
()
&&
s
.
identityService
!=
nil
{
// 1. 获取或创建指纹(包含随机生成的ClientID)
fp
,
err
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
.
Request
.
Header
)
if
err
!=
nil
{
log
.
Printf
(
"Warning: failed to get fingerprint for account %d: %v"
,
account
.
ID
,
err
)
// 失败时降级为透传原始headers
}
else
{
fingerprint
=
fp
// 2. 重写metadata.user_id(需要指纹中的ClientID和账号的account_uuid)
accountUUID
:=
account
.
GetExtraString
(
"account_uuid"
)
if
accountUUID
!=
""
&&
fp
.
ClientID
!=
""
{
if
newBody
,
err
:=
s
.
identityService
.
RewriteUserID
(
body
,
account
.
ID
,
accountUUID
,
fp
.
ClientID
);
err
==
nil
&&
len
(
newBody
)
>
0
{
body
=
newBody
}
}
}
}
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
"POST"
,
targetURL
,
bytes
.
NewReader
(
body
))
if
err
!=
nil
{
return
nil
,
err
}
// 设置认证头
if
tokenType
==
"oauth"
{
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
token
)
}
else
{
req
.
Header
.
Set
(
"x-api-key"
,
token
)
}
// 白名单透传headers
for
key
,
values
:=
range
c
.
Request
.
Header
{
lowerKey
:=
strings
.
ToLower
(
key
)
if
allowedHeaders
[
lowerKey
]
{
for
_
,
v
:=
range
values
{
req
.
Header
.
Add
(
key
,
v
)
}
}
}
// OAuth账号:应用缓存的指纹到请求头(覆盖白名单透传的头)
if
fingerprint
!=
nil
{
s
.
identityService
.
ApplyFingerprint
(
req
,
fingerprint
)
}
// 确保必要的headers存在
if
req
.
Header
.
Get
(
"Content-Type"
)
==
""
{
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
}
if
req
.
Header
.
Get
(
"anthropic-version"
)
==
""
{
req
.
Header
.
Set
(
"anthropic-version"
,
"2023-06-01"
)
}
// 处理anthropic-beta header(OAuth账号需要特殊处理)
if
tokenType
==
"oauth"
{
req
.
Header
.
Set
(
"anthropic-beta"
,
s
.
getBetaHeader
(
body
,
c
.
GetHeader
(
"anthropic-beta"
)))
}
// 配置代理
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
proxyURL
:=
account
.
Proxy
.
URL
()
if
proxyURL
!=
""
{
if
parsedURL
,
err
:=
url
.
Parse
(
proxyURL
);
err
==
nil
{
// 计算响应头超时时间(与默认 Transport 保持一致)
responseHeaderTimeout
:=
time
.
Duration
(
s
.
cfg
.
Gateway
.
ResponseHeaderTimeout
)
*
time
.
Second
if
responseHeaderTimeout
==
0
{
responseHeaderTimeout
=
300
*
time
.
Second
}
transport
:=
&
http
.
Transport
{
Proxy
:
http
.
ProxyURL
(
parsedURL
),
MaxIdleConns
:
100
,
MaxIdleConnsPerHost
:
10
,
IdleConnTimeout
:
90
*
time
.
Second
,
ResponseHeaderTimeout
:
responseHeaderTimeout
,
}
s
.
httpClient
.
Transport
=
transport
}
}
}
return
req
,
nil
}
// getBetaHeader 处理anthropic-beta header
// 对于OAuth账号,需要确保包含oauth-2025-04-20
func
(
s
*
GatewayService
)
getBetaHeader
(
body
[]
byte
,
clientBetaHeader
string
)
string
{
const
oauthBeta
=
"oauth-2025-04-20"
const
claudeCodeBeta
=
"claude-code-20250219"
// 如果客户端传了anthropic-beta
if
clientBetaHeader
!=
""
{
// 已包含oauth beta则直接返回
if
strings
.
Contains
(
clientBetaHeader
,
oauthBeta
)
{
return
clientBetaHeader
}
// 需要添加oauth beta
parts
:=
strings
.
Split
(
clientBetaHeader
,
","
)
for
i
,
p
:=
range
parts
{
parts
[
i
]
=
strings
.
TrimSpace
(
p
)
}
// 在claude-code-20250219后面插入oauth beta
claudeCodeIdx
:=
-
1
for
i
,
p
:=
range
parts
{
if
p
==
claudeCodeBeta
{
claudeCodeIdx
=
i
break
}
}
if
claudeCodeIdx
>=
0
{
// 在claude-code后面插入
newParts
:=
make
([]
string
,
0
,
len
(
parts
)
+
1
)
newParts
=
append
(
newParts
,
parts
[
:
claudeCodeIdx
+
1
]
...
)
newParts
=
append
(
newParts
,
oauthBeta
)
newParts
=
append
(
newParts
,
parts
[
claudeCodeIdx
+
1
:
]
...
)
return
strings
.
Join
(
newParts
,
","
)
}
// 没有claude-code,放在第一位
return
oauthBeta
+
","
+
clientBetaHeader
}
// 客户端没传,根据模型生成
var
modelID
string
var
reqMap
map
[
string
]
interface
{}
if
json
.
Unmarshal
(
body
,
&
reqMap
)
==
nil
{
if
m
,
ok
:=
reqMap
[
"model"
]
.
(
string
);
ok
{
modelID
=
m
}
}
// haiku模型不需要claude-code beta
if
strings
.
Contains
(
strings
.
ToLower
(
modelID
),
"haiku"
)
{
return
"oauth-2025-04-20,interleaved-thinking-2025-05-14"
}
return
"claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14"
}
func
(
s
*
GatewayService
)
forceRefreshToken
(
ctx
context
.
Context
,
account
*
model
.
Account
)
(
string
,
string
,
error
)
{
tokenInfo
,
err
:=
s
.
oauthService
.
RefreshAccountToken
(
ctx
,
account
)
if
err
!=
nil
{
return
""
,
""
,
err
}
account
.
Credentials
[
"access_token"
]
=
tokenInfo
.
AccessToken
account
.
Credentials
[
"expires_at"
]
=
strconv
.
FormatInt
(
tokenInfo
.
ExpiresAt
,
10
)
if
tokenInfo
.
RefreshToken
!=
""
{
account
.
Credentials
[
"refresh_token"
]
=
tokenInfo
.
RefreshToken
}
if
err
:=
s
.
repos
.
Account
.
Update
(
ctx
,
account
);
err
!=
nil
{
log
.
Printf
(
"Failed to update account credentials: %v"
,
err
)
}
return
tokenInfo
.
AccessToken
,
"oauth"
,
nil
}
func
(
s
*
GatewayService
)
handleErrorResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
model
.
Account
)
(
*
ForwardResult
,
error
)
{
body
,
_
:=
io
.
ReadAll
(
resp
.
Body
)
// apikey 类型账号:检查自定义错误码配置
// 如果启用且错误码不在列表中,返回通用 500 错误(不做任何账号状态处理)
if
!
account
.
ShouldHandleErrorCode
(
resp
.
StatusCode
)
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"type"
:
"error"
,
"error"
:
gin
.
H
{
"type"
:
"upstream_error"
,
"message"
:
"Upstream gateway error"
,
},
})
return
nil
,
fmt
.
Errorf
(
"upstream error: %d (not in custom error codes)"
,
resp
.
StatusCode
)
}
// 处理上游错误,标记账号状态
s
.
rateLimitService
.
HandleUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
body
)
// 根据状态码返回适当的自定义错误响应(不透传上游详细信息)
var
errType
,
errMsg
string
var
statusCode
int
switch
resp
.
StatusCode
{
case
401
:
statusCode
=
http
.
StatusBadGateway
errType
=
"upstream_error"
errMsg
=
"Upstream authentication failed, please contact administrator"
case
403
:
statusCode
=
http
.
StatusBadGateway
errType
=
"upstream_error"
errMsg
=
"Upstream access forbidden, please contact administrator"
case
429
:
statusCode
=
http
.
StatusTooManyRequests
errType
=
"rate_limit_error"
errMsg
=
"Upstream rate limit exceeded, please retry later"
case
529
:
statusCode
=
http
.
StatusServiceUnavailable
errType
=
"overloaded_error"
errMsg
=
"Upstream service overloaded, please retry later"
case
500
,
502
,
503
,
504
:
statusCode
=
http
.
StatusBadGateway
errType
=
"upstream_error"
errMsg
=
"Upstream service temporarily unavailable"
default
:
statusCode
=
http
.
StatusBadGateway
errType
=
"upstream_error"
errMsg
=
"Upstream request failed"
}
// 返回自定义错误响应
c
.
JSON
(
statusCode
,
gin
.
H
{
"type"
:
"error"
,
"error"
:
gin
.
H
{
"type"
:
errType
,
"message"
:
errMsg
,
},
})
return
nil
,
fmt
.
Errorf
(
"upstream error: %d"
,
resp
.
StatusCode
)
}
// streamingResult 流式响应结果
type
streamingResult
struct
{
usage
*
ClaudeUsage
firstTokenMs
*
int
}
func
(
s
*
GatewayService
)
handleStreamingResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
model
.
Account
,
startTime
time
.
Time
,
originalModel
,
mappedModel
string
)
(
*
streamingResult
,
error
)
{
// 更新5h窗口状态
s
.
rateLimitService
.
UpdateSessionWindow
(
ctx
,
account
,
resp
.
Header
)
// 设置SSE响应头
c
.
Header
(
"Content-Type"
,
"text/event-stream"
)
c
.
Header
(
"Cache-Control"
,
"no-cache"
)
c
.
Header
(
"Connection"
,
"keep-alive"
)
c
.
Header
(
"X-Accel-Buffering"
,
"no"
)
// 透传其他响应头
if
v
:=
resp
.
Header
.
Get
(
"x-request-id"
);
v
!=
""
{
c
.
Header
(
"x-request-id"
,
v
)
}
w
:=
c
.
Writer
flusher
,
ok
:=
w
.
(
http
.
Flusher
)
if
!
ok
{
return
nil
,
errors
.
New
(
"streaming not supported"
)
}
usage
:=
&
ClaudeUsage
{}
var
firstTokenMs
*
int
scanner
:=
bufio
.
NewScanner
(
resp
.
Body
)
// 设置更大的buffer以处理长行
scanner
.
Buffer
(
make
([]
byte
,
64
*
1024
),
1024
*
1024
)
needModelReplace
:=
originalModel
!=
mappedModel
for
scanner
.
Scan
()
{
line
:=
scanner
.
Text
()
// 如果有模型映射,替换响应中的model字段
if
needModelReplace
&&
strings
.
HasPrefix
(
line
,
"data: "
)
{
line
=
s
.
replaceModelInSSELine
(
line
,
mappedModel
,
originalModel
)
}
// 转发行
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
)
flusher
.
Flush
()
// 解析usage数据
if
strings
.
HasPrefix
(
line
,
"data: "
)
{
data
:=
line
[
6
:
]
// 记录首字时间:第一个有效的 content_block_delta 或 message_start
if
firstTokenMs
==
nil
&&
data
!=
""
&&
data
!=
"[DONE]"
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
firstTokenMs
=
&
ms
}
s
.
parseSSEUsage
(
data
,
usage
)
}
}
if
err
:=
scanner
.
Err
();
err
!=
nil
{
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream read error: %w"
,
err
)
}
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
// replaceModelInSSELine 替换SSE数据行中的model字段
func
(
s
*
GatewayService
)
replaceModelInSSELine
(
line
,
fromModel
,
toModel
string
)
string
{
data
:=
line
[
6
:
]
// 去掉 "data: " 前缀
if
data
==
""
||
data
==
"[DONE]"
{
return
line
}
var
event
map
[
string
]
interface
{}
if
err
:=
json
.
Unmarshal
([]
byte
(
data
),
&
event
);
err
!=
nil
{
return
line
}
// 只替换 message_start 事件中的 message.model
if
event
[
"type"
]
!=
"message_start"
{
return
line
}
msg
,
ok
:=
event
[
"message"
]
.
(
map
[
string
]
interface
{})
if
!
ok
{
return
line
}
model
,
ok
:=
msg
[
"model"
]
.
(
string
)
if
!
ok
||
model
!=
fromModel
{
return
line
}
msg
[
"model"
]
=
toModel
newData
,
err
:=
json
.
Marshal
(
event
)
if
err
!=
nil
{
return
line
}
return
"data: "
+
string
(
newData
)
}
func
(
s
*
GatewayService
)
parseSSEUsage
(
data
string
,
usage
*
ClaudeUsage
)
{
// 解析message_start获取input tokens
var
msgStart
struct
{
Type
string
`json:"type"`
Message
struct
{
Usage
ClaudeUsage
`json:"usage"`
}
`json:"message"`
}
if
json
.
Unmarshal
([]
byte
(
data
),
&
msgStart
)
==
nil
&&
msgStart
.
Type
==
"message_start"
{
usage
.
InputTokens
=
msgStart
.
Message
.
Usage
.
InputTokens
usage
.
CacheCreationInputTokens
=
msgStart
.
Message
.
Usage
.
CacheCreationInputTokens
usage
.
CacheReadInputTokens
=
msgStart
.
Message
.
Usage
.
CacheReadInputTokens
}
// 解析message_delta获取output tokens
var
msgDelta
struct
{
Type
string
`json:"type"`
Usage
struct
{
OutputTokens
int
`json:"output_tokens"`
}
`json:"usage"`
}
if
json
.
Unmarshal
([]
byte
(
data
),
&
msgDelta
)
==
nil
&&
msgDelta
.
Type
==
"message_delta"
{
usage
.
OutputTokens
=
msgDelta
.
Usage
.
OutputTokens
}
}
func
(
s
*
GatewayService
)
handleNonStreamingResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
model
.
Account
,
originalModel
,
mappedModel
string
)
(
*
ClaudeUsage
,
error
)
{
// 更新5h窗口状态
s
.
rateLimitService
.
UpdateSessionWindow
(
ctx
,
account
,
resp
.
Header
)
body
,
err
:=
io
.
ReadAll
(
resp
.
Body
)
if
err
!=
nil
{
return
nil
,
err
}
// 解析usage
var
response
struct
{
Usage
ClaudeUsage
`json:"usage"`
}
if
err
:=
json
.
Unmarshal
(
body
,
&
response
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"parse response: %w"
,
err
)
}
// 如果有模型映射,替换响应中的model字段
if
originalModel
!=
mappedModel
{
body
=
s
.
replaceModelInResponseBody
(
body
,
mappedModel
,
originalModel
)
}
// 透传响应头
for
key
,
values
:=
range
resp
.
Header
{
for
_
,
value
:=
range
values
{
c
.
Header
(
key
,
value
)
}
}
// 写入响应
c
.
Data
(
resp
.
StatusCode
,
"application/json"
,
body
)
return
&
response
.
Usage
,
nil
}
// replaceModelInResponseBody 替换响应体中的model字段
func
(
s
*
GatewayService
)
replaceModelInResponseBody
(
body
[]
byte
,
fromModel
,
toModel
string
)
[]
byte
{
var
resp
map
[
string
]
interface
{}
if
err
:=
json
.
Unmarshal
(
body
,
&
resp
);
err
!=
nil
{
return
body
}
model
,
ok
:=
resp
[
"model"
]
.
(
string
)
if
!
ok
||
model
!=
fromModel
{
return
body
}
resp
[
"model"
]
=
toModel
newBody
,
err
:=
json
.
Marshal
(
resp
)
if
err
!=
nil
{
return
body
}
return
newBody
}
// RecordUsageInput 记录使用量的输入参数
type
RecordUsageInput
struct
{
Result
*
ForwardResult
ApiKey
*
model
.
ApiKey
User
*
model
.
User
Account
*
model
.
Account
Subscription
*
model
.
UserSubscription
// 可选:订阅信息
}
// RecordUsage 记录使用量并扣费(或更新订阅用量)
func
(
s
*
GatewayService
)
RecordUsage
(
ctx
context
.
Context
,
input
*
RecordUsageInput
)
error
{
result
:=
input
.
Result
apiKey
:=
input
.
ApiKey
user
:=
input
.
User
account
:=
input
.
Account
subscription
:=
input
.
Subscription
// 计算费用
tokens
:=
UsageTokens
{
InputTokens
:
result
.
Usage
.
InputTokens
,
OutputTokens
:
result
.
Usage
.
OutputTokens
,
CacheCreationTokens
:
result
.
Usage
.
CacheCreationInputTokens
,
CacheReadTokens
:
result
.
Usage
.
CacheReadInputTokens
,
}
// 获取费率倍数
multiplier
:=
s
.
cfg
.
Default
.
RateMultiplier
if
apiKey
.
GroupID
!=
nil
&&
apiKey
.
Group
!=
nil
{
multiplier
=
apiKey
.
Group
.
RateMultiplier
}
cost
,
err
:=
s
.
billingService
.
CalculateCost
(
result
.
Model
,
tokens
,
multiplier
)
if
err
!=
nil
{
log
.
Printf
(
"Calculate cost failed: %v"
,
err
)
// 使用默认费用继续
cost
=
&
CostBreakdown
{
ActualCost
:
0
}
}
// 判断计费方式:订阅模式 vs 余额模式
isSubscriptionBilling
:=
subscription
!=
nil
&&
apiKey
.
Group
!=
nil
&&
apiKey
.
Group
.
IsSubscriptionType
()
billingType
:=
model
.
BillingTypeBalance
if
isSubscriptionBilling
{
billingType
=
model
.
BillingTypeSubscription
}
// 创建使用日志
durationMs
:=
int
(
result
.
Duration
.
Milliseconds
())
usageLog
:=
&
model
.
UsageLog
{
UserID
:
user
.
ID
,
ApiKeyID
:
apiKey
.
ID
,
AccountID
:
account
.
ID
,
RequestID
:
result
.
RequestID
,
Model
:
result
.
Model
,
InputTokens
:
result
.
Usage
.
InputTokens
,
OutputTokens
:
result
.
Usage
.
OutputTokens
,
CacheCreationTokens
:
result
.
Usage
.
CacheCreationInputTokens
,
CacheReadTokens
:
result
.
Usage
.
CacheReadInputTokens
,
InputCost
:
cost
.
InputCost
,
OutputCost
:
cost
.
OutputCost
,
CacheCreationCost
:
cost
.
CacheCreationCost
,
CacheReadCost
:
cost
.
CacheReadCost
,
TotalCost
:
cost
.
TotalCost
,
ActualCost
:
cost
.
ActualCost
,
RateMultiplier
:
multiplier
,
BillingType
:
billingType
,
Stream
:
result
.
Stream
,
DurationMs
:
&
durationMs
,
FirstTokenMs
:
result
.
FirstTokenMs
,
CreatedAt
:
time
.
Now
(),
}
// 添加分组和订阅关联
if
apiKey
.
GroupID
!=
nil
{
usageLog
.
GroupID
=
apiKey
.
GroupID
}
if
subscription
!=
nil
{
usageLog
.
SubscriptionID
=
&
subscription
.
ID
}
if
err
:=
s
.
repos
.
UsageLog
.
Create
(
ctx
,
usageLog
);
err
!=
nil
{
log
.
Printf
(
"Create usage log failed: %v"
,
err
)
}
// 根据计费类型执行扣费
if
isSubscriptionBilling
{
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
if
cost
.
TotalCost
>
0
{
if
err
:=
s
.
repos
.
UserSubscription
.
IncrementUsage
(
ctx
,
subscription
.
ID
,
cost
.
TotalCost
);
err
!=
nil
{
log
.
Printf
(
"Increment subscription usage failed: %v"
,
err
)
}
// 异步更新订阅缓存
go
func
()
{
cacheCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
5
*
time
.
Second
)
defer
cancel
()
if
err
:=
s
.
billingCacheService
.
UpdateSubscriptionUsage
(
cacheCtx
,
user
.
ID
,
*
apiKey
.
GroupID
,
cost
.
TotalCost
);
err
!=
nil
{
log
.
Printf
(
"Update subscription cache failed: %v"
,
err
)
}
}()
}
}
else
{
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
if
cost
.
ActualCost
>
0
{
if
err
:=
s
.
repos
.
User
.
DeductBalance
(
ctx
,
user
.
ID
,
cost
.
ActualCost
);
err
!=
nil
{
log
.
Printf
(
"Deduct balance failed: %v"
,
err
)
}
// 异步更新余额缓存
go
func
()
{
cacheCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
5
*
time
.
Second
)
defer
cancel
()
if
err
:=
s
.
billingCacheService
.
DeductBalanceCache
(
cacheCtx
,
user
.
ID
,
cost
.
ActualCost
);
err
!=
nil
{
log
.
Printf
(
"Update balance cache failed: %v"
,
err
)
}
}()
}
}
// 更新账号最后使用时间
if
err
:=
s
.
repos
.
Account
.
UpdateLastUsed
(
ctx
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"Update last used failed: %v"
,
err
)
}
return
nil
}
backend/internal/service/group_service.go
0 → 100644
View file @
642842c2
package
service
import
(
"context"
"errors"
"fmt"
"sub2api/internal/model"
"sub2api/internal/repository"
"gorm.io/gorm"
)
var
(
ErrGroupNotFound
=
errors
.
New
(
"group not found"
)
ErrGroupExists
=
errors
.
New
(
"group name already exists"
)
)
// CreateGroupRequest 创建分组请求
type
CreateGroupRequest
struct
{
Name
string
`json:"name"`
Description
string
`json:"description"`
RateMultiplier
float64
`json:"rate_multiplier"`
IsExclusive
bool
`json:"is_exclusive"`
}
// UpdateGroupRequest 更新分组请求
type
UpdateGroupRequest
struct
{
Name
*
string
`json:"name"`
Description
*
string
`json:"description"`
RateMultiplier
*
float64
`json:"rate_multiplier"`
IsExclusive
*
bool
`json:"is_exclusive"`
Status
*
string
`json:"status"`
}
// GroupService 分组管理服务
type
GroupService
struct
{
groupRepo
*
repository
.
GroupRepository
}
// NewGroupService 创建分组服务实例
func
NewGroupService
(
groupRepo
*
repository
.
GroupRepository
)
*
GroupService
{
return
&
GroupService
{
groupRepo
:
groupRepo
,
}
}
// Create 创建分组
func
(
s
*
GroupService
)
Create
(
ctx
context
.
Context
,
req
CreateGroupRequest
)
(
*
model
.
Group
,
error
)
{
// 检查名称是否已存在
exists
,
err
:=
s
.
groupRepo
.
ExistsByName
(
ctx
,
req
.
Name
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"check group exists: %w"
,
err
)
}
if
exists
{
return
nil
,
ErrGroupExists
}
// 创建分组
group
:=
&
model
.
Group
{
Name
:
req
.
Name
,
Description
:
req
.
Description
,
RateMultiplier
:
req
.
RateMultiplier
,
IsExclusive
:
req
.
IsExclusive
,
Status
:
model
.
StatusActive
,
}
if
err
:=
s
.
groupRepo
.
Create
(
ctx
,
group
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"create group: %w"
,
err
)
}
return
group
,
nil
}
// GetByID 根据ID获取分组
func
(
s
*
GroupService
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
Group
,
error
)
{
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
gorm
.
ErrRecordNotFound
)
{
return
nil
,
ErrGroupNotFound
}
return
nil
,
fmt
.
Errorf
(
"get group: %w"
,
err
)
}
return
group
,
nil
}
// List 获取分组列表
func
(
s
*
GroupService
)
List
(
ctx
context
.
Context
,
params
repository
.
PaginationParams
)
([]
model
.
Group
,
*
repository
.
PaginationResult
,
error
)
{
groups
,
pagination
,
err
:=
s
.
groupRepo
.
List
(
ctx
,
params
)
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"list groups: %w"
,
err
)
}
return
groups
,
pagination
,
nil
}
// ListActive 获取活跃分组列表
func
(
s
*
GroupService
)
ListActive
(
ctx
context
.
Context
)
([]
model
.
Group
,
error
)
{
groups
,
err
:=
s
.
groupRepo
.
ListActive
(
ctx
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"list active groups: %w"
,
err
)
}
return
groups
,
nil
}
// Update 更新分组
func
(
s
*
GroupService
)
Update
(
ctx
context
.
Context
,
id
int64
,
req
UpdateGroupRequest
)
(
*
model
.
Group
,
error
)
{
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
gorm
.
ErrRecordNotFound
)
{
return
nil
,
ErrGroupNotFound
}
return
nil
,
fmt
.
Errorf
(
"get group: %w"
,
err
)
}
// 更新字段
if
req
.
Name
!=
nil
&&
*
req
.
Name
!=
group
.
Name
{
// 检查新名称是否已存在
exists
,
err
:=
s
.
groupRepo
.
ExistsByName
(
ctx
,
*
req
.
Name
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"check group exists: %w"
,
err
)
}
if
exists
{
return
nil
,
ErrGroupExists
}
group
.
Name
=
*
req
.
Name
}
if
req
.
Description
!=
nil
{
group
.
Description
=
*
req
.
Description
}
if
req
.
RateMultiplier
!=
nil
{
group
.
RateMultiplier
=
*
req
.
RateMultiplier
}
if
req
.
IsExclusive
!=
nil
{
group
.
IsExclusive
=
*
req
.
IsExclusive
}
if
req
.
Status
!=
nil
{
group
.
Status
=
*
req
.
Status
}
if
err
:=
s
.
groupRepo
.
Update
(
ctx
,
group
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"update group: %w"
,
err
)
}
return
group
,
nil
}
// Delete 删除分组
func
(
s
*
GroupService
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
// 检查分组是否存在
_
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
gorm
.
ErrRecordNotFound
)
{
return
ErrGroupNotFound
}
return
fmt
.
Errorf
(
"get group: %w"
,
err
)
}
if
err
:=
s
.
groupRepo
.
Delete
(
ctx
,
id
);
err
!=
nil
{
return
fmt
.
Errorf
(
"delete group: %w"
,
err
)
}
return
nil
}
// GetStats 获取分组统计信息
func
(
s
*
GroupService
)
GetStats
(
ctx
context
.
Context
,
id
int64
)
(
map
[
string
]
interface
{},
error
)
{
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
gorm
.
ErrRecordNotFound
)
{
return
nil
,
ErrGroupNotFound
}
return
nil
,
fmt
.
Errorf
(
"get group: %w"
,
err
)
}
// 获取账号数量
accountCount
,
err
:=
s
.
groupRepo
.
GetAccountCount
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get account count: %w"
,
err
)
}
stats
:=
map
[
string
]
interface
{}{
"id"
:
group
.
ID
,
"name"
:
group
.
Name
,
"rate_multiplier"
:
group
.
RateMultiplier
,
"is_exclusive"
:
group
.
IsExclusive
,
"status"
:
group
.
Status
,
"account_count"
:
accountCount
,
}
return
stats
,
nil
}
backend/internal/service/identity_service.go
0 → 100644
View file @
642842c2
package
service
import
(
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"log"
"net/http"
"regexp"
"strconv"
"time"
"github.com/redis/go-redis/v9"
)
const
(
// Redis key prefix
identityFingerprintKey
=
"identity:fingerprint:"
)
// 预编译正则表达式(避免每次调用重新编译)
var
(
// 匹配 user_id 格式: user_{64位hex}_account__session_{uuid}
userIDRegex
=
regexp
.
MustCompile
(
`^user_[a-f0-9]{64}_account__session_([a-f0-9-]{36})$`
)
// 匹配 User-Agent 版本号: xxx/x.y.z
userAgentVersionRegex
=
regexp
.
MustCompile
(
`/(\d+)\.(\d+)\.(\d+)`
)
)
// Fingerprint 存储的指纹数据结构
type
Fingerprint
struct
{
ClientID
string
`json:"client_id"`
// 64位hex客户端ID(首次随机生成)
UserAgent
string
`json:"user_agent"`
// User-Agent
StainlessLang
string
`json:"x_stainless_lang"`
// x-stainless-lang
StainlessPackageVersion
string
`json:"x_stainless_package_version"`
// x-stainless-package-version
StainlessOS
string
`json:"x_stainless_os"`
// x-stainless-os
StainlessArch
string
`json:"x_stainless_arch"`
// x-stainless-arch
StainlessRuntime
string
`json:"x_stainless_runtime"`
// x-stainless-runtime
StainlessRuntimeVersion
string
`json:"x_stainless_runtime_version"`
// x-stainless-runtime-version
}
// 默认指纹值(当客户端未提供时使用)
var
defaultFingerprint
=
Fingerprint
{
UserAgent
:
"claude-cli/2.0.62 (external, cli)"
,
StainlessLang
:
"js"
,
StainlessPackageVersion
:
"0.52.0"
,
StainlessOS
:
"Linux"
,
StainlessArch
:
"x64"
,
StainlessRuntime
:
"node"
,
StainlessRuntimeVersion
:
"v22.14.0"
,
}
// IdentityService 管理OAuth账号的请求身份指纹
type
IdentityService
struct
{
rdb
*
redis
.
Client
}
// NewIdentityService 创建新的IdentityService
func
NewIdentityService
(
rdb
*
redis
.
Client
)
*
IdentityService
{
return
&
IdentityService
{
rdb
:
rdb
}
}
// GetOrCreateFingerprint 获取或创建账号的指纹
// 如果缓存存在,检测user-agent版本,新版本则更新
// 如果缓存不存在,生成随机ClientID并从请求头创建指纹,然后缓存
func
(
s
*
IdentityService
)
GetOrCreateFingerprint
(
ctx
context
.
Context
,
accountID
int64
,
headers
http
.
Header
)
(
*
Fingerprint
,
error
)
{
key
:=
identityFingerprintKey
+
strconv
.
FormatInt
(
accountID
,
10
)
// 尝试从Redis获取缓存的指纹
data
,
err
:=
s
.
rdb
.
Get
(
ctx
,
key
)
.
Bytes
()
if
err
==
nil
&&
len
(
data
)
>
0
{
// 缓存存在,解析指纹
var
cached
Fingerprint
if
err
:=
json
.
Unmarshal
(
data
,
&
cached
);
err
==
nil
{
// 检查客户端的user-agent是否是更新版本
clientUA
:=
headers
.
Get
(
"User-Agent"
)
if
clientUA
!=
""
&&
isNewerVersion
(
clientUA
,
cached
.
UserAgent
)
{
// 更新user-agent
cached
.
UserAgent
=
clientUA
// 保存更新后的指纹
if
newData
,
err
:=
json
.
Marshal
(
cached
);
err
==
nil
{
s
.
rdb
.
Set
(
ctx
,
key
,
newData
,
0
)
// 永不过期
}
log
.
Printf
(
"Updated fingerprint user-agent for account %d: %s"
,
accountID
,
clientUA
)
}
return
&
cached
,
nil
}
}
// 缓存不存在或解析失败,创建新指纹
fp
:=
s
.
createFingerprintFromHeaders
(
headers
)
// 生成随机ClientID
fp
.
ClientID
=
generateClientID
()
// 保存到Redis(永不过期)
if
data
,
err
:=
json
.
Marshal
(
fp
);
err
==
nil
{
if
err
:=
s
.
rdb
.
Set
(
ctx
,
key
,
data
,
0
)
.
Err
();
err
!=
nil
{
log
.
Printf
(
"Warning: failed to cache fingerprint for account %d: %v"
,
accountID
,
err
)
}
}
log
.
Printf
(
"Created new fingerprint for account %d with client_id: %s"
,
accountID
,
fp
.
ClientID
)
return
fp
,
nil
}
// createFingerprintFromHeaders 从请求头创建指纹
func
(
s
*
IdentityService
)
createFingerprintFromHeaders
(
headers
http
.
Header
)
*
Fingerprint
{
fp
:=
&
Fingerprint
{}
// 获取User-Agent
if
ua
:=
headers
.
Get
(
"User-Agent"
);
ua
!=
""
{
fp
.
UserAgent
=
ua
}
else
{
fp
.
UserAgent
=
defaultFingerprint
.
UserAgent
}
// 获取x-stainless-*头,如果没有则使用默认值
fp
.
StainlessLang
=
getHeaderOrDefault
(
headers
,
"X-Stainless-Lang"
,
defaultFingerprint
.
StainlessLang
)
fp
.
StainlessPackageVersion
=
getHeaderOrDefault
(
headers
,
"X-Stainless-Package-Version"
,
defaultFingerprint
.
StainlessPackageVersion
)
fp
.
StainlessOS
=
getHeaderOrDefault
(
headers
,
"X-Stainless-OS"
,
defaultFingerprint
.
StainlessOS
)
fp
.
StainlessArch
=
getHeaderOrDefault
(
headers
,
"X-Stainless-Arch"
,
defaultFingerprint
.
StainlessArch
)
fp
.
StainlessRuntime
=
getHeaderOrDefault
(
headers
,
"X-Stainless-Runtime"
,
defaultFingerprint
.
StainlessRuntime
)
fp
.
StainlessRuntimeVersion
=
getHeaderOrDefault
(
headers
,
"X-Stainless-Runtime-Version"
,
defaultFingerprint
.
StainlessRuntimeVersion
)
return
fp
}
// getHeaderOrDefault 获取header值,如果不存在则返回默认值
func
getHeaderOrDefault
(
headers
http
.
Header
,
key
,
defaultValue
string
)
string
{
if
v
:=
headers
.
Get
(
key
);
v
!=
""
{
return
v
}
return
defaultValue
}
// ApplyFingerprint 将指纹应用到请求头(覆盖原有的x-stainless-*头)
func
(
s
*
IdentityService
)
ApplyFingerprint
(
req
*
http
.
Request
,
fp
*
Fingerprint
)
{
if
fp
==
nil
{
return
}
// 设置User-Agent
if
fp
.
UserAgent
!=
""
{
req
.
Header
.
Set
(
"User-Agent"
,
fp
.
UserAgent
)
}
// 设置x-stainless-*头(使用正确的大小写)
if
fp
.
StainlessLang
!=
""
{
req
.
Header
.
Set
(
"X-Stainless-Lang"
,
fp
.
StainlessLang
)
}
if
fp
.
StainlessPackageVersion
!=
""
{
req
.
Header
.
Set
(
"X-Stainless-Package-Version"
,
fp
.
StainlessPackageVersion
)
}
if
fp
.
StainlessOS
!=
""
{
req
.
Header
.
Set
(
"X-Stainless-OS"
,
fp
.
StainlessOS
)
}
if
fp
.
StainlessArch
!=
""
{
req
.
Header
.
Set
(
"X-Stainless-Arch"
,
fp
.
StainlessArch
)
}
if
fp
.
StainlessRuntime
!=
""
{
req
.
Header
.
Set
(
"X-Stainless-Runtime"
,
fp
.
StainlessRuntime
)
}
if
fp
.
StainlessRuntimeVersion
!=
""
{
req
.
Header
.
Set
(
"X-Stainless-Runtime-Version"
,
fp
.
StainlessRuntimeVersion
)
}
}
// RewriteUserID 重写body中的metadata.user_id
// 输入格式:user_{clientId}_account__session_{sessionUUID}
// 输出格式:user_{cachedClientID}_account_{accountUUID}_session_{newHash}
func
(
s
*
IdentityService
)
RewriteUserID
(
body
[]
byte
,
accountID
int64
,
accountUUID
,
cachedClientID
string
)
([]
byte
,
error
)
{
if
len
(
body
)
==
0
||
accountUUID
==
""
||
cachedClientID
==
""
{
return
body
,
nil
}
// 解析JSON
var
reqMap
map
[
string
]
interface
{}
if
err
:=
json
.
Unmarshal
(
body
,
&
reqMap
);
err
!=
nil
{
return
body
,
nil
}
metadata
,
ok
:=
reqMap
[
"metadata"
]
.
(
map
[
string
]
interface
{})
if
!
ok
{
return
body
,
nil
}
userID
,
ok
:=
metadata
[
"user_id"
]
.
(
string
)
if
!
ok
||
userID
==
""
{
return
body
,
nil
}
// 匹配格式: user_{64位hex}_account__session_{uuid}
matches
:=
userIDRegex
.
FindStringSubmatch
(
userID
)
if
matches
==
nil
{
return
body
,
nil
}
sessionTail
:=
matches
[
1
]
// 原始session UUID
// 生成新的session hash: SHA256(accountID::sessionTail) -> UUID格式
seed
:=
fmt
.
Sprintf
(
"%d::%s"
,
accountID
,
sessionTail
)
newSessionHash
:=
generateUUIDFromSeed
(
seed
)
// 构建新的user_id
// 格式: user_{cachedClientID}_account_{account_uuid}_session_{newSessionHash}
newUserID
:=
fmt
.
Sprintf
(
"user_%s_account_%s_session_%s"
,
cachedClientID
,
accountUUID
,
newSessionHash
)
metadata
[
"user_id"
]
=
newUserID
reqMap
[
"metadata"
]
=
metadata
return
json
.
Marshal
(
reqMap
)
}
// generateClientID 生成64位十六进制客户端ID(32字节随机数)
func
generateClientID
()
string
{
b
:=
make
([]
byte
,
32
)
if
_
,
err
:=
rand
.
Read
(
b
);
err
!=
nil
{
// 极罕见的情况,使用时间戳+固定值作为fallback
log
.
Printf
(
"Warning: crypto/rand.Read failed: %v, using fallback"
,
err
)
// 使用SHA256(当前纳秒时间)作为fallback
h
:=
sha256
.
Sum256
([]
byte
(
fmt
.
Sprintf
(
"%d"
,
time
.
Now
()
.
UnixNano
())))
return
hex
.
EncodeToString
(
h
[
:
])
}
return
hex
.
EncodeToString
(
b
)
}
// generateUUIDFromSeed 从种子生成确定性UUID v4格式字符串
func
generateUUIDFromSeed
(
seed
string
)
string
{
hash
:=
sha256
.
Sum256
([]
byte
(
seed
))
bytes
:=
hash
[
:
16
]
// 设置UUID v4版本和变体位
bytes
[
6
]
=
(
bytes
[
6
]
&
0x0f
)
|
0x40
bytes
[
8
]
=
(
bytes
[
8
]
&
0x3f
)
|
0x80
return
fmt
.
Sprintf
(
"%x-%x-%x-%x-%x"
,
bytes
[
0
:
4
],
bytes
[
4
:
6
],
bytes
[
6
:
8
],
bytes
[
8
:
10
],
bytes
[
10
:
16
])
}
// parseUserAgentVersion 解析user-agent版本号
// 例如:claude-cli/2.0.62 -> (2, 0, 62)
func
parseUserAgentVersion
(
ua
string
)
(
major
,
minor
,
patch
int
,
ok
bool
)
{
// 匹配 xxx/x.y.z 格式
matches
:=
userAgentVersionRegex
.
FindStringSubmatch
(
ua
)
if
len
(
matches
)
!=
4
{
return
0
,
0
,
0
,
false
}
major
,
_
=
strconv
.
Atoi
(
matches
[
1
])
minor
,
_
=
strconv
.
Atoi
(
matches
[
2
])
patch
,
_
=
strconv
.
Atoi
(
matches
[
3
])
return
major
,
minor
,
patch
,
true
}
// isNewerVersion 比较版本号,判断newUA是否比cachedUA更新
func
isNewerVersion
(
newUA
,
cachedUA
string
)
bool
{
newMajor
,
newMinor
,
newPatch
,
newOk
:=
parseUserAgentVersion
(
newUA
)
cachedMajor
,
cachedMinor
,
cachedPatch
,
cachedOk
:=
parseUserAgentVersion
(
cachedUA
)
if
!
newOk
||
!
cachedOk
{
return
false
}
// 比较版本号
if
newMajor
>
cachedMajor
{
return
true
}
if
newMajor
<
cachedMajor
{
return
false
}
if
newMinor
>
cachedMinor
{
return
true
}
if
newMinor
<
cachedMinor
{
return
false
}
return
newPatch
>
cachedPatch
}
backend/internal/service/oauth_service.go
0 → 100644
View file @
642842c2
package
service
import
(
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"net/url"
"strings"
"time"
"sub2api/internal/model"
"sub2api/internal/pkg/oauth"
"sub2api/internal/repository"
"github.com/imroc/req/v3"
)
// OAuthService handles OAuth authentication flows
type
OAuthService
struct
{
sessionStore
*
oauth
.
SessionStore
proxyRepo
*
repository
.
ProxyRepository
}
// NewOAuthService creates a new OAuth service
func
NewOAuthService
(
proxyRepo
*
repository
.
ProxyRepository
)
*
OAuthService
{
return
&
OAuthService
{
sessionStore
:
oauth
.
NewSessionStore
(),
proxyRepo
:
proxyRepo
,
}
}
// GenerateAuthURLResult contains the authorization URL and session info
type
GenerateAuthURLResult
struct
{
AuthURL
string
`json:"auth_url"`
SessionID
string
`json:"session_id"`
}
// GenerateAuthURL generates an OAuth authorization URL with full scope
func
(
s
*
OAuthService
)
GenerateAuthURL
(
ctx
context
.
Context
,
proxyID
*
int64
)
(
*
GenerateAuthURLResult
,
error
)
{
scope
:=
fmt
.
Sprintf
(
"%s %s"
,
oauth
.
ScopeProfile
,
oauth
.
ScopeInference
)
return
s
.
generateAuthURLWithScope
(
ctx
,
scope
,
proxyID
)
}
// GenerateSetupTokenURL generates an OAuth authorization URL for setup token (inference only)
func
(
s
*
OAuthService
)
GenerateSetupTokenURL
(
ctx
context
.
Context
,
proxyID
*
int64
)
(
*
GenerateAuthURLResult
,
error
)
{
scope
:=
oauth
.
ScopeInference
return
s
.
generateAuthURLWithScope
(
ctx
,
scope
,
proxyID
)
}
func
(
s
*
OAuthService
)
generateAuthURLWithScope
(
ctx
context
.
Context
,
scope
string
,
proxyID
*
int64
)
(
*
GenerateAuthURLResult
,
error
)
{
// Generate PKCE values
state
,
err
:=
oauth
.
GenerateState
()
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to generate state: %w"
,
err
)
}
codeVerifier
,
err
:=
oauth
.
GenerateCodeVerifier
()
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to generate code verifier: %w"
,
err
)
}
codeChallenge
:=
oauth
.
GenerateCodeChallenge
(
codeVerifier
)
// Generate session ID
sessionID
,
err
:=
oauth
.
GenerateSessionID
()
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to generate session ID: %w"
,
err
)
}
// Get proxy URL if specified
var
proxyURL
string
if
proxyID
!=
nil
{
proxy
,
err
:=
s
.
proxyRepo
.
GetByID
(
ctx
,
*
proxyID
)
if
err
==
nil
&&
proxy
!=
nil
{
proxyURL
=
proxy
.
URL
()
}
}
// Store session
session
:=
&
oauth
.
OAuthSession
{
State
:
state
,
CodeVerifier
:
codeVerifier
,
Scope
:
scope
,
ProxyURL
:
proxyURL
,
CreatedAt
:
time
.
Now
(),
}
s
.
sessionStore
.
Set
(
sessionID
,
session
)
// Build authorization URL
authURL
:=
oauth
.
BuildAuthorizationURL
(
state
,
codeChallenge
,
scope
)
return
&
GenerateAuthURLResult
{
AuthURL
:
authURL
,
SessionID
:
sessionID
,
},
nil
}
// ExchangeCodeInput represents the input for code exchange
type
ExchangeCodeInput
struct
{
SessionID
string
Code
string
ProxyID
*
int64
}
// TokenInfo represents the token information stored in credentials
type
TokenInfo
struct
{
AccessToken
string
`json:"access_token"`
TokenType
string
`json:"token_type"`
ExpiresIn
int64
`json:"expires_in"`
ExpiresAt
int64
`json:"expires_at"`
RefreshToken
string
`json:"refresh_token,omitempty"`
Scope
string
`json:"scope,omitempty"`
OrgUUID
string
`json:"org_uuid,omitempty"`
AccountUUID
string
`json:"account_uuid,omitempty"`
}
// ExchangeCode exchanges authorization code for tokens
func
(
s
*
OAuthService
)
ExchangeCode
(
ctx
context
.
Context
,
input
*
ExchangeCodeInput
)
(
*
TokenInfo
,
error
)
{
// Get session
session
,
ok
:=
s
.
sessionStore
.
Get
(
input
.
SessionID
)
if
!
ok
{
return
nil
,
fmt
.
Errorf
(
"session not found or expired"
)
}
// Get proxy URL
proxyURL
:=
session
.
ProxyURL
if
input
.
ProxyID
!=
nil
{
proxy
,
err
:=
s
.
proxyRepo
.
GetByID
(
ctx
,
*
input
.
ProxyID
)
if
err
==
nil
&&
proxy
!=
nil
{
proxyURL
=
proxy
.
URL
()
}
}
// Exchange code for token
tokenInfo
,
err
:=
s
.
exchangeCodeForToken
(
ctx
,
input
.
Code
,
session
.
CodeVerifier
,
session
.
State
,
proxyURL
)
if
err
!=
nil
{
return
nil
,
err
}
// Delete session after successful exchange
s
.
sessionStore
.
Delete
(
input
.
SessionID
)
return
tokenInfo
,
nil
}
// CookieAuthInput represents the input for cookie-based authentication
type
CookieAuthInput
struct
{
SessionKey
string
ProxyID
*
int64
Scope
string
// "full" or "inference"
}
// CookieAuth performs OAuth using sessionKey (cookie-based auto-auth)
func
(
s
*
OAuthService
)
CookieAuth
(
ctx
context
.
Context
,
input
*
CookieAuthInput
)
(
*
TokenInfo
,
error
)
{
// Get proxy URL if specified
var
proxyURL
string
if
input
.
ProxyID
!=
nil
{
proxy
,
err
:=
s
.
proxyRepo
.
GetByID
(
ctx
,
*
input
.
ProxyID
)
if
err
==
nil
&&
proxy
!=
nil
{
proxyURL
=
proxy
.
URL
()
}
}
// Determine scope
scope
:=
fmt
.
Sprintf
(
"%s %s"
,
oauth
.
ScopeProfile
,
oauth
.
ScopeInference
)
if
input
.
Scope
==
"inference"
{
scope
=
oauth
.
ScopeInference
}
// Step 1: Get organization info using sessionKey
orgUUID
,
err
:=
s
.
getOrganizationUUID
(
ctx
,
input
.
SessionKey
,
proxyURL
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to get organization info: %w"
,
err
)
}
// Step 2: Generate PKCE values
codeVerifier
,
err
:=
oauth
.
GenerateCodeVerifier
()
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to generate code verifier: %w"
,
err
)
}
codeChallenge
:=
oauth
.
GenerateCodeChallenge
(
codeVerifier
)
state
,
err
:=
oauth
.
GenerateState
()
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to generate state: %w"
,
err
)
}
// Step 3: Get authorization code using cookie
authCode
,
err
:=
s
.
getAuthorizationCode
(
ctx
,
input
.
SessionKey
,
orgUUID
,
scope
,
codeChallenge
,
state
,
proxyURL
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to get authorization code: %w"
,
err
)
}
// Step 4: Exchange code for token
tokenInfo
,
err
:=
s
.
exchangeCodeForToken
(
ctx
,
authCode
,
codeVerifier
,
state
,
proxyURL
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to exchange code: %w"
,
err
)
}
// Ensure org_uuid is set (from step 1 if not from token response)
if
tokenInfo
.
OrgUUID
==
""
&&
orgUUID
!=
""
{
tokenInfo
.
OrgUUID
=
orgUUID
log
.
Printf
(
"[OAuth] Set org_uuid from cookie auth: %s"
,
orgUUID
)
}
return
tokenInfo
,
nil
}
// getOrganizationUUID gets the organization UUID from claude.ai using sessionKey
func
(
s
*
OAuthService
)
getOrganizationUUID
(
ctx
context
.
Context
,
sessionKey
,
proxyURL
string
)
(
string
,
error
)
{
client
:=
s
.
createReqClient
(
proxyURL
)
var
orgs
[]
struct
{
UUID
string
`json:"uuid"`
}
targetURL
:=
"https://claude.ai/api/organizations"
log
.
Printf
(
"[OAuth] Step 1: Getting organization UUID from %s"
,
targetURL
)
resp
,
err
:=
client
.
R
()
.
SetContext
(
ctx
)
.
SetCookies
(
&
http
.
Cookie
{
Name
:
"sessionKey"
,
Value
:
sessionKey
,
})
.
SetSuccessResult
(
&
orgs
)
.
Get
(
targetURL
)
if
err
!=
nil
{
log
.
Printf
(
"[OAuth] Step 1 FAILED - Request error: %v"
,
err
)
return
""
,
fmt
.
Errorf
(
"request failed: %w"
,
err
)
}
log
.
Printf
(
"[OAuth] Step 1 Response - Status: %d, Body: %s"
,
resp
.
StatusCode
,
resp
.
String
())
if
!
resp
.
IsSuccessState
()
{
return
""
,
fmt
.
Errorf
(
"failed to get organizations: status %d, body: %s"
,
resp
.
StatusCode
,
resp
.
String
())
}
if
len
(
orgs
)
==
0
{
return
""
,
fmt
.
Errorf
(
"no organizations found"
)
}
log
.
Printf
(
"[OAuth] Step 1 SUCCESS - Got org UUID: %s"
,
orgs
[
0
]
.
UUID
)
return
orgs
[
0
]
.
UUID
,
nil
}
// getAuthorizationCode gets the authorization code using sessionKey
func
(
s
*
OAuthService
)
getAuthorizationCode
(
ctx
context
.
Context
,
sessionKey
,
orgUUID
,
scope
,
codeChallenge
,
state
,
proxyURL
string
)
(
string
,
error
)
{
client
:=
s
.
createReqClient
(
proxyURL
)
authURL
:=
fmt
.
Sprintf
(
"https://claude.ai/v1/oauth/%s/authorize"
,
orgUUID
)
// Build request body - must include organization_uuid as per CRS
reqBody
:=
map
[
string
]
interface
{}{
"response_type"
:
"code"
,
"client_id"
:
oauth
.
ClientID
,
"organization_uuid"
:
orgUUID
,
// Required field!
"redirect_uri"
:
oauth
.
RedirectURI
,
"scope"
:
scope
,
"state"
:
state
,
"code_challenge"
:
codeChallenge
,
"code_challenge_method"
:
"S256"
,
}
reqBodyJSON
,
_
:=
json
.
Marshal
(
reqBody
)
log
.
Printf
(
"[OAuth] Step 2: Getting authorization code from %s"
,
authURL
)
log
.
Printf
(
"[OAuth] Step 2 Request Body: %s"
,
string
(
reqBodyJSON
))
// Response contains redirect_uri with code, not direct code field
var
result
struct
{
RedirectURI
string
`json:"redirect_uri"`
}
resp
,
err
:=
client
.
R
()
.
SetContext
(
ctx
)
.
SetCookies
(
&
http
.
Cookie
{
Name
:
"sessionKey"
,
Value
:
sessionKey
,
})
.
SetHeader
(
"Accept"
,
"application/json"
)
.
SetHeader
(
"Accept-Language"
,
"en-US,en;q=0.9"
)
.
SetHeader
(
"Cache-Control"
,
"no-cache"
)
.
SetHeader
(
"Origin"
,
"https://claude.ai"
)
.
SetHeader
(
"Referer"
,
"https://claude.ai/new"
)
.
SetHeader
(
"Content-Type"
,
"application/json"
)
.
SetBody
(
reqBody
)
.
SetSuccessResult
(
&
result
)
.
Post
(
authURL
)
if
err
!=
nil
{
log
.
Printf
(
"[OAuth] Step 2 FAILED - Request error: %v"
,
err
)
return
""
,
fmt
.
Errorf
(
"request failed: %w"
,
err
)
}
log
.
Printf
(
"[OAuth] Step 2 Response - Status: %d, Body: %s"
,
resp
.
StatusCode
,
resp
.
String
())
if
!
resp
.
IsSuccessState
()
{
return
""
,
fmt
.
Errorf
(
"failed to get authorization code: status %d, body: %s"
,
resp
.
StatusCode
,
resp
.
String
())
}
if
result
.
RedirectURI
==
""
{
return
""
,
fmt
.
Errorf
(
"no redirect_uri in response"
)
}
// Parse redirect_uri to extract code and state
parsedURL
,
err
:=
url
.
Parse
(
result
.
RedirectURI
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"failed to parse redirect_uri: %w"
,
err
)
}
queryParams
:=
parsedURL
.
Query
()
authCode
:=
queryParams
.
Get
(
"code"
)
responseState
:=
queryParams
.
Get
(
"state"
)
if
authCode
==
""
{
return
""
,
fmt
.
Errorf
(
"no authorization code in redirect_uri"
)
}
// Combine code with state if present (as CRS does)
fullCode
:=
authCode
if
responseState
!=
""
{
fullCode
=
authCode
+
"#"
+
responseState
}
log
.
Printf
(
"[OAuth] Step 2 SUCCESS - Got authorization code: %s..."
,
authCode
[
:
20
])
return
fullCode
,
nil
}
// exchangeCodeForToken exchanges authorization code for tokens
func
(
s
*
OAuthService
)
exchangeCodeForToken
(
ctx
context
.
Context
,
code
,
codeVerifier
,
state
,
proxyURL
string
)
(
*
TokenInfo
,
error
)
{
client
:=
s
.
createReqClient
(
proxyURL
)
// Parse code#state format if present
authCode
:=
code
codeState
:=
""
if
parts
:=
strings
.
Split
(
code
,
"#"
);
len
(
parts
)
>
1
{
authCode
=
parts
[
0
]
codeState
=
parts
[
1
]
}
// Build JSON body as CRS does (not form data!)
reqBody
:=
map
[
string
]
interface
{}{
"code"
:
authCode
,
"grant_type"
:
"authorization_code"
,
"client_id"
:
oauth
.
ClientID
,
"redirect_uri"
:
oauth
.
RedirectURI
,
"code_verifier"
:
codeVerifier
,
}
// Add state if present
if
codeState
!=
""
{
reqBody
[
"state"
]
=
codeState
}
reqBodyJSON
,
_
:=
json
.
Marshal
(
reqBody
)
log
.
Printf
(
"[OAuth] Step 3: Exchanging code for token at %s"
,
oauth
.
TokenURL
)
log
.
Printf
(
"[OAuth] Step 3 Request Body: %s"
,
string
(
reqBodyJSON
))
var
tokenResp
oauth
.
TokenResponse
resp
,
err
:=
client
.
R
()
.
SetContext
(
ctx
)
.
SetHeader
(
"Content-Type"
,
"application/json"
)
.
SetBody
(
reqBody
)
.
SetSuccessResult
(
&
tokenResp
)
.
Post
(
oauth
.
TokenURL
)
if
err
!=
nil
{
log
.
Printf
(
"[OAuth] Step 3 FAILED - Request error: %v"
,
err
)
return
nil
,
fmt
.
Errorf
(
"request failed: %w"
,
err
)
}
log
.
Printf
(
"[OAuth] Step 3 Response - Status: %d, Body: %s"
,
resp
.
StatusCode
,
resp
.
String
())
if
!
resp
.
IsSuccessState
()
{
return
nil
,
fmt
.
Errorf
(
"token exchange failed: status %d, body: %s"
,
resp
.
StatusCode
,
resp
.
String
())
}
log
.
Printf
(
"[OAuth] Step 3 SUCCESS - Got access token"
)
tokenInfo
:=
&
TokenInfo
{
AccessToken
:
tokenResp
.
AccessToken
,
TokenType
:
tokenResp
.
TokenType
,
ExpiresIn
:
tokenResp
.
ExpiresIn
,
ExpiresAt
:
time
.
Now
()
.
Unix
()
+
tokenResp
.
ExpiresIn
,
RefreshToken
:
tokenResp
.
RefreshToken
,
Scope
:
tokenResp
.
Scope
,
}
// Extract org_uuid and account_uuid from response
if
tokenResp
.
Organization
!=
nil
&&
tokenResp
.
Organization
.
UUID
!=
""
{
tokenInfo
.
OrgUUID
=
tokenResp
.
Organization
.
UUID
log
.
Printf
(
"[OAuth] Got org_uuid: %s"
,
tokenInfo
.
OrgUUID
)
}
if
tokenResp
.
Account
!=
nil
&&
tokenResp
.
Account
.
UUID
!=
""
{
tokenInfo
.
AccountUUID
=
tokenResp
.
Account
.
UUID
log
.
Printf
(
"[OAuth] Got account_uuid: %s"
,
tokenInfo
.
AccountUUID
)
}
return
tokenInfo
,
nil
}
// RefreshToken refreshes an OAuth token
func
(
s
*
OAuthService
)
RefreshToken
(
ctx
context
.
Context
,
refreshToken
string
,
proxyURL
string
)
(
*
TokenInfo
,
error
)
{
client
:=
s
.
createReqClient
(
proxyURL
)
formData
:=
url
.
Values
{}
formData
.
Set
(
"grant_type"
,
"refresh_token"
)
formData
.
Set
(
"refresh_token"
,
refreshToken
)
formData
.
Set
(
"client_id"
,
oauth
.
ClientID
)
var
tokenResp
oauth
.
TokenResponse
resp
,
err
:=
client
.
R
()
.
SetContext
(
ctx
)
.
SetFormDataFromValues
(
formData
)
.
SetSuccessResult
(
&
tokenResp
)
.
Post
(
oauth
.
TokenURL
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"request failed: %w"
,
err
)
}
if
!
resp
.
IsSuccessState
()
{
return
nil
,
fmt
.
Errorf
(
"token refresh failed: status %d, body: %s"
,
resp
.
StatusCode
,
resp
.
String
())
}
return
&
TokenInfo
{
AccessToken
:
tokenResp
.
AccessToken
,
TokenType
:
tokenResp
.
TokenType
,
ExpiresIn
:
tokenResp
.
ExpiresIn
,
ExpiresAt
:
time
.
Now
()
.
Unix
()
+
tokenResp
.
ExpiresIn
,
RefreshToken
:
tokenResp
.
RefreshToken
,
Scope
:
tokenResp
.
Scope
,
},
nil
}
// RefreshAccountToken refreshes token for an account
func
(
s
*
OAuthService
)
RefreshAccountToken
(
ctx
context
.
Context
,
account
*
model
.
Account
)
(
*
TokenInfo
,
error
)
{
refreshToken
:=
account
.
GetCredential
(
"refresh_token"
)
if
refreshToken
==
""
{
return
nil
,
fmt
.
Errorf
(
"no refresh token available"
)
}
var
proxyURL
string
if
account
.
ProxyID
!=
nil
{
proxy
,
err
:=
s
.
proxyRepo
.
GetByID
(
ctx
,
*
account
.
ProxyID
)
if
err
==
nil
&&
proxy
!=
nil
{
proxyURL
=
proxy
.
URL
()
}
}
return
s
.
RefreshToken
(
ctx
,
refreshToken
,
proxyURL
)
}
// createReqClient creates a req client with Chrome impersonation and optional proxy
func
(
s
*
OAuthService
)
createReqClient
(
proxyURL
string
)
*
req
.
Client
{
client
:=
req
.
C
()
.
ImpersonateChrome
()
.
// Impersonate Chrome browser to bypass Cloudflare
SetTimeout
(
60
*
time
.
Second
)
// Set proxy if specified
if
proxyURL
!=
""
{
client
.
SetProxyURL
(
proxyURL
)
}
return
client
}
backend/internal/service/pricing_service.go
0 → 100644
View file @
642842c2
package
service
import
(
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
"sub2api/internal/config"
)
// LiteLLMModelPricing LiteLLM价格数据结构
// 只保留我们需要的字段,使用指针来处理可能缺失的值
type
LiteLLMModelPricing
struct
{
InputCostPerToken
float64
`json:"input_cost_per_token"`
OutputCostPerToken
float64
`json:"output_cost_per_token"`
CacheCreationInputTokenCost
float64
`json:"cache_creation_input_token_cost"`
CacheReadInputTokenCost
float64
`json:"cache_read_input_token_cost"`
LiteLLMProvider
string
`json:"litellm_provider"`
Mode
string
`json:"mode"`
SupportsPromptCaching
bool
`json:"supports_prompt_caching"`
}
// LiteLLMRawEntry 用于解析原始JSON数据
type
LiteLLMRawEntry
struct
{
InputCostPerToken
*
float64
`json:"input_cost_per_token"`
OutputCostPerToken
*
float64
`json:"output_cost_per_token"`
CacheCreationInputTokenCost
*
float64
`json:"cache_creation_input_token_cost"`
CacheReadInputTokenCost
*
float64
`json:"cache_read_input_token_cost"`
LiteLLMProvider
string
`json:"litellm_provider"`
Mode
string
`json:"mode"`
SupportsPromptCaching
bool
`json:"supports_prompt_caching"`
}
// PricingService 动态价格服务
type
PricingService
struct
{
cfg
*
config
.
Config
mu
sync
.
RWMutex
pricingData
map
[
string
]
*
LiteLLMModelPricing
lastUpdated
time
.
Time
localHash
string
// 停止信号
stopCh
chan
struct
{}
wg
sync
.
WaitGroup
}
// NewPricingService 创建价格服务
func
NewPricingService
(
cfg
*
config
.
Config
)
*
PricingService
{
s
:=
&
PricingService
{
cfg
:
cfg
,
pricingData
:
make
(
map
[
string
]
*
LiteLLMModelPricing
),
stopCh
:
make
(
chan
struct
{}),
}
return
s
}
// Initialize 初始化价格服务
func
(
s
*
PricingService
)
Initialize
()
error
{
// 确保数据目录存在
if
err
:=
os
.
MkdirAll
(
s
.
cfg
.
Pricing
.
DataDir
,
0755
);
err
!=
nil
{
log
.
Printf
(
"[Pricing] Failed to create data directory: %v"
,
err
)
}
// 首次加载价格数据
if
err
:=
s
.
checkAndUpdatePricing
();
err
!=
nil
{
log
.
Printf
(
"[Pricing] Initial load failed, using fallback: %v"
,
err
)
if
err
:=
s
.
useFallbackPricing
();
err
!=
nil
{
return
fmt
.
Errorf
(
"failed to load pricing data: %w"
,
err
)
}
}
// 启动定时更新
s
.
startUpdateScheduler
()
log
.
Printf
(
"[Pricing] Service initialized with %d models"
,
len
(
s
.
pricingData
))
return
nil
}
// Stop 停止价格服务
func
(
s
*
PricingService
)
Stop
()
{
close
(
s
.
stopCh
)
s
.
wg
.
Wait
()
log
.
Println
(
"[Pricing] Service stopped"
)
}
// startUpdateScheduler 启动定时更新调度器
func
(
s
*
PricingService
)
startUpdateScheduler
()
{
// 定期检查哈希更新
hashInterval
:=
time
.
Duration
(
s
.
cfg
.
Pricing
.
HashCheckIntervalMinutes
)
*
time
.
Minute
if
hashInterval
<
time
.
Minute
{
hashInterval
=
10
*
time
.
Minute
}
s
.
wg
.
Add
(
1
)
go
func
()
{
defer
s
.
wg
.
Done
()
ticker
:=
time
.
NewTicker
(
hashInterval
)
defer
ticker
.
Stop
()
for
{
select
{
case
<-
ticker
.
C
:
if
err
:=
s
.
syncWithRemote
();
err
!=
nil
{
log
.
Printf
(
"[Pricing] Sync failed: %v"
,
err
)
}
case
<-
s
.
stopCh
:
return
}
}
}()
log
.
Printf
(
"[Pricing] Update scheduler started (check every %v)"
,
hashInterval
)
}
// checkAndUpdatePricing 检查并更新价格数据
func
(
s
*
PricingService
)
checkAndUpdatePricing
()
error
{
pricingFile
:=
s
.
getPricingFilePath
()
// 检查本地文件是否存在
if
_
,
err
:=
os
.
Stat
(
pricingFile
);
os
.
IsNotExist
(
err
)
{
log
.
Println
(
"[Pricing] Local pricing file not found, downloading..."
)
return
s
.
downloadPricingData
()
}
// 检查文件是否过期
info
,
err
:=
os
.
Stat
(
pricingFile
)
if
err
!=
nil
{
return
s
.
downloadPricingData
()
}
fileAge
:=
time
.
Since
(
info
.
ModTime
())
maxAge
:=
time
.
Duration
(
s
.
cfg
.
Pricing
.
UpdateIntervalHours
)
*
time
.
Hour
if
fileAge
>
maxAge
{
log
.
Printf
(
"[Pricing] Local file is %v old, updating..."
,
fileAge
.
Round
(
time
.
Hour
))
if
err
:=
s
.
downloadPricingData
();
err
!=
nil
{
log
.
Printf
(
"[Pricing] Download failed, using existing file: %v"
,
err
)
}
}
// 加载本地文件
return
s
.
loadPricingData
(
pricingFile
)
}
// syncWithRemote 与远程同步(基于哈希校验)
func
(
s
*
PricingService
)
syncWithRemote
()
error
{
pricingFile
:=
s
.
getPricingFilePath
()
// 计算本地文件哈希
localHash
,
err
:=
s
.
computeFileHash
(
pricingFile
)
if
err
!=
nil
{
log
.
Printf
(
"[Pricing] Failed to compute local hash: %v"
,
err
)
return
s
.
downloadPricingData
()
}
// 如果配置了哈希URL,从远程获取哈希进行比对
if
s
.
cfg
.
Pricing
.
HashURL
!=
""
{
remoteHash
,
err
:=
s
.
fetchRemoteHash
()
if
err
!=
nil
{
log
.
Printf
(
"[Pricing] Failed to fetch remote hash: %v"
,
err
)
return
nil
// 哈希获取失败不影响正常使用
}
if
remoteHash
!=
localHash
{
log
.
Println
(
"[Pricing] Remote hash differs, downloading new version..."
)
return
s
.
downloadPricingData
()
}
log
.
Println
(
"[Pricing] Hash check passed, no update needed"
)
return
nil
}
// 没有哈希URL时,基于时间检查
info
,
err
:=
os
.
Stat
(
pricingFile
)
if
err
!=
nil
{
return
s
.
downloadPricingData
()
}
fileAge
:=
time
.
Since
(
info
.
ModTime
())
maxAge
:=
time
.
Duration
(
s
.
cfg
.
Pricing
.
UpdateIntervalHours
)
*
time
.
Hour
if
fileAge
>
maxAge
{
log
.
Printf
(
"[Pricing] File is %v old, downloading..."
,
fileAge
.
Round
(
time
.
Hour
))
return
s
.
downloadPricingData
()
}
return
nil
}
// downloadPricingData 从远程下载价格数据
func
(
s
*
PricingService
)
downloadPricingData
()
error
{
log
.
Printf
(
"[Pricing] Downloading from %s"
,
s
.
cfg
.
Pricing
.
RemoteURL
)
client
:=
&
http
.
Client
{
Timeout
:
30
*
time
.
Second
}
resp
,
err
:=
client
.
Get
(
s
.
cfg
.
Pricing
.
RemoteURL
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"download failed: %w"
,
err
)
}
defer
resp
.
Body
.
Close
()
if
resp
.
StatusCode
!=
http
.
StatusOK
{
return
fmt
.
Errorf
(
"download failed: HTTP %d"
,
resp
.
StatusCode
)
}
body
,
err
:=
io
.
ReadAll
(
resp
.
Body
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"read response failed: %w"
,
err
)
}
// 解析JSON数据(使用灵活的解析方式)
data
,
err
:=
s
.
parsePricingData
(
body
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"parse pricing data: %w"
,
err
)
}
// 保存到本地文件
pricingFile
:=
s
.
getPricingFilePath
()
if
err
:=
os
.
WriteFile
(
pricingFile
,
body
,
0644
);
err
!=
nil
{
log
.
Printf
(
"[Pricing] Failed to save file: %v"
,
err
)
}
// 保存哈希
hash
:=
sha256
.
Sum256
(
body
)
hashStr
:=
hex
.
EncodeToString
(
hash
[
:
])
hashFile
:=
s
.
getHashFilePath
()
if
err
:=
os
.
WriteFile
(
hashFile
,
[]
byte
(
hashStr
+
"
\n
"
),
0644
);
err
!=
nil
{
log
.
Printf
(
"[Pricing] Failed to save hash: %v"
,
err
)
}
// 更新内存数据
s
.
mu
.
Lock
()
s
.
pricingData
=
data
s
.
lastUpdated
=
time
.
Now
()
s
.
localHash
=
hashStr
s
.
mu
.
Unlock
()
log
.
Printf
(
"[Pricing] Downloaded %d models successfully"
,
len
(
data
))
return
nil
}
// parsePricingData 解析价格数据(处理各种格式)
func
(
s
*
PricingService
)
parsePricingData
(
body
[]
byte
)
(
map
[
string
]
*
LiteLLMModelPricing
,
error
)
{
// 首先解析为 map[string]json.RawMessage
var
rawData
map
[
string
]
json
.
RawMessage
if
err
:=
json
.
Unmarshal
(
body
,
&
rawData
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"parse raw JSON: %w"
,
err
)
}
result
:=
make
(
map
[
string
]
*
LiteLLMModelPricing
)
skipped
:=
0
for
modelName
,
rawEntry
:=
range
rawData
{
// 跳过 sample_spec 等文档条目
if
modelName
==
"sample_spec"
{
continue
}
// 尝试解析每个条目
var
entry
LiteLLMRawEntry
if
err
:=
json
.
Unmarshal
(
rawEntry
,
&
entry
);
err
!=
nil
{
skipped
++
continue
}
// 只保留有有效价格的条目
if
entry
.
InputCostPerToken
==
nil
&&
entry
.
OutputCostPerToken
==
nil
{
continue
}
pricing
:=
&
LiteLLMModelPricing
{
LiteLLMProvider
:
entry
.
LiteLLMProvider
,
Mode
:
entry
.
Mode
,
SupportsPromptCaching
:
entry
.
SupportsPromptCaching
,
}
if
entry
.
InputCostPerToken
!=
nil
{
pricing
.
InputCostPerToken
=
*
entry
.
InputCostPerToken
}
if
entry
.
OutputCostPerToken
!=
nil
{
pricing
.
OutputCostPerToken
=
*
entry
.
OutputCostPerToken
}
if
entry
.
CacheCreationInputTokenCost
!=
nil
{
pricing
.
CacheCreationInputTokenCost
=
*
entry
.
CacheCreationInputTokenCost
}
if
entry
.
CacheReadInputTokenCost
!=
nil
{
pricing
.
CacheReadInputTokenCost
=
*
entry
.
CacheReadInputTokenCost
}
result
[
modelName
]
=
pricing
}
if
skipped
>
0
{
log
.
Printf
(
"[Pricing] Skipped %d invalid entries"
,
skipped
)
}
if
len
(
result
)
==
0
{
return
nil
,
fmt
.
Errorf
(
"no valid pricing entries found"
)
}
return
result
,
nil
}
// loadPricingData 从本地文件加载价格数据
func
(
s
*
PricingService
)
loadPricingData
(
filePath
string
)
error
{
data
,
err
:=
os
.
ReadFile
(
filePath
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"read file failed: %w"
,
err
)
}
// 使用灵活的解析方式
pricingData
,
err
:=
s
.
parsePricingData
(
data
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"parse pricing data: %w"
,
err
)
}
// 计算哈希
hash
:=
sha256
.
Sum256
(
data
)
hashStr
:=
hex
.
EncodeToString
(
hash
[
:
])
s
.
mu
.
Lock
()
s
.
pricingData
=
pricingData
s
.
localHash
=
hashStr
info
,
_
:=
os
.
Stat
(
filePath
)
if
info
!=
nil
{
s
.
lastUpdated
=
info
.
ModTime
()
}
else
{
s
.
lastUpdated
=
time
.
Now
()
}
s
.
mu
.
Unlock
()
log
.
Printf
(
"[Pricing] Loaded %d models from %s"
,
len
(
pricingData
),
filePath
)
return
nil
}
// useFallbackPricing 使用回退价格文件
func
(
s
*
PricingService
)
useFallbackPricing
()
error
{
fallbackFile
:=
s
.
cfg
.
Pricing
.
FallbackFile
if
_
,
err
:=
os
.
Stat
(
fallbackFile
);
os
.
IsNotExist
(
err
)
{
return
fmt
.
Errorf
(
"fallback file not found: %s"
,
fallbackFile
)
}
log
.
Printf
(
"[Pricing] Using fallback file: %s"
,
fallbackFile
)
// 复制到数据目录
data
,
err
:=
os
.
ReadFile
(
fallbackFile
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"read fallback failed: %w"
,
err
)
}
pricingFile
:=
s
.
getPricingFilePath
()
if
err
:=
os
.
WriteFile
(
pricingFile
,
data
,
0644
);
err
!=
nil
{
log
.
Printf
(
"[Pricing] Failed to copy fallback: %v"
,
err
)
}
return
s
.
loadPricingData
(
fallbackFile
)
}
// fetchRemoteHash 从远程获取哈希值
func
(
s
*
PricingService
)
fetchRemoteHash
()
(
string
,
error
)
{
client
:=
&
http
.
Client
{
Timeout
:
10
*
time
.
Second
}
resp
,
err
:=
client
.
Get
(
s
.
cfg
.
Pricing
.
HashURL
)
if
err
!=
nil
{
return
""
,
err
}
defer
resp
.
Body
.
Close
()
if
resp
.
StatusCode
!=
http
.
StatusOK
{
return
""
,
fmt
.
Errorf
(
"HTTP %d"
,
resp
.
StatusCode
)
}
body
,
err
:=
io
.
ReadAll
(
resp
.
Body
)
if
err
!=
nil
{
return
""
,
err
}
// 哈希文件格式:hash filename 或者纯 hash
hash
:=
strings
.
TrimSpace
(
string
(
body
))
parts
:=
strings
.
Fields
(
hash
)
if
len
(
parts
)
>
0
{
return
parts
[
0
],
nil
}
return
hash
,
nil
}
// computeFileHash 计算文件哈希
func
(
s
*
PricingService
)
computeFileHash
(
filePath
string
)
(
string
,
error
)
{
data
,
err
:=
os
.
ReadFile
(
filePath
)
if
err
!=
nil
{
return
""
,
err
}
hash
:=
sha256
.
Sum256
(
data
)
return
hex
.
EncodeToString
(
hash
[
:
]),
nil
}
// GetModelPricing 获取模型价格(带模糊匹配)
func
(
s
*
PricingService
)
GetModelPricing
(
modelName
string
)
*
LiteLLMModelPricing
{
s
.
mu
.
RLock
()
defer
s
.
mu
.
RUnlock
()
if
modelName
==
""
{
return
nil
}
// 标准化模型名称
modelLower
:=
strings
.
ToLower
(
modelName
)
// 1. 精确匹配
if
pricing
,
ok
:=
s
.
pricingData
[
modelLower
];
ok
{
return
pricing
}
if
pricing
,
ok
:=
s
.
pricingData
[
modelName
];
ok
{
return
pricing
}
// 2. 处理常见的模型名称变体
// claude-opus-4-5-20251101 -> claude-opus-4.5-20251101
normalized
:=
strings
.
ReplaceAll
(
modelLower
,
"-4-5-"
,
"-4.5-"
)
if
pricing
,
ok
:=
s
.
pricingData
[
normalized
];
ok
{
return
pricing
}
// 3. 尝试模糊匹配(去掉版本号后缀)
// claude-opus-4-5-20251101 -> claude-opus-4.5
baseName
:=
s
.
extractBaseName
(
modelLower
)
for
key
,
pricing
:=
range
s
.
pricingData
{
keyBase
:=
s
.
extractBaseName
(
strings
.
ToLower
(
key
))
if
keyBase
==
baseName
{
return
pricing
}
}
// 4. 基于模型系列匹配
return
s
.
matchByModelFamily
(
modelLower
)
}
// extractBaseName 提取基础模型名称(去掉日期版本号)
func
(
s
*
PricingService
)
extractBaseName
(
model
string
)
string
{
// 移除日期后缀 (如 -20251101, -20241022)
parts
:=
strings
.
Split
(
model
,
"-"
)
result
:=
make
([]
string
,
0
,
len
(
parts
))
for
_
,
part
:=
range
parts
{
// 跳过看起来像日期的部分(8位数字)
if
len
(
part
)
==
8
&&
isNumeric
(
part
)
{
continue
}
// 跳过版本号(如 v1:0)
if
strings
.
Contains
(
part
,
":"
)
{
continue
}
result
=
append
(
result
,
part
)
}
return
strings
.
Join
(
result
,
"-"
)
}
// matchByModelFamily 基于模型系列匹配
func
(
s
*
PricingService
)
matchByModelFamily
(
model
string
)
*
LiteLLMModelPricing
{
// Claude模型系列匹配规则
familyPatterns
:=
map
[
string
][]
string
{
"opus-4.5"
:
{
"claude-opus-4.5"
,
"claude-opus-4-5"
},
"opus-4"
:
{
"claude-opus-4"
,
"claude-3-opus"
},
"sonnet-4.5"
:
{
"claude-sonnet-4.5"
,
"claude-sonnet-4-5"
},
"sonnet-4"
:
{
"claude-sonnet-4"
,
"claude-3-5-sonnet"
},
"sonnet-3.5"
:
{
"claude-3-5-sonnet"
,
"claude-3.5-sonnet"
},
"sonnet-3"
:
{
"claude-3-sonnet"
},
"haiku-3.5"
:
{
"claude-3-5-haiku"
,
"claude-3.5-haiku"
},
"haiku-3"
:
{
"claude-3-haiku"
},
}
// 确定模型属于哪个系列
var
matchedFamily
string
for
family
,
patterns
:=
range
familyPatterns
{
for
_
,
pattern
:=
range
patterns
{
if
strings
.
Contains
(
model
,
pattern
)
||
strings
.
Contains
(
model
,
strings
.
ReplaceAll
(
pattern
,
"-"
,
""
))
{
matchedFamily
=
family
break
}
}
if
matchedFamily
!=
""
{
break
}
}
if
matchedFamily
==
""
{
// 简单的系列匹配
if
strings
.
Contains
(
model
,
"opus"
)
{
if
strings
.
Contains
(
model
,
"4.5"
)
||
strings
.
Contains
(
model
,
"4-5"
)
{
matchedFamily
=
"opus-4.5"
}
else
{
matchedFamily
=
"opus-4"
}
}
else
if
strings
.
Contains
(
model
,
"sonnet"
)
{
if
strings
.
Contains
(
model
,
"4.5"
)
||
strings
.
Contains
(
model
,
"4-5"
)
{
matchedFamily
=
"sonnet-4.5"
}
else
if
strings
.
Contains
(
model
,
"3-5"
)
||
strings
.
Contains
(
model
,
"3.5"
)
{
matchedFamily
=
"sonnet-3.5"
}
else
{
matchedFamily
=
"sonnet-4"
}
}
else
if
strings
.
Contains
(
model
,
"haiku"
)
{
if
strings
.
Contains
(
model
,
"3-5"
)
||
strings
.
Contains
(
model
,
"3.5"
)
{
matchedFamily
=
"haiku-3.5"
}
else
{
matchedFamily
=
"haiku-3"
}
}
}
if
matchedFamily
==
""
{
return
nil
}
// 在价格数据中查找该系列的模型
patterns
:=
familyPatterns
[
matchedFamily
]
for
_
,
pattern
:=
range
patterns
{
for
key
,
pricing
:=
range
s
.
pricingData
{
keyLower
:=
strings
.
ToLower
(
key
)
if
strings
.
Contains
(
keyLower
,
pattern
)
{
log
.
Printf
(
"[Pricing] Fuzzy matched %s -> %s"
,
model
,
key
)
return
pricing
}
}
}
return
nil
}
// GetStatus 获取服务状态
func
(
s
*
PricingService
)
GetStatus
()
map
[
string
]
interface
{}
{
s
.
mu
.
RLock
()
defer
s
.
mu
.
RUnlock
()
return
map
[
string
]
interface
{}{
"model_count"
:
len
(
s
.
pricingData
),
"last_updated"
:
s
.
lastUpdated
,
"local_hash"
:
s
.
localHash
[
:
min
(
8
,
len
(
s
.
localHash
))],
}
}
// ForceUpdate 强制更新
func
(
s
*
PricingService
)
ForceUpdate
()
error
{
return
s
.
downloadPricingData
()
}
// getPricingFilePath 获取价格文件路径
func
(
s
*
PricingService
)
getPricingFilePath
()
string
{
return
filepath
.
Join
(
s
.
cfg
.
Pricing
.
DataDir
,
"model_pricing.json"
)
}
// getHashFilePath 获取哈希文件路径
func
(
s
*
PricingService
)
getHashFilePath
()
string
{
return
filepath
.
Join
(
s
.
cfg
.
Pricing
.
DataDir
,
"model_pricing.sha256"
)
}
// isNumeric 检查字符串是否为纯数字
func
isNumeric
(
s
string
)
bool
{
for
_
,
c
:=
range
s
{
if
c
<
'0'
||
c
>
'9'
{
return
false
}
}
return
true
}
Prev
1
2
3
4
5
6
7
8
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment