Commit bfcd9501 authored by IanShaw027's avatar IanShaw027
Browse files

merge: 合并 upstream/main 解决 PR #37 冲突

- 删除 backend/internal/model/account.go 符合重构方向
- 合并最新的项目结构重构
- 包含 SSE 格式解析修复
- 更新依赖和配置文件
parents 9780f0fd 12252c60
...@@ -6,7 +6,6 @@ import ( ...@@ -6,7 +6,6 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -69,11 +68,11 @@ func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64 ...@@ -69,11 +68,11 @@ func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary. // AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait. // For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun. // streamStarted is updated if streaming response has begun.
func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, user *model.User, isStream bool, streamStarted *bool) (func(), error) { func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, userID int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
ctx := c.Request.Context() ctx := c.Request.Context()
// Try to acquire immediately // Try to acquire immediately
result, err := h.concurrencyService.AcquireUserSlot(ctx, user.ID, user.Concurrency) result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -83,17 +82,17 @@ func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, user *model. ...@@ -83,17 +82,17 @@ func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, user *model.
} }
// Need to wait - handle streaming ping if needed // Need to wait - handle streaming ping if needed
return h.waitForSlotWithPing(c, "user", user.ID, user.Concurrency, isStream, streamStarted) return h.waitForSlotWithPing(c, "user", userID, maxConcurrency, isStream, streamStarted)
} }
// AcquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary. // AcquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait. // For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun. // streamStarted is updated if streaming response has begun.
func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, account *model.Account, isStream bool, streamStarted *bool) (func(), error) { func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
ctx := c.Request.Context() ctx := c.Request.Context()
// Try to acquire immediately // Try to acquire immediately
result, err := h.concurrencyService.AcquireAccountSlot(ctx, account.ID, account.Concurrency) result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -103,7 +102,7 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, account * ...@@ -103,7 +102,7 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, account *
} }
// Need to wait - handle streaming ping if needed // Need to wait - handle streaming ping if needed
return h.waitForSlotWithPing(c, "account", account.ID, account.Concurrency, isStream, streamStarted) return h.waitForSlotWithPing(c, "account", accountID, maxConcurrency, isStream, streamStarted)
} }
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests. // waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
......
...@@ -46,7 +46,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -46,7 +46,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return return
} }
user, ok := middleware2.GetUserFromContext(c) subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok { if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return return
...@@ -94,8 +94,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -94,8 +94,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
subscription, _ := middleware2.GetSubscriptionFromContext(c) subscription, _ := middleware2.GetSubscriptionFromContext(c)
// 0. Check if wait queue is full // 0. Check if wait queue is full
maxWait := service.CalculateMaxWait(user.Concurrency) maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), user.ID, maxWait) canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
if err != nil { if err != nil {
log.Printf("Increment wait count failed: %v", err) log.Printf("Increment wait count failed: %v", err)
// On error, allow request to proceed // On error, allow request to proceed
...@@ -104,10 +104,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -104,10 +104,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return return
} }
// Ensure wait count is decremented when function exits // Ensure wait count is decremented when function exits
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), user.ID) defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
// 1. First acquire user concurrency slot // 1. First acquire user concurrency slot
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, user, reqStream, &streamStarted) userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
if err != nil { if err != nil {
log.Printf("User concurrency acquire failed: %v", err) log.Printf("User concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "user", streamStarted) h.handleConcurrencyError(c, err, "user", streamStarted)
...@@ -118,7 +118,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -118,7 +118,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
} }
// 2. Re-check billing eligibility after wait // 2. Re-check billing eligibility after wait
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err) log.Printf("Billing eligibility check failed after wait: %v", err)
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
return return
...@@ -138,7 +138,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -138,7 +138,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name) log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
// 3. Acquire account concurrency slot // 3. Acquire account concurrency slot
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account, reqStream, &streamStarted) accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
if err != nil { if err != nil {
log.Printf("Account concurrency acquire failed: %v", err) log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted) h.handleConcurrencyError(c, err, "account", streamStarted)
...@@ -163,7 +163,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -163,7 +163,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result, Result: result,
ApiKey: apiKey, ApiKey: apiKey,
User: user, User: apiKey.User,
Account: account, Account: account,
Subscription: subscription, Subscription: subscription,
}); err != nil { }); err != nil {
......
package handler package handler
import ( import (
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -37,15 +38,9 @@ type RedeemResponse struct { ...@@ -37,15 +38,9 @@ type RedeemResponse struct {
// Redeem handles redeeming a code // Redeem handles redeeming a code
// POST /api/v1/redeem // POST /api/v1/redeem
func (h *RedeemHandler) Redeem(c *gin.Context) { func (h *RedeemHandler) Redeem(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
...@@ -55,38 +50,36 @@ func (h *RedeemHandler) Redeem(c *gin.Context) { ...@@ -55,38 +50,36 @@ func (h *RedeemHandler) Redeem(c *gin.Context) {
return return
} }
result, err := h.redeemService.Redeem(c.Request.Context(), user.ID, req.Code) result, err := h.redeemService.Redeem(c.Request.Context(), subject.UserID, req.Code)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Success(c, result) response.Success(c, dto.RedeemCodeFromService(result))
} }
// GetHistory returns the user's redemption history // GetHistory returns the user's redemption history
// GET /api/v1/redeem/history // GET /api/v1/redeem/history
func (h *RedeemHandler) GetHistory(c *gin.Context) { func (h *RedeemHandler) GetHistory(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
// Default limit is 25 // Default limit is 25
limit := 25 limit := 25
codes, err := h.redeemService.GetUserHistory(c.Request.Context(), user.ID, limit) codes, err := h.redeemService.GetUserHistory(c.Request.Context(), subject.UserID, limit)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Success(c, codes) out := make([]dto.RedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
}
response.Success(c, out)
} }
package handler package handler
import ( import (
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -30,6 +31,17 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { ...@@ -30,6 +31,17 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
return return
} }
settings.Version = h.version response.Success(c, dto.PublicSettings{
response.Success(c, settings) RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
ApiBaseUrl: settings.ApiBaseUrl,
ContactInfo: settings.ContactInfo,
DocUrl: settings.DocUrl,
Version: h.version,
})
} }
package handler package handler
import ( import (
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -25,7 +26,7 @@ type SubscriptionSummaryItem struct { ...@@ -25,7 +26,7 @@ type SubscriptionSummaryItem struct {
// SubscriptionProgressInfo represents subscription with progress info // SubscriptionProgressInfo represents subscription with progress info
type SubscriptionProgressInfo struct { type SubscriptionProgressInfo struct {
Subscription *model.UserSubscription `json:"subscription"` Subscription *dto.UserSubscription `json:"subscription"`
Progress *service.SubscriptionProgress `json:"progress"` Progress *service.SubscriptionProgress `json:"progress"`
} }
...@@ -44,68 +45,58 @@ func NewSubscriptionHandler(subscriptionService *service.SubscriptionService) *S ...@@ -44,68 +45,58 @@ func NewSubscriptionHandler(subscriptionService *service.SubscriptionService) *S
// List handles listing current user's subscriptions // List handles listing current user's subscriptions
// GET /api/v1/subscriptions // GET /api/v1/subscriptions
func (h *SubscriptionHandler) List(c *gin.Context) { func (h *SubscriptionHandler) List(c *gin.Context) {
user, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user in context") response.Unauthorized(c, "User not found in context")
return return
} }
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), u.ID) subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), subject.UserID)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Success(c, subscriptions) out := make([]dto.UserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
}
response.Success(c, out)
} }
// GetActive handles getting current user's active subscriptions // GetActive handles getting current user's active subscriptions
// GET /api/v1/subscriptions/active // GET /api/v1/subscriptions/active
func (h *SubscriptionHandler) GetActive(c *gin.Context) { func (h *SubscriptionHandler) GetActive(c *gin.Context) {
user, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user in context") response.Unauthorized(c, "User not found in context")
return return
} }
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID) subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), subject.UserID)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Success(c, subscriptions) out := make([]dto.UserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
}
response.Success(c, out)
} }
// GetProgress handles getting subscription progress for current user // GetProgress handles getting subscription progress for current user
// GET /api/v1/subscriptions/progress // GET /api/v1/subscriptions/progress
func (h *SubscriptionHandler) GetProgress(c *gin.Context) { func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
user, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user in context") response.Unauthorized(c, "User not found in context")
return return
} }
// Get all active subscriptions with progress // Get all active subscriptions with progress
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID) subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), subject.UserID)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -120,7 +111,7 @@ func (h *SubscriptionHandler) GetProgress(c *gin.Context) { ...@@ -120,7 +111,7 @@ func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
continue continue
} }
result = append(result, SubscriptionProgressInfo{ result = append(result, SubscriptionProgressInfo{
Subscription: sub, Subscription: dto.UserSubscriptionFromService(sub),
Progress: progress, Progress: progress,
}) })
} }
...@@ -131,20 +122,14 @@ func (h *SubscriptionHandler) GetProgress(c *gin.Context) { ...@@ -131,20 +122,14 @@ func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
// GetSummary handles getting a summary of current user's subscription status // GetSummary handles getting a summary of current user's subscription status
// GET /api/v1/subscriptions/summary // GET /api/v1/subscriptions/summary
func (h *SubscriptionHandler) GetSummary(c *gin.Context) { func (h *SubscriptionHandler) GetSummary(c *gin.Context) {
user, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user in context") response.Unauthorized(c, "User not found in context")
return return
} }
// Get all active subscriptions // Get all active subscriptions
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID) subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), subject.UserID)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
......
...@@ -4,10 +4,11 @@ import ( ...@@ -4,10 +4,11 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -30,15 +31,9 @@ func NewUsageHandler(usageService *service.UsageService, apiKeyService *service. ...@@ -30,15 +31,9 @@ func NewUsageHandler(usageService *service.UsageService, apiKeyService *service.
// List handles listing usage records with pagination // List handles listing usage records with pagination
// GET /api/v1/usage // GET /api/v1/usage
func (h *UsageHandler) List(c *gin.Context) { func (h *UsageHandler) List(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
...@@ -58,7 +53,7 @@ func (h *UsageHandler) List(c *gin.Context) { ...@@ -58,7 +53,7 @@ func (h *UsageHandler) List(c *gin.Context) {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
if apiKey.UserID != user.ID { if apiKey.UserID != subject.UserID {
response.Forbidden(c, "Not authorized to access this API key's usage records") response.Forbidden(c, "Not authorized to access this API key's usage records")
return return
} }
...@@ -67,35 +62,33 @@ func (h *UsageHandler) List(c *gin.Context) { ...@@ -67,35 +62,33 @@ func (h *UsageHandler) List(c *gin.Context) {
} }
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
var records []model.UsageLog var records []service.UsageLog
var result *pagination.PaginationResult var result *pagination.PaginationResult
var err error var err error
if apiKeyID > 0 { if apiKeyID > 0 {
records, result, err = h.usageService.ListByApiKey(c.Request.Context(), apiKeyID, params) records, result, err = h.usageService.ListByApiKey(c.Request.Context(), apiKeyID, params)
} else { } else {
records, result, err = h.usageService.ListByUser(c.Request.Context(), user.ID, params) records, result, err = h.usageService.ListByUser(c.Request.Context(), subject.UserID, params)
} }
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Paginated(c, records, result.Total, page, pageSize) out := make([]dto.UsageLog, 0, len(records))
for i := range records {
out = append(out, *dto.UsageLogFromService(&records[i]))
}
response.Paginated(c, out, result.Total, page, pageSize)
} }
// GetByID handles getting a single usage record // GetByID handles getting a single usage record
// GET /api/v1/usage/:id // GET /api/v1/usage/:id
func (h *UsageHandler) GetByID(c *gin.Context) { func (h *UsageHandler) GetByID(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
...@@ -112,26 +105,20 @@ func (h *UsageHandler) GetByID(c *gin.Context) { ...@@ -112,26 +105,20 @@ func (h *UsageHandler) GetByID(c *gin.Context) {
} }
// 验证所有权 // 验证所有权
if record.UserID != user.ID { if record.UserID != subject.UserID {
response.Forbidden(c, "Not authorized to access this record") response.Forbidden(c, "Not authorized to access this record")
return return
} }
response.Success(c, record) response.Success(c, dto.UsageLogFromService(record))
} }
// Stats handles getting usage statistics // Stats handles getting usage statistics
// GET /api/v1/usage/stats // GET /api/v1/usage/stats
func (h *UsageHandler) Stats(c *gin.Context) { func (h *UsageHandler) Stats(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
...@@ -149,7 +136,7 @@ func (h *UsageHandler) Stats(c *gin.Context) { ...@@ -149,7 +136,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
response.NotFound(c, "API key not found") response.NotFound(c, "API key not found")
return return
} }
if apiKey.UserID != user.ID { if apiKey.UserID != subject.UserID {
response.Forbidden(c, "Not authorized to access this API key's statistics") response.Forbidden(c, "Not authorized to access this API key's statistics")
return return
} }
...@@ -201,7 +188,7 @@ func (h *UsageHandler) Stats(c *gin.Context) { ...@@ -201,7 +188,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
if apiKeyID > 0 { if apiKeyID > 0 {
stats, err = h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime) stats, err = h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
} else { } else {
stats, err = h.usageService.GetStatsByUser(c.Request.Context(), user.ID, startTime, endTime) stats, err = h.usageService.GetStatsByUser(c.Request.Context(), subject.UserID, startTime, endTime)
} }
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
...@@ -245,19 +232,13 @@ func parseUserTimeRange(c *gin.Context) (time.Time, time.Time) { ...@@ -245,19 +232,13 @@ func parseUserTimeRange(c *gin.Context) (time.Time, time.Time) {
// DashboardStats handles getting user dashboard statistics // DashboardStats handles getting user dashboard statistics
// GET /api/v1/usage/dashboard/stats // GET /api/v1/usage/dashboard/stats
func (h *UsageHandler) DashboardStats(c *gin.Context) { func (h *UsageHandler) DashboardStats(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), user.ID) stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), subject.UserID)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -269,22 +250,16 @@ func (h *UsageHandler) DashboardStats(c *gin.Context) { ...@@ -269,22 +250,16 @@ func (h *UsageHandler) DashboardStats(c *gin.Context) {
// DashboardTrend handles getting user usage trend data // DashboardTrend handles getting user usage trend data
// GET /api/v1/usage/dashboard/trend // GET /api/v1/usage/dashboard/trend
func (h *UsageHandler) DashboardTrend(c *gin.Context) { func (h *UsageHandler) DashboardTrend(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
startTime, endTime := parseUserTimeRange(c) startTime, endTime := parseUserTimeRange(c)
granularity := c.DefaultQuery("granularity", "day") granularity := c.DefaultQuery("granularity", "day")
trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), user.ID, startTime, endTime, granularity) trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), subject.UserID, startTime, endTime, granularity)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -301,21 +276,15 @@ func (h *UsageHandler) DashboardTrend(c *gin.Context) { ...@@ -301,21 +276,15 @@ func (h *UsageHandler) DashboardTrend(c *gin.Context) {
// DashboardModels handles getting user model usage statistics // DashboardModels handles getting user model usage statistics
// GET /api/v1/usage/dashboard/models // GET /api/v1/usage/dashboard/models
func (h *UsageHandler) DashboardModels(c *gin.Context) { func (h *UsageHandler) DashboardModels(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
startTime, endTime := parseUserTimeRange(c) startTime, endTime := parseUserTimeRange(c)
stats, err := h.usageService.GetUserModelStats(c.Request.Context(), user.ID, startTime, endTime) stats, err := h.usageService.GetUserModelStats(c.Request.Context(), subject.UserID, startTime, endTime)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -336,15 +305,9 @@ type BatchApiKeysUsageRequest struct { ...@@ -336,15 +305,9 @@ type BatchApiKeysUsageRequest struct {
// DashboardApiKeysUsage handles getting usage stats for user's own API keys // DashboardApiKeysUsage handles getting usage stats for user's own API keys
// POST /api/v1/usage/dashboard/api-keys-usage // POST /api/v1/usage/dashboard/api-keys-usage
func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) { func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
...@@ -360,7 +323,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) { ...@@ -360,7 +323,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
} }
// Verify ownership of all requested API keys // Verify ownership of all requested API keys
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, pagination.PaginationParams{Page: 1, PageSize: 1000}) userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, pagination.PaginationParams{Page: 1, PageSize: 1000})
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
......
package handler package handler
import ( import (
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -35,19 +36,13 @@ type UpdateProfileRequest struct { ...@@ -35,19 +36,13 @@ type UpdateProfileRequest struct {
// GetProfile handles getting user profile // GetProfile handles getting user profile
// GET /api/v1/users/me // GET /api/v1/users/me
func (h *UserHandler) GetProfile(c *gin.Context) { func (h *UserHandler) GetProfile(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
userData, err := h.userService.GetByID(c.Request.Context(), user.ID) userData, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -56,21 +51,15 @@ func (h *UserHandler) GetProfile(c *gin.Context) { ...@@ -56,21 +51,15 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
// 清空notes字段,普通用户不应看到备注 // 清空notes字段,普通用户不应看到备注
userData.Notes = "" userData.Notes = ""
response.Success(c, userData) response.Success(c, dto.UserFromService(userData))
} }
// ChangePassword handles changing user password // ChangePassword handles changing user password
// POST /api/v1/users/me/password // POST /api/v1/users/me/password
func (h *UserHandler) ChangePassword(c *gin.Context) { func (h *UserHandler) ChangePassword(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
...@@ -84,7 +73,7 @@ func (h *UserHandler) ChangePassword(c *gin.Context) { ...@@ -84,7 +73,7 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
CurrentPassword: req.OldPassword, CurrentPassword: req.OldPassword,
NewPassword: req.NewPassword, NewPassword: req.NewPassword,
} }
err := h.userService.ChangePassword(c.Request.Context(), user.ID, svcReq) err := h.userService.ChangePassword(c.Request.Context(), subject.UserID, svcReq)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -96,15 +85,9 @@ func (h *UserHandler) ChangePassword(c *gin.Context) { ...@@ -96,15 +85,9 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
// UpdateProfile handles updating user profile // UpdateProfile handles updating user profile
// PUT /api/v1/users/me // PUT /api/v1/users/me
func (h *UserHandler) UpdateProfile(c *gin.Context) { func (h *UserHandler) UpdateProfile(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
...@@ -118,7 +101,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { ...@@ -118,7 +101,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
Username: req.Username, Username: req.Username,
Wechat: req.Wechat, Wechat: req.Wechat,
} }
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), user.ID, svcReq) updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -127,5 +110,5 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { ...@@ -127,5 +110,5 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
// 清空notes字段,普通用户不应看到备注 // 清空notes字段,普通用户不应看到备注
updatedUser.Notes = "" updatedUser.Notes = ""
response.Success(c, updatedUser) response.Success(c, dto.UserFromService(updatedUser))
} }
...@@ -2,8 +2,8 @@ package infrastructure ...@@ -2,8 +2,8 @@ package infrastructure
import ( import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/repository"
"gorm.io/driver/postgres" "gorm.io/driver/postgres"
"gorm.io/gorm" "gorm.io/gorm"
...@@ -30,7 +30,7 @@ func InitDB(cfg *config.Config) (*gorm.DB, error) { ...@@ -30,7 +30,7 @@ func InitDB(cfg *config.Config) (*gorm.DB, error) {
// 自动迁移(始终执行,确保数据库结构与代码同步) // 自动迁移(始终执行,确保数据库结构与代码同步)
// GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的 // GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的
if err := model.AutoMigrate(db); err != nil { if err := repository.AutoMigrate(db); err != nil {
return nil, err return nil, err
} }
......
package model
import (
"time"
)
type AccountGroup struct {
AccountID int64 `gorm:"primaryKey" json:"account_id"`
GroupID int64 `gorm:"primaryKey" json:"group_id"`
Priority int `gorm:"default:50;not null" json:"priority"` // 分组内优先级
CreatedAt time.Time `gorm:"not null" json:"created_at"`
// 关联
Account *Account `gorm:"foreignKey:AccountID" json:"account,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func (AccountGroup) TableName() string {
return "account_groups"
}
package model
import (
"time"
"gorm.io/gorm"
)
type ApiKey struct {
ID int64 `gorm:"primaryKey" json:"id"`
UserID int64 `gorm:"index;not null" json:"user_id"`
Key string `gorm:"uniqueIndex;size:128;not null" json:"key"` // sk-xxx
Name string `gorm:"size:100;not null" json:"name"`
GroupID *int64 `gorm:"index" json:"group_id"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 关联
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func (ApiKey) TableName() string {
return "api_keys"
}
// IsActive 检查是否激活
func (k *ApiKey) IsActive() bool {
return k.Status == "active"
}
package model
import (
"time"
"gorm.io/gorm"
)
// 订阅类型常量
const (
SubscriptionTypeStandard = "standard" // 标准计费模式(按余额扣费)
SubscriptionTypeSubscription = "subscription" // 订阅模式(按限额控制)
)
type Group struct {
ID int64 `gorm:"primaryKey" json:"id"`
Name string `gorm:"uniqueIndex;size:100;not null" json:"name"`
Description string `gorm:"type:text" json:"description"`
Platform string `gorm:"size:50;default:anthropic;not null" json:"platform"` // anthropic/openai/gemini
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"`
IsExclusive bool `gorm:"default:false;not null" json:"is_exclusive"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
// 订阅功能字段
SubscriptionType string `gorm:"size:20;default:standard;not null" json:"subscription_type"` // standard/subscription
DailyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"daily_limit_usd"`
WeeklyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"monthly_limit_usd"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 关联
AccountGroups []AccountGroup `gorm:"foreignKey:GroupID" json:"account_groups,omitempty"`
// 虚拟字段 (不存储到数据库)
AccountCount int64 `gorm:"-" json:"account_count,omitempty"`
}
func (Group) TableName() string {
return "groups"
}
// IsActive 检查是否激活
func (g *Group) IsActive() bool {
return g.Status == "active"
}
// IsSubscriptionType 检查是否为订阅类型分组
func (g *Group) IsSubscriptionType() bool {
return g.SubscriptionType == SubscriptionTypeSubscription
}
// IsFreeSubscription 检查是否为免费订阅(不扣余额但有限额)
func (g *Group) IsFreeSubscription() bool {
return g.IsSubscriptionType() && g.RateMultiplier == 0
}
// HasDailyLimit 检查是否有日限额
func (g *Group) HasDailyLimit() bool {
return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0
}
// HasWeeklyLimit 检查是否有周限额
func (g *Group) HasWeeklyLimit() bool {
return g.WeeklyLimitUSD != nil && *g.WeeklyLimitUSD > 0
}
// HasMonthlyLimit 检查是否有月限额
func (g *Group) HasMonthlyLimit() bool {
return g.MonthlyLimitUSD != nil && *g.MonthlyLimitUSD > 0
}
package model
import (
"gorm.io/gorm"
)
// AutoMigrate 自动迁移所有模型
func AutoMigrate(db *gorm.DB) error {
return db.AutoMigrate(
&User{},
&ApiKey{},
&Group{},
&Account{},
&AccountGroup{},
&Proxy{},
&RedeemCode{},
&UsageLog{},
&Setting{},
&UserSubscription{},
)
}
// 状态常量
const (
StatusActive = "active"
StatusDisabled = "disabled"
StatusError = "error"
StatusUnused = "unused"
StatusUsed = "used"
StatusExpired = "expired"
)
// 角色常量
const (
RoleAdmin = "admin"
RoleUser = "user"
)
// 平台常量
const (
PlatformAnthropic = "anthropic"
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
)
// 账号类型常量
const (
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
AccountTypeApiKey = "apikey" // API Key类型账号
)
// 卡密类型常量
const (
RedeemTypeBalance = "balance"
RedeemTypeConcurrency = "concurrency"
RedeemTypeSubscription = "subscription"
)
// 管理员调整类型常量
const (
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
AdjustmentTypeAdminConcurrency = "admin_concurrency" // 管理员调整并发数
)
package model
import (
"fmt"
"time"
"gorm.io/gorm"
)
type Proxy struct {
ID int64 `gorm:"primaryKey" json:"id"`
Name string `gorm:"size:100;not null" json:"name"`
Protocol string `gorm:"size:20;not null" json:"protocol"` // http/https/socks5
Host string `gorm:"size:255;not null" json:"host"`
Port int `gorm:"not null" json:"port"`
Username string `gorm:"size:100" json:"username"`
Password string `gorm:"size:100" json:"-"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
}
func (Proxy) TableName() string {
return "proxies"
}
// IsActive 检查是否激活
func (p *Proxy) IsActive() bool {
return p.Status == "active"
}
// URL 返回代理URL
func (p *Proxy) URL() string {
if p.Username != "" && p.Password != "" {
return fmt.Sprintf("%s://%s:%s@%s:%d", p.Protocol, p.Username, p.Password, p.Host, p.Port)
}
return fmt.Sprintf("%s://%s:%d", p.Protocol, p.Host, p.Port)
}
// ProxyWithAccountCount extends Proxy with account count information
type ProxyWithAccountCount struct {
Proxy
AccountCount int64 `json:"account_count"`
}
package model
import (
"crypto/rand"
"encoding/hex"
"time"
)
type RedeemCode struct {
ID int64 `gorm:"primaryKey" json:"id"`
Code string `gorm:"uniqueIndex;size:32;not null" json:"code"`
Type string `gorm:"size:20;default:balance;not null" json:"type"` // balance/concurrency/subscription
Value float64 `gorm:"type:decimal(20,8);not null" json:"value"` // 面值(USD)或并发数或有效天数
Status string `gorm:"size:20;default:unused;not null" json:"status"` // unused/used
UsedBy *int64 `gorm:"index" json:"used_by"`
UsedAt *time.Time `json:"used_at"`
Notes string `gorm:"type:text" json:"notes"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
// 订阅类型专用字段
GroupID *int64 `gorm:"index" json:"group_id"` // 订阅分组ID (仅subscription类型使用)
ValidityDays int `gorm:"default:30" json:"validity_days"` // 订阅有效天数 (仅subscription类型使用)
// 关联
User *User `gorm:"foreignKey:UsedBy" json:"user,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func (RedeemCode) TableName() string {
return "redeem_codes"
}
// IsUsed 检查是否已使用
func (r *RedeemCode) IsUsed() bool {
return r.Status == "used"
}
// CanUse 检查是否可以使用
func (r *RedeemCode) CanUse() bool {
return r.Status == "unused"
}
// GenerateRedeemCode 生成唯一的兑换码
func GenerateRedeemCode() (string, error) {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
return "", err
}
return hex.EncodeToString(b), nil
}
package model
import (
"time"
)
// 消费类型常量
const (
BillingTypeBalance int8 = 0 // 钱包余额
BillingTypeSubscription int8 = 1 // 订阅套餐
)
type UsageLog struct {
ID int64 `gorm:"primaryKey" json:"id"`
UserID int64 `gorm:"index;not null" json:"user_id"`
ApiKeyID int64 `gorm:"index;not null" json:"api_key_id"`
AccountID int64 `gorm:"index;not null" json:"account_id"`
RequestID string `gorm:"size:64" json:"request_id"`
Model string `gorm:"size:100;index;not null" json:"model"`
// 订阅关联(可选)
GroupID *int64 `gorm:"index" json:"group_id"`
SubscriptionID *int64 `gorm:"index" json:"subscription_id"`
// Token使用量(4类)
InputTokens int `gorm:"default:0;not null" json:"input_tokens"`
OutputTokens int `gorm:"default:0;not null" json:"output_tokens"`
CacheCreationTokens int `gorm:"default:0;not null" json:"cache_creation_tokens"`
CacheReadTokens int `gorm:"default:0;not null" json:"cache_read_tokens"`
// 详细的缓存创建分类
CacheCreation5mTokens int `gorm:"default:0;not null" json:"cache_creation_5m_tokens"`
CacheCreation1hTokens int `gorm:"default:0;not null" json:"cache_creation_1h_tokens"`
// 费用(USD)
InputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"input_cost"`
OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"output_cost"`
CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_creation_cost"`
CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_read_cost"`
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"` // 原始总费用
ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"actual_cost"` // 实际扣除费用
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null" json:"rate_multiplier"` // 计费倍率
// 元数据
BillingType int8 `gorm:"type:smallint;default:0;not null" json:"billing_type"` // 0=余额 1=订阅
Stream bool `gorm:"default:false;not null" json:"stream"`
DurationMs *int `json:"duration_ms"`
FirstTokenMs *int `json:"first_token_ms"` // 首字时间(流式请求)
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
// 关联
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
ApiKey *ApiKey `gorm:"foreignKey:ApiKeyID" json:"api_key,omitempty"`
Account *Account `gorm:"foreignKey:AccountID" json:"account,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
Subscription *UserSubscription `gorm:"foreignKey:SubscriptionID" json:"subscription,omitempty"`
}
func (UsageLog) TableName() string {
return "usage_logs"
}
// TotalTokens 总token数
func (u *UsageLog) TotalTokens() int {
return u.InputTokens + u.OutputTokens + u.CacheCreationTokens + u.CacheReadTokens
}
package model
import (
"time"
"github.com/lib/pq"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
type User struct {
ID int64 `gorm:"primaryKey" json:"id"`
Email string `gorm:"uniqueIndex;size:255;not null" json:"email"`
Username string `gorm:"size:100;default:''" json:"username"`
Wechat string `gorm:"size:100;default:''" json:"wechat"`
Notes string `gorm:"type:text;default:''" json:"notes"`
PasswordHash string `gorm:"size:255;not null" json:"-"`
Role string `gorm:"size:20;default:user;not null" json:"role"` // admin/user
Balance float64 `gorm:"type:decimal(20,8);default:0;not null" json:"balance"`
Concurrency int `gorm:"default:5;not null" json:"concurrency"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
AllowedGroups pq.Int64Array `gorm:"type:bigint[]" json:"allowed_groups"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 关联
ApiKeys []ApiKey `gorm:"foreignKey:UserID" json:"api_keys,omitempty"`
Subscriptions []UserSubscription `gorm:"foreignKey:UserID" json:"subscriptions,omitempty"`
}
func (User) TableName() string {
return "users"
}
// IsAdmin 检查是否管理员
func (u *User) IsAdmin() bool {
return u.Role == "admin"
}
// IsActive 检查是否激活
func (u *User) IsActive() bool {
return u.Status == "active"
}
// CanBindGroup 检查是否可以绑定指定分组
// 对于标准类型分组:
// - 如果 AllowedGroups 设置了值(非空数组),只能绑定列表中的分组
// - 如果 AllowedGroups 为 nil 或空数组,可以绑定所有非专属分组
func (u *User) CanBindGroup(groupID int64, isExclusive bool) bool {
// 如果设置了 allowed_groups 且不为空,只能绑定指定的分组
if len(u.AllowedGroups) > 0 {
for _, id := range u.AllowedGroups {
if id == groupID {
return true
}
}
return false
}
// 如果没有设置 allowed_groups 或为空数组,可以绑定所有非专属分组
return !isExclusive
}
// SetPassword 设置密码(哈希存储)
func (u *User) SetPassword(password string) error {
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return err
}
u.PasswordHash = string(hash)
return nil
}
// CheckPassword 验证密码
func (u *User) CheckPassword(password string) bool {
err := bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password))
return err == nil
}
This diff is collapsed.
...@@ -2,6 +2,7 @@ package repository ...@@ -2,6 +2,7 @@ package repository
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"time" "time"
...@@ -14,6 +15,11 @@ const ( ...@@ -14,6 +15,11 @@ const (
apiKeyRateLimitDuration = 24 * time.Hour apiKeyRateLimitDuration = 24 * time.Hour
) )
// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
func apiKeyRateLimitKey(userID int64) string {
return fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
}
type apiKeyCache struct { type apiKeyCache struct {
rdb *redis.Client rdb *redis.Client
} }
...@@ -23,12 +29,16 @@ func NewApiKeyCache(rdb *redis.Client) service.ApiKeyCache { ...@@ -23,12 +29,16 @@ func NewApiKeyCache(rdb *redis.Client) service.ApiKeyCache {
} }
func (c *apiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) { func (c *apiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID) key := apiKeyRateLimitKey(userID)
return c.rdb.Get(ctx, key).Int() count, err := c.rdb.Get(ctx, key).Int()
if errors.Is(err, redis.Nil) {
return 0, nil
}
return count, err
} }
func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error { func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID) key := apiKeyRateLimitKey(userID)
pipe := c.rdb.Pipeline() pipe := c.rdb.Pipeline()
pipe.Incr(ctx, key) pipe.Incr(ctx, key)
pipe.Expire(ctx, key, apiKeyRateLimitDuration) pipe.Expire(ctx, key, apiKeyRateLimitDuration)
...@@ -37,7 +47,7 @@ func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID in ...@@ -37,7 +47,7 @@ func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID in
} }
func (c *apiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error { func (c *apiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID) key := apiKeyRateLimitKey(userID)
return c.rdb.Del(ctx, key).Err() return c.rdb.Del(ctx, key).Err()
} }
......
...@@ -23,13 +23,14 @@ func (s *ApiKeyCacheSuite) TestCreateAttemptCount() { ...@@ -23,13 +23,14 @@ func (s *ApiKeyCacheSuite) TestCreateAttemptCount() {
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
}{ }{
{ {
name: "missing_key_returns_redis_nil", name: "missing_key_returns_zero_nil",
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) { fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
userID := int64(1) userID := int64(1)
_, err := cache.GetCreateAttemptCount(ctx, userID) count, err := cache.GetCreateAttemptCount(ctx, userID)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing key") require.NoError(s.T(), err, "expected nil error for missing key")
require.Equal(s.T(), 0, count, "expected zero count for missing key")
}, },
}, },
{ {
...@@ -58,8 +59,9 @@ func (s *ApiKeyCacheSuite) TestCreateAttemptCount() { ...@@ -58,8 +59,9 @@ func (s *ApiKeyCacheSuite) TestCreateAttemptCount() {
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID)) require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID))
require.NoError(s.T(), cache.DeleteCreateAttemptCount(ctx, userID), "DeleteCreateAttemptCount") require.NoError(s.T(), cache.DeleteCreateAttemptCount(ctx, userID), "DeleteCreateAttemptCount")
_, err := cache.GetCreateAttemptCount(ctx, userID) count, err := cache.GetCreateAttemptCount(ctx, userID)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after delete") require.NoError(s.T(), err, "expected nil error after delete")
require.Equal(s.T(), 0, count, "expected zero count after delete")
}, },
}, },
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment