Commit 642842c2 authored by shaw's avatar shaw
Browse files

First commit

parent 569f4882
package admin
import (
"strconv"
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/repository"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// toResponsePagination converts repository.PaginationResult to response.PaginationResult
func toResponsePagination(p *repository.PaginationResult) *response.PaginationResult {
if p == nil {
return nil
}
return &response.PaginationResult{
Total: p.Total,
Page: p.Page,
PageSize: p.PageSize,
Pages: p.Pages,
}
}
// SubscriptionHandler handles admin subscription management
type SubscriptionHandler struct {
subscriptionService *service.SubscriptionService
}
// NewSubscriptionHandler creates a new admin subscription handler
func NewSubscriptionHandler(subscriptionService *service.SubscriptionService) *SubscriptionHandler {
return &SubscriptionHandler{
subscriptionService: subscriptionService,
}
}
// AssignSubscriptionRequest represents assign subscription request
type AssignSubscriptionRequest struct {
UserID int64 `json:"user_id" binding:"required"`
GroupID int64 `json:"group_id" binding:"required"`
ValidityDays int `json:"validity_days"`
Notes string `json:"notes"`
}
// BulkAssignSubscriptionRequest represents bulk assign subscription request
type BulkAssignSubscriptionRequest struct {
UserIDs []int64 `json:"user_ids" binding:"required,min=1"`
GroupID int64 `json:"group_id" binding:"required"`
ValidityDays int `json:"validity_days"`
Notes string `json:"notes"`
}
// ExtendSubscriptionRequest represents extend subscription request
type ExtendSubscriptionRequest struct {
Days int `json:"days" binding:"required,min=1"`
}
// List handles listing all subscriptions with pagination and filters
// GET /api/v1/admin/subscriptions
func (h *SubscriptionHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
// Parse optional filters
var userID, groupID *int64
if userIDStr := c.Query("user_id"); userIDStr != "" {
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
userID = &id
}
}
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil {
groupID = &id
}
}
status := c.Query("status")
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status)
if err != nil {
response.InternalError(c, "Failed to list subscriptions: "+err.Error())
return
}
response.PaginatedWithResult(c, subscriptions, toResponsePagination(pagination))
}
// GetByID handles getting a subscription by ID
// GET /api/v1/admin/subscriptions/:id
func (h *SubscriptionHandler) GetByID(c *gin.Context) {
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid subscription ID")
return
}
subscription, err := h.subscriptionService.GetByID(c.Request.Context(), subscriptionID)
if err != nil {
response.NotFound(c, "Subscription not found")
return
}
response.Success(c, subscription)
}
// GetProgress handles getting subscription usage progress
// GET /api/v1/admin/subscriptions/:id/progress
func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid subscription ID")
return
}
progress, err := h.subscriptionService.GetSubscriptionProgress(c.Request.Context(), subscriptionID)
if err != nil {
response.NotFound(c, "Subscription not found")
return
}
response.Success(c, progress)
}
// Assign handles assigning a subscription to a user
// POST /api/v1/admin/subscriptions/assign
func (h *SubscriptionHandler) Assign(c *gin.Context) {
var req AssignSubscriptionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Get admin user ID from context
adminID := getAdminIDFromContext(c)
subscription, err := h.subscriptionService.AssignSubscription(c.Request.Context(), &service.AssignSubscriptionInput{
UserID: req.UserID,
GroupID: req.GroupID,
ValidityDays: req.ValidityDays,
AssignedBy: adminID,
Notes: req.Notes,
})
if err != nil {
response.BadRequest(c, "Failed to assign subscription: "+err.Error())
return
}
response.Success(c, subscription)
}
// BulkAssign handles bulk assigning subscriptions to multiple users
// POST /api/v1/admin/subscriptions/bulk-assign
func (h *SubscriptionHandler) BulkAssign(c *gin.Context) {
var req BulkAssignSubscriptionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Get admin user ID from context
adminID := getAdminIDFromContext(c)
result, err := h.subscriptionService.BulkAssignSubscription(c.Request.Context(), &service.BulkAssignSubscriptionInput{
UserIDs: req.UserIDs,
GroupID: req.GroupID,
ValidityDays: req.ValidityDays,
AssignedBy: adminID,
Notes: req.Notes,
})
if err != nil {
response.InternalError(c, "Failed to bulk assign subscriptions: "+err.Error())
return
}
response.Success(c, result)
}
// Extend handles extending a subscription
// POST /api/v1/admin/subscriptions/:id/extend
func (h *SubscriptionHandler) Extend(c *gin.Context) {
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid subscription ID")
return
}
var req ExtendSubscriptionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
subscription, err := h.subscriptionService.ExtendSubscription(c.Request.Context(), subscriptionID, req.Days)
if err != nil {
response.InternalError(c, "Failed to extend subscription: "+err.Error())
return
}
response.Success(c, subscription)
}
// Revoke handles revoking a subscription
// DELETE /api/v1/admin/subscriptions/:id
func (h *SubscriptionHandler) Revoke(c *gin.Context) {
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid subscription ID")
return
}
err = h.subscriptionService.RevokeSubscription(c.Request.Context(), subscriptionID)
if err != nil {
response.InternalError(c, "Failed to revoke subscription: "+err.Error())
return
}
response.Success(c, gin.H{"message": "Subscription revoked successfully"})
}
// ListByGroup handles listing subscriptions for a specific group
// GET /api/v1/admin/groups/:id/subscriptions
func (h *SubscriptionHandler) ListByGroup(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
page, pageSize := response.ParsePagination(c)
subscriptions, pagination, err := h.subscriptionService.ListGroupSubscriptions(c.Request.Context(), groupID, page, pageSize)
if err != nil {
response.InternalError(c, "Failed to list group subscriptions: "+err.Error())
return
}
response.PaginatedWithResult(c, subscriptions, toResponsePagination(pagination))
}
// ListByUser handles listing subscriptions for a specific user
// GET /api/v1/admin/users/:id/subscriptions
func (h *SubscriptionHandler) ListByUser(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), userID)
if err != nil {
response.InternalError(c, "Failed to list user subscriptions: "+err.Error())
return
}
response.Success(c, subscriptions)
}
// Helper function to get admin ID from context
func getAdminIDFromContext(c *gin.Context) int64 {
if user, exists := c.Get("user"); exists {
if u, ok := user.(*model.User); ok && u != nil {
return u.ID
}
}
return 0
}
package admin
import (
"net/http"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
)
// SystemHandler handles system-related operations
type SystemHandler struct {
updateSvc *service.UpdateService
}
// NewSystemHandler creates a new SystemHandler
func NewSystemHandler(rdb *redis.Client, version, buildType string) *SystemHandler {
return &SystemHandler{
updateSvc: service.NewUpdateService(rdb, version, buildType),
}
}
// GetVersion returns the current version
// GET /api/v1/admin/system/version
func (h *SystemHandler) GetVersion(c *gin.Context) {
info, _ := h.updateSvc.CheckUpdate(c.Request.Context(), false)
response.Success(c, gin.H{
"version": info.CurrentVersion,
})
}
// CheckUpdates checks for available updates
// GET /api/v1/admin/system/check-updates
func (h *SystemHandler) CheckUpdates(c *gin.Context) {
force := c.Query("force") == "true"
info, err := h.updateSvc.CheckUpdate(c.Request.Context(), force)
if err != nil {
response.Error(c, http.StatusInternalServerError, err.Error())
return
}
response.Success(c, info)
}
// PerformUpdate downloads and applies the update
// POST /api/v1/admin/system/update
func (h *SystemHandler) PerformUpdate(c *gin.Context) {
if err := h.updateSvc.PerformUpdate(c.Request.Context()); err != nil {
response.Error(c, http.StatusInternalServerError, err.Error())
return
}
response.Success(c, gin.H{
"message": "Update completed. Please restart the service.",
"need_restart": true,
})
}
// Rollback restores the previous version
// POST /api/v1/admin/system/rollback
func (h *SystemHandler) Rollback(c *gin.Context) {
if err := h.updateSvc.Rollback(); err != nil {
response.Error(c, http.StatusInternalServerError, err.Error())
return
}
response.Success(c, gin.H{
"message": "Rollback completed. Please restart the service.",
"need_restart": true,
})
}
// RestartService restarts the systemd service
// POST /api/v1/admin/system/restart
func (h *SystemHandler) RestartService(c *gin.Context) {
if err := h.updateSvc.RestartService(); err != nil {
response.Error(c, http.StatusInternalServerError, err.Error())
return
}
response.Success(c, gin.H{
"message": "Service restart initiated",
})
}
package admin
import (
"strconv"
"time"
"sub2api/internal/pkg/response"
"sub2api/internal/pkg/timezone"
"sub2api/internal/repository"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// UsageHandler handles admin usage-related requests
type UsageHandler struct {
usageRepo *repository.UsageLogRepository
apiKeyRepo *repository.ApiKeyRepository
usageService *service.UsageService
adminService service.AdminService
}
// NewUsageHandler creates a new admin usage handler
func NewUsageHandler(
usageRepo *repository.UsageLogRepository,
apiKeyRepo *repository.ApiKeyRepository,
usageService *service.UsageService,
adminService service.AdminService,
) *UsageHandler {
return &UsageHandler{
usageRepo: usageRepo,
apiKeyRepo: apiKeyRepo,
usageService: usageService,
adminService: adminService,
}
}
// List handles listing all usage records with filters
// GET /api/v1/admin/usage
func (h *UsageHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
// Parse filters
var userID, apiKeyID int64
if userIDStr := c.Query("user_id"); userIDStr != "" {
id, err := strconv.ParseInt(userIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user_id")
return
}
userID = id
}
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid api_key_id")
return
}
apiKeyID = id
}
// Parse date range
var startTime, endTime *time.Time
if startDateStr := c.Query("start_date"); startDateStr != "" {
t, err := timezone.ParseInLocation("2006-01-02", startDateStr)
if err != nil {
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
return
}
startTime = &t
}
if endDateStr := c.Query("end_date"); endDateStr != "" {
t, err := timezone.ParseInLocation("2006-01-02", endDateStr)
if err != nil {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return
}
// Set end time to end of day
t = t.Add(24*time.Hour - time.Nanosecond)
endTime = &t
}
params := repository.PaginationParams{Page: page, PageSize: pageSize}
filters := repository.UsageLogFilters{
UserID: userID,
ApiKeyID: apiKeyID,
StartTime: startTime,
EndTime: endTime,
}
records, result, err := h.usageRepo.ListWithFilters(c.Request.Context(), params, filters)
if err != nil {
response.InternalError(c, "Failed to list usage records: "+err.Error())
return
}
response.Paginated(c, records, result.Total, page, pageSize)
}
// Stats handles getting usage statistics with filters
// GET /api/v1/admin/usage/stats
func (h *UsageHandler) Stats(c *gin.Context) {
// Parse filters
var userID, apiKeyID int64
if userIDStr := c.Query("user_id"); userIDStr != "" {
id, err := strconv.ParseInt(userIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user_id")
return
}
userID = id
}
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid api_key_id")
return
}
apiKeyID = id
}
// Parse date range
now := timezone.Now()
var startTime, endTime time.Time
startDateStr := c.Query("start_date")
endDateStr := c.Query("end_date")
if startDateStr != "" && endDateStr != "" {
var err error
startTime, err = timezone.ParseInLocation("2006-01-02", startDateStr)
if err != nil {
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
return
}
endTime, err = timezone.ParseInLocation("2006-01-02", endDateStr)
if err != nil {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return
}
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
} else {
period := c.DefaultQuery("period", "today")
switch period {
case "today":
startTime = timezone.StartOfDay(now)
case "week":
startTime = now.AddDate(0, 0, -7)
case "month":
startTime = now.AddDate(0, -1, 0)
default:
startTime = timezone.StartOfDay(now)
}
endTime = now
}
if apiKeyID > 0 {
stats, err := h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
return
}
response.Success(c, stats)
return
}
if userID > 0 {
stats, err := h.usageService.GetStatsByUser(c.Request.Context(), userID, startTime, endTime)
if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
return
}
response.Success(c, stats)
return
}
// Get global stats
stats, err := h.usageRepo.GetGlobalStats(c.Request.Context(), startTime, endTime)
if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
return
}
response.Success(c, stats)
}
// SearchUsers handles searching users by email keyword
// GET /api/v1/admin/usage/search-users
func (h *UsageHandler) SearchUsers(c *gin.Context) {
keyword := c.Query("q")
if keyword == "" {
response.Success(c, []interface{}{})
return
}
// Limit to 30 results
users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, "", "", keyword)
if err != nil {
response.InternalError(c, "Failed to search users: "+err.Error())
return
}
// Return simplified user list (only id and email)
type SimpleUser struct {
ID int64 `json:"id"`
Email string `json:"email"`
}
result := make([]SimpleUser, len(users))
for i, u := range users {
result[i] = SimpleUser{
ID: u.ID,
Email: u.Email,
}
}
response.Success(c, result)
}
// SearchApiKeys handles searching API keys by user
// GET /api/v1/admin/usage/search-api-keys
func (h *UsageHandler) SearchApiKeys(c *gin.Context) {
userIDStr := c.Query("user_id")
keyword := c.Query("q")
var userID int64
if userIDStr != "" {
id, err := strconv.ParseInt(userIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user_id")
return
}
userID = id
}
keys, err := h.apiKeyRepo.SearchApiKeys(c.Request.Context(), userID, keyword, 30)
if err != nil {
response.InternalError(c, "Failed to search API keys: "+err.Error())
return
}
// Return simplified API key list (only id and name)
type SimpleApiKey struct {
ID int64 `json:"id"`
Name string `json:"name"`
UserID int64 `json:"user_id"`
}
result := make([]SimpleApiKey, len(keys))
for i, k := range keys {
result[i] = SimpleApiKey{
ID: k.ID,
Name: k.Name,
UserID: k.UserID,
}
}
response.Success(c, result)
}
package admin
import (
"strconv"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// UserHandler handles admin user management
type UserHandler struct {
adminService service.AdminService
}
// NewUserHandler creates a new admin user handler
func NewUserHandler(adminService service.AdminService) *UserHandler {
return &UserHandler{
adminService: adminService,
}
}
// CreateUserRequest represents admin create user request
type CreateUserRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
AllowedGroups []int64 `json:"allowed_groups"`
}
// UpdateUserRequest represents admin update user request
// 使用指针类型来区分"未提供"和"设置为0"
type UpdateUserRequest struct {
Email string `json:"email" binding:"omitempty,email"`
Password string `json:"password" binding:"omitempty,min=6"`
Balance *float64 `json:"balance"`
Concurrency *int `json:"concurrency"`
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
AllowedGroups *[]int64 `json:"allowed_groups"`
}
// UpdateBalanceRequest represents balance update request
type UpdateBalanceRequest struct {
Balance float64 `json:"balance" binding:"required"`
Operation string `json:"operation" binding:"required,oneof=set add subtract"`
}
// List handles listing all users with pagination
// GET /api/v1/admin/users
func (h *UserHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
status := c.Query("status")
role := c.Query("role")
search := c.Query("search")
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, status, role, search)
if err != nil {
response.InternalError(c, "Failed to list users: "+err.Error())
return
}
response.Paginated(c, users, total, page, pageSize)
}
// GetByID handles getting a user by ID
// GET /api/v1/admin/users/:id
func (h *UserHandler) GetByID(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
user, err := h.adminService.GetUser(c.Request.Context(), userID)
if err != nil {
response.NotFound(c, "User not found")
return
}
response.Success(c, user)
}
// Create handles creating a new user
// POST /api/v1/admin/users
func (h *UserHandler) Create(c *gin.Context) {
var req CreateUserRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{
Email: req.Email,
Password: req.Password,
Balance: req.Balance,
Concurrency: req.Concurrency,
AllowedGroups: req.AllowedGroups,
})
if err != nil {
response.BadRequest(c, "Failed to create user: "+err.Error())
return
}
response.Success(c, user)
}
// Update handles updating a user
// PUT /api/v1/admin/users/:id
func (h *UserHandler) Update(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
var req UpdateUserRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// 使用指针类型直接传递,nil 表示未提供该字段
user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{
Email: req.Email,
Password: req.Password,
Balance: req.Balance,
Concurrency: req.Concurrency,
Status: req.Status,
AllowedGroups: req.AllowedGroups,
})
if err != nil {
response.InternalError(c, "Failed to update user: "+err.Error())
return
}
response.Success(c, user)
}
// Delete handles deleting a user
// DELETE /api/v1/admin/users/:id
func (h *UserHandler) Delete(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
err = h.adminService.DeleteUser(c.Request.Context(), userID)
if err != nil {
response.InternalError(c, "Failed to delete user: "+err.Error())
return
}
response.Success(c, gin.H{"message": "User deleted successfully"})
}
// UpdateBalance handles updating user balance
// POST /api/v1/admin/users/:id/balance
func (h *UserHandler) UpdateBalance(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
var req UpdateBalanceRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation)
if err != nil {
response.InternalError(c, "Failed to update balance: "+err.Error())
return
}
response.Success(c, user)
}
// GetUserAPIKeys handles getting user's API keys
// GET /api/v1/admin/users/:id/api-keys
func (h *UserHandler) GetUserAPIKeys(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
page, pageSize := response.ParsePagination(c)
keys, total, err := h.adminService.GetUserAPIKeys(c.Request.Context(), userID, page, pageSize)
if err != nil {
response.InternalError(c, "Failed to get user API keys: "+err.Error())
return
}
response.Paginated(c, keys, total, page, pageSize)
}
// GetUserUsage handles getting user's usage statistics
// GET /api/v1/admin/users/:id/usage
func (h *UserHandler) GetUserUsage(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
period := c.DefaultQuery("period", "month")
stats, err := h.adminService.GetUserUsageStats(c.Request.Context(), userID, period)
if err != nil {
response.InternalError(c, "Failed to get user usage: "+err.Error())
return
}
response.Success(c, stats)
}
package handler
import (
"strconv"
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/repository"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// APIKeyHandler handles API key-related requests
type APIKeyHandler struct {
apiKeyService *service.ApiKeyService
}
// NewAPIKeyHandler creates a new APIKeyHandler
func NewAPIKeyHandler(apiKeyService *service.ApiKeyService) *APIKeyHandler {
return &APIKeyHandler{
apiKeyService: apiKeyService,
}
}
// CreateAPIKeyRequest represents the create API key request payload
type CreateAPIKeyRequest struct {
Name string `json:"name" binding:"required"`
GroupID *int64 `json:"group_id"` // nullable
CustomKey *string `json:"custom_key"` // 可选的自定义key
}
// UpdateAPIKeyRequest represents the update API key request payload
type UpdateAPIKeyRequest struct {
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
}
// List handles listing user's API keys with pagination
// GET /api/v1/api-keys
func (h *APIKeyHandler) List(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
page, pageSize := response.ParsePagination(c)
params := repository.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := h.apiKeyService.List(c.Request.Context(), user.ID, params)
if err != nil {
response.InternalError(c, "Failed to list API keys: "+err.Error())
return
}
response.Paginated(c, keys, result.Total, page, pageSize)
}
// GetByID handles getting a single API key
// GET /api/v1/api-keys/:id
func (h *APIKeyHandler) GetByID(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid key ID")
return
}
key, err := h.apiKeyService.GetByID(c.Request.Context(), keyID)
if err != nil {
response.NotFound(c, "API key not found")
return
}
// 验证所有权
if key.UserID != user.ID {
response.Forbidden(c, "Not authorized to access this key")
return
}
response.Success(c, key)
}
// Create handles creating a new API key
// POST /api/v1/api-keys
func (h *APIKeyHandler) Create(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
var req CreateAPIKeyRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
svcReq := service.CreateApiKeyRequest{
Name: req.Name,
GroupID: req.GroupID,
CustomKey: req.CustomKey,
}
key, err := h.apiKeyService.Create(c.Request.Context(), user.ID, svcReq)
if err != nil {
response.InternalError(c, "Failed to create API key: "+err.Error())
return
}
response.Success(c, key)
}
// Update handles updating an API key
// PUT /api/v1/api-keys/:id
func (h *APIKeyHandler) Update(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid key ID")
return
}
var req UpdateAPIKeyRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
svcReq := service.UpdateApiKeyRequest{}
if req.Name != "" {
svcReq.Name = &req.Name
}
svcReq.GroupID = req.GroupID
if req.Status != "" {
svcReq.Status = &req.Status
}
key, err := h.apiKeyService.Update(c.Request.Context(), keyID, user.ID, svcReq)
if err != nil {
response.InternalError(c, "Failed to update API key: "+err.Error())
return
}
response.Success(c, key)
}
// Delete handles deleting an API key
// DELETE /api/v1/api-keys/:id
func (h *APIKeyHandler) Delete(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid key ID")
return
}
err = h.apiKeyService.Delete(c.Request.Context(), keyID, user.ID)
if err != nil {
response.InternalError(c, "Failed to delete API key: "+err.Error())
return
}
response.Success(c, gin.H{"message": "API key deleted successfully"})
}
// GetAvailableGroups 获取用户可以绑定的分组列表
// GET /api/v1/groups/available
func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
groups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), user.ID)
if err != nil {
response.InternalError(c, "Failed to get available groups: "+err.Error())
return
}
response.Success(c, groups)
}
package handler
import (
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// AuthHandler handles authentication-related requests
type AuthHandler struct {
authService *service.AuthService
}
// NewAuthHandler creates a new AuthHandler
func NewAuthHandler(authService *service.AuthService) *AuthHandler {
return &AuthHandler{
authService: authService,
}
}
// RegisterRequest represents the registration request payload
type RegisterRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"`
VerifyCode string `json:"verify_code"`
TurnstileToken string `json:"turnstile_token"`
}
// SendVerifyCodeRequest 发送验证码请求
type SendVerifyCodeRequest struct {
Email string `json:"email" binding:"required,email"`
TurnstileToken string `json:"turnstile_token"`
}
// SendVerifyCodeResponse 发送验证码响应
type SendVerifyCodeResponse struct {
Message string `json:"message"`
Countdown int `json:"countdown"` // 倒计时秒数
}
// LoginRequest represents the login request payload
type LoginRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required"`
TurnstileToken string `json:"turnstile_token"`
}
// AuthResponse 认证响应格式(匹配前端期望)
type AuthResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
User *model.User `json:"user"`
}
// Register handles user registration
// POST /api/v1/auth/register
func (h *AuthHandler) Register(c *gin.Context) {
var req RegisterRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过)
if req.VerifyCode == "" {
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
response.BadRequest(c, "Turnstile verification failed: "+err.Error())
return
}
}
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode)
if err != nil {
response.BadRequest(c, "Registration failed: "+err.Error())
return
}
response.Success(c, AuthResponse{
AccessToken: token,
TokenType: "Bearer",
User: user,
})
}
// SendVerifyCode 发送邮箱验证码
// POST /api/v1/auth/send-verify-code
func (h *AuthHandler) SendVerifyCode(c *gin.Context) {
var req SendVerifyCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Turnstile 验证
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
response.BadRequest(c, "Turnstile verification failed: "+err.Error())
return
}
result, err := h.authService.SendVerifyCodeAsync(c.Request.Context(), req.Email)
if err != nil {
response.BadRequest(c, "Failed to send verification code: "+err.Error())
return
}
response.Success(c, SendVerifyCodeResponse{
Message: "Verification code sent successfully",
Countdown: result.Countdown,
})
}
// Login handles user login
// POST /api/v1/auth/login
func (h *AuthHandler) Login(c *gin.Context) {
var req LoginRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Turnstile 验证
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
response.BadRequest(c, "Turnstile verification failed: "+err.Error())
return
}
token, user, err := h.authService.Login(c.Request.Context(), req.Email, req.Password)
if err != nil {
response.Unauthorized(c, "Login failed: "+err.Error())
return
}
response.Success(c, AuthResponse{
AccessToken: token,
TokenType: "Bearer",
User: user,
})
}
// GetCurrentUser handles getting current authenticated user
// GET /api/v1/auth/me
func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
response.Success(c, user)
}
package handler
import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"time"
"sub2api/internal/middleware"
"sub2api/internal/model"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
const (
// Maximum wait time for concurrency slot
maxConcurrencyWait = 60 * time.Second
// Ping interval during wait
pingInterval = 5 * time.Second
)
// GatewayHandler handles API gateway requests
type GatewayHandler struct {
gatewayService *service.GatewayService
userService *service.UserService
concurrencyService *service.ConcurrencyService
billingCacheService *service.BillingCacheService
}
// NewGatewayHandler creates a new GatewayHandler
func NewGatewayHandler(gatewayService *service.GatewayService, userService *service.UserService, concurrencyService *service.ConcurrencyService, billingCacheService *service.BillingCacheService) *GatewayHandler {
return &GatewayHandler{
gatewayService: gatewayService,
userService: userService,
concurrencyService: concurrencyService,
billingCacheService: billingCacheService,
}
}
// Messages handles Claude API compatible messages endpoint
// POST /v1/messages
func (h *GatewayHandler) Messages(c *gin.Context) {
// 从context获取apiKey和user(ApiKeyAuth中间件已设置)
apiKey, ok := middleware.GetApiKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
user, ok := middleware.GetUserFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
// 解析请求获取模型名和stream
var req struct {
Model string `json:"model"`
Stream bool `json:"stream"`
}
if err := json.Unmarshal(body, &req); err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// Track if we've started streaming (for error handling)
streamStarted := false
// 获取订阅信息(可能为nil)- 提前获取用于后续检查
subscription, _ := middleware.GetSubscriptionFromContext(c)
// 0. 检查wait队列是否已满
maxWait := service.CalculateMaxWait(user.Concurrency)
canWait, err := h.concurrencyService.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
if err != nil {
log.Printf("Increment wait count failed: %v", err)
// On error, allow request to proceed
} else if !canWait {
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
// 确保在函数退出时减少wait计数
defer h.concurrencyService.DecrementWaitCount(c.Request.Context(), user.ID)
// 1. 首先获取用户并发槽位
userReleaseFunc, err := h.acquireUserSlotWithWait(c, user, req.Stream, &streamStarted)
if err != nil {
log.Printf("User concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
if userReleaseFunc != nil {
defer userReleaseFunc()
}
// 2. 【新增】Wait后二次检查余额/订阅
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
return
}
// 计算粘性会话hash
sessionHash := h.gatewayService.GenerateSessionHash(body)
// 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
if err != nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
// 3. 获取账号并发槽位
accountReleaseFunc, err := h.acquireAccountSlotWithWait(c, account, req.Stream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if accountReleaseFunc != nil {
defer accountReleaseFunc()
}
// 转发请求
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
if err != nil {
// 错误响应已在Forward中处理,这里只记录日志
log.Printf("Forward request failed: %v", err)
return
}
// 异步记录使用量(subscription已在函数开头获取)
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ApiKey: apiKey,
User: user,
Account: account,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}()
}
// acquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary
// For streaming requests, sends ping events during the wait
// streamStarted is updated if streaming response has begun
func (h *GatewayHandler) acquireUserSlotWithWait(c *gin.Context, user *model.User, isStream bool, streamStarted *bool) (func(), error) {
ctx := c.Request.Context()
// Try to acquire immediately
result, err := h.concurrencyService.AcquireUserSlot(ctx, user.ID, user.Concurrency)
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
// Need to wait - handle streaming ping if needed
return h.waitForSlotWithPing(c, "user", user.ID, user.Concurrency, isStream, streamStarted)
}
// acquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary
// For streaming requests, sends ping events during the wait
// streamStarted is updated if streaming response has begun
func (h *GatewayHandler) acquireAccountSlotWithWait(c *gin.Context, account *model.Account, isStream bool, streamStarted *bool) (func(), error) {
ctx := c.Request.Context()
// Try to acquire immediately
result, err := h.concurrencyService.AcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
// Need to wait - handle streaming ping if needed
return h.waitForSlotWithPing(c, "account", account.ID, account.Concurrency, isStream, streamStarted)
}
// concurrencyError represents a concurrency limit error with context
type concurrencyError struct {
SlotType string
IsTimeout bool
}
func (e *concurrencyError) Error() string {
if e.IsTimeout {
return fmt.Sprintf("timeout waiting for %s concurrency slot", e.SlotType)
}
return fmt.Sprintf("%s concurrency limit reached", e.SlotType)
}
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests
// Note: For streaming requests, we send ping to keep the connection alive.
// streamStarted pointer is updated when streaming begins (for proper error handling by caller)
func (h *GatewayHandler) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait)
defer cancel()
// For streaming requests, set up SSE headers for ping
var flusher http.Flusher
if isStream {
var ok bool
flusher, ok = c.Writer.(http.Flusher)
if !ok {
return nil, fmt.Errorf("streaming not supported")
}
}
pingTicker := time.NewTicker(pingInterval)
defer pingTicker.Stop()
pollTicker := time.NewTicker(100 * time.Millisecond)
defer pollTicker.Stop()
for {
select {
case <-ctx.Done():
return nil, &concurrencyError{
SlotType: slotType,
IsTimeout: true,
}
case <-pingTicker.C:
// Send ping for streaming requests to keep connection alive
if isStream && flusher != nil {
// Set headers on first ping (lazy initialization)
if !*streamStarted {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
*streamStarted = true
}
fmt.Fprintf(c.Writer, "data: {\"type\": \"ping\"}\n\n")
flusher.Flush()
}
case <-pollTicker.C:
// Try to acquire slot
var result *service.AcquireResult
var err error
if slotType == "user" {
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
} else {
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
}
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
}
}
}
// Models handles listing available models
// GET /v1/models
func (h *GatewayHandler) Models(c *gin.Context) {
models := []gin.H{
{
"id": "claude-opus-4-5-20251101",
"type": "model",
"display_name": "Claude Opus 4.5",
"created_at": "2025-11-01T00:00:00Z",
},
{
"id": "claude-sonnet-4-5-20250929",
"type": "model",
"display_name": "Claude Sonnet 4.5",
"created_at": "2025-09-29T00:00:00Z",
},
{
"id": "claude-haiku-4-5-20251001",
"type": "model",
"display_name": "Claude Haiku 4.5",
"created_at": "2025-10-01T00:00:00Z",
},
}
c.JSON(http.StatusOK, gin.H{
"data": models,
"object": "list",
})
}
// Usage handles getting account balance for CC Switch integration
// GET /v1/usage
func (h *GatewayHandler) Usage(c *gin.Context) {
apiKey, ok := middleware.GetApiKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
user, ok := middleware.GetUserFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
// 订阅模式:返回订阅限额信息
if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() {
subscription, ok := middleware.GetSubscriptionFromContext(c)
if !ok {
h.errorResponse(c, http.StatusForbidden, "subscription_error", "No active subscription")
return
}
remaining := h.calculateSubscriptionRemaining(apiKey.Group, subscription)
c.JSON(http.StatusOK, gin.H{
"isValid": true,
"planName": apiKey.Group.Name,
"remaining": remaining,
"unit": "USD",
})
return
}
// 余额模式:返回钱包余额
latestUser, err := h.userService.GetByID(c.Request.Context(), user.ID)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info")
return
}
c.JSON(http.StatusOK, gin.H{
"isValid": true,
"planName": "钱包余额",
"remaining": latestUser.Balance,
"unit": "USD",
})
}
// calculateSubscriptionRemaining 计算订阅剩余可用额度
// 逻辑:
// 1. 如果日/周/月任一限额达到100%,返回0
// 2. 否则返回所有已配置周期中剩余额度的最小值
func (h *GatewayHandler) calculateSubscriptionRemaining(group *model.Group, sub *model.UserSubscription) float64 {
var remainingValues []float64
// 检查日限额
if group.HasDailyLimit() {
remaining := *group.DailyLimitUSD - sub.DailyUsageUSD
if remaining <= 0 {
return 0
}
remainingValues = append(remainingValues, remaining)
}
// 检查周限额
if group.HasWeeklyLimit() {
remaining := *group.WeeklyLimitUSD - sub.WeeklyUsageUSD
if remaining <= 0 {
return 0
}
remainingValues = append(remainingValues, remaining)
}
// 检查月限额
if group.HasMonthlyLimit() {
remaining := *group.MonthlyLimitUSD - sub.MonthlyUsageUSD
if remaining <= 0 {
return 0
}
remainingValues = append(remainingValues, remaining)
}
// 如果没有配置任何限额,返回-1表示无限制
if len(remainingValues) == 0 {
return -1
}
// 返回最小值
min := remainingValues[0]
for _, v := range remainingValues[1:] {
if v < min {
min = v
}
}
return min
}
// handleConcurrencyError handles concurrency-related errors with proper 429 response
func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
}
// handleStreamingAwareError handles errors that may occur after streaming has started
func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted {
// Stream already started, send error as SSE event then close
flusher, ok := c.Writer.(http.Flusher)
if ok {
// Send error event in SSE format
errorEvent := fmt.Sprintf(`data: {"type": "error", "error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
fmt.Fprint(c.Writer, errorEvent)
flusher.Flush()
}
return
}
// Normal case: return JSON response with proper status code
h.errorResponse(c, status, errType, message)
}
// errorResponse 返回Claude API格式的错误响应
func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{
"type": errType,
"message": message,
},
})
}
package handler
import (
"sub2api/internal/handler/admin"
"sub2api/internal/repository"
"sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// AdminHandlers contains all admin-related HTTP handlers
type AdminHandlers struct {
Dashboard *admin.DashboardHandler
User *admin.UserHandler
Group *admin.GroupHandler
Account *admin.AccountHandler
OAuth *admin.OAuthHandler
Proxy *admin.ProxyHandler
Redeem *admin.RedeemHandler
Setting *admin.SettingHandler
System *admin.SystemHandler
Subscription *admin.SubscriptionHandler
Usage *admin.UsageHandler
}
// Handlers contains all HTTP handlers
type Handlers struct {
Auth *AuthHandler
User *UserHandler
APIKey *APIKeyHandler
Usage *UsageHandler
Redeem *RedeemHandler
Subscription *SubscriptionHandler
Admin *AdminHandlers
Gateway *GatewayHandler
Setting *SettingHandler
}
// BuildInfo contains build-time information
type BuildInfo struct {
Version string
BuildType string // "source" for manual builds, "release" for CI builds
}
// NewHandlers creates a new Handlers instance with all handlers initialized
func NewHandlers(services *service.Services, repos *repository.Repositories, rdb *redis.Client, buildInfo BuildInfo) *Handlers {
return &Handlers{
Auth: NewAuthHandler(services.Auth),
User: NewUserHandler(services.User),
APIKey: NewAPIKeyHandler(services.ApiKey),
Usage: NewUsageHandler(services.Usage, repos.UsageLog, services.ApiKey),
Redeem: NewRedeemHandler(services.Redeem),
Subscription: NewSubscriptionHandler(services.Subscription),
Admin: &AdminHandlers{
Dashboard: admin.NewDashboardHandler(services.Admin, repos.UsageLog),
User: admin.NewUserHandler(services.Admin),
Group: admin.NewGroupHandler(services.Admin),
Account: admin.NewAccountHandler(services.Admin, services.OAuth, services.RateLimit, services.AccountUsage, services.AccountTest),
OAuth: admin.NewOAuthHandler(services.OAuth, services.Admin),
Proxy: admin.NewProxyHandler(services.Admin),
Redeem: admin.NewRedeemHandler(services.Admin),
Setting: admin.NewSettingHandler(services.Setting, services.Email),
System: admin.NewSystemHandler(rdb, buildInfo.Version, buildInfo.BuildType),
Subscription: admin.NewSubscriptionHandler(services.Subscription),
Usage: admin.NewUsageHandler(repos.UsageLog, repos.ApiKey, services.Usage, services.Admin),
},
Gateway: NewGatewayHandler(services.Gateway, services.User, services.Concurrency, services.BillingCache),
Setting: NewSettingHandler(services.Setting, buildInfo.Version),
}
}
package handler
import (
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// RedeemHandler handles redeem code-related requests
type RedeemHandler struct {
redeemService *service.RedeemService
}
// NewRedeemHandler creates a new RedeemHandler
func NewRedeemHandler(redeemService *service.RedeemService) *RedeemHandler {
return &RedeemHandler{
redeemService: redeemService,
}
}
// RedeemRequest represents the redeem code request payload
type RedeemRequest struct {
Code string `json:"code" binding:"required"`
}
// RedeemResponse represents the redeem response
type RedeemResponse struct {
Message string `json:"message"`
Type string `json:"type"`
Value float64 `json:"value"`
NewBalance *float64 `json:"new_balance,omitempty"`
NewConcurrency *int `json:"new_concurrency,omitempty"`
}
// Redeem handles redeeming a code
// POST /api/v1/redeem
func (h *RedeemHandler) Redeem(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
var req RedeemRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
result, err := h.redeemService.Redeem(c.Request.Context(), user.ID, req.Code)
if err != nil {
response.BadRequest(c, "Failed to redeem code: "+err.Error())
return
}
response.Success(c, result)
}
// GetHistory returns the user's redemption history
// GET /api/v1/redeem/history
func (h *RedeemHandler) GetHistory(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
// Default limit is 25
limit := 25
codes, err := h.redeemService.GetUserHistory(c.Request.Context(), user.ID, limit)
if err != nil {
response.InternalError(c, "Failed to get history: "+err.Error())
return
}
response.Success(c, codes)
}
package handler
import (
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// SettingHandler 公开设置处理器(无需认证)
type SettingHandler struct {
settingService *service.SettingService
version string
}
// NewSettingHandler 创建公开设置处理器
func NewSettingHandler(settingService *service.SettingService, version string) *SettingHandler {
return &SettingHandler{
settingService: settingService,
version: version,
}
}
// GetPublicSettings 获取公开设置
// GET /api/v1/settings/public
func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
settings, err := h.settingService.GetPublicSettings(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to get settings: "+err.Error())
return
}
settings.Version = h.version
response.Success(c, settings)
}
package handler
import (
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// SubscriptionSummaryItem represents a subscription item in summary
type SubscriptionSummaryItem struct {
ID int64 `json:"id"`
GroupID int64 `json:"group_id"`
GroupName string `json:"group_name"`
Status string `json:"status"`
DailyUsedUSD float64 `json:"daily_used_usd,omitempty"`
DailyLimitUSD float64 `json:"daily_limit_usd,omitempty"`
WeeklyUsedUSD float64 `json:"weekly_used_usd,omitempty"`
WeeklyLimitUSD float64 `json:"weekly_limit_usd,omitempty"`
MonthlyUsedUSD float64 `json:"monthly_used_usd,omitempty"`
MonthlyLimitUSD float64 `json:"monthly_limit_usd,omitempty"`
ExpiresAt *string `json:"expires_at,omitempty"`
}
// SubscriptionProgressInfo represents subscription with progress info
type SubscriptionProgressInfo struct {
Subscription *model.UserSubscription `json:"subscription"`
Progress *service.SubscriptionProgress `json:"progress"`
}
// SubscriptionHandler handles user subscription operations
type SubscriptionHandler struct {
subscriptionService *service.SubscriptionService
}
// NewSubscriptionHandler creates a new user subscription handler
func NewSubscriptionHandler(subscriptionService *service.SubscriptionService) *SubscriptionHandler {
return &SubscriptionHandler{
subscriptionService: subscriptionService,
}
}
// List handles listing current user's subscriptions
// GET /api/v1/subscriptions
func (h *SubscriptionHandler) List(c *gin.Context) {
user, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok {
response.InternalError(c, "Invalid user in context")
return
}
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), u.ID)
if err != nil {
response.InternalError(c, "Failed to list subscriptions: "+err.Error())
return
}
response.Success(c, subscriptions)
}
// GetActive handles getting current user's active subscriptions
// GET /api/v1/subscriptions/active
func (h *SubscriptionHandler) GetActive(c *gin.Context) {
user, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok {
response.InternalError(c, "Invalid user in context")
return
}
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
if err != nil {
response.InternalError(c, "Failed to get active subscriptions: "+err.Error())
return
}
response.Success(c, subscriptions)
}
// GetProgress handles getting subscription progress for current user
// GET /api/v1/subscriptions/progress
func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
user, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok {
response.InternalError(c, "Invalid user in context")
return
}
// Get all active subscriptions with progress
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
if err != nil {
response.InternalError(c, "Failed to get subscriptions: "+err.Error())
return
}
result := make([]SubscriptionProgressInfo, 0, len(subscriptions))
for i := range subscriptions {
sub := &subscriptions[i]
progress, err := h.subscriptionService.GetSubscriptionProgress(c.Request.Context(), sub.ID)
if err != nil {
// Skip subscriptions with errors
continue
}
result = append(result, SubscriptionProgressInfo{
Subscription: sub,
Progress: progress,
})
}
response.Success(c, result)
}
// GetSummary handles getting a summary of current user's subscription status
// GET /api/v1/subscriptions/summary
func (h *SubscriptionHandler) GetSummary(c *gin.Context) {
user, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok {
response.InternalError(c, "Invalid user in context")
return
}
// Get all active subscriptions
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
if err != nil {
response.InternalError(c, "Failed to get subscriptions: "+err.Error())
return
}
var totalUsed float64
items := make([]SubscriptionSummaryItem, 0, len(subscriptions))
for _, sub := range subscriptions {
item := SubscriptionSummaryItem{
ID: sub.ID,
GroupID: sub.GroupID,
Status: sub.Status,
DailyUsedUSD: sub.DailyUsageUSD,
WeeklyUsedUSD: sub.WeeklyUsageUSD,
MonthlyUsedUSD: sub.MonthlyUsageUSD,
}
// Add group info if preloaded
if sub.Group != nil {
item.GroupName = sub.Group.Name
if sub.Group.DailyLimitUSD != nil {
item.DailyLimitUSD = *sub.Group.DailyLimitUSD
}
if sub.Group.WeeklyLimitUSD != nil {
item.WeeklyLimitUSD = *sub.Group.WeeklyLimitUSD
}
if sub.Group.MonthlyLimitUSD != nil {
item.MonthlyLimitUSD = *sub.Group.MonthlyLimitUSD
}
}
// Format expiration time
if !sub.ExpiresAt.IsZero() {
formatted := sub.ExpiresAt.Format("2006-01-02T15:04:05Z07:00")
item.ExpiresAt = &formatted
}
// Track total usage (use monthly as the most comprehensive)
totalUsed += sub.MonthlyUsageUSD
items = append(items, item)
}
summary := struct {
ActiveCount int `json:"active_count"`
TotalUsedUSD float64 `json:"total_used_usd"`
Subscriptions []SubscriptionSummaryItem `json:"subscriptions"`
}{
ActiveCount: len(subscriptions),
TotalUsedUSD: totalUsed,
Subscriptions: items,
}
response.Success(c, summary)
}
package handler
import (
"strconv"
"time"
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/pkg/timezone"
"sub2api/internal/repository"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// UsageHandler handles usage-related requests
type UsageHandler struct {
usageService *service.UsageService
usageRepo *repository.UsageLogRepository
apiKeyService *service.ApiKeyService
}
// NewUsageHandler creates a new UsageHandler
func NewUsageHandler(usageService *service.UsageService, usageRepo *repository.UsageLogRepository, apiKeyService *service.ApiKeyService) *UsageHandler {
return &UsageHandler{
usageService: usageService,
usageRepo: usageRepo,
apiKeyService: apiKeyService,
}
}
// List handles listing usage records with pagination
// GET /api/v1/usage
func (h *UsageHandler) List(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
page, pageSize := response.ParsePagination(c)
var apiKeyID int64
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid api_key_id")
return
}
// [Security Fix] Verify API Key ownership to prevent horizontal privilege escalation
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), id)
if err != nil {
response.NotFound(c, "API key not found")
return
}
if apiKey.UserID != user.ID {
response.Forbidden(c, "Not authorized to access this API key's usage records")
return
}
apiKeyID = id
}
params := repository.PaginationParams{Page: page, PageSize: pageSize}
var records []model.UsageLog
var result *repository.PaginationResult
var err error
if apiKeyID > 0 {
records, result, err = h.usageService.ListByApiKey(c.Request.Context(), apiKeyID, params)
} else {
records, result, err = h.usageService.ListByUser(c.Request.Context(), user.ID, params)
}
if err != nil {
response.InternalError(c, "Failed to list usage records: "+err.Error())
return
}
response.Paginated(c, records, result.Total, page, pageSize)
}
// GetByID handles getting a single usage record
// GET /api/v1/usage/:id
func (h *UsageHandler) GetByID(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
usageID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid usage ID")
return
}
record, err := h.usageService.GetByID(c.Request.Context(), usageID)
if err != nil {
response.NotFound(c, "Usage record not found")
return
}
// 验证所有权
if record.UserID != user.ID {
response.Forbidden(c, "Not authorized to access this record")
return
}
response.Success(c, record)
}
// Stats handles getting usage statistics
// GET /api/v1/usage/stats
func (h *UsageHandler) Stats(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
var apiKeyID int64
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid api_key_id")
return
}
// [Security Fix] Verify API Key ownership to prevent horizontal privilege escalation
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), id)
if err != nil {
response.NotFound(c, "API key not found")
return
}
if apiKey.UserID != user.ID {
response.Forbidden(c, "Not authorized to access this API key's statistics")
return
}
apiKeyID = id
}
// 获取时间范围参数
now := timezone.Now()
var startTime, endTime time.Time
// 优先使用 start_date 和 end_date 参数
startDateStr := c.Query("start_date")
endDateStr := c.Query("end_date")
if startDateStr != "" && endDateStr != "" {
// 使用自定义日期范围
var err error
startTime, err = timezone.ParseInLocation("2006-01-02", startDateStr)
if err != nil {
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
return
}
endTime, err = timezone.ParseInLocation("2006-01-02", endDateStr)
if err != nil {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return
}
// 设置结束时间为当天结束
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
} else {
// 使用 period 参数
period := c.DefaultQuery("period", "today")
switch period {
case "today":
startTime = timezone.StartOfDay(now)
case "week":
startTime = now.AddDate(0, 0, -7)
case "month":
startTime = now.AddDate(0, -1, 0)
default:
startTime = timezone.StartOfDay(now)
}
endTime = now
}
var stats *service.UsageStats
var err error
if apiKeyID > 0 {
stats, err = h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
} else {
stats, err = h.usageService.GetStatsByUser(c.Request.Context(), user.ID, startTime, endTime)
}
if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
return
}
response.Success(c, stats)
}
// parseUserTimeRange parses start_date, end_date query parameters for user dashboard
func parseUserTimeRange(c *gin.Context) (time.Time, time.Time) {
now := timezone.Now()
startDate := c.Query("start_date")
endDate := c.Query("end_date")
var startTime, endTime time.Time
if startDate != "" {
if t, err := timezone.ParseInLocation("2006-01-02", startDate); err == nil {
startTime = t
} else {
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
}
} else {
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
}
if endDate != "" {
if t, err := timezone.ParseInLocation("2006-01-02", endDate); err == nil {
endTime = t.Add(24 * time.Hour) // Include the end date
} else {
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
}
} else {
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
}
return startTime, endTime
}
// DashboardStats handles getting user dashboard statistics
// GET /api/v1/usage/dashboard/stats
func (h *UsageHandler) DashboardStats(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
stats, err := h.usageRepo.GetUserDashboardStats(c.Request.Context(), user.ID)
if err != nil {
response.InternalError(c, "Failed to get dashboard statistics")
return
}
response.Success(c, stats)
}
// DashboardTrend handles getting user usage trend data
// GET /api/v1/usage/dashboard/trend
func (h *UsageHandler) DashboardTrend(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
startTime, endTime := parseUserTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
trend, err := h.usageRepo.GetUserUsageTrendByUserID(c.Request.Context(), user.ID, startTime, endTime, granularity)
if err != nil {
response.InternalError(c, "Failed to get usage trend")
return
}
response.Success(c, gin.H{
"trend": trend,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
"granularity": granularity,
})
}
// DashboardModels handles getting user model usage statistics
// GET /api/v1/usage/dashboard/models
func (h *UsageHandler) DashboardModels(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
startTime, endTime := parseUserTimeRange(c)
stats, err := h.usageRepo.GetUserModelStats(c.Request.Context(), user.ID, startTime, endTime)
if err != nil {
response.InternalError(c, "Failed to get model statistics")
return
}
response.Success(c, gin.H{
"models": stats,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
})
}
// BatchApiKeysUsageRequest represents the request for batch API keys usage
type BatchApiKeysUsageRequest struct {
ApiKeyIDs []int64 `json:"api_key_ids" binding:"required"`
}
// DashboardApiKeysUsage handles getting usage stats for user's own API keys
// POST /api/v1/usage/dashboard/api-keys-usage
func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
var req BatchApiKeysUsageRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.ApiKeyIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]interface{}{}})
return
}
// Verify ownership of all requested API keys
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, repository.PaginationParams{Page: 1, PageSize: 1000})
if err != nil {
response.InternalError(c, "Failed to verify API key ownership")
return
}
userApiKeyIDs := make(map[int64]bool)
for _, key := range userApiKeys {
userApiKeyIDs[key.ID] = true
}
// Filter to only include user's own API keys
validApiKeyIDs := make([]int64, 0)
for _, id := range req.ApiKeyIDs {
if userApiKeyIDs[id] {
validApiKeyIDs = append(validApiKeyIDs, id)
}
}
if len(validApiKeyIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]interface{}{}})
return
}
stats, err := h.usageRepo.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs)
if err != nil {
response.InternalError(c, "Failed to get API key usage stats")
return
}
response.Success(c, gin.H{"stats": stats})
}
package handler
import (
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// UserHandler handles user-related requests
type UserHandler struct {
userService *service.UserService
}
// NewUserHandler creates a new UserHandler
func NewUserHandler(userService *service.UserService) *UserHandler {
return &UserHandler{
userService: userService,
}
}
// ChangePasswordRequest represents the change password request payload
type ChangePasswordRequest struct {
OldPassword string `json:"old_password" binding:"required"`
NewPassword string `json:"new_password" binding:"required,min=6"`
}
// GetProfile handles getting user profile
// GET /api/v1/users/me
func (h *UserHandler) GetProfile(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
userData, err := h.userService.GetByID(c.Request.Context(), user.ID)
if err != nil {
response.InternalError(c, "Failed to get user profile: "+err.Error())
return
}
response.Success(c, userData)
}
// ChangePassword handles changing user password
// POST /api/v1/users/me/password
func (h *UserHandler) ChangePassword(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
var req ChangePasswordRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
svcReq := service.ChangePasswordRequest{
CurrentPassword: req.OldPassword,
NewPassword: req.NewPassword,
}
err := h.userService.ChangePassword(c.Request.Context(), user.ID, svcReq)
if err != nil {
response.BadRequest(c, "Failed to change password: "+err.Error())
return
}
response.Success(c, gin.H{"message": "Password changed successfully"})
}
package middleware
import (
"sub2api/internal/model"
"github.com/gin-gonic/gin"
)
// AdminOnly 管理员权限中间件
// 必须在JWTAuth中间件之后使用
func AdminOnly() gin.HandlerFunc {
return func(c *gin.Context) {
// 从上下文获取用户
user, exists := GetUserFromContext(c)
if !exists {
AbortWithError(c, 401, "UNAUTHORIZED", "User not found in context")
return
}
// 检查是否为管理员
if user.Role != model.RoleAdmin {
AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
return
}
c.Next()
}
}
package middleware
import (
"context"
"errors"
"log"
"strings"
"sub2api/internal/model"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
// ApiKeyAuthService 定义API Key认证服务需要的接口
type ApiKeyAuthService interface {
GetByKey(ctx context.Context, key string) (*model.ApiKey, error)
}
// SubscriptionAuthService 定义订阅认证服务需要的接口
type SubscriptionAuthService interface {
GetActiveSubscription(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error)
ValidateSubscription(ctx context.Context, sub *model.UserSubscription) error
CheckAndActivateWindow(ctx context.Context, sub *model.UserSubscription) error
CheckAndResetWindows(ctx context.Context, sub *model.UserSubscription) error
CheckUsageLimits(ctx context.Context, sub *model.UserSubscription, group *model.Group, additionalCost float64) error
}
// ApiKeyAuth API Key认证中间件
func ApiKeyAuth(apiKeyRepo ApiKeyAuthService) gin.HandlerFunc {
return ApiKeyAuthWithSubscription(apiKeyRepo, nil)
}
// ApiKeyAuthWithSubscription API Key认证中间件(支持订阅验证)
func ApiKeyAuthWithSubscription(apiKeyRepo ApiKeyAuthService, subscriptionService SubscriptionAuthService) gin.HandlerFunc {
return func(c *gin.Context) {
// 尝试从Authorization header中提取API key (Bearer scheme)
authHeader := c.GetHeader("Authorization")
var apiKeyString string
if authHeader != "" {
// 验证Bearer scheme
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) == 2 && parts[0] == "Bearer" {
apiKeyString = parts[1]
}
}
// 如果Authorization header中没有,尝试从x-api-key header中提取
if apiKeyString == "" {
apiKeyString = c.GetHeader("x-api-key")
}
// 如果两个header都没有API key
if apiKeyString == "" {
AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme) or x-api-key header")
return
}
// 从数据库验证API key
apiKey, err := apiKeyRepo.GetByKey(c.Request.Context(), apiKeyString)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key")
return
}
AbortWithError(c, 500, "INTERNAL_ERROR", "Failed to validate API key")
return
}
// 检查API key是否激活
if !apiKey.IsActive() {
AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled")
return
}
// 检查关联的用户
if apiKey.User == nil {
AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found")
return
}
// 检查用户状态
if !apiKey.User.IsActive() {
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
return
}
// 判断计费方式:订阅模式 vs 余额模式
isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
if isSubscriptionType && subscriptionService != nil {
// 订阅模式:验证订阅
subscription, err := subscriptionService.GetActiveSubscription(
c.Request.Context(),
apiKey.User.ID,
apiKey.Group.ID,
)
if err != nil {
AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group")
return
}
// 验证订阅状态(是否过期、暂停等)
if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil {
AbortWithError(c, 403, "SUBSCRIPTION_INVALID", err.Error())
return
}
// 激活滑动窗口(首次使用时)
if err := subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription); err != nil {
log.Printf("Failed to activate subscription windows: %v", err)
}
// 检查并重置过期窗口
if err := subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription); err != nil {
log.Printf("Failed to reset subscription windows: %v", err)
}
// 预检查用量限制(使用0作为额外费用进行预检查)
if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil {
AbortWithError(c, 429, "USAGE_LIMIT_EXCEEDED", err.Error())
return
}
// 将订阅信息存入上下文
c.Set(string(ContextKeySubscription), subscription)
} else {
// 余额模式:检查用户余额
if apiKey.User.Balance <= 0 {
AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance")
return
}
}
// 将API key和用户信息存入上下文
c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyUser), apiKey.User)
c.Next()
}
}
// GetApiKeyFromContext 从上下文中获取API key
func GetApiKeyFromContext(c *gin.Context) (*model.ApiKey, bool) {
value, exists := c.Get(string(ContextKeyApiKey))
if !exists {
return nil, false
}
apiKey, ok := value.(*model.ApiKey)
return apiKey, ok
}
// GetSubscriptionFromContext 从上下文中获取订阅信息
func GetSubscriptionFromContext(c *gin.Context) (*model.UserSubscription, bool) {
value, exists := c.Get(string(ContextKeySubscription))
if !exists {
return nil, false
}
subscription, ok := value.(*model.UserSubscription)
return subscription, ok
}
package middleware
import (
"github.com/gin-gonic/gin"
)
// CORS 跨域中间件
func CORS() gin.HandlerFunc {
return func(c *gin.Context) {
// 设置允许跨域的响应头
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
// 处理预检请求
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(204)
return
}
c.Next()
}
}
package middleware
import (
"context"
"strings"
"sub2api/internal/model"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// JWTAuth JWT认证中间件
func JWTAuth(authService *service.AuthService, userRepo interface {
GetByID(ctx context.Context, id int64) (*model.User, error)
}) gin.HandlerFunc {
return func(c *gin.Context) {
// 从Authorization header中提取token
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization header is required")
return
}
// 验证Bearer scheme
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || parts[0] != "Bearer" {
AbortWithError(c, 401, "INVALID_AUTH_HEADER", "Authorization header format must be 'Bearer {token}'")
return
}
tokenString := parts[1]
if tokenString == "" {
AbortWithError(c, 401, "EMPTY_TOKEN", "Token cannot be empty")
return
}
// 验证token
claims, err := authService.ValidateToken(tokenString)
if err != nil {
if err == service.ErrTokenExpired {
AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired")
return
}
AbortWithError(c, 401, "INVALID_TOKEN", "Invalid token")
return
}
// 从数据库获取最新的用户信息
user, err := userRepo.GetByID(c.Request.Context(), claims.UserID)
if err != nil {
AbortWithError(c, 401, "USER_NOT_FOUND", "User not found")
return
}
// 检查用户状态
if !user.IsActive() {
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
return
}
// 将用户信息存入上下文
c.Set(string(ContextKeyUser), user)
c.Next()
}
}
// GetUserFromContext 从上下文中获取用户
func GetUserFromContext(c *gin.Context) (*model.User, bool) {
value, exists := c.Get(string(ContextKeyUser))
if !exists {
return nil, false
}
user, ok := value.(*model.User)
return user, ok
}
package middleware
import (
"log"
"time"
"github.com/gin-gonic/gin"
)
// Logger 请求日志中间件
func Logger() gin.HandlerFunc {
return func(c *gin.Context) {
// 开始时间
startTime := time.Now()
// 处理请求
c.Next()
// 结束时间
endTime := time.Now()
// 执行时间
latency := endTime.Sub(startTime)
// 请求方法
method := c.Request.Method
// 请求路径
path := c.Request.URL.Path
// 状态码
statusCode := c.Writer.Status()
// 客户端IP
clientIP := c.ClientIP()
// 日志格式: [时间] 状态码 | 延迟 | IP | 方法 路径
log.Printf("[GIN] %v | %3d | %13v | %15s | %-7s %s",
endTime.Format("2006/01/02 - 15:04:05"),
statusCode,
latency,
clientIP,
method,
path,
)
// 如果有错误,额外记录错误信息
if len(c.Errors) > 0 {
log.Printf("[GIN] Errors: %v", c.Errors.String())
}
}
}
package middleware
import "github.com/gin-gonic/gin"
// ContextKey 定义上下文键类型
type ContextKey string
const (
// ContextKeyUser 用户上下文键
ContextKeyUser ContextKey = "user"
// ContextKeyApiKey API密钥上下文键
ContextKeyApiKey ContextKey = "api_key"
// ContextKeySubscription 订阅上下文键
ContextKeySubscription ContextKey = "subscription"
)
// ErrorResponse 标准错误响应结构
type ErrorResponse struct {
Code string `json:"code"`
Message string `json:"message"`
}
// NewErrorResponse 创建错误响应
func NewErrorResponse(code, message string) ErrorResponse {
return ErrorResponse{
Code: code,
Message: message,
}
}
// AbortWithError 中断请求并返回JSON错误
func AbortWithError(c *gin.Context, statusCode int, code, message string) {
c.JSON(statusCode, NewErrorResponse(code, message))
c.Abort()
}
package model
import (
"database/sql/driver"
"encoding/json"
"errors"
"time"
"gorm.io/gorm"
)
// JSONB 用于存储JSONB数据
type JSONB map[string]interface{}
func (j JSONB) Value() (driver.Value, error) {
if j == nil {
return nil, nil
}
return json.Marshal(j)
}
func (j *JSONB) Scan(value interface{}) error {
if value == nil {
*j = nil
return nil
}
bytes, ok := value.([]byte)
if !ok {
return errors.New("type assertion to []byte failed")
}
return json.Unmarshal(bytes, j)
}
type Account struct {
ID int64 `gorm:"primaryKey" json:"id"`
Name string `gorm:"size:100;not null" json:"name"`
Platform string `gorm:"size:50;not null" json:"platform"` // anthropic/openai/gemini
Type string `gorm:"size:20;not null" json:"type"` // oauth/apikey
Credentials JSONB `gorm:"type:jsonb;default:'{}'" json:"credentials"` // 凭证(加密存储)
Extra JSONB `gorm:"type:jsonb;default:'{}'" json:"extra"` // 扩展信息
ProxyID *int64 `gorm:"index" json:"proxy_id"`
Concurrency int `gorm:"default:3;not null" json:"concurrency"`
Priority int `gorm:"default:50;not null" json:"priority"` // 1-100,越小越高
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled/error
ErrorMessage string `gorm:"type:text" json:"error_message"`
LastUsedAt *time.Time `gorm:"index" json:"last_used_at"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 调度控制
Schedulable bool `gorm:"default:true;not null" json:"schedulable"`
// 限流状态 (429)
RateLimitedAt *time.Time `gorm:"index" json:"rate_limited_at"`
RateLimitResetAt *time.Time `gorm:"index" json:"rate_limit_reset_at"`
// 过载状态 (529)
OverloadUntil *time.Time `gorm:"index" json:"overload_until"`
// 5小时时间窗口
SessionWindowStart *time.Time `json:"session_window_start"`
SessionWindowEnd *time.Time `json:"session_window_end"`
SessionWindowStatus string `gorm:"size:20" json:"session_window_status"` // allowed/allowed_warning/rejected
// 关联
Proxy *Proxy `gorm:"foreignKey:ProxyID" json:"proxy,omitempty"`
AccountGroups []AccountGroup `gorm:"foreignKey:AccountID" json:"account_groups,omitempty"`
// 虚拟字段 (不存储到数据库)
GroupIDs []int64 `gorm:"-" json:"group_ids,omitempty"`
}
func (Account) TableName() string {
return "accounts"
}
// IsActive 检查是否激活
func (a *Account) IsActive() bool {
return a.Status == "active"
}
// IsSchedulable 检查账号是否可调度
func (a *Account) IsSchedulable() bool {
if !a.IsActive() || !a.Schedulable {
return false
}
now := time.Now()
if a.OverloadUntil != nil && now.Before(*a.OverloadUntil) {
return false
}
if a.RateLimitResetAt != nil && now.Before(*a.RateLimitResetAt) {
return false
}
return true
}
// IsRateLimited 检查是否处于限流状态
func (a *Account) IsRateLimited() bool {
if a.RateLimitResetAt == nil {
return false
}
return time.Now().Before(*a.RateLimitResetAt)
}
// IsOverloaded 检查是否处于过载状态
func (a *Account) IsOverloaded() bool {
if a.OverloadUntil == nil {
return false
}
return time.Now().Before(*a.OverloadUntil)
}
// IsOAuth 检查是否为OAuth类型账号(包括oauth和setup-token)
func (a *Account) IsOAuth() bool {
return a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken
}
// CanGetUsage 检查账号是否可以获取usage信息(只有oauth类型可以,setup-token没有profile权限)
func (a *Account) CanGetUsage() bool {
return a.Type == AccountTypeOAuth
}
// GetCredential 获取凭证字段
func (a *Account) GetCredential(key string) string {
if a.Credentials == nil {
return ""
}
if v, ok := a.Credentials[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}
// GetModelMapping 获取模型映射配置
// 返回格式: map[请求模型名]实际模型名
func (a *Account) GetModelMapping() map[string]string {
if a.Credentials == nil {
return nil
}
raw, ok := a.Credentials["model_mapping"]
if !ok || raw == nil {
return nil
}
// 处理map[string]interface{}类型
if m, ok := raw.(map[string]interface{}); ok {
result := make(map[string]string)
for k, v := range m {
if s, ok := v.(string); ok {
result[k] = s
}
}
if len(result) > 0 {
return result
}
}
return nil
}
// IsModelSupported 检查请求的模型是否被该账号支持
// 如果没有设置模型映射,则支持所有模型
func (a *Account) IsModelSupported(requestedModel string) bool {
mapping := a.GetModelMapping()
if mapping == nil || len(mapping) == 0 {
return true // 没有映射配置,支持所有模型
}
_, exists := mapping[requestedModel]
return exists
}
// GetMappedModel 获取映射后的实际模型名
// 如果没有映射,返回原始模型名
func (a *Account) GetMappedModel(requestedModel string) string {
mapping := a.GetModelMapping()
if mapping == nil || len(mapping) == 0 {
return requestedModel
}
if mappedModel, exists := mapping[requestedModel]; exists {
return mappedModel
}
return requestedModel
}
// GetBaseURL 获取API基础URL(用于apikey类型账号)
func (a *Account) GetBaseURL() string {
if a.Type != AccountTypeApiKey {
return ""
}
baseURL := a.GetCredential("base_url")
if baseURL == "" {
return "https://api.anthropic.com" // 默认URL
}
return baseURL
}
// GetExtraString 从Extra字段获取字符串值
func (a *Account) GetExtraString(key string) string {
if a.Extra == nil {
return ""
}
if v, ok := a.Extra[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}
// IsCustomErrorCodesEnabled 检查是否启用自定义错误码功能(仅适用于 apikey 类型)
func (a *Account) IsCustomErrorCodesEnabled() bool {
if a.Type != AccountTypeApiKey || a.Credentials == nil {
return false
}
if v, ok := a.Credentials["custom_error_codes_enabled"]; ok {
if enabled, ok := v.(bool); ok {
return enabled
}
}
return false
}
// GetCustomErrorCodes 获取自定义错误码列表
func (a *Account) GetCustomErrorCodes() []int {
if a.Credentials == nil {
return nil
}
raw, ok := a.Credentials["custom_error_codes"]
if !ok || raw == nil {
return nil
}
// 处理 []interface{} 类型(JSON反序列化后的格式)
if arr, ok := raw.([]interface{}); ok {
result := make([]int, 0, len(arr))
for _, v := range arr {
// JSON 数字默认解析为 float64
if f, ok := v.(float64); ok {
result = append(result, int(f))
}
}
return result
}
return nil
}
// ShouldHandleErrorCode 检查指定错误码是否应该被处理(停止调度/标记限流等)
// 如果未启用自定义错误码或列表为空,返回 true(使用默认策略)
// 如果启用且列表非空,只有在列表中的错误码才返回 true
func (a *Account) ShouldHandleErrorCode(statusCode int) bool {
if !a.IsCustomErrorCodesEnabled() {
return true // 未启用,使用默认策略
}
codes := a.GetCustomErrorCodes()
if len(codes) == 0 {
return true // 启用但列表为空,fallback到默认策略
}
// 检查是否在自定义列表中
for _, code := range codes {
if code == statusCode {
return true
}
}
return false
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment