Commit eeaff85e authored by Forest's avatar Forest
Browse files

refactor: 自定义业务错误

parent f51ad2e1
// nolint:mnd
package errors
import "net/http"
// BadRequest new BadRequest error that is mapped to a 400 response.
func BadRequest(reason, message string) *ApplicationError {
return New(http.StatusBadRequest, reason, message)
}
// IsBadRequest determines if err is an error which indicates a BadRequest error.
// It supports wrapped errors.
func IsBadRequest(err error) bool {
return Code(err) == http.StatusBadRequest
}
// TooManyRequests new TooManyRequests error that is mapped to a 429 response.
func TooManyRequests(reason, message string) *ApplicationError {
return New(http.StatusTooManyRequests, reason, message)
}
// IsTooManyRequests determines if err is an error which indicates a TooManyRequests error.
// It supports wrapped errors.
func IsTooManyRequests(err error) bool {
return Code(err) == http.StatusTooManyRequests
}
// Unauthorized new Unauthorized error that is mapped to a 401 response.
func Unauthorized(reason, message string) *ApplicationError {
return New(http.StatusUnauthorized, reason, message)
}
// IsUnauthorized determines if err is an error which indicates an Unauthorized error.
// It supports wrapped errors.
func IsUnauthorized(err error) bool {
return Code(err) == http.StatusUnauthorized
}
// Forbidden new Forbidden error that is mapped to a 403 response.
func Forbidden(reason, message string) *ApplicationError {
return New(http.StatusForbidden, reason, message)
}
// IsForbidden determines if err is an error which indicates a Forbidden error.
// It supports wrapped errors.
func IsForbidden(err error) bool {
return Code(err) == http.StatusForbidden
}
// NotFound new NotFound error that is mapped to a 404 response.
func NotFound(reason, message string) *ApplicationError {
return New(http.StatusNotFound, reason, message)
}
// IsNotFound determines if err is an error which indicates an NotFound error.
// It supports wrapped errors.
func IsNotFound(err error) bool {
return Code(err) == http.StatusNotFound
}
// Conflict new Conflict error that is mapped to a 409 response.
func Conflict(reason, message string) *ApplicationError {
return New(http.StatusConflict, reason, message)
}
// IsConflict determines if err is an error which indicates a Conflict error.
// It supports wrapped errors.
func IsConflict(err error) bool {
return Code(err) == http.StatusConflict
}
// InternalServer new InternalServer error that is mapped to a 500 response.
func InternalServer(reason, message string) *ApplicationError {
return New(http.StatusInternalServerError, reason, message)
}
// IsInternalServer determines if err is an error which indicates an Internal error.
// It supports wrapped errors.
func IsInternalServer(err error) bool {
return Code(err) == http.StatusInternalServerError
}
// ServiceUnavailable new ServiceUnavailable error that is mapped to an HTTP 503 response.
func ServiceUnavailable(reason, message string) *ApplicationError {
return New(http.StatusServiceUnavailable, reason, message)
}
// IsServiceUnavailable determines if err is an error which indicates an Unavailable error.
// It supports wrapped errors.
func IsServiceUnavailable(err error) bool {
return Code(err) == http.StatusServiceUnavailable
}
// GatewayTimeout new GatewayTimeout error that is mapped to an HTTP 504 response.
func GatewayTimeout(reason, message string) *ApplicationError {
return New(http.StatusGatewayTimeout, reason, message)
}
// IsGatewayTimeout determines if err is an error which indicates a GatewayTimeout error.
// It supports wrapped errors.
func IsGatewayTimeout(err error) bool {
return Code(err) == http.StatusGatewayTimeout
}
// ClientClosed new ClientClosed error that is mapped to an HTTP 499 response.
func ClientClosed(reason, message string) *ApplicationError {
return New(499, reason, message)
}
// IsClientClosed determines if err is an error which indicates a IsClientClosed error.
// It supports wrapped errors.
func IsClientClosed(err error) bool {
return Code(err) == 499
}
...@@ -3,9 +3,11 @@ package middleware ...@@ -3,9 +3,11 @@ package middleware
import ( import (
"context" "context"
"crypto/subtle" "crypto/subtle"
"errors"
"strings"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
...@@ -96,7 +98,7 @@ func validateJWTForAdmin( ...@@ -96,7 +98,7 @@ func validateJWTForAdmin(
// 验证 JWT token // 验证 JWT token
claims, err := authService.ValidateToken(token) claims, err := authService.ValidateToken(token)
if err != nil { if err != nil {
if err == service.ErrTokenExpired { if errors.Is(err, service.ErrTokenExpired) {
AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired") AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired")
return false return false
} }
......
...@@ -2,9 +2,11 @@ package middleware ...@@ -2,9 +2,11 @@ package middleware
import ( import (
"context" "context"
"errors"
"strings"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
...@@ -37,7 +39,7 @@ func JWTAuth(authService *service.AuthService, userRepo interface { ...@@ -37,7 +39,7 @@ func JWTAuth(authService *service.AuthService, userRepo interface {
// 验证token // 验证token
claims, err := authService.ValidateToken(tokenString) claims, err := authService.ValidateToken(tokenString)
if err != nil { if err != nil {
if err == service.ErrTokenExpired { if errors.Is(err, service.ErrTokenExpired) {
AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired") AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired")
return return
} }
......
...@@ -4,14 +4,17 @@ import ( ...@@ -4,14 +4,17 @@ import (
"math" "math"
"net/http" "net/http"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// Response 标准API响应格式 // Response 标准API响应格式
type Response struct { type Response struct {
Code int `json:"code"` Code int `json:"code"`
Message string `json:"message"` Message string `json:"message"`
Data any `json:"data,omitempty"` Reason string `json:"reason,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
Data any `json:"data,omitempty"`
} }
// PaginatedData 分页数据格式(匹配前端期望) // PaginatedData 分页数据格式(匹配前端期望)
...@@ -44,11 +47,36 @@ func Created(c *gin.Context, data any) { ...@@ -44,11 +47,36 @@ func Created(c *gin.Context, data any) {
// Error 返回错误响应 // Error 返回错误响应
func Error(c *gin.Context, statusCode int, message string) { func Error(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, Response{ c.JSON(statusCode, Response{
Code: statusCode, Code: statusCode,
Message: message, Message: message,
Reason: "",
Metadata: nil,
}) })
} }
// ErrorWithDetails returns an error response compatible with the existing envelope while
// optionally providing structured error fields (reason/metadata).
func ErrorWithDetails(c *gin.Context, statusCode int, message, reason string, metadata map[string]string) {
c.JSON(statusCode, Response{
Code: statusCode,
Message: message,
Reason: reason,
Metadata: metadata,
})
}
// ErrorFrom converts an ApplicationError (or any error) into the envelope-compatible error response.
// It returns true if an error was written.
func ErrorFrom(c *gin.Context, err error) bool {
if err == nil {
return false
}
statusCode, status := infraerrors.ToHTTP(err)
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
return true
}
// BadRequest 返回400错误 // BadRequest 返回400错误
func BadRequest(c *gin.Context, message string) { func BadRequest(c *gin.Context, message string) {
Error(c, http.StatusBadRequest, message) Error(c, http.StatusBadRequest, message)
......
//go:build unit
package response
import (
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestErrorWithDetails(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
statusCode int
message string
reason string
metadata map[string]string
want Response
}{
{
name: "plain_error",
statusCode: http.StatusBadRequest,
message: "invalid request",
want: Response{
Code: http.StatusBadRequest,
Message: "invalid request",
},
},
{
name: "structured_error",
statusCode: http.StatusForbidden,
message: "no access",
reason: "FORBIDDEN",
metadata: map[string]string{"k": "v"},
want: Response{
Code: http.StatusForbidden,
Message: "no access",
Reason: "FORBIDDEN",
Metadata: map[string]string{"k": "v"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
ErrorWithDetails(c, tt.statusCode, tt.message, tt.reason, tt.metadata)
require.Equal(t, tt.statusCode, w.Code)
var got Response
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
require.Equal(t, tt.want, got)
})
}
}
func TestErrorFrom(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
err error
wantWritten bool
wantHTTPCode int
wantBody Response
}{
{
name: "nil_error",
err: nil,
wantWritten: false,
},
{
name: "application_error",
err: infraerrors.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}),
wantWritten: true,
wantHTTPCode: http.StatusForbidden,
wantBody: Response{
Code: http.StatusForbidden,
Message: "no access",
Reason: "FORBIDDEN",
Metadata: map[string]string{"scope": "admin"},
},
},
{
name: "bad_request_error",
err: infraerrors.BadRequest("INVALID_REQUEST", "invalid request"),
wantWritten: true,
wantHTTPCode: http.StatusBadRequest,
wantBody: Response{
Code: http.StatusBadRequest,
Message: "invalid request",
Reason: "INVALID_REQUEST",
},
},
{
name: "unauthorized_error",
err: infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized"),
wantWritten: true,
wantHTTPCode: http.StatusUnauthorized,
wantBody: Response{
Code: http.StatusUnauthorized,
Message: "unauthorized",
Reason: "UNAUTHORIZED",
},
},
{
name: "not_found_error",
err: infraerrors.NotFound("NOT_FOUND", "not found"),
wantWritten: true,
wantHTTPCode: http.StatusNotFound,
wantBody: Response{
Code: http.StatusNotFound,
Message: "not found",
Reason: "NOT_FOUND",
},
},
{
name: "conflict_error",
err: infraerrors.Conflict("CONFLICT", "conflict"),
wantWritten: true,
wantHTTPCode: http.StatusConflict,
wantBody: Response{
Code: http.StatusConflict,
Message: "conflict",
Reason: "CONFLICT",
},
},
{
name: "unknown_error_defaults_to_500",
err: errors.New("boom"),
wantWritten: true,
wantHTTPCode: http.StatusInternalServerError,
wantBody: Response{
Code: http.StatusInternalServerError,
Message: "boom",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
written := ErrorFrom(c, tt.err)
require.Equal(t, tt.wantWritten, written)
if !tt.wantWritten {
require.Equal(t, 200, w.Code)
require.Empty(t, w.Body.String())
return
}
require.Equal(t, tt.wantHTTPCode, w.Code)
var got Response
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
require.Equal(t, tt.wantBody, got)
})
}
}
...@@ -13,23 +13,23 @@ import ( ...@@ -13,23 +13,23 @@ import (
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
) )
type AccountRepository struct { type accountRepository struct {
db *gorm.DB db *gorm.DB
} }
func NewAccountRepository(db *gorm.DB) *AccountRepository { func NewAccountRepository(db *gorm.DB) service.AccountRepository {
return &AccountRepository{db: db} return &accountRepository{db: db}
} }
func (r *AccountRepository) Create(ctx context.Context, account *model.Account) error { func (r *accountRepository) Create(ctx context.Context, account *model.Account) error {
return r.db.WithContext(ctx).Create(account).Error return r.db.WithContext(ctx).Create(account).Error
} }
func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) { func (r *accountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) {
var account model.Account var account model.Account
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&account, id).Error err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&account, id).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil)
} }
// 填充 GroupIDs 和 Groups 虚拟字段 // 填充 GroupIDs 和 Groups 虚拟字段
account.GroupIDs = make([]int64, 0, len(account.AccountGroups)) account.GroupIDs = make([]int64, 0, len(account.AccountGroups))
...@@ -43,7 +43,7 @@ func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Accou ...@@ -43,7 +43,7 @@ func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Accou
return &account, nil return &account, nil
} }
func (r *AccountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error) { func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error) {
if crsAccountID == "" { if crsAccountID == "" {
return nil, nil return nil, nil
} }
...@@ -59,11 +59,11 @@ func (r *AccountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID ...@@ -59,11 +59,11 @@ func (r *AccountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID
return &account, nil return &account, nil
} }
func (r *AccountRepository) Update(ctx context.Context, account *model.Account) error { func (r *accountRepository) Update(ctx context.Context, account *model.Account) error {
return r.db.WithContext(ctx).Save(account).Error return r.db.WithContext(ctx).Save(account).Error
} }
func (r *AccountRepository) Delete(ctx context.Context, id int64) error { func (r *accountRepository) Delete(ctx context.Context, id int64) error {
// 先删除账号与分组的绑定关系 // 先删除账号与分组的绑定关系
if err := r.db.WithContext(ctx).Where("account_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil { if err := r.db.WithContext(ctx).Where("account_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
return err return err
...@@ -72,12 +72,12 @@ func (r *AccountRepository) Delete(ctx context.Context, id int64) error { ...@@ -72,12 +72,12 @@ func (r *AccountRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error
} }
func (r *AccountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) { func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "", "") return r.ListWithFilters(ctx, params, "", "", "", "")
} }
// ListWithFilters lists accounts with optional filtering by platform, type, status, and search query // ListWithFilters lists accounts with optional filtering by platform, type, status, and search query
func (r *AccountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) { func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) {
var accounts []model.Account var accounts []model.Account
var total int64 var total int64
...@@ -131,7 +131,7 @@ func (r *AccountRepository) ListWithFilters(ctx context.Context, params paginati ...@@ -131,7 +131,7 @@ func (r *AccountRepository) ListWithFilters(ctx context.Context, params paginati
}, nil }, nil
} }
func (r *AccountRepository) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) { func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) {
var accounts []model.Account var accounts []model.Account
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id"). Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
...@@ -142,7 +142,7 @@ func (r *AccountRepository) ListByGroup(ctx context.Context, groupID int64) ([]m ...@@ -142,7 +142,7 @@ func (r *AccountRepository) ListByGroup(ctx context.Context, groupID int64) ([]m
return accounts, err return accounts, err
} }
func (r *AccountRepository) ListActive(ctx context.Context) ([]model.Account, error) { func (r *accountRepository) ListActive(ctx context.Context) ([]model.Account, error) {
var accounts []model.Account var accounts []model.Account
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("status = ?", model.StatusActive). Where("status = ?", model.StatusActive).
...@@ -152,12 +152,12 @@ func (r *AccountRepository) ListActive(ctx context.Context) ([]model.Account, er ...@@ -152,12 +152,12 @@ func (r *AccountRepository) ListActive(ctx context.Context) ([]model.Account, er
return accounts, err return accounts, err
} }
func (r *AccountRepository) UpdateLastUsed(ctx context.Context, id int64) error { func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error {
now := time.Now() now := time.Now()
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Update("last_used_at", now).Error return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Update("last_used_at", now).Error
} }
func (r *AccountRepository) SetError(ctx context.Context, id int64, errorMsg string) error { func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
"status": model.StatusError, "status": model.StatusError,
...@@ -165,7 +165,7 @@ func (r *AccountRepository) SetError(ctx context.Context, id int64, errorMsg str ...@@ -165,7 +165,7 @@ func (r *AccountRepository) SetError(ctx context.Context, id int64, errorMsg str
}).Error }).Error
} }
func (r *AccountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error { func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
ag := &model.AccountGroup{ ag := &model.AccountGroup{
AccountID: accountID, AccountID: accountID,
GroupID: groupID, GroupID: groupID,
...@@ -174,12 +174,12 @@ func (r *AccountRepository) AddToGroup(ctx context.Context, accountID, groupID i ...@@ -174,12 +174,12 @@ func (r *AccountRepository) AddToGroup(ctx context.Context, accountID, groupID i
return r.db.WithContext(ctx).Create(ag).Error return r.db.WithContext(ctx).Create(ag).Error
} }
func (r *AccountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error { func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
return r.db.WithContext(ctx).Where("account_id = ? AND group_id = ?", accountID, groupID). return r.db.WithContext(ctx).Where("account_id = ? AND group_id = ?", accountID, groupID).
Delete(&model.AccountGroup{}).Error Delete(&model.AccountGroup{}).Error
} }
func (r *AccountRepository) GetGroups(ctx context.Context, accountID int64) ([]model.Group, error) { func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]model.Group, error) {
var groups []model.Group var groups []model.Group
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.group_id = groups.id"). Joins("JOIN account_groups ON account_groups.group_id = groups.id").
...@@ -188,7 +188,7 @@ func (r *AccountRepository) GetGroups(ctx context.Context, accountID int64) ([]m ...@@ -188,7 +188,7 @@ func (r *AccountRepository) GetGroups(ctx context.Context, accountID int64) ([]m
return groups, err return groups, err
} }
func (r *AccountRepository) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) { func (r *accountRepository) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
var accounts []model.Account var accounts []model.Account
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("platform = ? AND status = ?", platform, model.StatusActive). Where("platform = ? AND status = ?", platform, model.StatusActive).
...@@ -198,7 +198,7 @@ func (r *AccountRepository) ListByPlatform(ctx context.Context, platform string) ...@@ -198,7 +198,7 @@ func (r *AccountRepository) ListByPlatform(ctx context.Context, platform string)
return accounts, err return accounts, err
} }
func (r *AccountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error { func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
// 删除现有绑定 // 删除现有绑定
if err := r.db.WithContext(ctx).Where("account_id = ?", accountID).Delete(&model.AccountGroup{}).Error; err != nil { if err := r.db.WithContext(ctx).Where("account_id = ?", accountID).Delete(&model.AccountGroup{}).Error; err != nil {
return err return err
...@@ -221,7 +221,7 @@ func (r *AccountRepository) BindGroups(ctx context.Context, accountID int64, gro ...@@ -221,7 +221,7 @@ func (r *AccountRepository) BindGroups(ctx context.Context, accountID int64, gro
} }
// ListSchedulable 获取所有可调度的账号 // ListSchedulable 获取所有可调度的账号
func (r *AccountRepository) ListSchedulable(ctx context.Context) ([]model.Account, error) { func (r *accountRepository) ListSchedulable(ctx context.Context) ([]model.Account, error) {
var accounts []model.Account var accounts []model.Account
now := time.Now() now := time.Now()
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
...@@ -235,7 +235,7 @@ func (r *AccountRepository) ListSchedulable(ctx context.Context) ([]model.Accoun ...@@ -235,7 +235,7 @@ func (r *AccountRepository) ListSchedulable(ctx context.Context) ([]model.Accoun
} }
// ListSchedulableByGroupID 按组获取可调度的账号 // ListSchedulableByGroupID 按组获取可调度的账号
func (r *AccountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error) { func (r *accountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error) {
var accounts []model.Account var accounts []model.Account
now := time.Now() now := time.Now()
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
...@@ -251,7 +251,7 @@ func (r *AccountRepository) ListSchedulableByGroupID(ctx context.Context, groupI ...@@ -251,7 +251,7 @@ func (r *AccountRepository) ListSchedulableByGroupID(ctx context.Context, groupI
} }
// ListSchedulableByPlatform 按平台获取可调度的账号 // ListSchedulableByPlatform 按平台获取可调度的账号
func (r *AccountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error) { func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
var accounts []model.Account var accounts []model.Account
now := time.Now() now := time.Now()
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
...@@ -266,7 +266,7 @@ func (r *AccountRepository) ListSchedulableByPlatform(ctx context.Context, platf ...@@ -266,7 +266,7 @@ func (r *AccountRepository) ListSchedulableByPlatform(ctx context.Context, platf
} }
// ListSchedulableByGroupIDAndPlatform 按组和平台获取可调度的账号 // ListSchedulableByGroupIDAndPlatform 按组和平台获取可调度的账号
func (r *AccountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error) { func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error) {
var accounts []model.Account var accounts []model.Account
now := time.Now() now := time.Now()
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
...@@ -283,7 +283,7 @@ func (r *AccountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Cont ...@@ -283,7 +283,7 @@ func (r *AccountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Cont
} }
// SetRateLimited 标记账号为限流状态(429) // SetRateLimited 标记账号为限流状态(429)
func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
now := time.Now() now := time.Now()
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
...@@ -293,13 +293,13 @@ func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetA ...@@ -293,13 +293,13 @@ func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetA
} }
// SetOverloaded 标记账号为过载状态(529) // SetOverloaded 标记账号为过载状态(529)
func (r *AccountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error { func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Update("overload_until", until).Error Update("overload_until", until).Error
} }
// ClearRateLimit 清除账号的限流状态 // ClearRateLimit 清除账号的限流状态
func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error { func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
"rate_limited_at": nil, "rate_limited_at": nil,
...@@ -309,7 +309,7 @@ func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error ...@@ -309,7 +309,7 @@ func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error
} }
// UpdateSessionWindow 更新账号的5小时时间窗口信息 // UpdateSessionWindow 更新账号的5小时时间窗口信息
func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
updates := map[string]any{ updates := map[string]any{
"session_window_status": status, "session_window_status": status,
} }
...@@ -323,14 +323,14 @@ func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, s ...@@ -323,14 +323,14 @@ func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, s
} }
// SetSchedulable 设置账号的调度开关 // SetSchedulable 设置账号的调度开关
func (r *AccountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Update("schedulable", schedulable).Error Update("schedulable", schedulable).Error
} }
// UpdateExtra updates specific fields in account's Extra JSONB field // UpdateExtra updates specific fields in account's Extra JSONB field
// It merges the updates into existing Extra data without overwriting other fields // It merges the updates into existing Extra data without overwriting other fields
func (r *AccountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error { func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
if len(updates) == 0 { if len(updates) == 0 {
return nil return nil
} }
...@@ -358,7 +358,7 @@ func (r *AccountRepository) UpdateExtra(ctx context.Context, id int64, updates m ...@@ -358,7 +358,7 @@ func (r *AccountRepository) UpdateExtra(ctx context.Context, id int64, updates m
// BulkUpdate updates multiple accounts with the provided fields. // BulkUpdate updates multiple accounts with the provided fields.
// It merges credentials/extra JSONB fields instead of overwriting them. // It merges credentials/extra JSONB fields instead of overwriting them.
func (r *AccountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) { func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
if len(ids) == 0 { if len(ids) == 0 {
return 0, nil return 0, nil
} }
......
...@@ -18,13 +18,13 @@ type AccountRepoSuite struct { ...@@ -18,13 +18,13 @@ type AccountRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
db *gorm.DB db *gorm.DB
repo *AccountRepository repo *accountRepository
} }
func (s *AccountRepoSuite) SetupTest() { func (s *AccountRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.db = testTx(s.T()) s.db = testTx(s.T())
s.repo = NewAccountRepository(s.db) s.repo = NewAccountRepository(s.db).(*accountRepository)
} }
func TestAccountRepoSuite(t *testing.T) { func TestAccountRepoSuite(t *testing.T) {
...@@ -167,7 +167,7 @@ func (s *AccountRepoSuite) TestListWithFilters() { ...@@ -167,7 +167,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
s.Run(tt.name, func() { s.Run(tt.name, func() {
// 每个 case 重新获取隔离资源 // 每个 case 重新获取隔离资源
db := testTx(s.T()) db := testTx(s.T())
repo := NewAccountRepository(db) repo := NewAccountRepository(db).(*accountRepository)
ctx := context.Background() ctx := context.Background()
tt.setup(db) tt.setup(db)
......
...@@ -2,51 +2,55 @@ package repository ...@@ -2,51 +2,55 @@ package repository
import ( import (
"context" "context"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
) )
type ApiKeyRepository struct { type apiKeyRepository struct {
db *gorm.DB db *gorm.DB
} }
func NewApiKeyRepository(db *gorm.DB) *ApiKeyRepository { func NewApiKeyRepository(db *gorm.DB) service.ApiKeyRepository {
return &ApiKeyRepository{db: db} return &apiKeyRepository{db: db}
} }
func (r *ApiKeyRepository) Create(ctx context.Context, key *model.ApiKey) error { func (r *apiKeyRepository) Create(ctx context.Context, key *model.ApiKey) error {
return r.db.WithContext(ctx).Create(key).Error err := r.db.WithContext(ctx).Create(key).Error
return translatePersistenceError(err, nil, service.ErrApiKeyExists)
} }
func (r *ApiKeyRepository) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) { func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) {
var key model.ApiKey var key model.ApiKey
err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&key, id).Error err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&key, id).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
} }
return &key, nil return &key, nil
} }
func (r *ApiKeyRepository) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) { func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) {
var apiKey model.ApiKey var apiKey model.ApiKey
err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&apiKey).Error err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&apiKey).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
} }
return &apiKey, nil return &apiKey, nil
} }
func (r *ApiKeyRepository) Update(ctx context.Context, key *model.ApiKey) error { func (r *apiKeyRepository) Update(ctx context.Context, key *model.ApiKey) error {
return r.db.WithContext(ctx).Model(key).Select("name", "group_id", "status", "updated_at").Updates(key).Error return r.db.WithContext(ctx).Model(key).Select("name", "group_id", "status", "updated_at").Updates(key).Error
} }
func (r *ApiKeyRepository) Delete(ctx context.Context, id int64) error { func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error
} }
func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) { func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
var keys []model.ApiKey var keys []model.ApiKey
var total int64 var total int64
...@@ -73,19 +77,19 @@ func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, param ...@@ -73,19 +77,19 @@ func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
}, nil }, nil
} }
func (r *ApiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) { func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID).Count(&count).Error err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID).Count(&count).Error
return count, err return count, err
} }
func (r *ApiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) { func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("key = ?", key).Count(&count).Error err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("key = ?", key).Count(&count).Error
return count > 0, err return count > 0, err
} }
func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) { func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
var keys []model.ApiKey var keys []model.ApiKey
var total int64 var total int64
...@@ -113,7 +117,7 @@ func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par ...@@ -113,7 +117,7 @@ func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
} }
// SearchApiKeys searches API keys by user ID and/or keyword (name) // SearchApiKeys searches API keys by user ID and/or keyword (name)
func (r *ApiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) { func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) {
var keys []model.ApiKey var keys []model.ApiKey
db := r.db.WithContext(ctx).Model(&model.ApiKey{}) db := r.db.WithContext(ctx).Model(&model.ApiKey{})
...@@ -135,7 +139,7 @@ func (r *ApiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw ...@@ -135,7 +139,7 @@ func (r *ApiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
} }
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil // ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
func (r *ApiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.ApiKey{}). result := r.db.WithContext(ctx).Model(&model.ApiKey{}).
Where("group_id = ?", groupID). Where("group_id = ?", groupID).
Update("group_id", nil) Update("group_id", nil)
...@@ -143,7 +147,7 @@ func (r *ApiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID in ...@@ -143,7 +147,7 @@ func (r *ApiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID in
} }
// CountByGroupID 获取分组的 API Key 数量 // CountByGroupID 获取分组的 API Key 数量
func (r *ApiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID).Count(&count).Error err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID).Count(&count).Error
return count, err return count, err
......
...@@ -16,13 +16,13 @@ type ApiKeyRepoSuite struct { ...@@ -16,13 +16,13 @@ type ApiKeyRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
db *gorm.DB db *gorm.DB
repo *ApiKeyRepository repo *apiKeyRepository
} }
func (s *ApiKeyRepoSuite) SetupTest() { func (s *ApiKeyRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.db = testTx(s.T()) s.db = testTx(s.T())
s.repo = NewApiKeyRepository(s.db) s.repo = NewApiKeyRepository(s.db).(*apiKeyRepository)
} }
func TestApiKeyRepoSuite(t *testing.T) { func TestApiKeyRepoSuite(t *testing.T) {
......
package repository
import (
"errors"
"strings"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"gorm.io/gorm"
)
func translatePersistenceError(err error, notFound, conflict *infraerrors.ApplicationError) error {
if err == nil {
return nil
}
if notFound != nil && errors.Is(err, gorm.ErrRecordNotFound) {
return notFound.WithCause(err)
}
if conflict != nil && isUniqueConstraintViolation(err) {
return conflict.WithCause(err)
}
return err
}
func isUniqueConstraintViolation(err error) bool {
if err == nil {
return false
}
if errors.Is(err, gorm.ErrDuplicatedKey) {
return true
}
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "duplicate key") ||
strings.Contains(msg, "unique constraint") ||
strings.Contains(msg, "duplicate entry")
}
...@@ -2,47 +2,52 @@ package repository ...@@ -2,47 +2,52 @@ package repository
import ( import (
"context" "context"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause"
) )
type GroupRepository struct { type groupRepository struct {
db *gorm.DB db *gorm.DB
} }
func NewGroupRepository(db *gorm.DB) *GroupRepository { func NewGroupRepository(db *gorm.DB) service.GroupRepository {
return &GroupRepository{db: db} return &groupRepository{db: db}
} }
func (r *GroupRepository) Create(ctx context.Context, group *model.Group) error { func (r *groupRepository) Create(ctx context.Context, group *model.Group) error {
return r.db.WithContext(ctx).Create(group).Error err := r.db.WithContext(ctx).Create(group).Error
return translatePersistenceError(err, nil, service.ErrGroupExists)
} }
func (r *GroupRepository) GetByID(ctx context.Context, id int64) (*model.Group, error) { func (r *groupRepository) GetByID(ctx context.Context, id int64) (*model.Group, error) {
var group model.Group var group model.Group
err := r.db.WithContext(ctx).First(&group, id).Error err := r.db.WithContext(ctx).First(&group, id).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
} }
return &group, nil return &group, nil
} }
func (r *GroupRepository) Update(ctx context.Context, group *model.Group) error { func (r *groupRepository) Update(ctx context.Context, group *model.Group) error {
return r.db.WithContext(ctx).Save(group).Error return r.db.WithContext(ctx).Save(group).Error
} }
func (r *GroupRepository) Delete(ctx context.Context, id int64) error { func (r *groupRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error
} }
func (r *GroupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) { func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", nil) return r.ListWithFilters(ctx, params, "", "", nil)
} }
// ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive // ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive
func (r *GroupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) { func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) {
var groups []model.Group var groups []model.Group
var total int64 var total int64
...@@ -86,7 +91,7 @@ func (r *GroupRepository) ListWithFilters(ctx context.Context, params pagination ...@@ -86,7 +91,7 @@ func (r *GroupRepository) ListWithFilters(ctx context.Context, params pagination
}, nil }, nil
} }
func (r *GroupRepository) ListActive(ctx context.Context) ([]model.Group, error) { func (r *groupRepository) ListActive(ctx context.Context) ([]model.Group, error) {
var groups []model.Group var groups []model.Group
err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Order("id ASC").Find(&groups).Error err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Order("id ASC").Find(&groups).Error
if err != nil { if err != nil {
...@@ -100,7 +105,7 @@ func (r *GroupRepository) ListActive(ctx context.Context) ([]model.Group, error) ...@@ -100,7 +105,7 @@ func (r *GroupRepository) ListActive(ctx context.Context) ([]model.Group, error)
return groups, nil return groups, nil
} }
func (r *GroupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error) { func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error) {
var groups []model.Group var groups []model.Group
err := r.db.WithContext(ctx).Where("status = ? AND platform = ?", model.StatusActive, platform).Order("id ASC").Find(&groups).Error err := r.db.WithContext(ctx).Where("status = ? AND platform = ?", model.StatusActive, platform).Order("id ASC").Find(&groups).Error
if err != nil { if err != nil {
...@@ -114,25 +119,80 @@ func (r *GroupRepository) ListActiveByPlatform(ctx context.Context, platform str ...@@ -114,25 +119,80 @@ func (r *GroupRepository) ListActiveByPlatform(ctx context.Context, platform str
return groups, nil return groups, nil
} }
func (r *GroupRepository) ExistsByName(ctx context.Context, name string) (bool, error) { func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.Group{}).Where("name = ?", name).Count(&count).Error err := r.db.WithContext(ctx).Model(&model.Group{}).Where("name = ?", name).Count(&count).Error
return count > 0, err return count > 0, err
} }
func (r *GroupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.AccountGroup{}).Where("group_id = ?", groupID).Count(&count).Error err := r.db.WithContext(ctx).Model(&model.AccountGroup{}).Where("group_id = ?", groupID).Count(&count).Error
return count, err return count, err
} }
// DeleteAccountGroupsByGroupID 删除分组与账号的关联关系 // DeleteAccountGroupsByGroupID 删除分组与账号的关联关系
func (r *GroupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.AccountGroup{}) result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.AccountGroup{})
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
} }
// DB 返回底层数据库连接,用于事务处理 func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
func (r *GroupRepository) DB() *gorm.DB { group, err := r.GetByID(ctx, id)
return r.db if err != nil {
return nil, err
}
var affectedUserIDs []int64
if group.IsSubscriptionType() {
var subscriptions []model.UserSubscription
if err := r.db.WithContext(ctx).
Model(&model.UserSubscription{}).
Where("group_id = ?", id).
Select("user_id").
Find(&subscriptions).Error; err != nil {
return nil, err
}
for _, sub := range subscriptions {
affectedUserIDs = append(affectedUserIDs, sub.UserID)
}
}
err = r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// 1. 删除订阅类型分组的订阅记录
if group.IsSubscriptionType() {
if err := tx.Where("group_id = ?", id).Delete(&model.UserSubscription{}).Error; err != nil {
return err
}
}
// 2. 将 api_keys 中绑定该分组的 group_id 设为 nil
if err := tx.Model(&model.ApiKey{}).Where("group_id = ?", id).Update("group_id", nil).Error; err != nil {
return err
}
// 3. 从 users.allowed_groups 数组中移除该分组 ID
if err := tx.Model(&model.User{}).
Where("? = ANY(allowed_groups)", id).
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", id)).Error; err != nil {
return err
}
// 4. 删除 account_groups 中间表的数据
if err := tx.Where("group_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
return err
}
// 5. 删除分组本身(带锁,避免并发写)
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Delete(&model.Group{}, id).Error; err != nil {
return err
}
return nil
})
if err != nil {
return nil, err
}
return affectedUserIDs, nil
} }
...@@ -16,13 +16,13 @@ type GroupRepoSuite struct { ...@@ -16,13 +16,13 @@ type GroupRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
db *gorm.DB db *gorm.DB
repo *GroupRepository repo *groupRepository
} }
func (s *GroupRepoSuite) SetupTest() { func (s *GroupRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.db = testTx(s.T()) s.db = testTx(s.T())
s.repo = NewGroupRepository(s.db) s.repo = NewGroupRepository(s.db).(*groupRepository)
} }
func TestGroupRepoSuite(t *testing.T) { func TestGroupRepoSuite(t *testing.T) {
...@@ -234,11 +234,3 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() { ...@@ -234,11 +234,3 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
count, _ := s.repo.GetAccountCount(s.ctx, g.ID) count, _ := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().Zero(count) s.Require().Zero(count)
} }
// --- DB ---
func (s *GroupRepoSuite) TestDB() {
db := s.repo.DB()
s.Require().NotNil(db, "DB should return non-nil")
s.Require().Equal(s.db, db, "DB should return the underlying gorm.DB")
}
...@@ -2,47 +2,50 @@ package repository ...@@ -2,47 +2,50 @@ package repository
import ( import (
"context" "context"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
) )
type ProxyRepository struct { type proxyRepository struct {
db *gorm.DB db *gorm.DB
} }
func NewProxyRepository(db *gorm.DB) *ProxyRepository { func NewProxyRepository(db *gorm.DB) service.ProxyRepository {
return &ProxyRepository{db: db} return &proxyRepository{db: db}
} }
func (r *ProxyRepository) Create(ctx context.Context, proxy *model.Proxy) error { func (r *proxyRepository) Create(ctx context.Context, proxy *model.Proxy) error {
return r.db.WithContext(ctx).Create(proxy).Error return r.db.WithContext(ctx).Create(proxy).Error
} }
func (r *ProxyRepository) GetByID(ctx context.Context, id int64) (*model.Proxy, error) { func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*model.Proxy, error) {
var proxy model.Proxy var proxy model.Proxy
err := r.db.WithContext(ctx).First(&proxy, id).Error err := r.db.WithContext(ctx).First(&proxy, id).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrProxyNotFound, nil)
} }
return &proxy, nil return &proxy, nil
} }
func (r *ProxyRepository) Update(ctx context.Context, proxy *model.Proxy) error { func (r *proxyRepository) Update(ctx context.Context, proxy *model.Proxy) error {
return r.db.WithContext(ctx).Save(proxy).Error return r.db.WithContext(ctx).Save(proxy).Error
} }
func (r *ProxyRepository) Delete(ctx context.Context, id int64) error { func (r *proxyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Proxy{}, id).Error return r.db.WithContext(ctx).Delete(&model.Proxy{}, id).Error
} }
func (r *ProxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) { func (r *proxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "") return r.ListWithFilters(ctx, params, "", "", "")
} }
// ListWithFilters lists proxies with optional filtering by protocol, status, and search query // ListWithFilters lists proxies with optional filtering by protocol, status, and search query
func (r *ProxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) { func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) {
var proxies []model.Proxy var proxies []model.Proxy
var total int64 var total int64
...@@ -81,14 +84,14 @@ func (r *ProxyRepository) ListWithFilters(ctx context.Context, params pagination ...@@ -81,14 +84,14 @@ func (r *ProxyRepository) ListWithFilters(ctx context.Context, params pagination
}, nil }, nil
} }
func (r *ProxyRepository) ListActive(ctx context.Context) ([]model.Proxy, error) { func (r *proxyRepository) ListActive(ctx context.Context) ([]model.Proxy, error) {
var proxies []model.Proxy var proxies []model.Proxy
err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Find(&proxies).Error err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Find(&proxies).Error
return proxies, err return proxies, err
} }
// ExistsByHostPortAuth checks if a proxy with the same host, port, username, and password exists // ExistsByHostPortAuth checks if a proxy with the same host, port, username, and password exists
func (r *ProxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) { func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.Proxy{}). err := r.db.WithContext(ctx).Model(&model.Proxy{}).
Where("host = ? AND port = ? AND username = ? AND password = ?", host, port, username, password). Where("host = ? AND port = ? AND username = ? AND password = ?", host, port, username, password).
...@@ -100,7 +103,7 @@ func (r *ProxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, ...@@ -100,7 +103,7 @@ func (r *ProxyRepository) ExistsByHostPortAuth(ctx context.Context, host string,
} }
// CountAccountsByProxyID returns the number of accounts using a specific proxy // CountAccountsByProxyID returns the number of accounts using a specific proxy
func (r *ProxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) { func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.Account{}). err := r.db.WithContext(ctx).Model(&model.Account{}).
Where("proxy_id = ?", proxyID). Where("proxy_id = ?", proxyID).
...@@ -109,7 +112,7 @@ func (r *ProxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID in ...@@ -109,7 +112,7 @@ func (r *ProxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID in
} }
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies // GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
func (r *ProxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[int64]int64, error) { func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[int64]int64, error) {
type result struct { type result struct {
ProxyID int64 `gorm:"column:proxy_id"` ProxyID int64 `gorm:"column:proxy_id"`
Count int64 `gorm:"column:count"` Count int64 `gorm:"column:count"`
...@@ -133,7 +136,7 @@ func (r *ProxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[i ...@@ -133,7 +136,7 @@ func (r *ProxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[i
} }
// ListActiveWithAccountCount returns all active proxies with account count, sorted by creation time descending // ListActiveWithAccountCount returns all active proxies with account count, sorted by creation time descending
func (r *ProxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) { func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) {
var proxies []model.Proxy var proxies []model.Proxy
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("status = ?", model.StatusActive). Where("status = ?", model.StatusActive).
......
...@@ -17,13 +17,13 @@ type ProxyRepoSuite struct { ...@@ -17,13 +17,13 @@ type ProxyRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
db *gorm.DB db *gorm.DB
repo *ProxyRepository repo *proxyRepository
} }
func (s *ProxyRepoSuite) SetupTest() { func (s *ProxyRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.db = testTx(s.T()) s.db = testTx(s.T())
s.repo = NewProxyRepository(s.db) s.repo = NewProxyRepository(s.db).(*proxyRepository)
} }
func TestProxyRepoSuite(t *testing.T) { func TestProxyRepoSuite(t *testing.T) {
......
...@@ -2,57 +2,60 @@ package repository ...@@ -2,57 +2,60 @@ package repository
import ( import (
"context" "context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"time"
"gorm.io/gorm" "gorm.io/gorm"
) )
type RedeemCodeRepository struct { type redeemCodeRepository struct {
db *gorm.DB db *gorm.DB
} }
func NewRedeemCodeRepository(db *gorm.DB) *RedeemCodeRepository { func NewRedeemCodeRepository(db *gorm.DB) service.RedeemCodeRepository {
return &RedeemCodeRepository{db: db} return &redeemCodeRepository{db: db}
} }
func (r *RedeemCodeRepository) Create(ctx context.Context, code *model.RedeemCode) error { func (r *redeemCodeRepository) Create(ctx context.Context, code *model.RedeemCode) error {
return r.db.WithContext(ctx).Create(code).Error return r.db.WithContext(ctx).Create(code).Error
} }
func (r *RedeemCodeRepository) CreateBatch(ctx context.Context, codes []model.RedeemCode) error { func (r *redeemCodeRepository) CreateBatch(ctx context.Context, codes []model.RedeemCode) error {
return r.db.WithContext(ctx).Create(&codes).Error return r.db.WithContext(ctx).Create(&codes).Error
} }
func (r *RedeemCodeRepository) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) { func (r *redeemCodeRepository) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) {
var code model.RedeemCode var code model.RedeemCode
err := r.db.WithContext(ctx).First(&code, id).Error err := r.db.WithContext(ctx).First(&code, id).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil)
} }
return &code, nil return &code, nil
} }
func (r *RedeemCodeRepository) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) { func (r *redeemCodeRepository) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) {
var redeemCode model.RedeemCode var redeemCode model.RedeemCode
err := r.db.WithContext(ctx).Where("code = ?", code).First(&redeemCode).Error err := r.db.WithContext(ctx).Where("code = ?", code).First(&redeemCode).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil)
} }
return &redeemCode, nil return &redeemCode, nil
} }
func (r *RedeemCodeRepository) Delete(ctx context.Context, id int64) error { func (r *redeemCodeRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.RedeemCode{}, id).Error return r.db.WithContext(ctx).Delete(&model.RedeemCode{}, id).Error
} }
func (r *RedeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) { func (r *redeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "") return r.ListWithFilters(ctx, params, "", "", "")
} }
// ListWithFilters lists redeem codes with optional filtering by type, status, and search query // ListWithFilters lists redeem codes with optional filtering by type, status, and search query
func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error) { func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error) {
var codes []model.RedeemCode var codes []model.RedeemCode
var total int64 var total int64
...@@ -91,11 +94,11 @@ func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params pagin ...@@ -91,11 +94,11 @@ func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params pagin
}, nil }, nil
} }
func (r *RedeemCodeRepository) Update(ctx context.Context, code *model.RedeemCode) error { func (r *redeemCodeRepository) Update(ctx context.Context, code *model.RedeemCode) error {
return r.db.WithContext(ctx).Save(code).Error return r.db.WithContext(ctx).Save(code).Error
} }
func (r *RedeemCodeRepository) Use(ctx context.Context, id, userID int64) error { func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error {
now := time.Now() now := time.Now()
result := r.db.WithContext(ctx).Model(&model.RedeemCode{}). result := r.db.WithContext(ctx).Model(&model.RedeemCode{}).
Where("id = ? AND status = ?", id, model.StatusUnused). Where("id = ? AND status = ?", id, model.StatusUnused).
...@@ -108,13 +111,13 @@ func (r *RedeemCodeRepository) Use(ctx context.Context, id, userID int64) error ...@@ -108,13 +111,13 @@ func (r *RedeemCodeRepository) Use(ctx context.Context, id, userID int64) error
return result.Error return result.Error
} }
if result.RowsAffected == 0 { if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound // 兑换码不存在或已被使用 return service.ErrRedeemCodeUsed.WithCause(gorm.ErrRecordNotFound)
} }
return nil return nil
} }
// ListByUser returns all redeem codes used by a specific user // ListByUser returns all redeem codes used by a specific user
func (r *RedeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) { func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) {
var codes []model.RedeemCode var codes []model.RedeemCode
if limit <= 0 { if limit <= 0 {
limit = 10 limit = 10
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"gorm.io/gorm" "gorm.io/gorm"
) )
...@@ -17,13 +18,13 @@ type RedeemCodeRepoSuite struct { ...@@ -17,13 +18,13 @@ type RedeemCodeRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
db *gorm.DB db *gorm.DB
repo *RedeemCodeRepository repo *redeemCodeRepository
} }
func (s *RedeemCodeRepoSuite) SetupTest() { func (s *RedeemCodeRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.db = testTx(s.T()) s.db = testTx(s.T())
s.repo = NewRedeemCodeRepository(s.db) s.repo = NewRedeemCodeRepository(s.db).(*redeemCodeRepository)
} }
func TestRedeemCodeRepoSuite(t *testing.T) { func TestRedeemCodeRepoSuite(t *testing.T) {
...@@ -195,7 +196,7 @@ func (s *RedeemCodeRepoSuite) TestUse_Idempotency() { ...@@ -195,7 +196,7 @@ func (s *RedeemCodeRepoSuite) TestUse_Idempotency() {
// Second use should fail // Second use should fail
err = s.repo.Use(s.ctx, code.ID, user.ID) err = s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().Error(err, "Use expected error on second call") s.Require().Error(err, "Use expected error on second call")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound) s.Require().ErrorIs(err, service.ErrRedeemCodeUsed)
} }
func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() { func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
...@@ -204,7 +205,7 @@ func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() { ...@@ -204,7 +205,7 @@ func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
err := s.repo.Use(s.ctx, code.ID, user.ID) err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().Error(err, "expected error for already used code") s.Require().Error(err, "expected error for already used code")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound) s.Require().ErrorIs(err, service.ErrRedeemCodeUsed)
} }
// --- ListByUser --- // --- ListByUser ---
...@@ -298,7 +299,7 @@ func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser ...@@ -298,7 +299,7 @@ func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser
s.Require().NoError(s.repo.Use(s.ctx, codeB.ID, user.ID), "Use") s.Require().NoError(s.repo.Use(s.ctx, codeB.ID, user.ID), "Use")
err = s.repo.Use(s.ctx, codeB.ID, user.ID) err = s.repo.Use(s.ctx, codeB.ID, user.ID)
s.Require().Error(err, "Use expected error on second call") s.Require().Error(err, "Use expected error on second call")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound) s.Require().ErrorIs(err, service.ErrRedeemCodeUsed)
codeA, err := s.repo.GetByCode(s.ctx, "CODEA") codeA, err := s.repo.GetByCode(s.ctx, "CODEA")
s.Require().NoError(err, "GetByCode") s.Require().NoError(err, "GetByCode")
......
package repository package repository
import "github.com/Wei-Shaw/sub2api/internal/service"
// Repositories 所有仓库的集合 // Repositories 所有仓库的集合
type Repositories struct { type Repositories struct {
User *UserRepository User service.UserRepository
ApiKey *ApiKeyRepository ApiKey service.ApiKeyRepository
Group *GroupRepository Group service.GroupRepository
Account *AccountRepository Account service.AccountRepository
Proxy *ProxyRepository Proxy service.ProxyRepository
RedeemCode *RedeemCodeRepository RedeemCode service.RedeemCodeRepository
UsageLog *UsageLogRepository UsageLog service.UsageLogRepository
Setting *SettingRepository Setting service.SettingRepository
UserSubscription *UserSubscriptionRepository UserSubscription service.UserSubscriptionRepository
} }
...@@ -2,35 +2,38 @@ package repository ...@@ -2,35 +2,38 @@ package repository
import ( import (
"context" "context"
"github.com/Wei-Shaw/sub2api/internal/model"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
) )
// SettingRepository 系统设置数据访问层 // SettingRepository 系统设置数据访问层
type SettingRepository struct { type settingRepository struct {
db *gorm.DB db *gorm.DB
} }
// NewSettingRepository 创建系统设置仓库实例 // NewSettingRepository 创建系统设置仓库实例
func NewSettingRepository(db *gorm.DB) *SettingRepository { func NewSettingRepository(db *gorm.DB) service.SettingRepository {
return &SettingRepository{db: db} return &settingRepository{db: db}
} }
// Get 根据Key获取设置值 // Get 根据Key获取设置值
func (r *SettingRepository) Get(ctx context.Context, key string) (*model.Setting, error) { func (r *settingRepository) Get(ctx context.Context, key string) (*model.Setting, error) {
var setting model.Setting var setting model.Setting
err := r.db.WithContext(ctx).Where("key = ?", key).First(&setting).Error err := r.db.WithContext(ctx).Where("key = ?", key).First(&setting).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrSettingNotFound, nil)
} }
return &setting, nil return &setting, nil
} }
// GetValue 获取设置值字符串 // GetValue 获取设置值字符串
func (r *SettingRepository) GetValue(ctx context.Context, key string) (string, error) { func (r *settingRepository) GetValue(ctx context.Context, key string) (string, error) {
setting, err := r.Get(ctx, key) setting, err := r.Get(ctx, key)
if err != nil { if err != nil {
return "", err return "", err
...@@ -39,7 +42,7 @@ func (r *SettingRepository) GetValue(ctx context.Context, key string) (string, e ...@@ -39,7 +42,7 @@ func (r *SettingRepository) GetValue(ctx context.Context, key string) (string, e
} }
// Set 设置值(存在则更新,不存在则创建) // Set 设置值(存在则更新,不存在则创建)
func (r *SettingRepository) Set(ctx context.Context, key, value string) error { func (r *settingRepository) Set(ctx context.Context, key, value string) error {
setting := &model.Setting{ setting := &model.Setting{
Key: key, Key: key,
Value: value, Value: value,
...@@ -53,7 +56,7 @@ func (r *SettingRepository) Set(ctx context.Context, key, value string) error { ...@@ -53,7 +56,7 @@ func (r *SettingRepository) Set(ctx context.Context, key, value string) error {
} }
// GetMultiple 批量获取设置 // GetMultiple 批量获取设置
func (r *SettingRepository) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { func (r *settingRepository) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
var settings []model.Setting var settings []model.Setting
err := r.db.WithContext(ctx).Where("key IN ?", keys).Find(&settings).Error err := r.db.WithContext(ctx).Where("key IN ?", keys).Find(&settings).Error
if err != nil { if err != nil {
...@@ -68,7 +71,7 @@ func (r *SettingRepository) GetMultiple(ctx context.Context, keys []string) (map ...@@ -68,7 +71,7 @@ func (r *SettingRepository) GetMultiple(ctx context.Context, keys []string) (map
} }
// SetMultiple 批量设置值 // SetMultiple 批量设置值
func (r *SettingRepository) SetMultiple(ctx context.Context, settings map[string]string) error { func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string]string) error {
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for key, value := range settings { for key, value := range settings {
setting := &model.Setting{ setting := &model.Setting{
...@@ -88,7 +91,7 @@ func (r *SettingRepository) SetMultiple(ctx context.Context, settings map[string ...@@ -88,7 +91,7 @@ func (r *SettingRepository) SetMultiple(ctx context.Context, settings map[string
} }
// GetAll 获取所有设置 // GetAll 获取所有设置
func (r *SettingRepository) GetAll(ctx context.Context) (map[string]string, error) { func (r *settingRepository) GetAll(ctx context.Context) (map[string]string, error) {
var settings []model.Setting var settings []model.Setting
err := r.db.WithContext(ctx).Find(&settings).Error err := r.db.WithContext(ctx).Find(&settings).Error
if err != nil { if err != nil {
...@@ -103,6 +106,6 @@ func (r *SettingRepository) GetAll(ctx context.Context) (map[string]string, erro ...@@ -103,6 +106,6 @@ func (r *SettingRepository) GetAll(ctx context.Context) (map[string]string, erro
} }
// Delete 删除设置 // Delete 删除设置
func (r *SettingRepository) Delete(ctx context.Context, key string) error { func (r *settingRepository) Delete(ctx context.Context, key string) error {
return r.db.WithContext(ctx).Where("key = ?", key).Delete(&model.Setting{}).Error return r.db.WithContext(ctx).Where("key = ?", key).Delete(&model.Setting{}).Error
} }
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"context" "context"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"gorm.io/gorm" "gorm.io/gorm"
) )
...@@ -14,13 +15,13 @@ type SettingRepoSuite struct { ...@@ -14,13 +15,13 @@ type SettingRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
db *gorm.DB db *gorm.DB
repo *SettingRepository repo *settingRepository
} }
func (s *SettingRepoSuite) SetupTest() { func (s *SettingRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.db = testTx(s.T()) s.db = testTx(s.T())
s.repo = NewSettingRepository(s.db) s.repo = NewSettingRepository(s.db).(*settingRepository)
} }
func TestSettingRepoSuite(t *testing.T) { func TestSettingRepoSuite(t *testing.T) {
...@@ -45,7 +46,7 @@ func (s *SettingRepoSuite) TestSet_Upsert() { ...@@ -45,7 +46,7 @@ func (s *SettingRepoSuite) TestSet_Upsert() {
func (s *SettingRepoSuite) TestGetValue_Missing() { func (s *SettingRepoSuite) TestGetValue_Missing() {
_, err := s.repo.GetValue(s.ctx, "nonexistent") _, err := s.repo.GetValue(s.ctx, "nonexistent")
s.Require().Error(err, "expected error for missing key") s.Require().Error(err, "expected error for missing key")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound) s.Require().ErrorIs(err, service.ErrSettingNotFound)
} }
func (s *SettingRepoSuite) TestSetMultiple_AndGetMultiple() { func (s *SettingRepoSuite) TestSetMultiple_AndGetMultiple() {
...@@ -86,7 +87,7 @@ func (s *SettingRepoSuite) TestDelete() { ...@@ -86,7 +87,7 @@ func (s *SettingRepoSuite) TestDelete() {
s.Require().NoError(s.repo.Delete(s.ctx, "todelete"), "Delete") s.Require().NoError(s.repo.Delete(s.ctx, "todelete"), "Delete")
_, err := s.repo.GetValue(s.ctx, "todelete") _, err := s.repo.GetValue(s.ctx, "todelete")
s.Require().Error(err, "expected missing key error after Delete") s.Require().Error(err, "expected missing key error after Delete")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound) s.Require().ErrorIs(err, service.ErrSettingNotFound)
} }
func (s *SettingRepoSuite) TestDelete_Idempotent() { func (s *SettingRepoSuite) TestDelete_Idempotent() {
......
...@@ -2,25 +2,28 @@ package repository ...@@ -2,25 +2,28 @@ package repository
import ( import (
"context" "context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"time"
"gorm.io/gorm" "gorm.io/gorm"
) )
type UsageLogRepository struct { type usageLogRepository struct {
db *gorm.DB db *gorm.DB
} }
func NewUsageLogRepository(db *gorm.DB) *UsageLogRepository { func NewUsageLogRepository(db *gorm.DB) service.UsageLogRepository {
return &UsageLogRepository{db: db} return &usageLogRepository{db: db}
} }
// getPerformanceStats 获取 RPM 和 TPM(近5分钟平均值,可选按用户过滤) // getPerformanceStats 获取 RPM 和 TPM(近5分钟平均值,可选按用户过滤)
func (r *UsageLogRepository) getPerformanceStats(ctx context.Context, userID int64) (rpm, tpm int64) { func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int64) (rpm, tpm int64) {
fiveMinutesAgo := time.Now().Add(-5 * time.Minute) fiveMinutesAgo := time.Now().Add(-5 * time.Minute)
var perfStats struct { var perfStats struct {
RequestCount int64 `gorm:"column:request_count"` RequestCount int64 `gorm:"column:request_count"`
...@@ -43,20 +46,20 @@ func (r *UsageLogRepository) getPerformanceStats(ctx context.Context, userID int ...@@ -43,20 +46,20 @@ func (r *UsageLogRepository) getPerformanceStats(ctx context.Context, userID int
return perfStats.RequestCount / 5, perfStats.TokenCount / 5 return perfStats.RequestCount / 5, perfStats.TokenCount / 5
} }
func (r *UsageLogRepository) Create(ctx context.Context, log *model.UsageLog) error { func (r *usageLogRepository) Create(ctx context.Context, log *model.UsageLog) error {
return r.db.WithContext(ctx).Create(log).Error return r.db.WithContext(ctx).Create(log).Error
} }
func (r *UsageLogRepository) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) { func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) {
var log model.UsageLog var log model.UsageLog
err := r.db.WithContext(ctx).First(&log, id).Error err := r.db.WithContext(ctx).First(&log, id).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrUsageLogNotFound, nil)
} }
return &log, nil return &log, nil
} }
func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
var total int64 var total int64
...@@ -83,7 +86,7 @@ func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, param ...@@ -83,7 +86,7 @@ func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, param
}, nil }, nil
} }
func (r *UsageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
var total int64 var total int64
...@@ -120,7 +123,7 @@ type UserStats struct { ...@@ -120,7 +123,7 @@ type UserStats struct {
CacheReadTokens int64 `json:"cache_read_tokens"` CacheReadTokens int64 `json:"cache_read_tokens"`
} }
func (r *UsageLogRepository) GetUserStats(ctx context.Context, userID int64, startTime, endTime time.Time) (*UserStats, error) { func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, startTime, endTime time.Time) (*UserStats, error) {
var stats UserStats var stats UserStats
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
Select(` Select(`
...@@ -139,7 +142,7 @@ func (r *UsageLogRepository) GetUserStats(ctx context.Context, userID int64, sta ...@@ -139,7 +142,7 @@ func (r *UsageLogRepository) GetUserStats(ctx context.Context, userID int64, sta
// DashboardStats 仪表盘统计 // DashboardStats 仪表盘统计
type DashboardStats = usagestats.DashboardStats type DashboardStats = usagestats.DashboardStats
func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) { func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
var stats DashboardStats var stats DashboardStats
today := timezone.Today() today := timezone.Today()
...@@ -260,7 +263,7 @@ func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS ...@@ -260,7 +263,7 @@ func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
return &stats, nil return &stats, nil
} }
func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
var total int64 var total int64
...@@ -287,7 +290,7 @@ func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64, ...@@ -287,7 +290,7 @@ func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64,
}, nil }, nil
} }
func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime). Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime).
...@@ -296,7 +299,7 @@ func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID ...@@ -296,7 +299,7 @@ func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID
return logs, nil, err return logs, nil, err
} }
func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime). Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime).
...@@ -305,7 +308,7 @@ func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKe ...@@ -305,7 +308,7 @@ func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKe
return logs, nil, err return logs, nil, err
} }
func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime). Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
...@@ -314,7 +317,7 @@ func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, acco ...@@ -314,7 +317,7 @@ func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, acco
return logs, nil, err return logs, nil, err
} }
func (r *UsageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("model = ? AND created_at >= ? AND created_at < ?", modelName, startTime, endTime). Where("model = ? AND created_at >= ? AND created_at < ?", modelName, startTime, endTime).
...@@ -323,12 +326,12 @@ func (r *UsageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelN ...@@ -323,12 +326,12 @@ func (r *UsageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelN
return logs, nil, err return logs, nil, err
} }
func (r *UsageLogRepository) Delete(ctx context.Context, id int64) error { func (r *usageLogRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.UsageLog{}, id).Error return r.db.WithContext(ctx).Delete(&model.UsageLog{}, id).Error
} }
// GetAccountTodayStats 获取账号今日统计 // GetAccountTodayStats 获取账号今日统计
func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) { func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) {
today := timezone.Today() today := timezone.Today()
var stats struct { var stats struct {
...@@ -358,7 +361,7 @@ func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID ...@@ -358,7 +361,7 @@ func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
} }
// GetAccountWindowStats 获取账号时间窗口内的统计 // GetAccountWindowStats 获取账号时间窗口内的统计
func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) { func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
var stats struct { var stats struct {
Requests int64 `gorm:"column:requests"` Requests int64 `gorm:"column:requests"`
Tokens int64 `gorm:"column:tokens"` Tokens int64 `gorm:"column:tokens"`
...@@ -398,7 +401,7 @@ type UserUsageTrendPoint = usagestats.UserUsageTrendPoint ...@@ -398,7 +401,7 @@ type UserUsageTrendPoint = usagestats.UserUsageTrendPoint
type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date // GetApiKeyUsageTrend returns usage trend data grouped by API key and date
func (r *UsageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]ApiKeyUsageTrendPoint, error) { func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]ApiKeyUsageTrendPoint, error) {
var results []ApiKeyUsageTrendPoint var results []ApiKeyUsageTrendPoint
// Choose date format based on granularity // Choose date format based on granularity
...@@ -442,7 +445,7 @@ func (r *UsageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, ...@@ -442,7 +445,7 @@ func (r *UsageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime,
} }
// GetUserUsageTrend returns usage trend data grouped by user and date // GetUserUsageTrend returns usage trend data grouped by user and date
func (r *UsageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]UserUsageTrendPoint, error) { func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]UserUsageTrendPoint, error) {
var results []UserUsageTrendPoint var results []UserUsageTrendPoint
// Choose date format based on granularity // Choose date format based on granularity
...@@ -491,7 +494,7 @@ func (r *UsageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e ...@@ -491,7 +494,7 @@ func (r *UsageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e
type UserDashboardStats = usagestats.UserDashboardStats type UserDashboardStats = usagestats.UserDashboardStats
// GetUserDashboardStats 获取用户专属的仪表盘统计 // GetUserDashboardStats 获取用户专属的仪表盘统计
func (r *UsageLogRepository) GetUserDashboardStats(ctx context.Context, userID int64) (*UserDashboardStats, error) { func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID int64) (*UserDashboardStats, error) {
var stats UserDashboardStats var stats UserDashboardStats
today := timezone.Today() today := timezone.Today()
...@@ -578,7 +581,7 @@ func (r *UsageLogRepository) GetUserDashboardStats(ctx context.Context, userID i ...@@ -578,7 +581,7 @@ func (r *UsageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
} }
// GetUserUsageTrendByUserID 获取指定用户的使用趋势 // GetUserUsageTrendByUserID 获取指定用户的使用趋势
func (r *UsageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]TrendDataPoint, error) { func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]TrendDataPoint, error) {
var results []TrendDataPoint var results []TrendDataPoint
var dateFormat string var dateFormat string
...@@ -612,7 +615,7 @@ func (r *UsageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user ...@@ -612,7 +615,7 @@ func (r *UsageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
} }
// GetUserModelStats 获取指定用户的模型统计 // GetUserModelStats 获取指定用户的模型统计
func (r *UsageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]ModelStat, error) { func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]ModelStat, error) {
var results []ModelStat var results []ModelStat
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
...@@ -641,7 +644,7 @@ func (r *UsageLogRepository) GetUserModelStats(ctx context.Context, userID int64 ...@@ -641,7 +644,7 @@ func (r *UsageLogRepository) GetUserModelStats(ctx context.Context, userID int64
type UsageLogFilters = usagestats.UsageLogFilters type UsageLogFilters = usagestats.UsageLogFilters
// ListWithFilters lists usage logs with optional filters (for admin) // ListWithFilters lists usage logs with optional filters (for admin)
func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
var total int64 var total int64
...@@ -692,7 +695,7 @@ type UsageStats = usagestats.UsageStats ...@@ -692,7 +695,7 @@ type UsageStats = usagestats.UsageStats
type BatchUserUsageStats = usagestats.BatchUserUsageStats type BatchUserUsageStats = usagestats.BatchUserUsageStats
// GetBatchUserUsageStats gets today and total actual_cost for multiple users // GetBatchUserUsageStats gets today and total actual_cost for multiple users
func (r *UsageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) { func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) {
if len(userIDs) == 0 { if len(userIDs) == 0 {
return make(map[int64]*BatchUserUsageStats), nil return make(map[int64]*BatchUserUsageStats), nil
} }
...@@ -752,7 +755,7 @@ func (r *UsageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs ...@@ -752,7 +755,7 @@ func (r *UsageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
type BatchApiKeyUsageStats = usagestats.BatchApiKeyUsageStats type BatchApiKeyUsageStats = usagestats.BatchApiKeyUsageStats
// GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys // GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys
func (r *UsageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) { func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) {
if len(apiKeyIDs) == 0 { if len(apiKeyIDs) == 0 {
return make(map[int64]*BatchApiKeyUsageStats), nil return make(map[int64]*BatchApiKeyUsageStats), nil
} }
...@@ -809,7 +812,7 @@ func (r *UsageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe ...@@ -809,7 +812,7 @@ func (r *UsageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
} }
// GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters // GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters
func (r *UsageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]TrendDataPoint, error) { func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]TrendDataPoint, error) {
var results []TrendDataPoint var results []TrendDataPoint
var dateFormat string var dateFormat string
...@@ -848,7 +851,7 @@ func (r *UsageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start ...@@ -848,7 +851,7 @@ func (r *UsageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
} }
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters // GetModelStatsWithFilters returns model statistics with optional user/api_key filters
func (r *UsageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) { func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) {
var results []ModelStat var results []ModelStat
db := r.db.WithContext(ctx).Model(&model.UsageLog{}). db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
...@@ -882,7 +885,7 @@ func (r *UsageLogRepository) GetModelStatsWithFilters(ctx context.Context, start ...@@ -882,7 +885,7 @@ func (r *UsageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
} }
// GetGlobalStats gets usage statistics for all users within a time range // GetGlobalStats gets usage statistics for all users within a time range
func (r *UsageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) { func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) {
var stats struct { var stats struct {
TotalRequests int64 `gorm:"column:total_requests"` TotalRequests int64 `gorm:"column:total_requests"`
TotalInputTokens int64 `gorm:"column:total_input_tokens"` TotalInputTokens int64 `gorm:"column:total_input_tokens"`
...@@ -932,7 +935,7 @@ type AccountUsageSummary = usagestats.AccountUsageSummary ...@@ -932,7 +935,7 @@ type AccountUsageSummary = usagestats.AccountUsageSummary
type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse
// GetAccountUsageStats returns comprehensive usage statistics for an account over a time range // GetAccountUsageStats returns comprehensive usage statistics for an account over a time range
func (r *UsageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*AccountUsageStatsResponse, error) { func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*AccountUsageStatsResponse, error) {
daysCount := int(endTime.Sub(startTime).Hours()/24) + 1 daysCount := int(endTime.Sub(startTime).Hours()/24) + 1
if daysCount <= 0 { if daysCount <= 0 {
daysCount = 30 daysCount = 30
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment