Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
1a641392
Commit
1a641392
authored
Jan 10, 2026
by
cyhhao
Browse files
Merge up/main
parents
36b817d0
24d19a5f
Changes
174
Show whitespace changes
Inline
Side-by-side
backend/internal/handler/handler.go
View file @
1a641392
...
...
@@ -16,6 +16,7 @@ type AdminHandlers struct {
AntigravityOAuth
*
admin
.
AntigravityOAuthHandler
Proxy
*
admin
.
ProxyHandler
Redeem
*
admin
.
RedeemHandler
Promo
*
admin
.
PromoHandler
Setting
*
admin
.
SettingHandler
System
*
admin
.
SystemHandler
Subscription
*
admin
.
SubscriptionHandler
...
...
backend/internal/handler/openai_gateway_handler.go
View file @
1a641392
...
...
@@ -8,9 +8,12 @@ import (
"io"
"log"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
...
...
@@ -93,6 +96,24 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
userAgent
:=
c
.
GetHeader
(
"User-Agent"
)
// 获取客户端 IP
clientIP
:=
ip
.
GetClientIP
(
c
)
if
!
openai
.
IsCodexCLIRequest
(
userAgent
)
{
existingInstructions
,
_
:=
reqBody
[
"instructions"
]
.
(
string
)
if
strings
.
TrimSpace
(
existingInstructions
)
==
""
{
if
instructions
:=
strings
.
TrimSpace
(
service
.
GetOpenCodeInstructions
());
instructions
!=
""
{
reqBody
[
"instructions"
]
=
instructions
// Re-serialize body
body
,
err
=
json
.
Marshal
(
reqBody
)
if
err
!=
nil
{
h
.
errorResponse
(
c
,
http
.
StatusInternalServerError
,
"api_error"
,
"Failed to process request"
)
return
}
}
}
}
// Track if we've started streaming (for error handling)
streamStarted
:=
false
...
...
@@ -231,7 +252,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
}
// Async record usage
go
func
(
result
*
service
.
OpenAIForwardResult
,
usedAccount
*
service
.
Account
,
ua
string
)
{
go
func
(
result
*
service
.
OpenAIForwardResult
,
usedAccount
*
service
.
Account
,
ua
string
,
cip
string
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
10
*
time
.
Second
)
defer
cancel
()
if
err
:=
h
.
gatewayService
.
RecordUsage
(
ctx
,
&
service
.
OpenAIRecordUsageInput
{
...
...
@@ -241,10 +262,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
Account
:
usedAccount
,
Subscription
:
subscription
,
UserAgent
:
ua
,
IPAddress
:
cip
,
});
err
!=
nil
{
log
.
Printf
(
"Record usage failed: %v"
,
err
)
}
}(
result
,
account
,
userAgent
)
}(
result
,
account
,
userAgent
,
clientIP
)
return
}
}
...
...
backend/internal/handler/setting_handler.go
View file @
1a641392
...
...
@@ -42,6 +42,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
APIBaseURL
:
settings
.
APIBaseURL
,
ContactInfo
:
settings
.
ContactInfo
,
DocURL
:
settings
.
DocURL
,
LinuxDoOAuthEnabled
:
settings
.
LinuxDoOAuthEnabled
,
Version
:
h
.
version
,
})
}
backend/internal/handler/wire.go
View file @
1a641392
...
...
@@ -19,6 +19,7 @@ func ProvideAdminHandlers(
antigravityOAuthHandler
*
admin
.
AntigravityOAuthHandler
,
proxyHandler
*
admin
.
ProxyHandler
,
redeemHandler
*
admin
.
RedeemHandler
,
promoHandler
*
admin
.
PromoHandler
,
settingHandler
*
admin
.
SettingHandler
,
systemHandler
*
admin
.
SystemHandler
,
subscriptionHandler
*
admin
.
SubscriptionHandler
,
...
...
@@ -36,6 +37,7 @@ func ProvideAdminHandlers(
AntigravityOAuth
:
antigravityOAuthHandler
,
Proxy
:
proxyHandler
,
Redeem
:
redeemHandler
,
Promo
:
promoHandler
,
Setting
:
settingHandler
,
System
:
systemHandler
,
Subscription
:
subscriptionHandler
,
...
...
@@ -105,6 +107,7 @@ var ProviderSet = wire.NewSet(
admin
.
NewAntigravityOAuthHandler
,
admin
.
NewProxyHandler
,
admin
.
NewRedeemHandler
,
admin
.
NewPromoHandler
,
admin
.
NewSettingHandler
,
ProvideSystemHandler
,
admin
.
NewSubscriptionHandler
,
...
...
backend/internal/middleware/rate_limiter.go
0 → 100644
View file @
1a641392
package
middleware
import
(
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
)
// RateLimiter Redis 速率限制器
type
RateLimiter
struct
{
redis
*
redis
.
Client
prefix
string
}
// NewRateLimiter 创建速率限制器实例
func
NewRateLimiter
(
redisClient
*
redis
.
Client
)
*
RateLimiter
{
return
&
RateLimiter
{
redis
:
redisClient
,
prefix
:
"rate_limit:"
,
}
}
// Limit 返回速率限制中间件
// key: 限制类型标识
// limit: 时间窗口内最大请求数
// window: 时间窗口
func
(
r
*
RateLimiter
)
Limit
(
key
string
,
limit
int
,
window
time
.
Duration
)
gin
.
HandlerFunc
{
return
func
(
c
*
gin
.
Context
)
{
ip
:=
c
.
ClientIP
()
redisKey
:=
r
.
prefix
+
key
+
":"
+
ip
ctx
:=
c
.
Request
.
Context
()
// 使用 INCR 原子操作增加计数
count
,
err
:=
r
.
redis
.
Incr
(
ctx
,
redisKey
)
.
Result
()
if
err
!=
nil
{
// Redis 错误时放行,避免影响正常服务
c
.
Next
()
return
}
// 首次访问时设置过期时间
if
count
==
1
{
r
.
redis
.
Expire
(
ctx
,
redisKey
,
window
)
}
// 超过限制
if
count
>
int64
(
limit
)
{
c
.
AbortWithStatusJSON
(
http
.
StatusTooManyRequests
,
gin
.
H
{
"error"
:
"rate limit exceeded"
,
"message"
:
"Too many requests, please try again later"
,
})
return
}
c
.
Next
()
}
}
backend/internal/pkg/ctxkey/ctxkey.go
View file @
1a641392
...
...
@@ -9,4 +9,6 @@ const (
ForcePlatform
Key
=
"ctx_force_platform"
// IsClaudeCodeClient 是否为 Claude Code 客户端,由中间件设置
IsClaudeCodeClient
Key
=
"ctx_is_claude_code_client"
// Group 认证后的分组信息,由 API Key 认证中间件设置
Group
Key
=
"ctx_group"
)
backend/internal/pkg/ip/ip.go
0 → 100644
View file @
1a641392
// Package ip 提供客户端 IP 地址提取工具。
package
ip
import
(
"net"
"strings"
"github.com/gin-gonic/gin"
)
// GetClientIP 从 Gin Context 中提取客户端真实 IP 地址。
// 按以下优先级检查 Header:
// 1. CF-Connecting-IP (Cloudflare)
// 2. X-Real-IP (Nginx)
// 3. X-Forwarded-For (取第一个非私有 IP)
// 4. c.ClientIP() (Gin 内置方法)
func
GetClientIP
(
c
*
gin
.
Context
)
string
{
// 1. Cloudflare
if
ip
:=
c
.
GetHeader
(
"CF-Connecting-IP"
);
ip
!=
""
{
return
normalizeIP
(
ip
)
}
// 2. Nginx X-Real-IP
if
ip
:=
c
.
GetHeader
(
"X-Real-IP"
);
ip
!=
""
{
return
normalizeIP
(
ip
)
}
// 3. X-Forwarded-For (多个 IP 时取第一个公网 IP)
if
xff
:=
c
.
GetHeader
(
"X-Forwarded-For"
);
xff
!=
""
{
ips
:=
strings
.
Split
(
xff
,
","
)
for
_
,
ip
:=
range
ips
{
ip
=
strings
.
TrimSpace
(
ip
)
if
ip
!=
""
&&
!
isPrivateIP
(
ip
)
{
return
normalizeIP
(
ip
)
}
}
// 如果都是私有 IP,返回第一个
if
len
(
ips
)
>
0
{
return
normalizeIP
(
strings
.
TrimSpace
(
ips
[
0
]))
}
}
// 4. Gin 内置方法
return
normalizeIP
(
c
.
ClientIP
())
}
// normalizeIP 规范化 IP 地址,去除端口号和空格。
func
normalizeIP
(
ip
string
)
string
{
ip
=
strings
.
TrimSpace
(
ip
)
// 移除端口号(如 "192.168.1.1:8080" -> "192.168.1.1")
if
host
,
_
,
err
:=
net
.
SplitHostPort
(
ip
);
err
==
nil
{
return
host
}
return
ip
}
// isPrivateIP 检查 IP 是否为私有地址。
func
isPrivateIP
(
ipStr
string
)
bool
{
ip
:=
net
.
ParseIP
(
ipStr
)
if
ip
==
nil
{
return
false
}
// 私有 IP 范围
privateBlocks
:=
[]
string
{
"10.0.0.0/8"
,
"172.16.0.0/12"
,
"192.168.0.0/16"
,
"127.0.0.0/8"
,
"::1/128"
,
"fc00::/7"
,
}
for
_
,
block
:=
range
privateBlocks
{
_
,
cidr
,
err
:=
net
.
ParseCIDR
(
block
)
if
err
!=
nil
{
continue
}
if
cidr
.
Contains
(
ip
)
{
return
true
}
}
return
false
}
// MatchesPattern 检查 IP 是否匹配指定的模式(支持单个 IP 或 CIDR)。
// pattern 可以是:
// - 单个 IP: "192.168.1.100"
// - CIDR 范围: "192.168.1.0/24"
func
MatchesPattern
(
clientIP
,
pattern
string
)
bool
{
ip
:=
net
.
ParseIP
(
clientIP
)
if
ip
==
nil
{
return
false
}
// 尝试解析为 CIDR
if
strings
.
Contains
(
pattern
,
"/"
)
{
_
,
cidr
,
err
:=
net
.
ParseCIDR
(
pattern
)
if
err
!=
nil
{
return
false
}
return
cidr
.
Contains
(
ip
)
}
// 作为单个 IP 处理
patternIP
:=
net
.
ParseIP
(
pattern
)
if
patternIP
==
nil
{
return
false
}
return
ip
.
Equal
(
patternIP
)
}
// MatchesAnyPattern 检查 IP 是否匹配任意一个模式。
func
MatchesAnyPattern
(
clientIP
string
,
patterns
[]
string
)
bool
{
for
_
,
pattern
:=
range
patterns
{
if
MatchesPattern
(
clientIP
,
pattern
)
{
return
true
}
}
return
false
}
// CheckIPRestriction 检查 IP 是否被 API Key 的 IP 限制允许。
// 返回值:(是否允许, 拒绝原因)
// 逻辑:
// 1. 先检查黑名单,如果在黑名单中则直接拒绝
// 2. 如果白名单不为空,IP 必须在白名单中
// 3. 如果白名单为空,允许访问(除非被黑名单拒绝)
func
CheckIPRestriction
(
clientIP
string
,
whitelist
,
blacklist
[]
string
)
(
bool
,
string
)
{
// 规范化 IP
clientIP
=
normalizeIP
(
clientIP
)
if
clientIP
==
""
{
return
false
,
"access denied"
}
// 1. 检查黑名单
if
len
(
blacklist
)
>
0
&&
MatchesAnyPattern
(
clientIP
,
blacklist
)
{
return
false
,
"access denied"
}
// 2. 检查白名单(如果设置了白名单,IP 必须在其中)
if
len
(
whitelist
)
>
0
&&
!
MatchesAnyPattern
(
clientIP
,
whitelist
)
{
return
false
,
"access denied"
}
return
true
,
""
}
// ValidateIPPattern 验证 IP 或 CIDR 格式是否有效。
func
ValidateIPPattern
(
pattern
string
)
bool
{
if
strings
.
Contains
(
pattern
,
"/"
)
{
_
,
_
,
err
:=
net
.
ParseCIDR
(
pattern
)
return
err
==
nil
}
return
net
.
ParseIP
(
pattern
)
!=
nil
}
// ValidateIPPatterns 验证多个 IP 或 CIDR 格式。
// 返回无效的模式列表。
func
ValidateIPPatterns
(
patterns
[]
string
)
[]
string
{
var
invalid
[]
string
for
_
,
p
:=
range
patterns
{
if
!
ValidateIPPattern
(
p
)
{
invalid
=
append
(
invalid
,
p
)
}
}
return
invalid
}
backend/internal/repository/account_repo.go
View file @
1a641392
...
...
@@ -675,6 +675,40 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
return
err
}
func
(
r
*
accountRepository
)
SetAntigravityQuotaScopeLimit
(
ctx
context
.
Context
,
id
int64
,
scope
service
.
AntigravityQuotaScope
,
resetAt
time
.
Time
)
error
{
now
:=
time
.
Now
()
.
UTC
()
payload
:=
map
[
string
]
string
{
"rate_limited_at"
:
now
.
Format
(
time
.
RFC3339
),
"rate_limit_reset_at"
:
resetAt
.
UTC
()
.
Format
(
time
.
RFC3339
),
}
raw
,
err
:=
json
.
Marshal
(
payload
)
if
err
!=
nil
{
return
err
}
path
:=
"{antigravity_quota_scopes,"
+
string
(
scope
)
+
"}"
client
:=
clientFromContext
(
ctx
,
r
.
client
)
result
,
err
:=
client
.
ExecContext
(
ctx
,
"UPDATE accounts SET extra = jsonb_set(COALESCE(extra, '{}'::jsonb), $1::text[], $2::jsonb, true), updated_at = NOW() WHERE id = $3 AND deleted_at IS NULL"
,
path
,
raw
,
id
,
)
if
err
!=
nil
{
return
err
}
affected
,
err
:=
result
.
RowsAffected
()
if
err
!=
nil
{
return
err
}
if
affected
==
0
{
return
service
.
ErrAccountNotFound
}
return
nil
}
func
(
r
*
accountRepository
)
SetOverloaded
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
)
error
{
_
,
err
:=
r
.
client
.
Account
.
Update
()
.
Where
(
dbaccount
.
IDEQ
(
id
))
.
...
...
@@ -718,6 +752,27 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
return
err
}
func
(
r
*
accountRepository
)
ClearAntigravityQuotaScopes
(
ctx
context
.
Context
,
id
int64
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
result
,
err
:=
client
.
ExecContext
(
ctx
,
"UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) - 'antigravity_quota_scopes', updated_at = NOW() WHERE id = $1 AND deleted_at IS NULL"
,
id
,
)
if
err
!=
nil
{
return
err
}
affected
,
err
:=
result
.
RowsAffected
()
if
err
!=
nil
{
return
err
}
if
affected
==
0
{
return
service
.
ErrAccountNotFound
}
return
nil
}
func
(
r
*
accountRepository
)
UpdateSessionWindow
(
ctx
context
.
Context
,
id
int64
,
start
,
end
*
time
.
Time
,
status
string
)
error
{
builder
:=
r
.
client
.
Account
.
Update
()
.
Where
(
dbaccount
.
IDEQ
(
id
))
.
...
...
@@ -831,6 +886,11 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
args
=
append
(
args
,
*
updates
.
Status
)
idx
++
}
if
updates
.
Schedulable
!=
nil
{
setClauses
=
append
(
setClauses
,
"schedulable = $"
+
itoa
(
idx
))
args
=
append
(
args
,
*
updates
.
Schedulable
)
idx
++
}
// JSONB 需要合并而非覆盖,使用 raw SQL 保持旧行为。
if
len
(
updates
.
Credentials
)
>
0
{
payload
,
err
:=
json
.
Marshal
(
updates
.
Credentials
)
...
...
backend/internal/repository/api_key_repo.go
View file @
1a641392
...
...
@@ -26,13 +26,21 @@ func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery {
}
func
(
r
*
apiKeyRepository
)
Create
(
ctx
context
.
Context
,
key
*
service
.
APIKey
)
error
{
created
,
er
r
:=
r
.
client
.
APIKey
.
Create
()
.
builde
r
:=
r
.
client
.
APIKey
.
Create
()
.
SetUserID
(
key
.
UserID
)
.
SetKey
(
key
.
Key
)
.
SetName
(
key
.
Name
)
.
SetStatus
(
key
.
Status
)
.
SetNillableGroupID
(
key
.
GroupID
)
.
Save
(
ctx
)
SetNillableGroupID
(
key
.
GroupID
)
if
len
(
key
.
IPWhitelist
)
>
0
{
builder
.
SetIPWhitelist
(
key
.
IPWhitelist
)
}
if
len
(
key
.
IPBlacklist
)
>
0
{
builder
.
SetIPBlacklist
(
key
.
IPBlacklist
)
}
created
,
err
:=
builder
.
Save
(
ctx
)
if
err
==
nil
{
key
.
ID
=
created
.
ID
key
.
CreatedAt
=
created
.
CreatedAt
...
...
@@ -108,6 +116,18 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
builder
.
ClearGroupID
()
}
// IP 限制字段
if
len
(
key
.
IPWhitelist
)
>
0
{
builder
.
SetIPWhitelist
(
key
.
IPWhitelist
)
}
else
{
builder
.
ClearIPWhitelist
()
}
if
len
(
key
.
IPBlacklist
)
>
0
{
builder
.
SetIPBlacklist
(
key
.
IPBlacklist
)
}
else
{
builder
.
ClearIPBlacklist
()
}
affected
,
err
:=
builder
.
Save
(
ctx
)
if
err
!=
nil
{
return
err
...
...
@@ -273,6 +293,8 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
Key
:
m
.
Key
,
Name
:
m
.
Name
,
Status
:
m
.
Status
,
IPWhitelist
:
m
.
IPWhitelist
,
IPBlacklist
:
m
.
IPBlacklist
,
CreatedAt
:
m
.
CreatedAt
,
UpdatedAt
:
m
.
UpdatedAt
,
GroupID
:
m
.
GroupID
,
...
...
@@ -317,6 +339,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
RateMultiplier
:
g
.
RateMultiplier
,
IsExclusive
:
g
.
IsExclusive
,
Status
:
g
.
Status
,
Hydrated
:
true
,
SubscriptionType
:
g
.
SubscriptionType
,
DailyLimitUSD
:
g
.
DailyLimitUsd
,
WeeklyLimitUSD
:
g
.
WeeklyLimitUsd
,
...
...
backend/internal/repository/group_repo.go
View file @
1a641392
...
...
@@ -60,6 +60,17 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
}
func
(
r
*
groupRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
Group
,
error
)
{
out
,
err
:=
r
.
GetByIDLite
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
err
}
count
,
_
:=
r
.
GetAccountCount
(
ctx
,
out
.
ID
)
out
.
AccountCount
=
count
return
out
,
nil
}
func
(
r
*
groupRepository
)
GetByIDLite
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
Group
,
error
)
{
// AccountCount is intentionally not loaded here; use GetByID when needed.
m
,
err
:=
r
.
client
.
Group
.
Query
()
.
Where
(
group
.
IDEQ
(
id
))
.
Only
(
ctx
)
...
...
@@ -67,10 +78,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
return
nil
,
translatePersistenceError
(
err
,
service
.
ErrGroupNotFound
,
nil
)
}
out
:=
groupEntityToService
(
m
)
count
,
_
:=
r
.
GetAccountCount
(
ctx
,
out
.
ID
)
out
.
AccountCount
=
count
return
out
,
nil
return
groupEntityToService
(
m
),
nil
}
func
(
r
*
groupRepository
)
Update
(
ctx
context
.
Context
,
groupIn
*
service
.
Group
)
error
{
...
...
@@ -112,10 +120,10 @@ func (r *groupRepository) Delete(ctx context.Context, id int64) error {
}
func
(
r
*
groupRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
service
.
Group
,
*
pagination
.
PaginationResult
,
error
)
{
return
r
.
ListWithFilters
(
ctx
,
params
,
""
,
""
,
nil
)
return
r
.
ListWithFilters
(
ctx
,
params
,
""
,
""
,
""
,
nil
)
}
func
(
r
*
groupRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
status
string
,
isExclusive
*
bool
)
([]
service
.
Group
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
groupRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
status
,
search
string
,
isExclusive
*
bool
)
([]
service
.
Group
,
*
pagination
.
PaginationResult
,
error
)
{
q
:=
r
.
client
.
Group
.
Query
()
if
platform
!=
""
{
...
...
@@ -124,6 +132,12 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
if
status
!=
""
{
q
=
q
.
Where
(
group
.
StatusEQ
(
status
))
}
if
search
!=
""
{
q
=
q
.
Where
(
group
.
Or
(
group
.
NameContainsFold
(
search
),
group
.
DescriptionContainsFold
(
search
),
))
}
if
isExclusive
!=
nil
{
q
=
q
.
Where
(
group
.
IsExclusiveEQ
(
*
isExclusive
))
}
...
...
backend/internal/repository/group_repo_integration_test.go
View file @
1a641392
...
...
@@ -4,6 +4,8 @@ package repository
import
(
"context"
"database/sql"
"errors"
"testing"
dbent
"github.com/Wei-Shaw/sub2api/ent"
...
...
@@ -19,6 +21,20 @@ type GroupRepoSuite struct {
repo
*
groupRepository
}
type
forbidSQLExecutor
struct
{
called
bool
}
func
(
s
*
forbidSQLExecutor
)
ExecContext
(
ctx
context
.
Context
,
query
string
,
args
...
any
)
(
sql
.
Result
,
error
)
{
s
.
called
=
true
return
nil
,
errors
.
New
(
"unexpected sql exec"
)
}
func
(
s
*
forbidSQLExecutor
)
QueryContext
(
ctx
context
.
Context
,
query
string
,
args
...
any
)
(
*
sql
.
Rows
,
error
)
{
s
.
called
=
true
return
nil
,
errors
.
New
(
"unexpected sql query"
)
}
func
(
s
*
GroupRepoSuite
)
SetupTest
()
{
s
.
ctx
=
context
.
Background
()
tx
:=
testEntTx
(
s
.
T
())
...
...
@@ -57,6 +73,26 @@ func (s *GroupRepoSuite) TestGetByID_NotFound() {
s
.
Require
()
.
ErrorIs
(
err
,
service
.
ErrGroupNotFound
)
}
func
(
s
*
GroupRepoSuite
)
TestGetByIDLite_DoesNotUseAccountCount
()
{
group
:=
&
service
.
Group
{
Name
:
"lite-group"
,
Platform
:
service
.
PlatformAnthropic
,
RateMultiplier
:
1.0
,
IsExclusive
:
false
,
Status
:
service
.
StatusActive
,
SubscriptionType
:
service
.
SubscriptionTypeStandard
,
}
s
.
Require
()
.
NoError
(
s
.
repo
.
Create
(
s
.
ctx
,
group
))
spy
:=
&
forbidSQLExecutor
{}
repo
:=
newGroupRepositoryWithSQL
(
s
.
tx
.
Client
(),
spy
)
got
,
err
:=
repo
.
GetByIDLite
(
s
.
ctx
,
group
.
ID
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Equal
(
group
.
ID
,
got
.
ID
)
s
.
Require
()
.
False
(
spy
.
called
,
"expected no direct sql executor usage"
)
}
func
(
s
*
GroupRepoSuite
)
TestUpdate
()
{
group
:=
&
service
.
Group
{
Name
:
"original"
,
...
...
@@ -131,6 +167,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() {
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
service
.
PlatformOpenAI
,
""
,
""
,
nil
,
)
s
.
Require
()
.
NoError
(
err
,
"ListWithFilters base"
)
...
...
@@ -152,7 +189,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() {
SubscriptionType
:
service
.
SubscriptionTypeStandard
,
}))
groups
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
service
.
PlatformOpenAI
,
""
,
nil
)
groups
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
service
.
PlatformOpenAI
,
""
,
""
,
nil
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
groups
,
len
(
baseGroups
)
+
1
)
// Verify all groups are OpenAI platform
...
...
@@ -179,7 +216,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Status() {
SubscriptionType
:
service
.
SubscriptionTypeStandard
,
}))
groups
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
""
,
service
.
StatusDisabled
,
nil
)
groups
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
""
,
service
.
StatusDisabled
,
""
,
nil
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
groups
,
1
)
s
.
Require
()
.
Equal
(
service
.
StatusDisabled
,
groups
[
0
]
.
Status
)
...
...
@@ -204,12 +241,117 @@ func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
}))
isExclusive
:=
true
groups
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
""
,
""
,
&
isExclusive
)
groups
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
""
,
""
,
""
,
&
isExclusive
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
groups
,
1
)
s
.
Require
()
.
True
(
groups
[
0
]
.
IsExclusive
)
}
func
(
s
*
GroupRepoSuite
)
TestListWithFilters_Search
()
{
newRepo
:=
func
()
(
*
groupRepository
,
context
.
Context
)
{
tx
:=
testEntTx
(
s
.
T
())
return
newGroupRepositoryWithSQL
(
tx
.
Client
(),
tx
),
context
.
Background
()
}
containsID
:=
func
(
groups
[]
service
.
Group
,
id
int64
)
bool
{
for
i
:=
range
groups
{
if
groups
[
i
]
.
ID
==
id
{
return
true
}
}
return
false
}
mustCreate
:=
func
(
repo
*
groupRepository
,
ctx
context
.
Context
,
g
*
service
.
Group
)
*
service
.
Group
{
s
.
Require
()
.
NoError
(
repo
.
Create
(
ctx
,
g
))
s
.
Require
()
.
NotZero
(
g
.
ID
)
return
g
}
newGroup
:=
func
(
name
string
)
*
service
.
Group
{
return
&
service
.
Group
{
Name
:
name
,
Platform
:
service
.
PlatformAnthropic
,
RateMultiplier
:
1.0
,
IsExclusive
:
false
,
Status
:
service
.
StatusActive
,
SubscriptionType
:
service
.
SubscriptionTypeStandard
,
}
}
s
.
Run
(
"search_name_should_match"
,
func
()
{
repo
,
ctx
:=
newRepo
()
target
:=
mustCreate
(
repo
,
ctx
,
newGroup
(
"it-group-search-name-target"
))
other
:=
mustCreate
(
repo
,
ctx
,
newGroup
(
"it-group-search-name-other"
))
groups
,
_
,
err
:=
repo
.
ListWithFilters
(
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
50
},
""
,
""
,
"name-target"
,
nil
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
True
(
containsID
(
groups
,
target
.
ID
),
"expected target group to match by name"
)
s
.
Require
()
.
False
(
containsID
(
groups
,
other
.
ID
),
"expected other group to be filtered out"
)
})
s
.
Run
(
"search_description_should_match"
,
func
()
{
repo
,
ctx
:=
newRepo
()
target
:=
newGroup
(
"it-group-search-desc-target"
)
target
.
Description
=
"something about desc-needle in here"
target
=
mustCreate
(
repo
,
ctx
,
target
)
other
:=
newGroup
(
"it-group-search-desc-other"
)
other
.
Description
=
"nothing to see here"
other
=
mustCreate
(
repo
,
ctx
,
other
)
groups
,
_
,
err
:=
repo
.
ListWithFilters
(
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
50
},
""
,
""
,
"desc-needle"
,
nil
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
True
(
containsID
(
groups
,
target
.
ID
),
"expected target group to match by description"
)
s
.
Require
()
.
False
(
containsID
(
groups
,
other
.
ID
),
"expected other group to be filtered out"
)
})
s
.
Run
(
"search_nonexistent_should_return_empty"
,
func
()
{
repo
,
ctx
:=
newRepo
()
_
=
mustCreate
(
repo
,
ctx
,
newGroup
(
"it-group-search-nonexistent-baseline"
))
search
:=
s
.
T
()
.
Name
()
+
"__no_such_group__"
groups
,
_
,
err
:=
repo
.
ListWithFilters
(
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
50
},
""
,
""
,
search
,
nil
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Empty
(
groups
)
})
s
.
Run
(
"search_should_be_case_insensitive"
,
func
()
{
repo
,
ctx
:=
newRepo
()
target
:=
mustCreate
(
repo
,
ctx
,
newGroup
(
"MiXeDCaSe-Needle"
))
other
:=
mustCreate
(
repo
,
ctx
,
newGroup
(
"it-group-search-case-other"
))
groups
,
_
,
err
:=
repo
.
ListWithFilters
(
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
50
},
""
,
""
,
"mixedcase-needle"
,
nil
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
True
(
containsID
(
groups
,
target
.
ID
),
"expected case-insensitive match"
)
s
.
Require
()
.
False
(
containsID
(
groups
,
other
.
ID
),
"expected other group to be filtered out"
)
})
s
.
Run
(
"search_should_escape_like_wildcards"
,
func
()
{
repo
,
ctx
:=
newRepo
()
percentTarget
:=
mustCreate
(
repo
,
ctx
,
newGroup
(
"it-group-search-100%-target"
))
percentOther
:=
mustCreate
(
repo
,
ctx
,
newGroup
(
"it-group-search-100X-other"
))
groups
,
_
,
err
:=
repo
.
ListWithFilters
(
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
50
},
""
,
""
,
"100%"
,
nil
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
True
(
containsID
(
groups
,
percentTarget
.
ID
),
"expected literal %% match"
)
s
.
Require
()
.
False
(
containsID
(
groups
,
percentOther
.
ID
),
"expected %% not to act as wildcard"
)
underscoreTarget
:=
mustCreate
(
repo
,
ctx
,
newGroup
(
"it-group-search-ab_cd-target"
))
underscoreOther
:=
mustCreate
(
repo
,
ctx
,
newGroup
(
"it-group-search-abXcd-other"
))
groups
,
_
,
err
=
repo
.
ListWithFilters
(
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
50
},
""
,
""
,
"ab_cd"
,
nil
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
True
(
containsID
(
groups
,
underscoreTarget
.
ID
),
"expected literal _ match"
)
s
.
Require
()
.
False
(
containsID
(
groups
,
underscoreOther
.
ID
),
"expected _ not to act as wildcard"
)
})
}
func
(
s
*
GroupRepoSuite
)
TestListWithFilters_AccountCount
()
{
g1
:=
&
service
.
Group
{
Name
:
"g1"
,
...
...
@@ -244,7 +386,7 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
s
.
Require
()
.
NoError
(
err
)
isExclusive
:=
true
groups
,
page
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
service
.
PlatformAnthropic
,
service
.
StatusActive
,
&
isExclusive
)
groups
,
page
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
service
.
PlatformAnthropic
,
service
.
StatusActive
,
""
,
&
isExclusive
)
s
.
Require
()
.
NoError
(
err
,
"ListWithFilters"
)
s
.
Require
()
.
Equal
(
int64
(
1
),
page
.
Total
)
s
.
Require
()
.
Len
(
groups
,
1
)
...
...
backend/internal/repository/promo_code_repo.go
0 → 100644
View file @
1a641392
package
repository
import
(
"context"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type
promoCodeRepository
struct
{
client
*
dbent
.
Client
}
func
NewPromoCodeRepository
(
client
*
dbent
.
Client
)
service
.
PromoCodeRepository
{
return
&
promoCodeRepository
{
client
:
client
}
}
func
(
r
*
promoCodeRepository
)
Create
(
ctx
context
.
Context
,
code
*
service
.
PromoCode
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
builder
:=
client
.
PromoCode
.
Create
()
.
SetCode
(
code
.
Code
)
.
SetBonusAmount
(
code
.
BonusAmount
)
.
SetMaxUses
(
code
.
MaxUses
)
.
SetUsedCount
(
code
.
UsedCount
)
.
SetStatus
(
code
.
Status
)
.
SetNotes
(
code
.
Notes
)
if
code
.
ExpiresAt
!=
nil
{
builder
.
SetExpiresAt
(
*
code
.
ExpiresAt
)
}
created
,
err
:=
builder
.
Save
(
ctx
)
if
err
!=
nil
{
return
err
}
code
.
ID
=
created
.
ID
code
.
CreatedAt
=
created
.
CreatedAt
code
.
UpdatedAt
=
created
.
UpdatedAt
return
nil
}
func
(
r
*
promoCodeRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
PromoCode
,
error
)
{
m
,
err
:=
r
.
client
.
PromoCode
.
Query
()
.
Where
(
promocode
.
IDEQ
(
id
))
.
Only
(
ctx
)
if
err
!=
nil
{
if
dbent
.
IsNotFound
(
err
)
{
return
nil
,
service
.
ErrPromoCodeNotFound
}
return
nil
,
err
}
return
promoCodeEntityToService
(
m
),
nil
}
func
(
r
*
promoCodeRepository
)
GetByCode
(
ctx
context
.
Context
,
code
string
)
(
*
service
.
PromoCode
,
error
)
{
m
,
err
:=
r
.
client
.
PromoCode
.
Query
()
.
Where
(
promocode
.
CodeEqualFold
(
code
))
.
Only
(
ctx
)
if
err
!=
nil
{
if
dbent
.
IsNotFound
(
err
)
{
return
nil
,
service
.
ErrPromoCodeNotFound
}
return
nil
,
err
}
return
promoCodeEntityToService
(
m
),
nil
}
func
(
r
*
promoCodeRepository
)
GetByCodeForUpdate
(
ctx
context
.
Context
,
code
string
)
(
*
service
.
PromoCode
,
error
)
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
m
,
err
:=
client
.
PromoCode
.
Query
()
.
Where
(
promocode
.
CodeEqualFold
(
code
))
.
ForUpdate
()
.
Only
(
ctx
)
if
err
!=
nil
{
if
dbent
.
IsNotFound
(
err
)
{
return
nil
,
service
.
ErrPromoCodeNotFound
}
return
nil
,
err
}
return
promoCodeEntityToService
(
m
),
nil
}
func
(
r
*
promoCodeRepository
)
Update
(
ctx
context
.
Context
,
code
*
service
.
PromoCode
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
builder
:=
client
.
PromoCode
.
UpdateOneID
(
code
.
ID
)
.
SetCode
(
code
.
Code
)
.
SetBonusAmount
(
code
.
BonusAmount
)
.
SetMaxUses
(
code
.
MaxUses
)
.
SetUsedCount
(
code
.
UsedCount
)
.
SetStatus
(
code
.
Status
)
.
SetNotes
(
code
.
Notes
)
if
code
.
ExpiresAt
!=
nil
{
builder
.
SetExpiresAt
(
*
code
.
ExpiresAt
)
}
else
{
builder
.
ClearExpiresAt
()
}
updated
,
err
:=
builder
.
Save
(
ctx
)
if
err
!=
nil
{
if
dbent
.
IsNotFound
(
err
)
{
return
service
.
ErrPromoCodeNotFound
}
return
err
}
code
.
UpdatedAt
=
updated
.
UpdatedAt
return
nil
}
func
(
r
*
promoCodeRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
_
,
err
:=
client
.
PromoCode
.
Delete
()
.
Where
(
promocode
.
IDEQ
(
id
))
.
Exec
(
ctx
)
return
err
}
func
(
r
*
promoCodeRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
service
.
PromoCode
,
*
pagination
.
PaginationResult
,
error
)
{
return
r
.
ListWithFilters
(
ctx
,
params
,
""
,
""
)
}
func
(
r
*
promoCodeRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
status
,
search
string
)
([]
service
.
PromoCode
,
*
pagination
.
PaginationResult
,
error
)
{
q
:=
r
.
client
.
PromoCode
.
Query
()
if
status
!=
""
{
q
=
q
.
Where
(
promocode
.
StatusEQ
(
status
))
}
if
search
!=
""
{
q
=
q
.
Where
(
promocode
.
CodeContainsFold
(
search
))
}
total
,
err
:=
q
.
Count
(
ctx
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
codes
,
err
:=
q
.
Offset
(
params
.
Offset
())
.
Limit
(
params
.
Limit
())
.
Order
(
dbent
.
Desc
(
promocode
.
FieldID
))
.
All
(
ctx
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
outCodes
:=
promoCodeEntitiesToService
(
codes
)
return
outCodes
,
paginationResultFromTotal
(
int64
(
total
),
params
),
nil
}
func
(
r
*
promoCodeRepository
)
CreateUsage
(
ctx
context
.
Context
,
usage
*
service
.
PromoCodeUsage
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
created
,
err
:=
client
.
PromoCodeUsage
.
Create
()
.
SetPromoCodeID
(
usage
.
PromoCodeID
)
.
SetUserID
(
usage
.
UserID
)
.
SetBonusAmount
(
usage
.
BonusAmount
)
.
SetUsedAt
(
usage
.
UsedAt
)
.
Save
(
ctx
)
if
err
!=
nil
{
return
err
}
usage
.
ID
=
created
.
ID
return
nil
}
func
(
r
*
promoCodeRepository
)
GetUsageByPromoCodeAndUser
(
ctx
context
.
Context
,
promoCodeID
,
userID
int64
)
(
*
service
.
PromoCodeUsage
,
error
)
{
m
,
err
:=
r
.
client
.
PromoCodeUsage
.
Query
()
.
Where
(
promocodeusage
.
PromoCodeIDEQ
(
promoCodeID
),
promocodeusage
.
UserIDEQ
(
userID
),
)
.
Only
(
ctx
)
if
err
!=
nil
{
if
dbent
.
IsNotFound
(
err
)
{
return
nil
,
nil
}
return
nil
,
err
}
return
promoCodeUsageEntityToService
(
m
),
nil
}
func
(
r
*
promoCodeRepository
)
ListUsagesByPromoCode
(
ctx
context
.
Context
,
promoCodeID
int64
,
params
pagination
.
PaginationParams
)
([]
service
.
PromoCodeUsage
,
*
pagination
.
PaginationResult
,
error
)
{
q
:=
r
.
client
.
PromoCodeUsage
.
Query
()
.
Where
(
promocodeusage
.
PromoCodeIDEQ
(
promoCodeID
))
total
,
err
:=
q
.
Count
(
ctx
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
usages
,
err
:=
q
.
WithUser
()
.
Offset
(
params
.
Offset
())
.
Limit
(
params
.
Limit
())
.
Order
(
dbent
.
Desc
(
promocodeusage
.
FieldID
))
.
All
(
ctx
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
outUsages
:=
promoCodeUsageEntitiesToService
(
usages
)
return
outUsages
,
paginationResultFromTotal
(
int64
(
total
),
params
),
nil
}
func
(
r
*
promoCodeRepository
)
IncrementUsedCount
(
ctx
context
.
Context
,
id
int64
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
_
,
err
:=
client
.
PromoCode
.
UpdateOneID
(
id
)
.
AddUsedCount
(
1
)
.
Save
(
ctx
)
return
err
}
// Entity to Service conversions
func
promoCodeEntityToService
(
m
*
dbent
.
PromoCode
)
*
service
.
PromoCode
{
if
m
==
nil
{
return
nil
}
return
&
service
.
PromoCode
{
ID
:
m
.
ID
,
Code
:
m
.
Code
,
BonusAmount
:
m
.
BonusAmount
,
MaxUses
:
m
.
MaxUses
,
UsedCount
:
m
.
UsedCount
,
Status
:
m
.
Status
,
ExpiresAt
:
m
.
ExpiresAt
,
Notes
:
derefString
(
m
.
Notes
),
CreatedAt
:
m
.
CreatedAt
,
UpdatedAt
:
m
.
UpdatedAt
,
}
}
func
promoCodeEntitiesToService
(
models
[]
*
dbent
.
PromoCode
)
[]
service
.
PromoCode
{
out
:=
make
([]
service
.
PromoCode
,
0
,
len
(
models
))
for
i
:=
range
models
{
if
s
:=
promoCodeEntityToService
(
models
[
i
]);
s
!=
nil
{
out
=
append
(
out
,
*
s
)
}
}
return
out
}
func
promoCodeUsageEntityToService
(
m
*
dbent
.
PromoCodeUsage
)
*
service
.
PromoCodeUsage
{
if
m
==
nil
{
return
nil
}
out
:=
&
service
.
PromoCodeUsage
{
ID
:
m
.
ID
,
PromoCodeID
:
m
.
PromoCodeID
,
UserID
:
m
.
UserID
,
BonusAmount
:
m
.
BonusAmount
,
UsedAt
:
m
.
UsedAt
,
}
if
m
.
Edges
.
User
!=
nil
{
out
.
User
=
userEntityToService
(
m
.
Edges
.
User
)
}
return
out
}
func
promoCodeUsageEntitiesToService
(
models
[]
*
dbent
.
PromoCodeUsage
)
[]
service
.
PromoCodeUsage
{
out
:=
make
([]
service
.
PromoCodeUsage
,
0
,
len
(
models
))
for
i
:=
range
models
{
if
s
:=
promoCodeUsageEntityToService
(
models
[
i
]);
s
!=
nil
{
out
=
append
(
out
,
*
s
)
}
}
return
out
}
backend/internal/repository/usage_log_repo.go
View file @
1a641392
...
...
@@ -22,7 +22,7 @@ import (
"github.com/lib/pq"
)
const
usageLogSelectColumns
=
"id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, image_count, image_size, created_at"
const
usageLogSelectColumns
=
"id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent,
ip_address,
image_count, image_size, created_at"
type
usageLogRepository
struct
{
client
*
dbent
.
Client
...
...
@@ -110,6 +110,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
duration_ms,
first_token_ms,
user_agent,
ip_address,
image_count,
image_size,
created_at
...
...
@@ -119,7 +120,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28
$20, $21, $22, $23, $24, $25, $26, $27, $28
, $29
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
...
...
@@ -130,6 +131,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
duration
:=
nullInt
(
log
.
DurationMs
)
firstToken
:=
nullInt
(
log
.
FirstTokenMs
)
userAgent
:=
nullString
(
log
.
UserAgent
)
ipAddress
:=
nullString
(
log
.
IPAddress
)
imageSize
:=
nullString
(
log
.
ImageSize
)
var
requestIDArg
any
...
...
@@ -163,6 +165,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
duration
,
firstToken
,
userAgent
,
ipAddress
,
log
.
ImageCount
,
imageSize
,
createdAt
,
...
...
@@ -1873,6 +1876,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
durationMs
sql
.
NullInt64
firstTokenMs
sql
.
NullInt64
userAgent
sql
.
NullString
ipAddress
sql
.
NullString
imageCount
int
imageSize
sql
.
NullString
createdAt
time
.
Time
...
...
@@ -1905,6 +1909,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&
durationMs
,
&
firstTokenMs
,
&
userAgent
,
&
ipAddress
,
&
imageCount
,
&
imageSize
,
&
createdAt
,
...
...
@@ -1959,6 +1964,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if
userAgent
.
Valid
{
log
.
UserAgent
=
&
userAgent
.
String
}
if
ipAddress
.
Valid
{
log
.
IPAddress
=
&
ipAddress
.
String
}
if
imageSize
.
Valid
{
log
.
ImageSize
=
&
imageSize
.
String
}
...
...
backend/internal/repository/wire.go
View file @
1a641392
...
...
@@ -45,6 +45,7 @@ var ProviderSet = wire.NewSet(
NewAccountRepository
,
NewProxyRepository
,
NewRedeemCodeRepository
,
NewPromoCodeRepository
,
NewUsageLogRepository
,
NewSettingRepository
,
NewUserSubscriptionRepository
,
...
...
backend/internal/server/api_contract_test.go
View file @
1a641392
...
...
@@ -82,6 +82,8 @@ func TestAPIContracts(t *testing.T) {
"name": "Key One",
"group_id": null,
"status": "active",
"ip_whitelist": null,
"ip_blacklist": null,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
...
...
@@ -116,6 +118,8 @@ func TestAPIContracts(t *testing.T) {
"name": "Key One",
"group_id": null,
"status": "active",
"ip_whitelist": null,
"ip_blacklist": null,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
...
...
@@ -304,6 +308,10 @@ func TestAPIContracts(t *testing.T) {
"turnstile_enabled": true,
"turnstile_site_key": "site-key",
"turnstile_secret_key_configured": true,
"linuxdo_connect_enabled": false,
"linuxdo_connect_client_id": "",
"linuxdo_connect_client_secret_configured": false,
"linuxdo_connect_redirect_url": "",
"site_name": "Sub2API",
"site_logo": "",
"site_subtitle": "Subtitle",
...
...
@@ -390,7 +398,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo
:=
newStubSettingRepo
()
settingService
:=
service
.
NewSettingService
(
settingRepo
,
cfg
)
authHandler
:=
handler
.
NewAuthHandler
(
cfg
,
nil
,
userService
)
authHandler
:=
handler
.
NewAuthHandler
(
cfg
,
nil
,
userService
,
settingService
,
nil
)
apiKeyHandler
:=
handler
.
NewAPIKeyHandler
(
apiKeyService
)
usageHandler
:=
handler
.
NewUsageHandler
(
usageService
,
apiKeyService
)
adminSettingHandler
:=
adminhandler
.
NewSettingHandler
(
settingService
,
nil
,
nil
)
...
...
@@ -567,6 +575,10 @@ func (stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, err
return
nil
,
service
.
ErrGroupNotFound
}
func
(
stubGroupRepo
)
GetByIDLite
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
Group
,
error
)
{
return
nil
,
service
.
ErrGroupNotFound
}
func
(
stubGroupRepo
)
Update
(
ctx
context
.
Context
,
group
*
service
.
Group
)
error
{
return
errors
.
New
(
"not implemented"
)
}
...
...
@@ -583,7 +595,7 @@ func (stubGroupRepo) List(ctx context.Context, params pagination.PaginationParam
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
stubGroupRepo
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
status
string
,
isExclusive
*
bool
)
([]
service
.
Group
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
stubGroupRepo
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
status
,
search
string
,
isExclusive
*
bool
)
([]
service
.
Group
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
...
...
backend/internal/server/http.go
View file @
1a641392
...
...
@@ -13,6 +13,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/google/wire"
"github.com/redis/go-redis/v9"
)
// ProviderSet 提供服务器层的依赖
...
...
@@ -30,6 +31,7 @@ func ProvideRouter(
apiKeyAuth
middleware2
.
APIKeyAuthMiddleware
,
apiKeyService
*
service
.
APIKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
redisClient
*
redis
.
Client
,
)
*
gin
.
Engine
{
if
cfg
.
Server
.
Mode
==
"release"
{
gin
.
SetMode
(
gin
.
ReleaseMode
)
...
...
@@ -47,7 +49,7 @@ func ProvideRouter(
}
}
return
SetupRouter
(
r
,
handlers
,
jwtAuth
,
adminAuth
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
,
cfg
)
return
SetupRouter
(
r
,
handlers
,
jwtAuth
,
adminAuth
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
,
cfg
,
redisClient
)
}
// ProvideHTTPServer 提供 HTTP 服务器
...
...
backend/internal/server/middleware/api_key_auth.go
View file @
1a641392
package
middleware
import
(
"context"
"errors"
"log"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
...
...
@@ -71,6 +74,17 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
return
}
// 检查 IP 限制(白名单/黑名单)
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
if
len
(
apiKey
.
IPWhitelist
)
>
0
||
len
(
apiKey
.
IPBlacklist
)
>
0
{
clientIP
:=
ip
.
GetClientIP
(
c
)
allowed
,
_
:=
ip
.
CheckIPRestriction
(
clientIP
,
apiKey
.
IPWhitelist
,
apiKey
.
IPBlacklist
)
if
!
allowed
{
AbortWithError
(
c
,
403
,
"ACCESS_DENIED"
,
"Access denied"
)
return
}
}
// 检查关联的用户
if
apiKey
.
User
==
nil
{
AbortWithError
(
c
,
401
,
"USER_NOT_FOUND"
,
"User associated with API key not found"
)
...
...
@@ -91,6 +105,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
Concurrency
:
apiKey
.
User
.
Concurrency
,
})
c
.
Set
(
string
(
ContextKeyUserRole
),
apiKey
.
User
.
Role
)
setGroupContext
(
c
,
apiKey
.
Group
)
c
.
Next
()
return
}
...
...
@@ -149,6 +164,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
Concurrency
:
apiKey
.
User
.
Concurrency
,
})
c
.
Set
(
string
(
ContextKeyUserRole
),
apiKey
.
User
.
Role
)
setGroupContext
(
c
,
apiKey
.
Group
)
c
.
Next
()
}
...
...
@@ -173,3 +189,14 @@ func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool
subscription
,
ok
:=
value
.
(
*
service
.
UserSubscription
)
return
subscription
,
ok
}
func
setGroupContext
(
c
*
gin
.
Context
,
group
*
service
.
Group
)
{
if
!
service
.
IsGroupContextValid
(
group
)
{
return
}
if
existing
,
ok
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
Group
)
.
(
*
service
.
Group
);
ok
&&
existing
!=
nil
&&
existing
.
ID
==
group
.
ID
&&
service
.
IsGroupContextValid
(
existing
)
{
return
}
ctx
:=
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
Group
,
group
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
}
backend/internal/server/middleware/api_key_auth_google.go
View file @
1a641392
...
...
@@ -63,6 +63,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
Concurrency
:
apiKey
.
User
.
Concurrency
,
})
c
.
Set
(
string
(
ContextKeyUserRole
),
apiKey
.
User
.
Role
)
setGroupContext
(
c
,
apiKey
.
Group
)
c
.
Next
()
return
}
...
...
@@ -102,6 +103,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
Concurrency
:
apiKey
.
User
.
Concurrency
,
})
c
.
Set
(
string
(
ContextKeyUserRole
),
apiKey
.
User
.
Role
)
setGroupContext
(
c
,
apiKey
.
Group
)
c
.
Next
()
}
}
...
...
backend/internal/server/middleware/api_key_auth_google_test.go
View file @
1a641392
...
...
@@ -9,6 +9,7 @@ import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
...
...
@@ -133,6 +134,70 @@ func TestApiKeyAuthWithSubscriptionGoogle_QueryApiKeyRejected(t *testing.T) {
require
.
Equal
(
t
,
"INVALID_ARGUMENT"
,
resp
.
Error
.
Status
)
}
func
TestApiKeyAuthWithSubscriptionGoogleSetsGroupContext
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
group
:=
&
service
.
Group
{
ID
:
99
,
Name
:
"g1"
,
Status
:
service
.
StatusActive
,
Platform
:
service
.
PlatformGemini
,
Hydrated
:
true
,
}
user
:=
&
service
.
User
{
ID
:
7
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
APIKey
{
ID
:
100
,
UserID
:
user
.
ID
,
Key
:
"test-key"
,
Status
:
service
.
StatusActive
,
User
:
user
,
Group
:
group
,
}
apiKey
.
GroupID
=
&
group
.
ID
apiKeyService
:=
service
.
NewAPIKeyService
(
fakeAPIKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
},
nil
,
nil
,
nil
,
nil
,
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
},
)
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
r
:=
gin
.
New
()
r
.
Use
(
APIKeyAuthWithSubscriptionGoogle
(
apiKeyService
,
nil
,
cfg
))
r
.
GET
(
"/v1beta/test"
,
func
(
c
*
gin
.
Context
)
{
groupFromCtx
,
ok
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
Group
)
.
(
*
service
.
Group
)
if
!
ok
||
groupFromCtx
==
nil
||
groupFromCtx
.
ID
!=
group
.
ID
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"ok"
:
false
})
return
}
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"ok"
:
true
})
})
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/v1beta/test"
,
nil
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
.
Key
)
rec
:=
httptest
.
NewRecorder
()
r
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
}
func
TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
...
...
backend/internal/server/middleware/api_key_auth_test.go
View file @
1a641392
...
...
@@ -11,6 +11,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
...
...
@@ -25,6 +26,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
ID
:
42
,
Name
:
"sub"
,
Status
:
service
.
StatusActive
,
Hydrated
:
true
,
SubscriptionType
:
service
.
SubscriptionTypeSubscription
,
DailyLimitUSD
:
&
limit
,
}
...
...
@@ -110,6 +112,129 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
})
}
func
TestAPIKeyAuthSetsGroupContext
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
group
:=
&
service
.
Group
{
ID
:
101
,
Name
:
"g1"
,
Status
:
service
.
StatusActive
,
Platform
:
service
.
PlatformAnthropic
,
Hydrated
:
true
,
}
user
:=
&
service
.
User
{
ID
:
7
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
APIKey
{
ID
:
100
,
UserID
:
user
.
ID
,
Key
:
"test-key"
,
Status
:
service
.
StatusActive
,
User
:
user
,
Group
:
group
,
}
apiKey
.
GroupID
=
&
group
.
ID
apiKeyRepo
:=
&
stubApiKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
}
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
cfg
)
router
:=
gin
.
New
()
router
.
Use
(
gin
.
HandlerFunc
(
NewAPIKeyAuthMiddleware
(
apiKeyService
,
nil
,
cfg
)))
router
.
GET
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
groupFromCtx
,
ok
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
Group
)
.
(
*
service
.
Group
)
if
!
ok
||
groupFromCtx
==
nil
||
groupFromCtx
.
ID
!=
group
.
ID
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"ok"
:
false
})
return
}
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"ok"
:
true
})
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
.
Key
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
}
func
TestAPIKeyAuthOverwritesInvalidContextGroup
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
group
:=
&
service
.
Group
{
ID
:
101
,
Name
:
"g1"
,
Status
:
service
.
StatusActive
,
Platform
:
service
.
PlatformAnthropic
,
Hydrated
:
true
,
}
user
:=
&
service
.
User
{
ID
:
7
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
APIKey
{
ID
:
100
,
UserID
:
user
.
ID
,
Key
:
"test-key"
,
Status
:
service
.
StatusActive
,
User
:
user
,
Group
:
group
,
}
apiKey
.
GroupID
=
&
group
.
ID
apiKeyRepo
:=
&
stubApiKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
}
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
cfg
)
router
:=
gin
.
New
()
router
.
Use
(
gin
.
HandlerFunc
(
NewAPIKeyAuthMiddleware
(
apiKeyService
,
nil
,
cfg
)))
invalidGroup
:=
&
service
.
Group
{
ID
:
group
.
ID
,
Platform
:
group
.
Platform
,
Status
:
group
.
Status
,
}
router
.
GET
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
groupFromCtx
,
ok
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
Group
)
.
(
*
service
.
Group
)
if
!
ok
||
groupFromCtx
==
nil
||
groupFromCtx
.
ID
!=
group
.
ID
||
!
groupFromCtx
.
Hydrated
||
groupFromCtx
==
invalidGroup
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"ok"
:
false
})
return
}
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"ok"
:
true
})
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
.
Key
)
req
=
req
.
WithContext
(
context
.
WithValue
(
req
.
Context
(),
ctxkey
.
Group
,
invalidGroup
))
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
}
func
newAuthTestRouter
(
apiKeyService
*
service
.
APIKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
cfg
*
config
.
Config
)
*
gin
.
Engine
{
router
:=
gin
.
New
()
router
.
Use
(
gin
.
HandlerFunc
(
NewAPIKeyAuthMiddleware
(
apiKeyService
,
subscriptionService
,
cfg
)))
...
...
Prev
1
2
3
4
5
6
7
8
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment