Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
eeaff85e
Commit
eeaff85e
authored
Dec 25, 2025
by
Forest
Browse files
refactor: 自定义业务错误
parent
f51ad2e1
Changes
60
Show whitespace changes
Inline
Side-by-side
backend/internal/infrastructure/errors/types.go
0 → 100644
View file @
eeaff85e
// nolint:mnd
package
errors
import
"net/http"
// BadRequest new BadRequest error that is mapped to a 400 response.
func
BadRequest
(
reason
,
message
string
)
*
ApplicationError
{
return
New
(
http
.
StatusBadRequest
,
reason
,
message
)
}
// IsBadRequest determines if err is an error which indicates a BadRequest error.
// It supports wrapped errors.
func
IsBadRequest
(
err
error
)
bool
{
return
Code
(
err
)
==
http
.
StatusBadRequest
}
// TooManyRequests new TooManyRequests error that is mapped to a 429 response.
func
TooManyRequests
(
reason
,
message
string
)
*
ApplicationError
{
return
New
(
http
.
StatusTooManyRequests
,
reason
,
message
)
}
// IsTooManyRequests determines if err is an error which indicates a TooManyRequests error.
// It supports wrapped errors.
func
IsTooManyRequests
(
err
error
)
bool
{
return
Code
(
err
)
==
http
.
StatusTooManyRequests
}
// Unauthorized new Unauthorized error that is mapped to a 401 response.
func
Unauthorized
(
reason
,
message
string
)
*
ApplicationError
{
return
New
(
http
.
StatusUnauthorized
,
reason
,
message
)
}
// IsUnauthorized determines if err is an error which indicates an Unauthorized error.
// It supports wrapped errors.
func
IsUnauthorized
(
err
error
)
bool
{
return
Code
(
err
)
==
http
.
StatusUnauthorized
}
// Forbidden new Forbidden error that is mapped to a 403 response.
func
Forbidden
(
reason
,
message
string
)
*
ApplicationError
{
return
New
(
http
.
StatusForbidden
,
reason
,
message
)
}
// IsForbidden determines if err is an error which indicates a Forbidden error.
// It supports wrapped errors.
func
IsForbidden
(
err
error
)
bool
{
return
Code
(
err
)
==
http
.
StatusForbidden
}
// NotFound new NotFound error that is mapped to a 404 response.
func
NotFound
(
reason
,
message
string
)
*
ApplicationError
{
return
New
(
http
.
StatusNotFound
,
reason
,
message
)
}
// IsNotFound determines if err is an error which indicates an NotFound error.
// It supports wrapped errors.
func
IsNotFound
(
err
error
)
bool
{
return
Code
(
err
)
==
http
.
StatusNotFound
}
// Conflict new Conflict error that is mapped to a 409 response.
func
Conflict
(
reason
,
message
string
)
*
ApplicationError
{
return
New
(
http
.
StatusConflict
,
reason
,
message
)
}
// IsConflict determines if err is an error which indicates a Conflict error.
// It supports wrapped errors.
func
IsConflict
(
err
error
)
bool
{
return
Code
(
err
)
==
http
.
StatusConflict
}
// InternalServer new InternalServer error that is mapped to a 500 response.
func
InternalServer
(
reason
,
message
string
)
*
ApplicationError
{
return
New
(
http
.
StatusInternalServerError
,
reason
,
message
)
}
// IsInternalServer determines if err is an error which indicates an Internal error.
// It supports wrapped errors.
func
IsInternalServer
(
err
error
)
bool
{
return
Code
(
err
)
==
http
.
StatusInternalServerError
}
// ServiceUnavailable new ServiceUnavailable error that is mapped to an HTTP 503 response.
func
ServiceUnavailable
(
reason
,
message
string
)
*
ApplicationError
{
return
New
(
http
.
StatusServiceUnavailable
,
reason
,
message
)
}
// IsServiceUnavailable determines if err is an error which indicates an Unavailable error.
// It supports wrapped errors.
func
IsServiceUnavailable
(
err
error
)
bool
{
return
Code
(
err
)
==
http
.
StatusServiceUnavailable
}
// GatewayTimeout new GatewayTimeout error that is mapped to an HTTP 504 response.
func
GatewayTimeout
(
reason
,
message
string
)
*
ApplicationError
{
return
New
(
http
.
StatusGatewayTimeout
,
reason
,
message
)
}
// IsGatewayTimeout determines if err is an error which indicates a GatewayTimeout error.
// It supports wrapped errors.
func
IsGatewayTimeout
(
err
error
)
bool
{
return
Code
(
err
)
==
http
.
StatusGatewayTimeout
}
// ClientClosed new ClientClosed error that is mapped to an HTTP 499 response.
func
ClientClosed
(
reason
,
message
string
)
*
ApplicationError
{
return
New
(
499
,
reason
,
message
)
}
// IsClientClosed determines if err is an error which indicates a IsClientClosed error.
// It supports wrapped errors.
func
IsClientClosed
(
err
error
)
bool
{
return
Code
(
err
)
==
499
}
backend/internal/middleware/admin_auth.go
View file @
eeaff85e
...
...
@@ -3,9 +3,11 @@ package middleware
import
(
"context"
"crypto/subtle"
"errors"
"strings"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
"strings"
"github.com/gin-gonic/gin"
)
...
...
@@ -96,7 +98,7 @@ func validateJWTForAdmin(
// 验证 JWT token
claims
,
err
:=
authService
.
ValidateToken
(
token
)
if
err
!=
nil
{
if
err
==
service
.
ErrTokenExpired
{
if
err
ors
.
Is
(
err
,
service
.
ErrTokenExpired
)
{
AbortWithError
(
c
,
401
,
"TOKEN_EXPIRED"
,
"Token has expired"
)
return
false
}
...
...
backend/internal/middleware/jwt_auth.go
View file @
eeaff85e
...
...
@@ -2,9 +2,11 @@ package middleware
import
(
"context"
"errors"
"strings"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
"strings"
"github.com/gin-gonic/gin"
)
...
...
@@ -37,7 +39,7 @@ func JWTAuth(authService *service.AuthService, userRepo interface {
// 验证token
claims
,
err
:=
authService
.
ValidateToken
(
tokenString
)
if
err
!=
nil
{
if
err
==
service
.
ErrTokenExpired
{
if
err
ors
.
Is
(
err
,
service
.
ErrTokenExpired
)
{
AbortWithError
(
c
,
401
,
"TOKEN_EXPIRED"
,
"Token has expired"
)
return
}
...
...
backend/internal/pkg/response/response.go
View file @
eeaff85e
...
...
@@ -4,6 +4,7 @@ import (
"math"
"net/http"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/gin-gonic/gin"
)
...
...
@@ -11,6 +12,8 @@ import (
type
Response
struct
{
Code
int
`json:"code"`
Message
string
`json:"message"`
Reason
string
`json:"reason,omitempty"`
Metadata
map
[
string
]
string
`json:"metadata,omitempty"`
Data
any
`json:"data,omitempty"`
}
...
...
@@ -46,9 +49,34 @@ func Error(c *gin.Context, statusCode int, message string) {
c
.
JSON
(
statusCode
,
Response
{
Code
:
statusCode
,
Message
:
message
,
Reason
:
""
,
Metadata
:
nil
,
})
}
// ErrorWithDetails returns an error response compatible with the existing envelope while
// optionally providing structured error fields (reason/metadata).
func
ErrorWithDetails
(
c
*
gin
.
Context
,
statusCode
int
,
message
,
reason
string
,
metadata
map
[
string
]
string
)
{
c
.
JSON
(
statusCode
,
Response
{
Code
:
statusCode
,
Message
:
message
,
Reason
:
reason
,
Metadata
:
metadata
,
})
}
// ErrorFrom converts an ApplicationError (or any error) into the envelope-compatible error response.
// It returns true if an error was written.
func
ErrorFrom
(
c
*
gin
.
Context
,
err
error
)
bool
{
if
err
==
nil
{
return
false
}
statusCode
,
status
:=
infraerrors
.
ToHTTP
(
err
)
ErrorWithDetails
(
c
,
statusCode
,
status
.
Message
,
status
.
Reason
,
status
.
Metadata
)
return
true
}
// BadRequest 返回400错误
func
BadRequest
(
c
*
gin
.
Context
,
message
string
)
{
Error
(
c
,
http
.
StatusBadRequest
,
message
)
...
...
backend/internal/pkg/response/response_test.go
0 → 100644
View file @
eeaff85e
//go:build unit
package
response
import
(
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func
TestErrorWithDetails
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
tests
:=
[]
struct
{
name
string
statusCode
int
message
string
reason
string
metadata
map
[
string
]
string
want
Response
}{
{
name
:
"plain_error"
,
statusCode
:
http
.
StatusBadRequest
,
message
:
"invalid request"
,
want
:
Response
{
Code
:
http
.
StatusBadRequest
,
Message
:
"invalid request"
,
},
},
{
name
:
"structured_error"
,
statusCode
:
http
.
StatusForbidden
,
message
:
"no access"
,
reason
:
"FORBIDDEN"
,
metadata
:
map
[
string
]
string
{
"k"
:
"v"
},
want
:
Response
{
Code
:
http
.
StatusForbidden
,
Message
:
"no access"
,
Reason
:
"FORBIDDEN"
,
Metadata
:
map
[
string
]
string
{
"k"
:
"v"
},
},
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
ErrorWithDetails
(
c
,
tt
.
statusCode
,
tt
.
message
,
tt
.
reason
,
tt
.
metadata
)
require
.
Equal
(
t
,
tt
.
statusCode
,
w
.
Code
)
var
got
Response
require
.
NoError
(
t
,
json
.
Unmarshal
(
w
.
Body
.
Bytes
(),
&
got
))
require
.
Equal
(
t
,
tt
.
want
,
got
)
})
}
}
func
TestErrorFrom
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
tests
:=
[]
struct
{
name
string
err
error
wantWritten
bool
wantHTTPCode
int
wantBody
Response
}{
{
name
:
"nil_error"
,
err
:
nil
,
wantWritten
:
false
,
},
{
name
:
"application_error"
,
err
:
infraerrors
.
Forbidden
(
"FORBIDDEN"
,
"no access"
)
.
WithMetadata
(
map
[
string
]
string
{
"scope"
:
"admin"
}),
wantWritten
:
true
,
wantHTTPCode
:
http
.
StatusForbidden
,
wantBody
:
Response
{
Code
:
http
.
StatusForbidden
,
Message
:
"no access"
,
Reason
:
"FORBIDDEN"
,
Metadata
:
map
[
string
]
string
{
"scope"
:
"admin"
},
},
},
{
name
:
"bad_request_error"
,
err
:
infraerrors
.
BadRequest
(
"INVALID_REQUEST"
,
"invalid request"
),
wantWritten
:
true
,
wantHTTPCode
:
http
.
StatusBadRequest
,
wantBody
:
Response
{
Code
:
http
.
StatusBadRequest
,
Message
:
"invalid request"
,
Reason
:
"INVALID_REQUEST"
,
},
},
{
name
:
"unauthorized_error"
,
err
:
infraerrors
.
Unauthorized
(
"UNAUTHORIZED"
,
"unauthorized"
),
wantWritten
:
true
,
wantHTTPCode
:
http
.
StatusUnauthorized
,
wantBody
:
Response
{
Code
:
http
.
StatusUnauthorized
,
Message
:
"unauthorized"
,
Reason
:
"UNAUTHORIZED"
,
},
},
{
name
:
"not_found_error"
,
err
:
infraerrors
.
NotFound
(
"NOT_FOUND"
,
"not found"
),
wantWritten
:
true
,
wantHTTPCode
:
http
.
StatusNotFound
,
wantBody
:
Response
{
Code
:
http
.
StatusNotFound
,
Message
:
"not found"
,
Reason
:
"NOT_FOUND"
,
},
},
{
name
:
"conflict_error"
,
err
:
infraerrors
.
Conflict
(
"CONFLICT"
,
"conflict"
),
wantWritten
:
true
,
wantHTTPCode
:
http
.
StatusConflict
,
wantBody
:
Response
{
Code
:
http
.
StatusConflict
,
Message
:
"conflict"
,
Reason
:
"CONFLICT"
,
},
},
{
name
:
"unknown_error_defaults_to_500"
,
err
:
errors
.
New
(
"boom"
),
wantWritten
:
true
,
wantHTTPCode
:
http
.
StatusInternalServerError
,
wantBody
:
Response
{
Code
:
http
.
StatusInternalServerError
,
Message
:
"boom"
,
},
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
written
:=
ErrorFrom
(
c
,
tt
.
err
)
require
.
Equal
(
t
,
tt
.
wantWritten
,
written
)
if
!
tt
.
wantWritten
{
require
.
Equal
(
t
,
200
,
w
.
Code
)
require
.
Empty
(
t
,
w
.
Body
.
String
())
return
}
require
.
Equal
(
t
,
tt
.
wantHTTPCode
,
w
.
Code
)
var
got
Response
require
.
NoError
(
t
,
json
.
Unmarshal
(
w
.
Body
.
Bytes
(),
&
got
))
require
.
Equal
(
t
,
tt
.
wantBody
,
got
)
})
}
}
backend/internal/repository/account_repo.go
View file @
eeaff85e
...
...
@@ -13,23 +13,23 @@ import (
"gorm.io/gorm/clause"
)
type
A
ccountRepository
struct
{
type
a
ccountRepository
struct
{
db
*
gorm
.
DB
}
func
NewAccountRepository
(
db
*
gorm
.
DB
)
*
AccountRepository
{
return
&
A
ccountRepository
{
db
:
db
}
func
NewAccountRepository
(
db
*
gorm
.
DB
)
service
.
AccountRepository
{
return
&
a
ccountRepository
{
db
:
db
}
}
func
(
r
*
A
ccountRepository
)
Create
(
ctx
context
.
Context
,
account
*
model
.
Account
)
error
{
func
(
r
*
a
ccountRepository
)
Create
(
ctx
context
.
Context
,
account
*
model
.
Account
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Create
(
account
)
.
Error
}
func
(
r
*
A
ccountRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
Account
,
error
)
{
func
(
r
*
a
ccountRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
Account
,
error
)
{
var
account
model
.
Account
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Preload
(
"Proxy"
)
.
Preload
(
"AccountGroups.Group"
)
.
First
(
&
account
,
id
)
.
Error
if
err
!=
nil
{
return
nil
,
err
return
nil
,
translatePersistenceError
(
err
,
service
.
ErrAccountNotFound
,
nil
)
}
// 填充 GroupIDs 和 Groups 虚拟字段
account
.
GroupIDs
=
make
([]
int64
,
0
,
len
(
account
.
AccountGroups
))
...
...
@@ -43,7 +43,7 @@ func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Accou
return
&
account
,
nil
}
func
(
r
*
A
ccountRepository
)
GetByCRSAccountID
(
ctx
context
.
Context
,
crsAccountID
string
)
(
*
model
.
Account
,
error
)
{
func
(
r
*
a
ccountRepository
)
GetByCRSAccountID
(
ctx
context
.
Context
,
crsAccountID
string
)
(
*
model
.
Account
,
error
)
{
if
crsAccountID
==
""
{
return
nil
,
nil
}
...
...
@@ -59,11 +59,11 @@ func (r *AccountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID
return
&
account
,
nil
}
func
(
r
*
A
ccountRepository
)
Update
(
ctx
context
.
Context
,
account
*
model
.
Account
)
error
{
func
(
r
*
a
ccountRepository
)
Update
(
ctx
context
.
Context
,
account
*
model
.
Account
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Save
(
account
)
.
Error
}
func
(
r
*
A
ccountRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
func
(
r
*
a
ccountRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
// 先删除账号与分组的绑定关系
if
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"account_id = ?"
,
id
)
.
Delete
(
&
model
.
AccountGroup
{})
.
Error
;
err
!=
nil
{
return
err
...
...
@@ -72,12 +72,12 @@ func (r *AccountRepository) Delete(ctx context.Context, id int64) error {
return
r
.
db
.
WithContext
(
ctx
)
.
Delete
(
&
model
.
Account
{},
id
)
.
Error
}
func
(
r
*
A
ccountRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
model
.
Account
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
a
ccountRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
model
.
Account
,
*
pagination
.
PaginationResult
,
error
)
{
return
r
.
ListWithFilters
(
ctx
,
params
,
""
,
""
,
""
,
""
)
}
// ListWithFilters lists accounts with optional filtering by platform, type, status, and search query
func
(
r
*
A
ccountRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
accountType
,
status
,
search
string
)
([]
model
.
Account
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
a
ccountRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
accountType
,
status
,
search
string
)
([]
model
.
Account
,
*
pagination
.
PaginationResult
,
error
)
{
var
accounts
[]
model
.
Account
var
total
int64
...
...
@@ -131,7 +131,7 @@ func (r *AccountRepository) ListWithFilters(ctx context.Context, params paginati
},
nil
}
func
(
r
*
A
ccountRepository
)
ListByGroup
(
ctx
context
.
Context
,
groupID
int64
)
([]
model
.
Account
,
error
)
{
func
(
r
*
a
ccountRepository
)
ListByGroup
(
ctx
context
.
Context
,
groupID
int64
)
([]
model
.
Account
,
error
)
{
var
accounts
[]
model
.
Account
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Joins
(
"JOIN account_groups ON account_groups.account_id = accounts.id"
)
.
...
...
@@ -142,7 +142,7 @@ func (r *AccountRepository) ListByGroup(ctx context.Context, groupID int64) ([]m
return
accounts
,
err
}
func
(
r
*
A
ccountRepository
)
ListActive
(
ctx
context
.
Context
)
([]
model
.
Account
,
error
)
{
func
(
r
*
a
ccountRepository
)
ListActive
(
ctx
context
.
Context
)
([]
model
.
Account
,
error
)
{
var
accounts
[]
model
.
Account
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"status = ?"
,
model
.
StatusActive
)
.
...
...
@@ -152,12 +152,12 @@ func (r *AccountRepository) ListActive(ctx context.Context) ([]model.Account, er
return
accounts
,
err
}
func
(
r
*
A
ccountRepository
)
UpdateLastUsed
(
ctx
context
.
Context
,
id
int64
)
error
{
func
(
r
*
a
ccountRepository
)
UpdateLastUsed
(
ctx
context
.
Context
,
id
int64
)
error
{
now
:=
time
.
Now
()
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Account
{})
.
Where
(
"id = ?"
,
id
)
.
Update
(
"last_used_at"
,
now
)
.
Error
}
func
(
r
*
A
ccountRepository
)
SetError
(
ctx
context
.
Context
,
id
int64
,
errorMsg
string
)
error
{
func
(
r
*
a
ccountRepository
)
SetError
(
ctx
context
.
Context
,
id
int64
,
errorMsg
string
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Account
{})
.
Where
(
"id = ?"
,
id
)
.
Updates
(
map
[
string
]
any
{
"status"
:
model
.
StatusError
,
...
...
@@ -165,7 +165,7 @@ func (r *AccountRepository) SetError(ctx context.Context, id int64, errorMsg str
})
.
Error
}
func
(
r
*
A
ccountRepository
)
AddToGroup
(
ctx
context
.
Context
,
accountID
,
groupID
int64
,
priority
int
)
error
{
func
(
r
*
a
ccountRepository
)
AddToGroup
(
ctx
context
.
Context
,
accountID
,
groupID
int64
,
priority
int
)
error
{
ag
:=
&
model
.
AccountGroup
{
AccountID
:
accountID
,
GroupID
:
groupID
,
...
...
@@ -174,12 +174,12 @@ func (r *AccountRepository) AddToGroup(ctx context.Context, accountID, groupID i
return
r
.
db
.
WithContext
(
ctx
)
.
Create
(
ag
)
.
Error
}
func
(
r
*
A
ccountRepository
)
RemoveFromGroup
(
ctx
context
.
Context
,
accountID
,
groupID
int64
)
error
{
func
(
r
*
a
ccountRepository
)
RemoveFromGroup
(
ctx
context
.
Context
,
accountID
,
groupID
int64
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"account_id = ? AND group_id = ?"
,
accountID
,
groupID
)
.
Delete
(
&
model
.
AccountGroup
{})
.
Error
}
func
(
r
*
A
ccountRepository
)
GetGroups
(
ctx
context
.
Context
,
accountID
int64
)
([]
model
.
Group
,
error
)
{
func
(
r
*
a
ccountRepository
)
GetGroups
(
ctx
context
.
Context
,
accountID
int64
)
([]
model
.
Group
,
error
)
{
var
groups
[]
model
.
Group
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Joins
(
"JOIN account_groups ON account_groups.group_id = groups.id"
)
.
...
...
@@ -188,7 +188,7 @@ func (r *AccountRepository) GetGroups(ctx context.Context, accountID int64) ([]m
return
groups
,
err
}
func
(
r
*
A
ccountRepository
)
ListByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
model
.
Account
,
error
)
{
func
(
r
*
a
ccountRepository
)
ListByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
model
.
Account
,
error
)
{
var
accounts
[]
model
.
Account
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"platform = ? AND status = ?"
,
platform
,
model
.
StatusActive
)
.
...
...
@@ -198,7 +198,7 @@ func (r *AccountRepository) ListByPlatform(ctx context.Context, platform string)
return
accounts
,
err
}
func
(
r
*
A
ccountRepository
)
BindGroups
(
ctx
context
.
Context
,
accountID
int64
,
groupIDs
[]
int64
)
error
{
func
(
r
*
a
ccountRepository
)
BindGroups
(
ctx
context
.
Context
,
accountID
int64
,
groupIDs
[]
int64
)
error
{
// 删除现有绑定
if
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"account_id = ?"
,
accountID
)
.
Delete
(
&
model
.
AccountGroup
{})
.
Error
;
err
!=
nil
{
return
err
...
...
@@ -221,7 +221,7 @@ func (r *AccountRepository) BindGroups(ctx context.Context, accountID int64, gro
}
// ListSchedulable 获取所有可调度的账号
func
(
r
*
A
ccountRepository
)
ListSchedulable
(
ctx
context
.
Context
)
([]
model
.
Account
,
error
)
{
func
(
r
*
a
ccountRepository
)
ListSchedulable
(
ctx
context
.
Context
)
([]
model
.
Account
,
error
)
{
var
accounts
[]
model
.
Account
now
:=
time
.
Now
()
err
:=
r
.
db
.
WithContext
(
ctx
)
.
...
...
@@ -235,7 +235,7 @@ func (r *AccountRepository) ListSchedulable(ctx context.Context) ([]model.Accoun
}
// ListSchedulableByGroupID 按组获取可调度的账号
func
(
r
*
A
ccountRepository
)
ListSchedulableByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
model
.
Account
,
error
)
{
func
(
r
*
a
ccountRepository
)
ListSchedulableByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
model
.
Account
,
error
)
{
var
accounts
[]
model
.
Account
now
:=
time
.
Now
()
err
:=
r
.
db
.
WithContext
(
ctx
)
.
...
...
@@ -251,7 +251,7 @@ func (r *AccountRepository) ListSchedulableByGroupID(ctx context.Context, groupI
}
// ListSchedulableByPlatform 按平台获取可调度的账号
func
(
r
*
A
ccountRepository
)
ListSchedulableByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
model
.
Account
,
error
)
{
func
(
r
*
a
ccountRepository
)
ListSchedulableByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
model
.
Account
,
error
)
{
var
accounts
[]
model
.
Account
now
:=
time
.
Now
()
err
:=
r
.
db
.
WithContext
(
ctx
)
.
...
...
@@ -266,7 +266,7 @@ func (r *AccountRepository) ListSchedulableByPlatform(ctx context.Context, platf
}
// ListSchedulableByGroupIDAndPlatform 按组和平台获取可调度的账号
func
(
r
*
A
ccountRepository
)
ListSchedulableByGroupIDAndPlatform
(
ctx
context
.
Context
,
groupID
int64
,
platform
string
)
([]
model
.
Account
,
error
)
{
func
(
r
*
a
ccountRepository
)
ListSchedulableByGroupIDAndPlatform
(
ctx
context
.
Context
,
groupID
int64
,
platform
string
)
([]
model
.
Account
,
error
)
{
var
accounts
[]
model
.
Account
now
:=
time
.
Now
()
err
:=
r
.
db
.
WithContext
(
ctx
)
.
...
...
@@ -283,7 +283,7 @@ func (r *AccountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Cont
}
// SetRateLimited 标记账号为限流状态(429)
func
(
r
*
A
ccountRepository
)
SetRateLimited
(
ctx
context
.
Context
,
id
int64
,
resetAt
time
.
Time
)
error
{
func
(
r
*
a
ccountRepository
)
SetRateLimited
(
ctx
context
.
Context
,
id
int64
,
resetAt
time
.
Time
)
error
{
now
:=
time
.
Now
()
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Account
{})
.
Where
(
"id = ?"
,
id
)
.
Updates
(
map
[
string
]
any
{
...
...
@@ -293,13 +293,13 @@ func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetA
}
// SetOverloaded 标记账号为过载状态(529)
func
(
r
*
A
ccountRepository
)
SetOverloaded
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
)
error
{
func
(
r
*
a
ccountRepository
)
SetOverloaded
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Account
{})
.
Where
(
"id = ?"
,
id
)
.
Update
(
"overload_until"
,
until
)
.
Error
}
// ClearRateLimit 清除账号的限流状态
func
(
r
*
A
ccountRepository
)
ClearRateLimit
(
ctx
context
.
Context
,
id
int64
)
error
{
func
(
r
*
a
ccountRepository
)
ClearRateLimit
(
ctx
context
.
Context
,
id
int64
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Account
{})
.
Where
(
"id = ?"
,
id
)
.
Updates
(
map
[
string
]
any
{
"rate_limited_at"
:
nil
,
...
...
@@ -309,7 +309,7 @@ func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error
}
// UpdateSessionWindow 更新账号的5小时时间窗口信息
func
(
r
*
A
ccountRepository
)
UpdateSessionWindow
(
ctx
context
.
Context
,
id
int64
,
start
,
end
*
time
.
Time
,
status
string
)
error
{
func
(
r
*
a
ccountRepository
)
UpdateSessionWindow
(
ctx
context
.
Context
,
id
int64
,
start
,
end
*
time
.
Time
,
status
string
)
error
{
updates
:=
map
[
string
]
any
{
"session_window_status"
:
status
,
}
...
...
@@ -323,14 +323,14 @@ func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, s
}
// SetSchedulable 设置账号的调度开关
func
(
r
*
A
ccountRepository
)
SetSchedulable
(
ctx
context
.
Context
,
id
int64
,
schedulable
bool
)
error
{
func
(
r
*
a
ccountRepository
)
SetSchedulable
(
ctx
context
.
Context
,
id
int64
,
schedulable
bool
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Account
{})
.
Where
(
"id = ?"
,
id
)
.
Update
(
"schedulable"
,
schedulable
)
.
Error
}
// UpdateExtra updates specific fields in account's Extra JSONB field
// It merges the updates into existing Extra data without overwriting other fields
func
(
r
*
A
ccountRepository
)
UpdateExtra
(
ctx
context
.
Context
,
id
int64
,
updates
map
[
string
]
any
)
error
{
func
(
r
*
a
ccountRepository
)
UpdateExtra
(
ctx
context
.
Context
,
id
int64
,
updates
map
[
string
]
any
)
error
{
if
len
(
updates
)
==
0
{
return
nil
}
...
...
@@ -358,7 +358,7 @@ func (r *AccountRepository) UpdateExtra(ctx context.Context, id int64, updates m
// BulkUpdate updates multiple accounts with the provided fields.
// It merges credentials/extra JSONB fields instead of overwriting them.
func
(
r
*
A
ccountRepository
)
BulkUpdate
(
ctx
context
.
Context
,
ids
[]
int64
,
updates
service
.
AccountBulkUpdate
)
(
int64
,
error
)
{
func
(
r
*
a
ccountRepository
)
BulkUpdate
(
ctx
context
.
Context
,
ids
[]
int64
,
updates
service
.
AccountBulkUpdate
)
(
int64
,
error
)
{
if
len
(
ids
)
==
0
{
return
0
,
nil
}
...
...
backend/internal/repository/account_repo_integration_test.go
View file @
eeaff85e
...
...
@@ -18,13 +18,13 @@ type AccountRepoSuite struct {
suite
.
Suite
ctx
context
.
Context
db
*
gorm
.
DB
repo
*
A
ccountRepository
repo
*
a
ccountRepository
}
func
(
s
*
AccountRepoSuite
)
SetupTest
()
{
s
.
ctx
=
context
.
Background
()
s
.
db
=
testTx
(
s
.
T
())
s
.
repo
=
NewAccountRepository
(
s
.
db
)
s
.
repo
=
NewAccountRepository
(
s
.
db
)
.
(
*
accountRepository
)
}
func
TestAccountRepoSuite
(
t
*
testing
.
T
)
{
...
...
@@ -167,7 +167,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
s
.
Run
(
tt
.
name
,
func
()
{
// 每个 case 重新获取隔离资源
db
:=
testTx
(
s
.
T
())
repo
:=
NewAccountRepository
(
db
)
repo
:=
NewAccountRepository
(
db
)
.
(
*
accountRepository
)
ctx
:=
context
.
Background
()
tt
.
setup
(
db
)
...
...
backend/internal/repository/api_key_repo.go
View file @
eeaff85e
...
...
@@ -2,51 +2,55 @@ package repository
import
(
"context"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
type
A
piKeyRepository
struct
{
type
a
piKeyRepository
struct
{
db
*
gorm
.
DB
}
func
NewApiKeyRepository
(
db
*
gorm
.
DB
)
*
ApiKeyRepository
{
return
&
A
piKeyRepository
{
db
:
db
}
func
NewApiKeyRepository
(
db
*
gorm
.
DB
)
service
.
ApiKeyRepository
{
return
&
a
piKeyRepository
{
db
:
db
}
}
func
(
r
*
ApiKeyRepository
)
Create
(
ctx
context
.
Context
,
key
*
model
.
ApiKey
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Create
(
key
)
.
Error
func
(
r
*
apiKeyRepository
)
Create
(
ctx
context
.
Context
,
key
*
model
.
ApiKey
)
error
{
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Create
(
key
)
.
Error
return
translatePersistenceError
(
err
,
nil
,
service
.
ErrApiKeyExists
)
}
func
(
r
*
A
piKeyRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
ApiKey
,
error
)
{
func
(
r
*
a
piKeyRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
ApiKey
,
error
)
{
var
key
model
.
ApiKey
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Preload
(
"User"
)
.
Preload
(
"Group"
)
.
First
(
&
key
,
id
)
.
Error
if
err
!=
nil
{
return
nil
,
err
return
nil
,
translatePersistenceError
(
err
,
service
.
ErrApiKeyNotFound
,
nil
)
}
return
&
key
,
nil
}
func
(
r
*
A
piKeyRepository
)
GetByKey
(
ctx
context
.
Context
,
key
string
)
(
*
model
.
ApiKey
,
error
)
{
func
(
r
*
a
piKeyRepository
)
GetByKey
(
ctx
context
.
Context
,
key
string
)
(
*
model
.
ApiKey
,
error
)
{
var
apiKey
model
.
ApiKey
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Preload
(
"User"
)
.
Preload
(
"Group"
)
.
Where
(
"key = ?"
,
key
)
.
First
(
&
apiKey
)
.
Error
if
err
!=
nil
{
return
nil
,
err
return
nil
,
translatePersistenceError
(
err
,
service
.
ErrApiKeyNotFound
,
nil
)
}
return
&
apiKey
,
nil
}
func
(
r
*
A
piKeyRepository
)
Update
(
ctx
context
.
Context
,
key
*
model
.
ApiKey
)
error
{
func
(
r
*
a
piKeyRepository
)
Update
(
ctx
context
.
Context
,
key
*
model
.
ApiKey
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Model
(
key
)
.
Select
(
"name"
,
"group_id"
,
"status"
,
"updated_at"
)
.
Updates
(
key
)
.
Error
}
func
(
r
*
A
piKeyRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
func
(
r
*
a
piKeyRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Delete
(
&
model
.
ApiKey
{},
id
)
.
Error
}
func
(
r
*
A
piKeyRepository
)
ListByUserID
(
ctx
context
.
Context
,
userID
int64
,
params
pagination
.
PaginationParams
)
([]
model
.
ApiKey
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
a
piKeyRepository
)
ListByUserID
(
ctx
context
.
Context
,
userID
int64
,
params
pagination
.
PaginationParams
)
([]
model
.
ApiKey
,
*
pagination
.
PaginationResult
,
error
)
{
var
keys
[]
model
.
ApiKey
var
total
int64
...
...
@@ -73,19 +77,19 @@ func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
},
nil
}
func
(
r
*
A
piKeyRepository
)
CountByUserID
(
ctx
context
.
Context
,
userID
int64
)
(
int64
,
error
)
{
func
(
r
*
a
piKeyRepository
)
CountByUserID
(
ctx
context
.
Context
,
userID
int64
)
(
int64
,
error
)
{
var
count
int64
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
ApiKey
{})
.
Where
(
"user_id = ?"
,
userID
)
.
Count
(
&
count
)
.
Error
return
count
,
err
}
func
(
r
*
A
piKeyRepository
)
ExistsByKey
(
ctx
context
.
Context
,
key
string
)
(
bool
,
error
)
{
func
(
r
*
a
piKeyRepository
)
ExistsByKey
(
ctx
context
.
Context
,
key
string
)
(
bool
,
error
)
{
var
count
int64
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
ApiKey
{})
.
Where
(
"key = ?"
,
key
)
.
Count
(
&
count
)
.
Error
return
count
>
0
,
err
}
func
(
r
*
A
piKeyRepository
)
ListByGroupID
(
ctx
context
.
Context
,
groupID
int64
,
params
pagination
.
PaginationParams
)
([]
model
.
ApiKey
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
a
piKeyRepository
)
ListByGroupID
(
ctx
context
.
Context
,
groupID
int64
,
params
pagination
.
PaginationParams
)
([]
model
.
ApiKey
,
*
pagination
.
PaginationResult
,
error
)
{
var
keys
[]
model
.
ApiKey
var
total
int64
...
...
@@ -113,7 +117,7 @@ func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
}
// SearchApiKeys searches API keys by user ID and/or keyword (name)
func
(
r
*
A
piKeyRepository
)
SearchApiKeys
(
ctx
context
.
Context
,
userID
int64
,
keyword
string
,
limit
int
)
([]
model
.
ApiKey
,
error
)
{
func
(
r
*
a
piKeyRepository
)
SearchApiKeys
(
ctx
context
.
Context
,
userID
int64
,
keyword
string
,
limit
int
)
([]
model
.
ApiKey
,
error
)
{
var
keys
[]
model
.
ApiKey
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
ApiKey
{})
...
...
@@ -135,7 +139,7 @@ func (r *ApiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
}
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
func
(
r
*
A
piKeyRepository
)
ClearGroupIDByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
func
(
r
*
a
piKeyRepository
)
ClearGroupIDByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
result
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
ApiKey
{})
.
Where
(
"group_id = ?"
,
groupID
)
.
Update
(
"group_id"
,
nil
)
...
...
@@ -143,7 +147,7 @@ func (r *ApiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID in
}
// CountByGroupID 获取分组的 API Key 数量
func
(
r
*
A
piKeyRepository
)
CountByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
func
(
r
*
a
piKeyRepository
)
CountByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
var
count
int64
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
ApiKey
{})
.
Where
(
"group_id = ?"
,
groupID
)
.
Count
(
&
count
)
.
Error
return
count
,
err
...
...
backend/internal/repository/api_key_repo_integration_test.go
View file @
eeaff85e
...
...
@@ -16,13 +16,13 @@ type ApiKeyRepoSuite struct {
suite
.
Suite
ctx
context
.
Context
db
*
gorm
.
DB
repo
*
A
piKeyRepository
repo
*
a
piKeyRepository
}
func
(
s
*
ApiKeyRepoSuite
)
SetupTest
()
{
s
.
ctx
=
context
.
Background
()
s
.
db
=
testTx
(
s
.
T
())
s
.
repo
=
NewApiKeyRepository
(
s
.
db
)
s
.
repo
=
NewApiKeyRepository
(
s
.
db
)
.
(
*
apiKeyRepository
)
}
func
TestApiKeyRepoSuite
(
t
*
testing
.
T
)
{
...
...
backend/internal/repository/error_translate.go
0 → 100644
View file @
eeaff85e
package
repository
import
(
"errors"
"strings"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"gorm.io/gorm"
)
func
translatePersistenceError
(
err
error
,
notFound
,
conflict
*
infraerrors
.
ApplicationError
)
error
{
if
err
==
nil
{
return
nil
}
if
notFound
!=
nil
&&
errors
.
Is
(
err
,
gorm
.
ErrRecordNotFound
)
{
return
notFound
.
WithCause
(
err
)
}
if
conflict
!=
nil
&&
isUniqueConstraintViolation
(
err
)
{
return
conflict
.
WithCause
(
err
)
}
return
err
}
func
isUniqueConstraintViolation
(
err
error
)
bool
{
if
err
==
nil
{
return
false
}
if
errors
.
Is
(
err
,
gorm
.
ErrDuplicatedKey
)
{
return
true
}
msg
:=
strings
.
ToLower
(
err
.
Error
())
return
strings
.
Contains
(
msg
,
"duplicate key"
)
||
strings
.
Contains
(
msg
,
"unique constraint"
)
||
strings
.
Contains
(
msg
,
"duplicate entry"
)
}
backend/internal/repository/group_repo.go
View file @
eeaff85e
...
...
@@ -2,47 +2,52 @@ package repository
import
(
"context"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type
G
roupRepository
struct
{
type
g
roupRepository
struct
{
db
*
gorm
.
DB
}
func
NewGroupRepository
(
db
*
gorm
.
DB
)
*
GroupRepository
{
return
&
G
roupRepository
{
db
:
db
}
func
NewGroupRepository
(
db
*
gorm
.
DB
)
service
.
GroupRepository
{
return
&
g
roupRepository
{
db
:
db
}
}
func
(
r
*
GroupRepository
)
Create
(
ctx
context
.
Context
,
group
*
model
.
Group
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Create
(
group
)
.
Error
func
(
r
*
groupRepository
)
Create
(
ctx
context
.
Context
,
group
*
model
.
Group
)
error
{
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Create
(
group
)
.
Error
return
translatePersistenceError
(
err
,
nil
,
service
.
ErrGroupExists
)
}
func
(
r
*
G
roupRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
Group
,
error
)
{
func
(
r
*
g
roupRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
Group
,
error
)
{
var
group
model
.
Group
err
:=
r
.
db
.
WithContext
(
ctx
)
.
First
(
&
group
,
id
)
.
Error
if
err
!=
nil
{
return
nil
,
err
return
nil
,
translatePersistenceError
(
err
,
service
.
ErrGroupNotFound
,
nil
)
}
return
&
group
,
nil
}
func
(
r
*
G
roupRepository
)
Update
(
ctx
context
.
Context
,
group
*
model
.
Group
)
error
{
func
(
r
*
g
roupRepository
)
Update
(
ctx
context
.
Context
,
group
*
model
.
Group
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Save
(
group
)
.
Error
}
func
(
r
*
G
roupRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
func
(
r
*
g
roupRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Delete
(
&
model
.
Group
{},
id
)
.
Error
}
func
(
r
*
G
roupRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
model
.
Group
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
g
roupRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
model
.
Group
,
*
pagination
.
PaginationResult
,
error
)
{
return
r
.
ListWithFilters
(
ctx
,
params
,
""
,
""
,
nil
)
}
// ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive
func
(
r
*
G
roupRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
status
string
,
isExclusive
*
bool
)
([]
model
.
Group
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
g
roupRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
status
string
,
isExclusive
*
bool
)
([]
model
.
Group
,
*
pagination
.
PaginationResult
,
error
)
{
var
groups
[]
model
.
Group
var
total
int64
...
...
@@ -86,7 +91,7 @@ func (r *GroupRepository) ListWithFilters(ctx context.Context, params pagination
},
nil
}
func
(
r
*
G
roupRepository
)
ListActive
(
ctx
context
.
Context
)
([]
model
.
Group
,
error
)
{
func
(
r
*
g
roupRepository
)
ListActive
(
ctx
context
.
Context
)
([]
model
.
Group
,
error
)
{
var
groups
[]
model
.
Group
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"status = ?"
,
model
.
StatusActive
)
.
Order
(
"id ASC"
)
.
Find
(
&
groups
)
.
Error
if
err
!=
nil
{
...
...
@@ -100,7 +105,7 @@ func (r *GroupRepository) ListActive(ctx context.Context) ([]model.Group, error)
return
groups
,
nil
}
func
(
r
*
G
roupRepository
)
ListActiveByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
model
.
Group
,
error
)
{
func
(
r
*
g
roupRepository
)
ListActiveByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
model
.
Group
,
error
)
{
var
groups
[]
model
.
Group
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"status = ? AND platform = ?"
,
model
.
StatusActive
,
platform
)
.
Order
(
"id ASC"
)
.
Find
(
&
groups
)
.
Error
if
err
!=
nil
{
...
...
@@ -114,25 +119,80 @@ func (r *GroupRepository) ListActiveByPlatform(ctx context.Context, platform str
return
groups
,
nil
}
func
(
r
*
G
roupRepository
)
ExistsByName
(
ctx
context
.
Context
,
name
string
)
(
bool
,
error
)
{
func
(
r
*
g
roupRepository
)
ExistsByName
(
ctx
context
.
Context
,
name
string
)
(
bool
,
error
)
{
var
count
int64
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Group
{})
.
Where
(
"name = ?"
,
name
)
.
Count
(
&
count
)
.
Error
return
count
>
0
,
err
}
func
(
r
*
G
roupRepository
)
GetAccountCount
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
func
(
r
*
g
roupRepository
)
GetAccountCount
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
var
count
int64
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
AccountGroup
{})
.
Where
(
"group_id = ?"
,
groupID
)
.
Count
(
&
count
)
.
Error
return
count
,
err
}
// DeleteAccountGroupsByGroupID 删除分组与账号的关联关系
func
(
r
*
G
roupRepository
)
DeleteAccountGroupsByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
func
(
r
*
g
roupRepository
)
DeleteAccountGroupsByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
result
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"group_id = ?"
,
groupID
)
.
Delete
(
&
model
.
AccountGroup
{})
return
result
.
RowsAffected
,
result
.
Error
}
// DB 返回底层数据库连接,用于事务处理
func
(
r
*
GroupRepository
)
DB
()
*
gorm
.
DB
{
return
r
.
db
func
(
r
*
groupRepository
)
DeleteCascade
(
ctx
context
.
Context
,
id
int64
)
([]
int64
,
error
)
{
group
,
err
:=
r
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
err
}
var
affectedUserIDs
[]
int64
if
group
.
IsSubscriptionType
()
{
var
subscriptions
[]
model
.
UserSubscription
if
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UserSubscription
{})
.
Where
(
"group_id = ?"
,
id
)
.
Select
(
"user_id"
)
.
Find
(
&
subscriptions
)
.
Error
;
err
!=
nil
{
return
nil
,
err
}
for
_
,
sub
:=
range
subscriptions
{
affectedUserIDs
=
append
(
affectedUserIDs
,
sub
.
UserID
)
}
}
err
=
r
.
db
.
WithContext
(
ctx
)
.
Transaction
(
func
(
tx
*
gorm
.
DB
)
error
{
// 1. 删除订阅类型分组的订阅记录
if
group
.
IsSubscriptionType
()
{
if
err
:=
tx
.
Where
(
"group_id = ?"
,
id
)
.
Delete
(
&
model
.
UserSubscription
{})
.
Error
;
err
!=
nil
{
return
err
}
}
// 2. 将 api_keys 中绑定该分组的 group_id 设为 nil
if
err
:=
tx
.
Model
(
&
model
.
ApiKey
{})
.
Where
(
"group_id = ?"
,
id
)
.
Update
(
"group_id"
,
nil
)
.
Error
;
err
!=
nil
{
return
err
}
// 3. 从 users.allowed_groups 数组中移除该分组 ID
if
err
:=
tx
.
Model
(
&
model
.
User
{})
.
Where
(
"? = ANY(allowed_groups)"
,
id
)
.
Update
(
"allowed_groups"
,
gorm
.
Expr
(
"array_remove(allowed_groups, ?)"
,
id
))
.
Error
;
err
!=
nil
{
return
err
}
// 4. 删除 account_groups 中间表的数据
if
err
:=
tx
.
Where
(
"group_id = ?"
,
id
)
.
Delete
(
&
model
.
AccountGroup
{})
.
Error
;
err
!=
nil
{
return
err
}
// 5. 删除分组本身(带锁,避免并发写)
if
err
:=
tx
.
Clauses
(
clause
.
Locking
{
Strength
:
"UPDATE"
})
.
Delete
(
&
model
.
Group
{},
id
)
.
Error
;
err
!=
nil
{
return
err
}
return
nil
})
if
err
!=
nil
{
return
nil
,
err
}
return
affectedUserIDs
,
nil
}
backend/internal/repository/group_repo_integration_test.go
View file @
eeaff85e
...
...
@@ -16,13 +16,13 @@ type GroupRepoSuite struct {
suite
.
Suite
ctx
context
.
Context
db
*
gorm
.
DB
repo
*
G
roupRepository
repo
*
g
roupRepository
}
func
(
s
*
GroupRepoSuite
)
SetupTest
()
{
s
.
ctx
=
context
.
Background
()
s
.
db
=
testTx
(
s
.
T
())
s
.
repo
=
NewGroupRepository
(
s
.
db
)
s
.
repo
=
NewGroupRepository
(
s
.
db
)
.
(
*
groupRepository
)
}
func
TestGroupRepoSuite
(
t
*
testing
.
T
)
{
...
...
@@ -234,11 +234,3 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
count
,
_
:=
s
.
repo
.
GetAccountCount
(
s
.
ctx
,
g
.
ID
)
s
.
Require
()
.
Zero
(
count
)
}
// --- DB ---
func
(
s
*
GroupRepoSuite
)
TestDB
()
{
db
:=
s
.
repo
.
DB
()
s
.
Require
()
.
NotNil
(
db
,
"DB should return non-nil"
)
s
.
Require
()
.
Equal
(
s
.
db
,
db
,
"DB should return the underlying gorm.DB"
)
}
backend/internal/repository/proxy_repo.go
View file @
eeaff85e
...
...
@@ -2,47 +2,50 @@ package repository
import
(
"context"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
type
P
roxyRepository
struct
{
type
p
roxyRepository
struct
{
db
*
gorm
.
DB
}
func
NewProxyRepository
(
db
*
gorm
.
DB
)
*
ProxyRepository
{
return
&
P
roxyRepository
{
db
:
db
}
func
NewProxyRepository
(
db
*
gorm
.
DB
)
service
.
ProxyRepository
{
return
&
p
roxyRepository
{
db
:
db
}
}
func
(
r
*
P
roxyRepository
)
Create
(
ctx
context
.
Context
,
proxy
*
model
.
Proxy
)
error
{
func
(
r
*
p
roxyRepository
)
Create
(
ctx
context
.
Context
,
proxy
*
model
.
Proxy
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Create
(
proxy
)
.
Error
}
func
(
r
*
P
roxyRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
Proxy
,
error
)
{
func
(
r
*
p
roxyRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
Proxy
,
error
)
{
var
proxy
model
.
Proxy
err
:=
r
.
db
.
WithContext
(
ctx
)
.
First
(
&
proxy
,
id
)
.
Error
if
err
!=
nil
{
return
nil
,
err
return
nil
,
translatePersistenceError
(
err
,
service
.
ErrProxyNotFound
,
nil
)
}
return
&
proxy
,
nil
}
func
(
r
*
P
roxyRepository
)
Update
(
ctx
context
.
Context
,
proxy
*
model
.
Proxy
)
error
{
func
(
r
*
p
roxyRepository
)
Update
(
ctx
context
.
Context
,
proxy
*
model
.
Proxy
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Save
(
proxy
)
.
Error
}
func
(
r
*
P
roxyRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
func
(
r
*
p
roxyRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Delete
(
&
model
.
Proxy
{},
id
)
.
Error
}
func
(
r
*
P
roxyRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
model
.
Proxy
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
p
roxyRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
model
.
Proxy
,
*
pagination
.
PaginationResult
,
error
)
{
return
r
.
ListWithFilters
(
ctx
,
params
,
""
,
""
,
""
)
}
// ListWithFilters lists proxies with optional filtering by protocol, status, and search query
func
(
r
*
P
roxyRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
protocol
,
status
,
search
string
)
([]
model
.
Proxy
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
p
roxyRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
protocol
,
status
,
search
string
)
([]
model
.
Proxy
,
*
pagination
.
PaginationResult
,
error
)
{
var
proxies
[]
model
.
Proxy
var
total
int64
...
...
@@ -81,14 +84,14 @@ func (r *ProxyRepository) ListWithFilters(ctx context.Context, params pagination
},
nil
}
func
(
r
*
P
roxyRepository
)
ListActive
(
ctx
context
.
Context
)
([]
model
.
Proxy
,
error
)
{
func
(
r
*
p
roxyRepository
)
ListActive
(
ctx
context
.
Context
)
([]
model
.
Proxy
,
error
)
{
var
proxies
[]
model
.
Proxy
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"status = ?"
,
model
.
StatusActive
)
.
Find
(
&
proxies
)
.
Error
return
proxies
,
err
}
// ExistsByHostPortAuth checks if a proxy with the same host, port, username, and password exists
func
(
r
*
P
roxyRepository
)
ExistsByHostPortAuth
(
ctx
context
.
Context
,
host
string
,
port
int
,
username
,
password
string
)
(
bool
,
error
)
{
func
(
r
*
p
roxyRepository
)
ExistsByHostPortAuth
(
ctx
context
.
Context
,
host
string
,
port
int
,
username
,
password
string
)
(
bool
,
error
)
{
var
count
int64
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Proxy
{})
.
Where
(
"host = ? AND port = ? AND username = ? AND password = ?"
,
host
,
port
,
username
,
password
)
.
...
...
@@ -100,7 +103,7 @@ func (r *ProxyRepository) ExistsByHostPortAuth(ctx context.Context, host string,
}
// CountAccountsByProxyID returns the number of accounts using a specific proxy
func
(
r
*
P
roxyRepository
)
CountAccountsByProxyID
(
ctx
context
.
Context
,
proxyID
int64
)
(
int64
,
error
)
{
func
(
r
*
p
roxyRepository
)
CountAccountsByProxyID
(
ctx
context
.
Context
,
proxyID
int64
)
(
int64
,
error
)
{
var
count
int64
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
Account
{})
.
Where
(
"proxy_id = ?"
,
proxyID
)
.
...
...
@@ -109,7 +112,7 @@ func (r *ProxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID in
}
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
func
(
r
*
P
roxyRepository
)
GetAccountCountsForProxies
(
ctx
context
.
Context
)
(
map
[
int64
]
int64
,
error
)
{
func
(
r
*
p
roxyRepository
)
GetAccountCountsForProxies
(
ctx
context
.
Context
)
(
map
[
int64
]
int64
,
error
)
{
type
result
struct
{
ProxyID
int64
`gorm:"column:proxy_id"`
Count
int64
`gorm:"column:count"`
...
...
@@ -133,7 +136,7 @@ func (r *ProxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[i
}
// ListActiveWithAccountCount returns all active proxies with account count, sorted by creation time descending
func
(
r
*
P
roxyRepository
)
ListActiveWithAccountCount
(
ctx
context
.
Context
)
([]
model
.
ProxyWithAccountCount
,
error
)
{
func
(
r
*
p
roxyRepository
)
ListActiveWithAccountCount
(
ctx
context
.
Context
)
([]
model
.
ProxyWithAccountCount
,
error
)
{
var
proxies
[]
model
.
Proxy
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"status = ?"
,
model
.
StatusActive
)
.
...
...
backend/internal/repository/proxy_repo_integration_test.go
View file @
eeaff85e
...
...
@@ -17,13 +17,13 @@ type ProxyRepoSuite struct {
suite
.
Suite
ctx
context
.
Context
db
*
gorm
.
DB
repo
*
P
roxyRepository
repo
*
p
roxyRepository
}
func
(
s
*
ProxyRepoSuite
)
SetupTest
()
{
s
.
ctx
=
context
.
Background
()
s
.
db
=
testTx
(
s
.
T
())
s
.
repo
=
NewProxyRepository
(
s
.
db
)
s
.
repo
=
NewProxyRepository
(
s
.
db
)
.
(
*
proxyRepository
)
}
func
TestProxyRepoSuite
(
t
*
testing
.
T
)
{
...
...
backend/internal/repository/redeem_code_repo.go
View file @
eeaff85e
...
...
@@ -2,57 +2,60 @@ package repository
import
(
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"time"
"gorm.io/gorm"
)
type
R
edeemCodeRepository
struct
{
type
r
edeemCodeRepository
struct
{
db
*
gorm
.
DB
}
func
NewRedeemCodeRepository
(
db
*
gorm
.
DB
)
*
RedeemCodeRepository
{
return
&
R
edeemCodeRepository
{
db
:
db
}
func
NewRedeemCodeRepository
(
db
*
gorm
.
DB
)
service
.
RedeemCodeRepository
{
return
&
r
edeemCodeRepository
{
db
:
db
}
}
func
(
r
*
R
edeemCodeRepository
)
Create
(
ctx
context
.
Context
,
code
*
model
.
RedeemCode
)
error
{
func
(
r
*
r
edeemCodeRepository
)
Create
(
ctx
context
.
Context
,
code
*
model
.
RedeemCode
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Create
(
code
)
.
Error
}
func
(
r
*
R
edeemCodeRepository
)
CreateBatch
(
ctx
context
.
Context
,
codes
[]
model
.
RedeemCode
)
error
{
func
(
r
*
r
edeemCodeRepository
)
CreateBatch
(
ctx
context
.
Context
,
codes
[]
model
.
RedeemCode
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Create
(
&
codes
)
.
Error
}
func
(
r
*
R
edeemCodeRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
RedeemCode
,
error
)
{
func
(
r
*
r
edeemCodeRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
RedeemCode
,
error
)
{
var
code
model
.
RedeemCode
err
:=
r
.
db
.
WithContext
(
ctx
)
.
First
(
&
code
,
id
)
.
Error
if
err
!=
nil
{
return
nil
,
err
return
nil
,
translatePersistenceError
(
err
,
service
.
ErrRedeemCodeNotFound
,
nil
)
}
return
&
code
,
nil
}
func
(
r
*
R
edeemCodeRepository
)
GetByCode
(
ctx
context
.
Context
,
code
string
)
(
*
model
.
RedeemCode
,
error
)
{
func
(
r
*
r
edeemCodeRepository
)
GetByCode
(
ctx
context
.
Context
,
code
string
)
(
*
model
.
RedeemCode
,
error
)
{
var
redeemCode
model
.
RedeemCode
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"code = ?"
,
code
)
.
First
(
&
redeemCode
)
.
Error
if
err
!=
nil
{
return
nil
,
err
return
nil
,
translatePersistenceError
(
err
,
service
.
ErrRedeemCodeNotFound
,
nil
)
}
return
&
redeemCode
,
nil
}
func
(
r
*
R
edeemCodeRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
func
(
r
*
r
edeemCodeRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Delete
(
&
model
.
RedeemCode
{},
id
)
.
Error
}
func
(
r
*
R
edeemCodeRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
model
.
RedeemCode
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
r
edeemCodeRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
model
.
RedeemCode
,
*
pagination
.
PaginationResult
,
error
)
{
return
r
.
ListWithFilters
(
ctx
,
params
,
""
,
""
,
""
)
}
// ListWithFilters lists redeem codes with optional filtering by type, status, and search query
func
(
r
*
R
edeemCodeRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
codeType
,
status
,
search
string
)
([]
model
.
RedeemCode
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
r
edeemCodeRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
codeType
,
status
,
search
string
)
([]
model
.
RedeemCode
,
*
pagination
.
PaginationResult
,
error
)
{
var
codes
[]
model
.
RedeemCode
var
total
int64
...
...
@@ -91,11 +94,11 @@ func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params pagin
},
nil
}
func
(
r
*
R
edeemCodeRepository
)
Update
(
ctx
context
.
Context
,
code
*
model
.
RedeemCode
)
error
{
func
(
r
*
r
edeemCodeRepository
)
Update
(
ctx
context
.
Context
,
code
*
model
.
RedeemCode
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Save
(
code
)
.
Error
}
func
(
r
*
R
edeemCodeRepository
)
Use
(
ctx
context
.
Context
,
id
,
userID
int64
)
error
{
func
(
r
*
r
edeemCodeRepository
)
Use
(
ctx
context
.
Context
,
id
,
userID
int64
)
error
{
now
:=
time
.
Now
()
result
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
RedeemCode
{})
.
Where
(
"id = ? AND status = ?"
,
id
,
model
.
StatusUnused
)
.
...
...
@@ -108,13 +111,13 @@ func (r *RedeemCodeRepository) Use(ctx context.Context, id, userID int64) error
return
result
.
Error
}
if
result
.
RowsAffected
==
0
{
return
gorm
.
ErrRecordNotFound
// 兑换码不存在或已被使用
return
service
.
ErrRedeemCodeUsed
.
WithCause
(
gorm
.
ErrRecordNotFound
)
}
return
nil
}
// ListByUser returns all redeem codes used by a specific user
func
(
r
*
R
edeemCodeRepository
)
ListByUser
(
ctx
context
.
Context
,
userID
int64
,
limit
int
)
([]
model
.
RedeemCode
,
error
)
{
func
(
r
*
r
edeemCodeRepository
)
ListByUser
(
ctx
context
.
Context
,
userID
int64
,
limit
int
)
([]
model
.
RedeemCode
,
error
)
{
var
codes
[]
model
.
RedeemCode
if
limit
<=
0
{
limit
=
10
...
...
backend/internal/repository/redeem_code_repo_integration_test.go
View file @
eeaff85e
...
...
@@ -9,6 +9,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
...
...
@@ -17,13 +18,13 @@ type RedeemCodeRepoSuite struct {
suite
.
Suite
ctx
context
.
Context
db
*
gorm
.
DB
repo
*
R
edeemCodeRepository
repo
*
r
edeemCodeRepository
}
func
(
s
*
RedeemCodeRepoSuite
)
SetupTest
()
{
s
.
ctx
=
context
.
Background
()
s
.
db
=
testTx
(
s
.
T
())
s
.
repo
=
NewRedeemCodeRepository
(
s
.
db
)
s
.
repo
=
NewRedeemCodeRepository
(
s
.
db
)
.
(
*
redeemCodeRepository
)
}
func
TestRedeemCodeRepoSuite
(
t
*
testing
.
T
)
{
...
...
@@ -195,7 +196,7 @@ func (s *RedeemCodeRepoSuite) TestUse_Idempotency() {
// Second use should fail
err
=
s
.
repo
.
Use
(
s
.
ctx
,
code
.
ID
,
user
.
ID
)
s
.
Require
()
.
Error
(
err
,
"Use expected error on second call"
)
s
.
Require
()
.
ErrorIs
(
err
,
gorm
.
ErrRecordNotFoun
d
)
s
.
Require
()
.
ErrorIs
(
err
,
service
.
ErrRedeemCodeUse
d
)
}
func
(
s
*
RedeemCodeRepoSuite
)
TestUse_AlreadyUsed
()
{
...
...
@@ -204,7 +205,7 @@ func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
err
:=
s
.
repo
.
Use
(
s
.
ctx
,
code
.
ID
,
user
.
ID
)
s
.
Require
()
.
Error
(
err
,
"expected error for already used code"
)
s
.
Require
()
.
ErrorIs
(
err
,
gorm
.
ErrRecordNotFoun
d
)
s
.
Require
()
.
ErrorIs
(
err
,
service
.
ErrRedeemCodeUse
d
)
}
// --- ListByUser ---
...
...
@@ -298,7 +299,7 @@ func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser
s
.
Require
()
.
NoError
(
s
.
repo
.
Use
(
s
.
ctx
,
codeB
.
ID
,
user
.
ID
),
"Use"
)
err
=
s
.
repo
.
Use
(
s
.
ctx
,
codeB
.
ID
,
user
.
ID
)
s
.
Require
()
.
Error
(
err
,
"Use expected error on second call"
)
s
.
Require
()
.
ErrorIs
(
err
,
gorm
.
ErrRecordNotFoun
d
)
s
.
Require
()
.
ErrorIs
(
err
,
service
.
ErrRedeemCodeUse
d
)
codeA
,
err
:=
s
.
repo
.
GetByCode
(
s
.
ctx
,
"CODEA"
)
s
.
Require
()
.
NoError
(
err
,
"GetByCode"
)
...
...
backend/internal/repository/repository.go
View file @
eeaff85e
package
repository
import
"github.com/Wei-Shaw/sub2api/internal/service"
// Repositories 所有仓库的集合
type
Repositories
struct
{
User
*
UserRepository
ApiKey
*
ApiKeyRepository
Group
*
GroupRepository
Account
*
AccountRepository
Proxy
*
ProxyRepository
RedeemCode
*
RedeemCodeRepository
UsageLog
*
UsageLogRepository
Setting
*
SettingRepository
UserSubscription
*
UserSubscriptionRepository
User
service
.
UserRepository
ApiKey
service
.
ApiKeyRepository
Group
service
.
GroupRepository
Account
service
.
AccountRepository
Proxy
service
.
ProxyRepository
RedeemCode
service
.
RedeemCodeRepository
UsageLog
service
.
UsageLogRepository
Setting
service
.
SettingRepository
UserSubscription
service
.
UserSubscriptionRepository
}
backend/internal/repository/setting_repo.go
View file @
eeaff85e
...
...
@@ -2,35 +2,38 @@ package repository
import
(
"context"
"github.com/Wei-Shaw/sub2api/internal/model"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// SettingRepository 系统设置数据访问层
type
S
ettingRepository
struct
{
type
s
ettingRepository
struct
{
db
*
gorm
.
DB
}
// NewSettingRepository 创建系统设置仓库实例
func
NewSettingRepository
(
db
*
gorm
.
DB
)
*
SettingRepository
{
return
&
S
ettingRepository
{
db
:
db
}
func
NewSettingRepository
(
db
*
gorm
.
DB
)
service
.
SettingRepository
{
return
&
s
ettingRepository
{
db
:
db
}
}
// Get 根据Key获取设置值
func
(
r
*
S
ettingRepository
)
Get
(
ctx
context
.
Context
,
key
string
)
(
*
model
.
Setting
,
error
)
{
func
(
r
*
s
ettingRepository
)
Get
(
ctx
context
.
Context
,
key
string
)
(
*
model
.
Setting
,
error
)
{
var
setting
model
.
Setting
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"key = ?"
,
key
)
.
First
(
&
setting
)
.
Error
if
err
!=
nil
{
return
nil
,
err
return
nil
,
translatePersistenceError
(
err
,
service
.
ErrSettingNotFound
,
nil
)
}
return
&
setting
,
nil
}
// GetValue 获取设置值字符串
func
(
r
*
S
ettingRepository
)
GetValue
(
ctx
context
.
Context
,
key
string
)
(
string
,
error
)
{
func
(
r
*
s
ettingRepository
)
GetValue
(
ctx
context
.
Context
,
key
string
)
(
string
,
error
)
{
setting
,
err
:=
r
.
Get
(
ctx
,
key
)
if
err
!=
nil
{
return
""
,
err
...
...
@@ -39,7 +42,7 @@ func (r *SettingRepository) GetValue(ctx context.Context, key string) (string, e
}
// Set 设置值(存在则更新,不存在则创建)
func
(
r
*
S
ettingRepository
)
Set
(
ctx
context
.
Context
,
key
,
value
string
)
error
{
func
(
r
*
s
ettingRepository
)
Set
(
ctx
context
.
Context
,
key
,
value
string
)
error
{
setting
:=
&
model
.
Setting
{
Key
:
key
,
Value
:
value
,
...
...
@@ -53,7 +56,7 @@ func (r *SettingRepository) Set(ctx context.Context, key, value string) error {
}
// GetMultiple 批量获取设置
func
(
r
*
S
ettingRepository
)
GetMultiple
(
ctx
context
.
Context
,
keys
[]
string
)
(
map
[
string
]
string
,
error
)
{
func
(
r
*
s
ettingRepository
)
GetMultiple
(
ctx
context
.
Context
,
keys
[]
string
)
(
map
[
string
]
string
,
error
)
{
var
settings
[]
model
.
Setting
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"key IN ?"
,
keys
)
.
Find
(
&
settings
)
.
Error
if
err
!=
nil
{
...
...
@@ -68,7 +71,7 @@ func (r *SettingRepository) GetMultiple(ctx context.Context, keys []string) (map
}
// SetMultiple 批量设置值
func
(
r
*
S
ettingRepository
)
SetMultiple
(
ctx
context
.
Context
,
settings
map
[
string
]
string
)
error
{
func
(
r
*
s
ettingRepository
)
SetMultiple
(
ctx
context
.
Context
,
settings
map
[
string
]
string
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Transaction
(
func
(
tx
*
gorm
.
DB
)
error
{
for
key
,
value
:=
range
settings
{
setting
:=
&
model
.
Setting
{
...
...
@@ -88,7 +91,7 @@ func (r *SettingRepository) SetMultiple(ctx context.Context, settings map[string
}
// GetAll 获取所有设置
func
(
r
*
S
ettingRepository
)
GetAll
(
ctx
context
.
Context
)
(
map
[
string
]
string
,
error
)
{
func
(
r
*
s
ettingRepository
)
GetAll
(
ctx
context
.
Context
)
(
map
[
string
]
string
,
error
)
{
var
settings
[]
model
.
Setting
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Find
(
&
settings
)
.
Error
if
err
!=
nil
{
...
...
@@ -103,6 +106,6 @@ func (r *SettingRepository) GetAll(ctx context.Context) (map[string]string, erro
}
// Delete 删除设置
func
(
r
*
S
ettingRepository
)
Delete
(
ctx
context
.
Context
,
key
string
)
error
{
func
(
r
*
s
ettingRepository
)
Delete
(
ctx
context
.
Context
,
key
string
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"key = ?"
,
key
)
.
Delete
(
&
model
.
Setting
{})
.
Error
}
backend/internal/repository/setting_repo_integration_test.go
View file @
eeaff85e
...
...
@@ -6,6 +6,7 @@ import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
...
...
@@ -14,13 +15,13 @@ type SettingRepoSuite struct {
suite
.
Suite
ctx
context
.
Context
db
*
gorm
.
DB
repo
*
S
ettingRepository
repo
*
s
ettingRepository
}
func
(
s
*
SettingRepoSuite
)
SetupTest
()
{
s
.
ctx
=
context
.
Background
()
s
.
db
=
testTx
(
s
.
T
())
s
.
repo
=
NewSettingRepository
(
s
.
db
)
s
.
repo
=
NewSettingRepository
(
s
.
db
)
.
(
*
settingRepository
)
}
func
TestSettingRepoSuite
(
t
*
testing
.
T
)
{
...
...
@@ -45,7 +46,7 @@ func (s *SettingRepoSuite) TestSet_Upsert() {
func
(
s
*
SettingRepoSuite
)
TestGetValue_Missing
()
{
_
,
err
:=
s
.
repo
.
GetValue
(
s
.
ctx
,
"nonexistent"
)
s
.
Require
()
.
Error
(
err
,
"expected error for missing key"
)
s
.
Require
()
.
ErrorIs
(
err
,
gorm
.
ErrRecord
NotFound
)
s
.
Require
()
.
ErrorIs
(
err
,
service
.
ErrSetting
NotFound
)
}
func
(
s
*
SettingRepoSuite
)
TestSetMultiple_AndGetMultiple
()
{
...
...
@@ -86,7 +87,7 @@ func (s *SettingRepoSuite) TestDelete() {
s
.
Require
()
.
NoError
(
s
.
repo
.
Delete
(
s
.
ctx
,
"todelete"
),
"Delete"
)
_
,
err
:=
s
.
repo
.
GetValue
(
s
.
ctx
,
"todelete"
)
s
.
Require
()
.
Error
(
err
,
"expected missing key error after Delete"
)
s
.
Require
()
.
ErrorIs
(
err
,
gorm
.
ErrRecord
NotFound
)
s
.
Require
()
.
ErrorIs
(
err
,
service
.
ErrSetting
NotFound
)
}
func
(
s
*
SettingRepoSuite
)
TestDelete_Idempotent
()
{
...
...
backend/internal/repository/usage_log_repo.go
View file @
eeaff85e
...
...
@@ -2,25 +2,28 @@ package repository
import
(
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"time"
"gorm.io/gorm"
)
type
U
sageLogRepository
struct
{
type
u
sageLogRepository
struct
{
db
*
gorm
.
DB
}
func
NewUsageLogRepository
(
db
*
gorm
.
DB
)
*
UsageLogRepository
{
return
&
U
sageLogRepository
{
db
:
db
}
func
NewUsageLogRepository
(
db
*
gorm
.
DB
)
service
.
UsageLogRepository
{
return
&
u
sageLogRepository
{
db
:
db
}
}
// getPerformanceStats 获取 RPM 和 TPM(近5分钟平均值,可选按用户过滤)
func
(
r
*
U
sageLogRepository
)
getPerformanceStats
(
ctx
context
.
Context
,
userID
int64
)
(
rpm
,
tpm
int64
)
{
func
(
r
*
u
sageLogRepository
)
getPerformanceStats
(
ctx
context
.
Context
,
userID
int64
)
(
rpm
,
tpm
int64
)
{
fiveMinutesAgo
:=
time
.
Now
()
.
Add
(
-
5
*
time
.
Minute
)
var
perfStats
struct
{
RequestCount
int64
`gorm:"column:request_count"`
...
...
@@ -43,20 +46,20 @@ func (r *UsageLogRepository) getPerformanceStats(ctx context.Context, userID int
return
perfStats
.
RequestCount
/
5
,
perfStats
.
TokenCount
/
5
}
func
(
r
*
U
sageLogRepository
)
Create
(
ctx
context
.
Context
,
log
*
model
.
UsageLog
)
error
{
func
(
r
*
u
sageLogRepository
)
Create
(
ctx
context
.
Context
,
log
*
model
.
UsageLog
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Create
(
log
)
.
Error
}
func
(
r
*
U
sageLogRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
UsageLog
,
error
)
{
func
(
r
*
u
sageLogRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
UsageLog
,
error
)
{
var
log
model
.
UsageLog
err
:=
r
.
db
.
WithContext
(
ctx
)
.
First
(
&
log
,
id
)
.
Error
if
err
!=
nil
{
return
nil
,
err
return
nil
,
translatePersistenceError
(
err
,
service
.
ErrUsageLogNotFound
,
nil
)
}
return
&
log
,
nil
}
func
(
r
*
U
sageLogRepository
)
ListByUser
(
ctx
context
.
Context
,
userID
int64
,
params
pagination
.
PaginationParams
)
([]
model
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
u
sageLogRepository
)
ListByUser
(
ctx
context
.
Context
,
userID
int64
,
params
pagination
.
PaginationParams
)
([]
model
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
var
logs
[]
model
.
UsageLog
var
total
int64
...
...
@@ -83,7 +86,7 @@ func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, param
},
nil
}
func
(
r
*
U
sageLogRepository
)
ListByApiKey
(
ctx
context
.
Context
,
apiKeyID
int64
,
params
pagination
.
PaginationParams
)
([]
model
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
u
sageLogRepository
)
ListByApiKey
(
ctx
context
.
Context
,
apiKeyID
int64
,
params
pagination
.
PaginationParams
)
([]
model
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
var
logs
[]
model
.
UsageLog
var
total
int64
...
...
@@ -120,7 +123,7 @@ type UserStats struct {
CacheReadTokens
int64
`json:"cache_read_tokens"`
}
func
(
r
*
U
sageLogRepository
)
GetUserStats
(
ctx
context
.
Context
,
userID
int64
,
startTime
,
endTime
time
.
Time
)
(
*
UserStats
,
error
)
{
func
(
r
*
u
sageLogRepository
)
GetUserStats
(
ctx
context
.
Context
,
userID
int64
,
startTime
,
endTime
time
.
Time
)
(
*
UserStats
,
error
)
{
var
stats
UserStats
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UsageLog
{})
.
Select
(
`
...
...
@@ -139,7 +142,7 @@ func (r *UsageLogRepository) GetUserStats(ctx context.Context, userID int64, sta
// DashboardStats 仪表盘统计
type
DashboardStats
=
usagestats
.
DashboardStats
func
(
r
*
U
sageLogRepository
)
GetDashboardStats
(
ctx
context
.
Context
)
(
*
DashboardStats
,
error
)
{
func
(
r
*
u
sageLogRepository
)
GetDashboardStats
(
ctx
context
.
Context
)
(
*
DashboardStats
,
error
)
{
var
stats
DashboardStats
today
:=
timezone
.
Today
()
...
...
@@ -260,7 +263,7 @@ func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
return
&
stats
,
nil
}
func
(
r
*
U
sageLogRepository
)
ListByAccount
(
ctx
context
.
Context
,
accountID
int64
,
params
pagination
.
PaginationParams
)
([]
model
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
u
sageLogRepository
)
ListByAccount
(
ctx
context
.
Context
,
accountID
int64
,
params
pagination
.
PaginationParams
)
([]
model
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
var
logs
[]
model
.
UsageLog
var
total
int64
...
...
@@ -287,7 +290,7 @@ func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64,
},
nil
}
func
(
r
*
U
sageLogRepository
)
ListByUserAndTimeRange
(
ctx
context
.
Context
,
userID
int64
,
startTime
,
endTime
time
.
Time
)
([]
model
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
u
sageLogRepository
)
ListByUserAndTimeRange
(
ctx
context
.
Context
,
userID
int64
,
startTime
,
endTime
time
.
Time
)
([]
model
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
var
logs
[]
model
.
UsageLog
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"user_id = ? AND created_at >= ? AND created_at < ?"
,
userID
,
startTime
,
endTime
)
.
...
...
@@ -296,7 +299,7 @@ func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID
return
logs
,
nil
,
err
}
func
(
r
*
U
sageLogRepository
)
ListByApiKeyAndTimeRange
(
ctx
context
.
Context
,
apiKeyID
int64
,
startTime
,
endTime
time
.
Time
)
([]
model
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
u
sageLogRepository
)
ListByApiKeyAndTimeRange
(
ctx
context
.
Context
,
apiKeyID
int64
,
startTime
,
endTime
time
.
Time
)
([]
model
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
var
logs
[]
model
.
UsageLog
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"api_key_id = ? AND created_at >= ? AND created_at < ?"
,
apiKeyID
,
startTime
,
endTime
)
.
...
...
@@ -305,7 +308,7 @@ func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKe
return
logs
,
nil
,
err
}
func
(
r
*
U
sageLogRepository
)
ListByAccountAndTimeRange
(
ctx
context
.
Context
,
accountID
int64
,
startTime
,
endTime
time
.
Time
)
([]
model
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
u
sageLogRepository
)
ListByAccountAndTimeRange
(
ctx
context
.
Context
,
accountID
int64
,
startTime
,
endTime
time
.
Time
)
([]
model
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
var
logs
[]
model
.
UsageLog
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"account_id = ? AND created_at >= ? AND created_at < ?"
,
accountID
,
startTime
,
endTime
)
.
...
...
@@ -314,7 +317,7 @@ func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, acco
return
logs
,
nil
,
err
}
func
(
r
*
U
sageLogRepository
)
ListByModelAndTimeRange
(
ctx
context
.
Context
,
modelName
string
,
startTime
,
endTime
time
.
Time
)
([]
model
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
u
sageLogRepository
)
ListByModelAndTimeRange
(
ctx
context
.
Context
,
modelName
string
,
startTime
,
endTime
time
.
Time
)
([]
model
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
var
logs
[]
model
.
UsageLog
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Where
(
"model = ? AND created_at >= ? AND created_at < ?"
,
modelName
,
startTime
,
endTime
)
.
...
...
@@ -323,12 +326,12 @@ func (r *UsageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelN
return
logs
,
nil
,
err
}
func
(
r
*
U
sageLogRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
func
(
r
*
u
sageLogRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
r
.
db
.
WithContext
(
ctx
)
.
Delete
(
&
model
.
UsageLog
{},
id
)
.
Error
}
// GetAccountTodayStats 获取账号今日统计
func
(
r
*
U
sageLogRepository
)
GetAccountTodayStats
(
ctx
context
.
Context
,
accountID
int64
)
(
*
usagestats
.
AccountStats
,
error
)
{
func
(
r
*
u
sageLogRepository
)
GetAccountTodayStats
(
ctx
context
.
Context
,
accountID
int64
)
(
*
usagestats
.
AccountStats
,
error
)
{
today
:=
timezone
.
Today
()
var
stats
struct
{
...
...
@@ -358,7 +361,7 @@ func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
}
// GetAccountWindowStats 获取账号时间窗口内的统计
func
(
r
*
U
sageLogRepository
)
GetAccountWindowStats
(
ctx
context
.
Context
,
accountID
int64
,
startTime
time
.
Time
)
(
*
usagestats
.
AccountStats
,
error
)
{
func
(
r
*
u
sageLogRepository
)
GetAccountWindowStats
(
ctx
context
.
Context
,
accountID
int64
,
startTime
time
.
Time
)
(
*
usagestats
.
AccountStats
,
error
)
{
var
stats
struct
{
Requests
int64
`gorm:"column:requests"`
Tokens
int64
`gorm:"column:tokens"`
...
...
@@ -398,7 +401,7 @@ type UserUsageTrendPoint = usagestats.UserUsageTrendPoint
type
ApiKeyUsageTrendPoint
=
usagestats
.
ApiKeyUsageTrendPoint
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date
func
(
r
*
U
sageLogRepository
)
GetApiKeyUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
limit
int
)
([]
ApiKeyUsageTrendPoint
,
error
)
{
func
(
r
*
u
sageLogRepository
)
GetApiKeyUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
limit
int
)
([]
ApiKeyUsageTrendPoint
,
error
)
{
var
results
[]
ApiKeyUsageTrendPoint
// Choose date format based on granularity
...
...
@@ -442,7 +445,7 @@ func (r *UsageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime,
}
// GetUserUsageTrend returns usage trend data grouped by user and date
func
(
r
*
U
sageLogRepository
)
GetUserUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
limit
int
)
([]
UserUsageTrendPoint
,
error
)
{
func
(
r
*
u
sageLogRepository
)
GetUserUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
limit
int
)
([]
UserUsageTrendPoint
,
error
)
{
var
results
[]
UserUsageTrendPoint
// Choose date format based on granularity
...
...
@@ -491,7 +494,7 @@ func (r *UsageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e
type
UserDashboardStats
=
usagestats
.
UserDashboardStats
// GetUserDashboardStats 获取用户专属的仪表盘统计
func
(
r
*
U
sageLogRepository
)
GetUserDashboardStats
(
ctx
context
.
Context
,
userID
int64
)
(
*
UserDashboardStats
,
error
)
{
func
(
r
*
u
sageLogRepository
)
GetUserDashboardStats
(
ctx
context
.
Context
,
userID
int64
)
(
*
UserDashboardStats
,
error
)
{
var
stats
UserDashboardStats
today
:=
timezone
.
Today
()
...
...
@@ -578,7 +581,7 @@ func (r *UsageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
}
// GetUserUsageTrendByUserID 获取指定用户的使用趋势
func
(
r
*
U
sageLogRepository
)
GetUserUsageTrendByUserID
(
ctx
context
.
Context
,
userID
int64
,
startTime
,
endTime
time
.
Time
,
granularity
string
)
([]
TrendDataPoint
,
error
)
{
func
(
r
*
u
sageLogRepository
)
GetUserUsageTrendByUserID
(
ctx
context
.
Context
,
userID
int64
,
startTime
,
endTime
time
.
Time
,
granularity
string
)
([]
TrendDataPoint
,
error
)
{
var
results
[]
TrendDataPoint
var
dateFormat
string
...
...
@@ -612,7 +615,7 @@ func (r *UsageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
}
// GetUserModelStats 获取指定用户的模型统计
func
(
r
*
U
sageLogRepository
)
GetUserModelStats
(
ctx
context
.
Context
,
userID
int64
,
startTime
,
endTime
time
.
Time
)
([]
ModelStat
,
error
)
{
func
(
r
*
u
sageLogRepository
)
GetUserModelStats
(
ctx
context
.
Context
,
userID
int64
,
startTime
,
endTime
time
.
Time
)
([]
ModelStat
,
error
)
{
var
results
[]
ModelStat
err
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UsageLog
{})
.
...
...
@@ -641,7 +644,7 @@ func (r *UsageLogRepository) GetUserModelStats(ctx context.Context, userID int64
type
UsageLogFilters
=
usagestats
.
UsageLogFilters
// ListWithFilters lists usage logs with optional filters (for admin)
func
(
r
*
U
sageLogRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
filters
UsageLogFilters
)
([]
model
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
u
sageLogRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
filters
UsageLogFilters
)
([]
model
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
var
logs
[]
model
.
UsageLog
var
total
int64
...
...
@@ -692,7 +695,7 @@ type UsageStats = usagestats.UsageStats
type
BatchUserUsageStats
=
usagestats
.
BatchUserUsageStats
// GetBatchUserUsageStats gets today and total actual_cost for multiple users
func
(
r
*
U
sageLogRepository
)
GetBatchUserUsageStats
(
ctx
context
.
Context
,
userIDs
[]
int64
)
(
map
[
int64
]
*
BatchUserUsageStats
,
error
)
{
func
(
r
*
u
sageLogRepository
)
GetBatchUserUsageStats
(
ctx
context
.
Context
,
userIDs
[]
int64
)
(
map
[
int64
]
*
BatchUserUsageStats
,
error
)
{
if
len
(
userIDs
)
==
0
{
return
make
(
map
[
int64
]
*
BatchUserUsageStats
),
nil
}
...
...
@@ -752,7 +755,7 @@ func (r *UsageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
type
BatchApiKeyUsageStats
=
usagestats
.
BatchApiKeyUsageStats
// GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys
func
(
r
*
U
sageLogRepository
)
GetBatchApiKeyUsageStats
(
ctx
context
.
Context
,
apiKeyIDs
[]
int64
)
(
map
[
int64
]
*
BatchApiKeyUsageStats
,
error
)
{
func
(
r
*
u
sageLogRepository
)
GetBatchApiKeyUsageStats
(
ctx
context
.
Context
,
apiKeyIDs
[]
int64
)
(
map
[
int64
]
*
BatchApiKeyUsageStats
,
error
)
{
if
len
(
apiKeyIDs
)
==
0
{
return
make
(
map
[
int64
]
*
BatchApiKeyUsageStats
),
nil
}
...
...
@@ -809,7 +812,7 @@ func (r *UsageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
}
// GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters
func
(
r
*
U
sageLogRepository
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
int64
)
([]
TrendDataPoint
,
error
)
{
func
(
r
*
u
sageLogRepository
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
int64
)
([]
TrendDataPoint
,
error
)
{
var
results
[]
TrendDataPoint
var
dateFormat
string
...
...
@@ -848,7 +851,7 @@ func (r *UsageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
}
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters
func
(
r
*
U
sageLogRepository
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
int64
)
([]
ModelStat
,
error
)
{
func
(
r
*
u
sageLogRepository
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
int64
)
([]
ModelStat
,
error
)
{
var
results
[]
ModelStat
db
:=
r
.
db
.
WithContext
(
ctx
)
.
Model
(
&
model
.
UsageLog
{})
.
...
...
@@ -882,7 +885,7 @@ func (r *UsageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
}
// GetGlobalStats gets usage statistics for all users within a time range
func
(
r
*
U
sageLogRepository
)
GetGlobalStats
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
)
(
*
UsageStats
,
error
)
{
func
(
r
*
u
sageLogRepository
)
GetGlobalStats
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
)
(
*
UsageStats
,
error
)
{
var
stats
struct
{
TotalRequests
int64
`gorm:"column:total_requests"`
TotalInputTokens
int64
`gorm:"column:total_input_tokens"`
...
...
@@ -932,7 +935,7 @@ type AccountUsageSummary = usagestats.AccountUsageSummary
type
AccountUsageStatsResponse
=
usagestats
.
AccountUsageStatsResponse
// GetAccountUsageStats returns comprehensive usage statistics for an account over a time range
func
(
r
*
U
sageLogRepository
)
GetAccountUsageStats
(
ctx
context
.
Context
,
accountID
int64
,
startTime
,
endTime
time
.
Time
)
(
*
AccountUsageStatsResponse
,
error
)
{
func
(
r
*
u
sageLogRepository
)
GetAccountUsageStats
(
ctx
context
.
Context
,
accountID
int64
,
startTime
,
endTime
time
.
Time
)
(
*
AccountUsageStatsResponse
,
error
)
{
daysCount
:=
int
(
endTime
.
Sub
(
startTime
)
.
Hours
()
/
24
)
+
1
if
daysCount
<=
0
{
daysCount
=
30
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment