Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
a14dfb76
Commit
a14dfb76
authored
Feb 07, 2026
by
yangjianbo
Browse files
Merge branch 'dev-release'
parents
f3605ddc
2588fa6a
Changes
62
Hide whitespace changes
Inline
Side-by-side
backend/internal/pkg/openai/oauth_test.go
0 → 100644
View file @
a14dfb76
package
openai
import
(
"sync"
"testing"
"time"
)
func
TestSessionStore_Stop_Idempotent
(
t
*
testing
.
T
)
{
store
:=
NewSessionStore
()
store
.
Stop
()
store
.
Stop
()
select
{
case
<-
store
.
stopCh
:
// ok
case
<-
time
.
After
(
time
.
Second
)
:
t
.
Fatal
(
"stopCh 未关闭"
)
}
}
func
TestSessionStore_Stop_Concurrent
(
t
*
testing
.
T
)
{
store
:=
NewSessionStore
()
var
wg
sync
.
WaitGroup
for
range
50
{
wg
.
Add
(
1
)
go
func
()
{
defer
wg
.
Done
()
store
.
Stop
()
}()
}
wg
.
Wait
()
select
{
case
<-
store
.
stopCh
:
// ok
case
<-
time
.
After
(
time
.
Second
)
:
t
.
Fatal
(
"stopCh 未关闭"
)
}
}
backend/internal/pkg/tlsfingerprint/dialer.go
View file @
a14dfb76
...
...
@@ -286,7 +286,7 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
return
nil
,
fmt
.
Errorf
(
"apply TLS preset: %w"
,
err
)
}
if
err
:=
tlsConn
.
Handshake
(
);
err
!=
nil
{
if
err
:=
tlsConn
.
Handshake
Context
(
ctx
);
err
!=
nil
{
slog
.
Debug
(
"tls_fingerprint_socks5_handshake_failed"
,
"error"
,
err
)
_
=
conn
.
Close
()
return
nil
,
fmt
.
Errorf
(
"TLS handshake failed: %w"
,
err
)
...
...
backend/internal/repository/api_key_repo.go
View file @
a14dfb76
...
...
@@ -375,36 +375,19 @@ func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64)
return
keys
,
nil
}
// IncrementQuotaUsed
atomically increments the quota_used field and returns the new value
// IncrementQuotaUsed
使用 Ent 原子递增 quota_used 字段并返回新值
func
(
r
*
apiKeyRepository
)
IncrementQuotaUsed
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
(
float64
,
error
)
{
// Use raw SQL for atomic increment to avoid race conditions
// First get current value
m
,
err
:=
r
.
activeQuery
()
.
Where
(
apikey
.
IDEQ
(
id
))
.
Select
(
apikey
.
FieldQuotaUsed
)
.
Only
(
ctx
)
updated
,
err
:=
r
.
client
.
APIKey
.
UpdateOneID
(
id
)
.
Where
(
apikey
.
DeletedAtIsNil
())
.
AddQuotaUsed
(
amount
)
.
Save
(
ctx
)
if
err
!=
nil
{
if
dbent
.
IsNotFound
(
err
)
{
return
0
,
service
.
ErrAPIKeyNotFound
}
return
0
,
err
}
newValue
:=
m
.
QuotaUsed
+
amount
// Update with new value
affected
,
err
:=
r
.
client
.
APIKey
.
Update
()
.
Where
(
apikey
.
IDEQ
(
id
),
apikey
.
DeletedAtIsNil
())
.
SetQuotaUsed
(
newValue
)
.
Save
(
ctx
)
if
err
!=
nil
{
return
0
,
err
}
if
affected
==
0
{
return
0
,
service
.
ErrAPIKeyNotFound
}
return
newValue
,
nil
return
updated
.
QuotaUsed
,
nil
}
func
apiKeyEntityToService
(
m
*
dbent
.
APIKey
)
*
service
.
APIKey
{
...
...
backend/internal/repository/api_key_repo_integration_test.go
View file @
a14dfb76
...
...
@@ -4,11 +4,14 @@ package repository
import
(
"context"
"sync"
"testing"
"time"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
...
...
@@ -383,3 +386,87 @@ func (s *APIKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, group
s
.
Require
()
.
NoError
(
s
.
repo
.
Create
(
s
.
ctx
,
k
),
"create api key"
)
return
k
}
// --- IncrementQuotaUsed ---
func
(
s
*
APIKeyRepoSuite
)
TestIncrementQuotaUsed_Basic
()
{
user
:=
s
.
mustCreateUser
(
"incr-basic@test.com"
)
key
:=
s
.
mustCreateApiKey
(
user
.
ID
,
"sk-incr-basic"
,
"Incr"
,
nil
)
newQuota
,
err
:=
s
.
repo
.
IncrementQuotaUsed
(
s
.
ctx
,
key
.
ID
,
1.5
)
s
.
Require
()
.
NoError
(
err
,
"IncrementQuotaUsed"
)
s
.
Require
()
.
Equal
(
1.5
,
newQuota
,
"第一次递增后应为 1.5"
)
newQuota
,
err
=
s
.
repo
.
IncrementQuotaUsed
(
s
.
ctx
,
key
.
ID
,
2.5
)
s
.
Require
()
.
NoError
(
err
,
"IncrementQuotaUsed second"
)
s
.
Require
()
.
Equal
(
4.0
,
newQuota
,
"第二次递增后应为 4.0"
)
}
func
(
s
*
APIKeyRepoSuite
)
TestIncrementQuotaUsed_NotFound
()
{
_
,
err
:=
s
.
repo
.
IncrementQuotaUsed
(
s
.
ctx
,
999999
,
1.0
)
s
.
Require
()
.
ErrorIs
(
err
,
service
.
ErrAPIKeyNotFound
,
"不存在的 key 应返回 ErrAPIKeyNotFound"
)
}
func
(
s
*
APIKeyRepoSuite
)
TestIncrementQuotaUsed_DeletedKey
()
{
user
:=
s
.
mustCreateUser
(
"incr-deleted@test.com"
)
key
:=
s
.
mustCreateApiKey
(
user
.
ID
,
"sk-incr-del"
,
"Deleted"
,
nil
)
s
.
Require
()
.
NoError
(
s
.
repo
.
Delete
(
s
.
ctx
,
key
.
ID
),
"Delete"
)
_
,
err
:=
s
.
repo
.
IncrementQuotaUsed
(
s
.
ctx
,
key
.
ID
,
1.0
)
s
.
Require
()
.
ErrorIs
(
err
,
service
.
ErrAPIKeyNotFound
,
"已删除的 key 应返回 ErrAPIKeyNotFound"
)
}
// TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。
// 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。
func
TestIncrementQuotaUsed_Concurrent
(
t
*
testing
.
T
)
{
client
:=
testEntClient
(
t
)
repo
:=
NewAPIKeyRepository
(
client
)
.
(
*
apiKeyRepository
)
ctx
:=
context
.
Background
()
// 创建测试用户和 API Key
u
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
"concurrent-incr-"
+
time
.
Now
()
.
Format
(
time
.
RFC3339Nano
)
+
"@test.com"
)
.
SetPasswordHash
(
"hash"
)
.
SetStatus
(
service
.
StatusActive
)
.
SetRole
(
service
.
RoleUser
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
,
"create user"
)
k
:=
&
service
.
APIKey
{
UserID
:
u
.
ID
,
Key
:
"sk-concurrent-"
+
time
.
Now
()
.
Format
(
time
.
RFC3339Nano
),
Name
:
"Concurrent"
,
Status
:
service
.
StatusActive
,
}
require
.
NoError
(
t
,
repo
.
Create
(
ctx
,
k
),
"create api key"
)
t
.
Cleanup
(
func
()
{
_
=
client
.
APIKey
.
DeleteOneID
(
k
.
ID
)
.
Exec
(
ctx
)
_
=
client
.
User
.
DeleteOneID
(
u
.
ID
)
.
Exec
(
ctx
)
})
// 10 个 goroutine 各递增 1.0,总计应为 10.0
const
goroutines
=
10
const
increment
=
1.0
var
wg
sync
.
WaitGroup
errs
:=
make
([]
error
,
goroutines
)
for
i
:=
0
;
i
<
goroutines
;
i
++
{
wg
.
Add
(
1
)
go
func
(
idx
int
)
{
defer
wg
.
Done
()
_
,
errs
[
idx
]
=
repo
.
IncrementQuotaUsed
(
ctx
,
k
.
ID
,
increment
)
}(
i
)
}
wg
.
Wait
()
for
i
,
e
:=
range
errs
{
require
.
NoError
(
t
,
e
,
"goroutine %d failed"
,
i
)
}
// 验证最终结果
got
,
err
:=
repo
.
GetByID
(
ctx
,
k
.
ID
)
require
.
NoError
(
t
,
err
,
"GetByID"
)
require
.
Equal
(
t
,
float64
(
goroutines
)
*
increment
,
got
.
QuotaUsed
,
"并发递增后总和应为 %v,实际为 %v"
,
float64
(
goroutines
)
*
increment
,
got
.
QuotaUsed
)
}
backend/internal/repository/billing_cache.go
View file @
a14dfb76
...
...
@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"log"
"math/rand"
"strconv"
"time"
...
...
@@ -16,8 +17,15 @@ const (
billingBalanceKeyPrefix
=
"billing:balance:"
billingSubKeyPrefix
=
"billing:sub:"
billingCacheTTL
=
5
*
time
.
Minute
billingCacheJitter
=
30
*
time
.
Second
)
// jitteredTTL 返回带随机抖动的 TTL,防止缓存雪崩
func
jitteredTTL
()
time
.
Duration
{
jitter
:=
time
.
Duration
(
rand
.
Int63n
(
int64
(
2
*
billingCacheJitter
)))
-
billingCacheJitter
return
billingCacheTTL
+
jitter
}
// billingBalanceKey generates the Redis key for user balance cache.
func
billingBalanceKey
(
userID
int64
)
string
{
return
fmt
.
Sprintf
(
"%s%d"
,
billingBalanceKeyPrefix
,
userID
)
...
...
@@ -82,14 +90,15 @@ func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float6
func
(
c
*
billingCache
)
SetUserBalance
(
ctx
context
.
Context
,
userID
int64
,
balance
float64
)
error
{
key
:=
billingBalanceKey
(
userID
)
return
c
.
rdb
.
Set
(
ctx
,
key
,
balance
,
billingCache
TTL
)
.
Err
()
return
c
.
rdb
.
Set
(
ctx
,
key
,
balance
,
jittered
TTL
()
)
.
Err
()
}
func
(
c
*
billingCache
)
DeductUserBalance
(
ctx
context
.
Context
,
userID
int64
,
amount
float64
)
error
{
key
:=
billingBalanceKey
(
userID
)
_
,
err
:=
deductBalanceScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
amount
,
int
(
billingCache
TTL
.
Seconds
()))
.
Result
()
_
,
err
:=
deductBalanceScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
amount
,
int
(
jittered
TTL
()
.
Seconds
()))
.
Result
()
if
err
!=
nil
&&
!
errors
.
Is
(
err
,
redis
.
Nil
)
{
log
.
Printf
(
"Warning: deduct balance cache failed for user %d: %v"
,
userID
,
err
)
return
err
}
return
nil
}
...
...
@@ -163,16 +172,17 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID
pipe
:=
c
.
rdb
.
Pipeline
()
pipe
.
HSet
(
ctx
,
key
,
fields
)
pipe
.
Expire
(
ctx
,
key
,
billingCache
TTL
)
pipe
.
Expire
(
ctx
,
key
,
jittered
TTL
()
)
_
,
err
:=
pipe
.
Exec
(
ctx
)
return
err
}
func
(
c
*
billingCache
)
UpdateSubscriptionUsage
(
ctx
context
.
Context
,
userID
,
groupID
int64
,
cost
float64
)
error
{
key
:=
billingSubKey
(
userID
,
groupID
)
_
,
err
:=
updateSubUsageScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
cost
,
int
(
billingCache
TTL
.
Seconds
()))
.
Result
()
_
,
err
:=
updateSubUsageScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
cost
,
int
(
jittered
TTL
()
.
Seconds
()))
.
Result
()
if
err
!=
nil
&&
!
errors
.
Is
(
err
,
redis
.
Nil
)
{
log
.
Printf
(
"Warning: update subscription usage cache failed for user %d group %d: %v"
,
userID
,
groupID
,
err
)
return
err
}
return
nil
}
...
...
backend/internal/repository/billing_cache_integration_test.go
View file @
a14dfb76
...
...
@@ -278,6 +278,90 @@ func (s *BillingCacheSuite) TestSubscriptionCache() {
}
}
// TestDeductUserBalance_ErrorPropagation 验证 P2-12 修复:
// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。
func
(
s
*
BillingCacheSuite
)
TestDeductUserBalance_ErrorPropagation
()
{
tests
:=
[]
struct
{
name
string
fn
func
(
ctx
context
.
Context
,
cache
service
.
BillingCache
)
expectErr
bool
}{
{
name
:
"key_not_exists_returns_nil"
,
fn
:
func
(
ctx
context
.
Context
,
cache
service
.
BillingCache
)
{
// key 不存在时,Lua 脚本返回 0(redis.Nil),应返回 nil 而非错误
err
:=
cache
.
DeductUserBalance
(
ctx
,
99999
,
1.0
)
require
.
NoError
(
s
.
T
(),
err
,
"DeductUserBalance on non-existent key should return nil"
)
},
},
{
name
:
"existing_key_deducts_successfully"
,
fn
:
func
(
ctx
context
.
Context
,
cache
service
.
BillingCache
)
{
require
.
NoError
(
s
.
T
(),
cache
.
SetUserBalance
(
ctx
,
200
,
50.0
))
err
:=
cache
.
DeductUserBalance
(
ctx
,
200
,
10.0
)
require
.
NoError
(
s
.
T
(),
err
,
"DeductUserBalance should succeed"
)
bal
,
err
:=
cache
.
GetUserBalance
(
ctx
,
200
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
Equal
(
s
.
T
(),
40.0
,
bal
,
"余额应为 40.0"
)
},
},
{
name
:
"cancelled_context_propagates_error"
,
fn
:
func
(
ctx
context
.
Context
,
cache
service
.
BillingCache
)
{
require
.
NoError
(
s
.
T
(),
cache
.
SetUserBalance
(
ctx
,
201
,
50.0
))
cancelCtx
,
cancel
:=
context
.
WithCancel
(
ctx
)
cancel
()
// 立即取消
err
:=
cache
.
DeductUserBalance
(
cancelCtx
,
201
,
10.0
)
require
.
Error
(
s
.
T
(),
err
,
"cancelled context should propagate error"
)
},
},
}
for
_
,
tt
:=
range
tests
{
s
.
Run
(
tt
.
name
,
func
()
{
rdb
:=
testRedis
(
s
.
T
())
cache
:=
NewBillingCache
(
rdb
)
ctx
:=
context
.
Background
()
tt
.
fn
(
ctx
,
cache
)
})
}
}
// TestUpdateSubscriptionUsage_ErrorPropagation 验证 P2-12 修复:
// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。
func
(
s
*
BillingCacheSuite
)
TestUpdateSubscriptionUsage_ErrorPropagation
()
{
s
.
Run
(
"key_not_exists_returns_nil"
,
func
()
{
rdb
:=
testRedis
(
s
.
T
())
cache
:=
NewBillingCache
(
rdb
)
ctx
:=
context
.
Background
()
err
:=
cache
.
UpdateSubscriptionUsage
(
ctx
,
88888
,
77777
,
1.0
)
require
.
NoError
(
s
.
T
(),
err
,
"UpdateSubscriptionUsage on non-existent key should return nil"
)
})
s
.
Run
(
"cancelled_context_propagates_error"
,
func
()
{
rdb
:=
testRedis
(
s
.
T
())
cache
:=
NewBillingCache
(
rdb
)
ctx
:=
context
.
Background
()
data
:=
&
service
.
SubscriptionCacheData
{
Status
:
"active"
,
ExpiresAt
:
time
.
Now
()
.
Add
(
1
*
time
.
Hour
),
Version
:
1
,
}
require
.
NoError
(
s
.
T
(),
cache
.
SetSubscriptionCache
(
ctx
,
301
,
401
,
data
))
cancelCtx
,
cancel
:=
context
.
WithCancel
(
ctx
)
cancel
()
err
:=
cache
.
UpdateSubscriptionUsage
(
cancelCtx
,
301
,
401
,
1.0
)
require
.
Error
(
s
.
T
(),
err
,
"cancelled context should propagate error"
)
})
}
func
TestBillingCacheSuite
(
t
*
testing
.
T
)
{
suite
.
Run
(
t
,
new
(
BillingCacheSuite
))
}
backend/internal/repository/billing_cache_test.go
View file @
a14dfb76
...
...
@@ -5,6 +5,7 @@ package repository
import
(
"math"
"testing"
"time"
"github.com/stretchr/testify/require"
)
...
...
@@ -85,3 +86,26 @@ func TestBillingSubKey(t *testing.T) {
})
}
}
func
TestJitteredTTL
(
t
*
testing
.
T
)
{
const
(
minTTL
=
4
*
time
.
Minute
+
30
*
time
.
Second
// 270s = 5min - 30s
maxTTL
=
5
*
time
.
Minute
+
30
*
time
.
Second
// 330s = 5min + 30s
)
for
i
:=
0
;
i
<
200
;
i
++
{
ttl
:=
jitteredTTL
()
require
.
GreaterOrEqual
(
t
,
ttl
,
minTTL
,
"jitteredTTL() 返回值低于下限: %v"
,
ttl
)
require
.
LessOrEqual
(
t
,
ttl
,
maxTTL
,
"jitteredTTL() 返回值超过上限: %v"
,
ttl
)
}
}
func
TestJitteredTTL_HasVariation
(
t
*
testing
.
T
)
{
// 多次调用应该产生不同的值(验证抖动存在)
seen
:=
make
(
map
[
time
.
Duration
]
struct
{},
50
)
for
i
:=
0
;
i
<
50
;
i
++
{
seen
[
jitteredTTL
()]
=
struct
{}{}
}
// 50 次调用中应该至少有 2 个不同的值
require
.
Greater
(
t
,
len
(
seen
),
1
,
"jitteredTTL() 应产生不同的 TTL 值"
)
}
backend/internal/repository/group_repo.go
View file @
a14dfb76
...
...
@@ -183,7 +183,7 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
q
=
q
.
Where
(
group
.
IsExclusiveEQ
(
*
isExclusive
))
}
total
,
err
:=
q
.
Count
(
ctx
)
total
,
err
:=
q
.
Clone
()
.
Count
(
ctx
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
...
...
backend/internal/repository/promo_code_repo.go
View file @
a14dfb76
...
...
@@ -132,7 +132,7 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina
q
=
q
.
Where
(
promocode
.
CodeContainsFold
(
search
))
}
total
,
err
:=
q
.
Count
(
ctx
)
total
,
err
:=
q
.
Clone
()
.
Count
(
ctx
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
...
...
@@ -187,7 +187,7 @@ func (r *promoCodeRepository) ListUsagesByPromoCode(ctx context.Context, promoCo
q
:=
r
.
client
.
PromoCodeUsage
.
Query
()
.
Where
(
promocodeusage
.
PromoCodeIDEQ
(
promoCodeID
))
total
,
err
:=
q
.
Count
(
ctx
)
total
,
err
:=
q
.
Clone
()
.
Count
(
ctx
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
...
...
backend/internal/repository/usage_log_repo.go
View file @
a14dfb76
...
...
@@ -24,6 +24,22 @@ import (
const
usageLogSelectColumns
=
"id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, reasoning_effort, created_at"
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
var
dateFormatWhitelist
=
map
[
string
]
string
{
"hour"
:
"YYYY-MM-DD HH24:00"
,
"day"
:
"YYYY-MM-DD"
,
"week"
:
"IYYY-IW"
,
"month"
:
"YYYY-MM"
,
}
// safeDateFormat 根据白名单获取 dateFormat,未匹配时返回默认值
func
safeDateFormat
(
granularity
string
)
string
{
if
f
,
ok
:=
dateFormatWhitelist
[
granularity
];
ok
{
return
f
}
return
"YYYY-MM-DD"
}
type
usageLogRepository
struct
{
client
*
dbent
.
Client
sql
sqlExecutor
...
...
@@ -564,7 +580,7 @@ func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64,
}
func
(
r
*
usageLogRepository
)
ListByUserAndTimeRange
(
ctx
context
.
Context
,
userID
int64
,
startTime
,
endTime
time
.
Time
)
([]
service
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
query
:=
"SELECT "
+
usageLogSelectColumns
+
" FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
query
:=
"SELECT "
+
usageLogSelectColumns
+
" FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC
LIMIT 10000
"
logs
,
err
:=
r
.
queryUsageLogs
(
ctx
,
query
,
userID
,
startTime
,
endTime
)
return
logs
,
nil
,
err
}
...
...
@@ -810,19 +826,19 @@ func resolveUsageStatsTimezone() string {
}
func
(
r
*
usageLogRepository
)
ListByAPIKeyAndTimeRange
(
ctx
context
.
Context
,
apiKeyID
int64
,
startTime
,
endTime
time
.
Time
)
([]
service
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
query
:=
"SELECT "
+
usageLogSelectColumns
+
" FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
query
:=
"SELECT "
+
usageLogSelectColumns
+
" FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC
LIMIT 10000
"
logs
,
err
:=
r
.
queryUsageLogs
(
ctx
,
query
,
apiKeyID
,
startTime
,
endTime
)
return
logs
,
nil
,
err
}
func
(
r
*
usageLogRepository
)
ListByAccountAndTimeRange
(
ctx
context
.
Context
,
accountID
int64
,
startTime
,
endTime
time
.
Time
)
([]
service
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
query
:=
"SELECT "
+
usageLogSelectColumns
+
" FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
query
:=
"SELECT "
+
usageLogSelectColumns
+
" FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC
LIMIT 10000
"
logs
,
err
:=
r
.
queryUsageLogs
(
ctx
,
query
,
accountID
,
startTime
,
endTime
)
return
logs
,
nil
,
err
}
func
(
r
*
usageLogRepository
)
ListByModelAndTimeRange
(
ctx
context
.
Context
,
modelName
string
,
startTime
,
endTime
time
.
Time
)
([]
service
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
query
:=
"SELECT "
+
usageLogSelectColumns
+
" FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
query
:=
"SELECT "
+
usageLogSelectColumns
+
" FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC
LIMIT 10000
"
logs
,
err
:=
r
.
queryUsageLogs
(
ctx
,
query
,
modelName
,
startTime
,
endTime
)
return
logs
,
nil
,
err
}
...
...
@@ -908,10 +924,7 @@ type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint
// GetAPIKeyUsageTrend returns usage trend data grouped by API key and date
func
(
r
*
usageLogRepository
)
GetAPIKeyUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
limit
int
)
(
results
[]
APIKeyUsageTrendPoint
,
err
error
)
{
dateFormat
:=
"YYYY-MM-DD"
if
granularity
==
"hour"
{
dateFormat
=
"YYYY-MM-DD HH24:00"
}
dateFormat
:=
safeDateFormat
(
granularity
)
query
:=
fmt
.
Sprintf
(
`
WITH top_keys AS (
...
...
@@ -966,10 +979,7 @@ func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime,
// GetUserUsageTrend returns usage trend data grouped by user and date
func
(
r
*
usageLogRepository
)
GetUserUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
limit
int
)
(
results
[]
UserUsageTrendPoint
,
err
error
)
{
dateFormat
:=
"YYYY-MM-DD"
if
granularity
==
"hour"
{
dateFormat
=
"YYYY-MM-DD HH24:00"
}
dateFormat
:=
safeDateFormat
(
granularity
)
query
:=
fmt
.
Sprintf
(
`
WITH top_users AS (
...
...
@@ -1228,10 +1238,7 @@ func (r *usageLogRepository) GetAPIKeyDashboardStats(ctx context.Context, apiKey
// GetUserUsageTrendByUserID 获取指定用户的使用趋势
func
(
r
*
usageLogRepository
)
GetUserUsageTrendByUserID
(
ctx
context
.
Context
,
userID
int64
,
startTime
,
endTime
time
.
Time
,
granularity
string
)
(
results
[]
TrendDataPoint
,
err
error
)
{
dateFormat
:=
"YYYY-MM-DD"
if
granularity
==
"hour"
{
dateFormat
=
"YYYY-MM-DD HH24:00"
}
dateFormat
:=
safeDateFormat
(
granularity
)
query
:=
fmt
.
Sprintf
(
`
SELECT
...
...
@@ -1369,13 +1376,22 @@ type UsageStats = usagestats.UsageStats
// BatchUserUsageStats represents usage stats for a single user
type
BatchUserUsageStats
=
usagestats
.
BatchUserUsageStats
// GetBatchUserUsageStats gets today and total actual_cost for multiple users
func
(
r
*
usageLogRepository
)
GetBatchUserUsageStats
(
ctx
context
.
Context
,
userIDs
[]
int64
)
(
map
[
int64
]
*
BatchUserUsageStats
,
error
)
{
// GetBatchUserUsageStats gets today and total actual_cost for multiple users within a time range.
// If startTime is zero, defaults to 30 days ago.
func
(
r
*
usageLogRepository
)
GetBatchUserUsageStats
(
ctx
context
.
Context
,
userIDs
[]
int64
,
startTime
,
endTime
time
.
Time
)
(
map
[
int64
]
*
BatchUserUsageStats
,
error
)
{
result
:=
make
(
map
[
int64
]
*
BatchUserUsageStats
)
if
len
(
userIDs
)
==
0
{
return
result
,
nil
}
// 默认最近 30 天
if
startTime
.
IsZero
()
{
startTime
=
time
.
Now
()
.
AddDate
(
0
,
0
,
-
30
)
}
if
endTime
.
IsZero
()
{
endTime
=
time
.
Now
()
}
for
_
,
id
:=
range
userIDs
{
result
[
id
]
=
&
BatchUserUsageStats
{
UserID
:
id
}
}
...
...
@@ -1383,10 +1399,10 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
query
:=
`
SELECT user_id, COALESCE(SUM(actual_cost), 0) as total_cost
FROM usage_logs
WHERE user_id = ANY($1)
WHERE user_id = ANY($1)
AND created_at >= $2 AND created_at < $3
GROUP BY user_id
`
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
query
,
pq
.
Array
(
userIDs
))
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
query
,
pq
.
Array
(
userIDs
)
,
startTime
,
endTime
)
if
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -1443,13 +1459,22 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
// BatchAPIKeyUsageStats represents usage stats for a single API key
type
BatchAPIKeyUsageStats
=
usagestats
.
BatchAPIKeyUsageStats
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys
func
(
r
*
usageLogRepository
)
GetBatchAPIKeyUsageStats
(
ctx
context
.
Context
,
apiKeyIDs
[]
int64
)
(
map
[
int64
]
*
BatchAPIKeyUsageStats
,
error
)
{
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys within a time range.
// If startTime is zero, defaults to 30 days ago.
func
(
r
*
usageLogRepository
)
GetBatchAPIKeyUsageStats
(
ctx
context
.
Context
,
apiKeyIDs
[]
int64
,
startTime
,
endTime
time
.
Time
)
(
map
[
int64
]
*
BatchAPIKeyUsageStats
,
error
)
{
result
:=
make
(
map
[
int64
]
*
BatchAPIKeyUsageStats
)
if
len
(
apiKeyIDs
)
==
0
{
return
result
,
nil
}
// 默认最近 30 天
if
startTime
.
IsZero
()
{
startTime
=
time
.
Now
()
.
AddDate
(
0
,
0
,
-
30
)
}
if
endTime
.
IsZero
()
{
endTime
=
time
.
Now
()
}
for
_
,
id
:=
range
apiKeyIDs
{
result
[
id
]
=
&
BatchAPIKeyUsageStats
{
APIKeyID
:
id
}
}
...
...
@@ -1457,10 +1482,10 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
query
:=
`
SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost
FROM usage_logs
WHERE api_key_id = ANY($1)
WHERE api_key_id = ANY($1)
AND created_at >= $2 AND created_at < $3
GROUP BY api_key_id
`
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
query
,
pq
.
Array
(
apiKeyIDs
))
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
query
,
pq
.
Array
(
apiKeyIDs
)
,
startTime
,
endTime
)
if
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -1516,10 +1541,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
// GetUsageTrendWithFilters returns usage trend data with optional filters
func
(
r
*
usageLogRepository
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
stream
*
bool
,
billingType
*
int8
)
(
results
[]
TrendDataPoint
,
err
error
)
{
dateFormat
:=
"YYYY-MM-DD"
if
granularity
==
"hour"
{
dateFormat
=
"YYYY-MM-DD HH24:00"
}
dateFormat
:=
safeDateFormat
(
granularity
)
query
:=
fmt
.
Sprintf
(
`
SELECT
...
...
backend/internal/repository/usage_log_repo_integration_test.go
View file @
a14dfb76
...
...
@@ -648,7 +648,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
s
.
createUsageLog
(
user1
,
apiKey1
,
account
,
10
,
20
,
0.5
,
time
.
Now
())
s
.
createUsageLog
(
user2
,
apiKey2
,
account
,
15
,
25
,
0.6
,
time
.
Now
())
stats
,
err
:=
s
.
repo
.
GetBatchUserUsageStats
(
s
.
ctx
,
[]
int64
{
user1
.
ID
,
user2
.
ID
})
stats
,
err
:=
s
.
repo
.
GetBatchUserUsageStats
(
s
.
ctx
,
[]
int64
{
user1
.
ID
,
user2
.
ID
}
,
time
.
Time
{},
time
.
Time
{}
)
s
.
Require
()
.
NoError
(
err
,
"GetBatchUserUsageStats"
)
s
.
Require
()
.
Len
(
stats
,
2
)
s
.
Require
()
.
NotNil
(
stats
[
user1
.
ID
])
...
...
@@ -656,7 +656,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
}
func
(
s
*
UsageLogRepoSuite
)
TestGetBatchUserUsageStats_Empty
()
{
stats
,
err
:=
s
.
repo
.
GetBatchUserUsageStats
(
s
.
ctx
,
[]
int64
{})
stats
,
err
:=
s
.
repo
.
GetBatchUserUsageStats
(
s
.
ctx
,
[]
int64
{}
,
time
.
Time
{},
time
.
Time
{}
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Empty
(
stats
)
}
...
...
@@ -672,13 +672,13 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
s
.
createUsageLog
(
user
,
apiKey1
,
account
,
10
,
20
,
0.5
,
time
.
Now
())
s
.
createUsageLog
(
user
,
apiKey2
,
account
,
15
,
25
,
0.6
,
time
.
Now
())
stats
,
err
:=
s
.
repo
.
GetBatchAPIKeyUsageStats
(
s
.
ctx
,
[]
int64
{
apiKey1
.
ID
,
apiKey2
.
ID
})
stats
,
err
:=
s
.
repo
.
GetBatchAPIKeyUsageStats
(
s
.
ctx
,
[]
int64
{
apiKey1
.
ID
,
apiKey2
.
ID
}
,
time
.
Time
{},
time
.
Time
{}
)
s
.
Require
()
.
NoError
(
err
,
"GetBatchAPIKeyUsageStats"
)
s
.
Require
()
.
Len
(
stats
,
2
)
}
func
(
s
*
UsageLogRepoSuite
)
TestGetBatchApiKeyUsageStats_Empty
()
{
stats
,
err
:=
s
.
repo
.
GetBatchAPIKeyUsageStats
(
s
.
ctx
,
[]
int64
{})
stats
,
err
:=
s
.
repo
.
GetBatchAPIKeyUsageStats
(
s
.
ctx
,
[]
int64
{}
,
time
.
Time
{},
time
.
Time
{}
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Empty
(
stats
)
}
...
...
backend/internal/repository/usage_log_repo_unit_test.go
0 → 100644
View file @
a14dfb76
//go:build unit
package
repository
import
(
"testing"
"github.com/stretchr/testify/require"
)
func
TestSafeDateFormat
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
granularity
string
expected
string
}{
// 合法值
{
"hour"
,
"hour"
,
"YYYY-MM-DD HH24:00"
},
{
"day"
,
"day"
,
"YYYY-MM-DD"
},
{
"week"
,
"week"
,
"IYYY-IW"
},
{
"month"
,
"month"
,
"YYYY-MM"
},
// 非法值回退到默认
{
"空字符串"
,
""
,
"YYYY-MM-DD"
},
{
"未知粒度 year"
,
"year"
,
"YYYY-MM-DD"
},
{
"未知粒度 minute"
,
"minute"
,
"YYYY-MM-DD"
},
// 恶意字符串
{
"SQL 注入尝试"
,
"'; DROP TABLE users; --"
,
"YYYY-MM-DD"
},
{
"带引号"
,
"day'"
,
"YYYY-MM-DD"
},
{
"带括号"
,
"day)"
,
"YYYY-MM-DD"
},
{
"Unicode"
,
"日"
,
"YYYY-MM-DD"
},
}
for
_
,
tc
:=
range
tests
{
t
.
Run
(
tc
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
safeDateFormat
(
tc
.
granularity
)
require
.
Equal
(
t
,
tc
.
expected
,
got
,
"safeDateFormat(%q)"
,
tc
.
granularity
)
})
}
}
backend/internal/server/api_contract_test.go
View file @
a14dfb76
...
...
@@ -592,13 +592,13 @@ func newContractDeps(t *testing.T) *contractDeps {
RunMode
:
config
.
RunModeStandard
,
}
userService
:=
service
.
NewUserService
(
userRepo
,
nil
)
userService
:=
service
.
NewUserService
(
userRepo
,
nil
,
nil
)
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
userRepo
,
groupRepo
,
userSubRepo
,
nil
,
apiKeyCache
,
cfg
)
usageRepo
:=
newStubUsageLogRepo
()
usageService
:=
service
.
NewUsageService
(
usageRepo
,
userRepo
,
nil
,
nil
)
subscriptionService
:=
service
.
NewSubscriptionService
(
groupRepo
,
userSubRepo
,
nil
)
subscriptionService
:=
service
.
NewSubscriptionService
(
groupRepo
,
userSubRepo
,
nil
,
cfg
)
subscriptionHandler
:=
handler
.
NewSubscriptionHandler
(
subscriptionService
)
redeemService
:=
service
.
NewRedeemService
(
redeemRepo
,
userRepo
,
subscriptionService
,
nil
,
nil
,
nil
,
nil
)
...
...
@@ -1602,11 +1602,11 @@ func (r *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID i
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUsageLogRepo
)
GetBatchUserUsageStats
(
ctx
context
.
Context
,
userIDs
[]
int64
)
(
map
[
int64
]
*
usagestats
.
BatchUserUsageStats
,
error
)
{
func
(
r
*
stubUsageLogRepo
)
GetBatchUserUsageStats
(
ctx
context
.
Context
,
userIDs
[]
int64
,
startTime
,
endTime
time
.
Time
)
(
map
[
int64
]
*
usagestats
.
BatchUserUsageStats
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUsageLogRepo
)
GetBatchAPIKeyUsageStats
(
ctx
context
.
Context
,
apiKeyIDs
[]
int64
)
(
map
[
int64
]
*
usagestats
.
BatchAPIKeyUsageStats
,
error
)
{
func
(
r
*
stubUsageLogRepo
)
GetBatchAPIKeyUsageStats
(
ctx
context
.
Context
,
apiKeyIDs
[]
int64
,
startTime
,
endTime
time
.
Time
)
(
map
[
int64
]
*
usagestats
.
BatchAPIKeyUsageStats
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
...
...
backend/internal/server/middleware/admin_auth.go
View file @
a14dfb76
...
...
@@ -176,6 +176,12 @@ func validateJWTForAdmin(
return
false
}
// 校验 TokenVersion,确保管理员改密后旧 token 失效
if
claims
.
TokenVersion
!=
user
.
TokenVersion
{
AbortWithError
(
c
,
401
,
"TOKEN_REVOKED"
,
"Token has been revoked (password changed)"
)
return
false
}
// 检查管理员权限
if
!
user
.
IsAdmin
()
{
AbortWithError
(
c
,
403
,
"FORBIDDEN"
,
"Admin access required"
)
...
...
backend/internal/server/middleware/admin_auth_test.go
0 → 100644
View file @
a14dfb76
//go:build unit
package
middleware
import
(
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func
TestAdminAuthJWTValidatesTokenVersion
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
cfg
:=
&
config
.
Config
{
JWT
:
config
.
JWTConfig
{
Secret
:
"test-secret"
,
ExpireHour
:
1
}}
authService
:=
service
.
NewAuthService
(
nil
,
nil
,
nil
,
cfg
,
nil
,
nil
,
nil
,
nil
,
nil
)
admin
:=
&
service
.
User
{
ID
:
1
,
Email
:
"admin@example.com"
,
Role
:
service
.
RoleAdmin
,
Status
:
service
.
StatusActive
,
TokenVersion
:
2
,
Concurrency
:
1
,
}
userRepo
:=
&
stubUserRepo
{
getByID
:
func
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
User
,
error
)
{
if
id
!=
admin
.
ID
{
return
nil
,
service
.
ErrUserNotFound
}
clone
:=
*
admin
return
&
clone
,
nil
},
}
userService
:=
service
.
NewUserService
(
userRepo
,
nil
,
nil
)
router
:=
gin
.
New
()
router
.
Use
(
gin
.
HandlerFunc
(
NewAdminAuthMiddleware
(
authService
,
userService
,
nil
)))
router
.
GET
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"ok"
:
true
})
})
t
.
Run
(
"token_version_mismatch_rejected"
,
func
(
t
*
testing
.
T
)
{
token
,
err
:=
authService
.
GenerateToken
(
&
service
.
User
{
ID
:
admin
.
ID
,
Email
:
admin
.
Email
,
Role
:
admin
.
Role
,
TokenVersion
:
admin
.
TokenVersion
-
1
,
})
require
.
NoError
(
t
,
err
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
token
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
w
.
Code
)
require
.
Contains
(
t
,
w
.
Body
.
String
(),
"TOKEN_REVOKED"
)
})
t
.
Run
(
"token_version_match_allows"
,
func
(
t
*
testing
.
T
)
{
token
,
err
:=
authService
.
GenerateToken
(
&
service
.
User
{
ID
:
admin
.
ID
,
Email
:
admin
.
Email
,
Role
:
admin
.
Role
,
TokenVersion
:
admin
.
TokenVersion
,
})
require
.
NoError
(
t
,
err
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
token
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
})
t
.
Run
(
"websocket_token_version_mismatch_rejected"
,
func
(
t
*
testing
.
T
)
{
token
,
err
:=
authService
.
GenerateToken
(
&
service
.
User
{
ID
:
admin
.
ID
,
Email
:
admin
.
Email
,
Role
:
admin
.
Role
,
TokenVersion
:
admin
.
TokenVersion
-
1
,
})
require
.
NoError
(
t
,
err
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"Upgrade"
,
"websocket"
)
req
.
Header
.
Set
(
"Connection"
,
"Upgrade"
)
req
.
Header
.
Set
(
"Sec-WebSocket-Protocol"
,
"sub2api-admin, jwt."
+
token
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusUnauthorized
,
w
.
Code
)
require
.
Contains
(
t
,
w
.
Body
.
String
(),
"TOKEN_REVOKED"
)
})
t
.
Run
(
"websocket_token_version_match_allows"
,
func
(
t
*
testing
.
T
)
{
token
,
err
:=
authService
.
GenerateToken
(
&
service
.
User
{
ID
:
admin
.
ID
,
Email
:
admin
.
Email
,
Role
:
admin
.
Role
,
TokenVersion
:
admin
.
TokenVersion
,
})
require
.
NoError
(
t
,
err
)
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"Upgrade"
,
"websocket"
)
req
.
Header
.
Set
(
"Connection"
,
"Upgrade"
)
req
.
Header
.
Set
(
"Sec-WebSocket-Protocol"
,
"sub2api-admin, jwt."
+
token
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
})
}
type
stubUserRepo
struct
{
getByID
func
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
User
,
error
)
}
func
(
s
*
stubUserRepo
)
Create
(
ctx
context
.
Context
,
user
*
service
.
User
)
error
{
panic
(
"unexpected Create call"
)
}
func
(
s
*
stubUserRepo
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
User
,
error
)
{
if
s
.
getByID
==
nil
{
panic
(
"GetByID not stubbed"
)
}
return
s
.
getByID
(
ctx
,
id
)
}
func
(
s
*
stubUserRepo
)
GetByEmail
(
ctx
context
.
Context
,
email
string
)
(
*
service
.
User
,
error
)
{
panic
(
"unexpected GetByEmail call"
)
}
func
(
s
*
stubUserRepo
)
GetFirstAdmin
(
ctx
context
.
Context
)
(
*
service
.
User
,
error
)
{
panic
(
"unexpected GetFirstAdmin call"
)
}
func
(
s
*
stubUserRepo
)
Update
(
ctx
context
.
Context
,
user
*
service
.
User
)
error
{
panic
(
"unexpected Update call"
)
}
func
(
s
*
stubUserRepo
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
panic
(
"unexpected Delete call"
)
}
func
(
s
*
stubUserRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
service
.
User
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected List call"
)
}
func
(
s
*
stubUserRepo
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
filters
service
.
UserListFilters
)
([]
service
.
User
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected ListWithFilters call"
)
}
func
(
s
*
stubUserRepo
)
UpdateBalance
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
error
{
panic
(
"unexpected UpdateBalance call"
)
}
func
(
s
*
stubUserRepo
)
DeductBalance
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
error
{
panic
(
"unexpected DeductBalance call"
)
}
func
(
s
*
stubUserRepo
)
UpdateConcurrency
(
ctx
context
.
Context
,
id
int64
,
amount
int
)
error
{
panic
(
"unexpected UpdateConcurrency call"
)
}
func
(
s
*
stubUserRepo
)
ExistsByEmail
(
ctx
context
.
Context
,
email
string
)
(
bool
,
error
)
{
panic
(
"unexpected ExistsByEmail call"
)
}
func
(
s
*
stubUserRepo
)
RemoveGroupFromAllowedGroups
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
panic
(
"unexpected RemoveGroupFromAllowedGroups call"
)
}
func
(
s
*
stubUserRepo
)
UpdateTotpSecret
(
ctx
context
.
Context
,
userID
int64
,
encryptedSecret
*
string
)
error
{
panic
(
"unexpected UpdateTotpSecret call"
)
}
func
(
s
*
stubUserRepo
)
EnableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
{
panic
(
"unexpected EnableTotp call"
)
}
func
(
s
*
stubUserRepo
)
DisableTotp
(
ctx
context
.
Context
,
userID
int64
)
error
{
panic
(
"unexpected DisableTotp call"
)
}
backend/internal/server/middleware/api_key_auth.go
View file @
a14dfb76
...
...
@@ -3,7 +3,6 @@ package middleware
import
(
"context"
"errors"
"log"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
...
...
@@ -134,7 +133,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
isSubscriptionType
:=
apiKey
.
Group
!=
nil
&&
apiKey
.
Group
.
IsSubscriptionType
()
if
isSubscriptionType
&&
subscriptionService
!=
nil
{
// 订阅模式:
验证订阅
// 订阅模式:
获取订阅(L1 缓存 + singleflight)
subscription
,
err
:=
subscriptionService
.
GetActiveSubscription
(
c
.
Request
.
Context
(),
apiKey
.
User
.
ID
,
...
...
@@ -145,30 +144,30 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
return
}
// 验证订阅状态(是否过期、暂停等)
if
err
:=
subscriptionService
.
ValidateSubscription
(
c
.
Request
.
Context
(),
subscription
);
err
!=
nil
{
AbortWithError
(
c
,
403
,
"SUBSCRIPTION_INVALID"
,
err
.
Error
())
return
}
// 激活滑动窗口(首次使用时)
if
err
:=
subscriptionService
.
CheckAndActivateWindow
(
c
.
Request
.
Context
(),
subscription
);
err
!=
nil
{
log
.
Printf
(
"Failed to activate subscription windows: %v"
,
err
)
}
// 检查并重置过期窗口
if
err
:=
subscriptionService
.
CheckAndResetWindows
(
c
.
Request
.
Context
(),
subscription
);
err
!=
nil
{
log
.
Printf
(
"Failed to reset subscription windows: %v"
,
err
)
}
// 预检查用量限制(使用0作为额外费用进行预检查)
if
err
:=
subscriptionService
.
CheckUsageLimits
(
c
.
Request
.
Context
(),
subscription
,
apiKey
.
Group
,
0
);
err
!=
nil
{
AbortWithError
(
c
,
429
,
"USAGE_LIMIT_EXCEEDED"
,
err
.
Error
())
// 合并验证 + 限额检查(纯内存操作)
needsMaintenance
,
err
:=
subscriptionService
.
ValidateAndCheckLimits
(
subscription
,
apiKey
.
Group
)
if
err
!=
nil
{
code
:=
"SUBSCRIPTION_INVALID"
status
:=
403
if
errors
.
Is
(
err
,
service
.
ErrDailyLimitExceeded
)
||
errors
.
Is
(
err
,
service
.
ErrWeeklyLimitExceeded
)
||
errors
.
Is
(
err
,
service
.
ErrMonthlyLimitExceeded
)
{
code
=
"USAGE_LIMIT_EXCEEDED"
status
=
429
}
AbortWithError
(
c
,
status
,
code
,
err
.
Error
())
return
}
// 将订阅信息存入上下文
c
.
Set
(
string
(
ContextKeySubscription
),
subscription
)
// 窗口维护异步化(不阻塞请求)
// 传递独立拷贝,避免与 handler 读取 context 中的 subscription 产生 data race
if
needsMaintenance
{
maintenanceCopy
:=
*
subscription
go
subscriptionService
.
DoWindowMaintenance
(
&
maintenanceCopy
)
}
}
else
{
// 余额模式:检查用户余额
if
apiKey
.
User
.
Balance
<=
0
{
...
...
backend/internal/server/middleware/api_key_auth_test.go
View file @
a14dfb76
...
...
@@ -60,7 +60,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t
.
Run
(
"simple_mode_bypasses_quota_check"
,
func
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
)
subscriptionService
:=
service
.
NewSubscriptionService
(
nil
,
&
stubUserSubscriptionRepo
{},
nil
)
subscriptionService
:=
service
.
NewSubscriptionService
(
nil
,
&
stubUserSubscriptionRepo
{},
nil
,
cfg
)
router
:=
newAuthTestRouter
(
apiKeyService
,
subscriptionService
,
cfg
)
w
:=
httptest
.
NewRecorder
()
...
...
@@ -99,7 +99,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
resetWeekly
:
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
return
nil
},
resetMonthly
:
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
return
nil
},
}
subscriptionService
:=
service
.
NewSubscriptionService
(
nil
,
subscriptionRepo
,
nil
)
subscriptionService
:=
service
.
NewSubscriptionService
(
nil
,
subscriptionRepo
,
nil
,
cfg
)
router
:=
newAuthTestRouter
(
apiKeyService
,
subscriptionService
,
cfg
)
w
:=
httptest
.
NewRecorder
()
...
...
backend/internal/server/middleware/cors.go
View file @
a14dfb76
...
...
@@ -72,6 +72,7 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc {
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Headers"
,
"Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key"
)
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Methods"
,
"POST, OPTIONS, GET, PUT, DELETE, PATCH"
)
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Max-Age"
,
"86400"
)
// 处理预检请求
if
c
.
Request
.
Method
==
http
.
MethodOptions
{
...
...
backend/internal/service/account_usage_service.go
View file @
a14dfb76
...
...
@@ -36,8 +36,8 @@ type UsageLogRepository interface {
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
ModelStat
,
error
)
GetAPIKeyUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
limit
int
)
([]
usagestats
.
APIKeyUsageTrendPoint
,
error
)
GetUserUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
limit
int
)
([]
usagestats
.
UserUsageTrendPoint
,
error
)
GetBatchUserUsageStats
(
ctx
context
.
Context
,
userIDs
[]
int64
)
(
map
[
int64
]
*
usagestats
.
BatchUserUsageStats
,
error
)
GetBatchAPIKeyUsageStats
(
ctx
context
.
Context
,
apiKeyIDs
[]
int64
)
(
map
[
int64
]
*
usagestats
.
BatchAPIKeyUsageStats
,
error
)
GetBatchUserUsageStats
(
ctx
context
.
Context
,
userIDs
[]
int64
,
startTime
,
endTime
time
.
Time
)
(
map
[
int64
]
*
usagestats
.
BatchUserUsageStats
,
error
)
GetBatchAPIKeyUsageStats
(
ctx
context
.
Context
,
apiKeyIDs
[]
int64
,
startTime
,
endTime
time
.
Time
)
(
map
[
int64
]
*
usagestats
.
BatchAPIKeyUsageStats
,
error
)
// User dashboard stats
GetUserDashboardStats
(
ctx
context
.
Context
,
userID
int64
)
(
*
usagestats
.
UserDashboardStats
,
error
)
...
...
backend/internal/service/antigravity_gateway_service.go
View file @
a14dfb76
...
...
@@ -1582,6 +1582,208 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque
return
changed
,
nil
}
// ForwardUpstream 透传请求到上游 Antigravity 服务
// 用于 upstream 类型账号,直接使用 base_url + api_key 转发,不走 OAuth token
func
(
s
*
AntigravityGatewayService
)
ForwardUpstream
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
)
(
*
ForwardResult
,
error
)
{
startTime
:=
time
.
Now
()
sessionID
:=
getSessionID
(
c
)
prefix
:=
logPrefix
(
sessionID
,
account
.
Name
)
// 获取上游配置
baseURL
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"base_url"
))
apiKey
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"api_key"
))
if
baseURL
==
""
||
apiKey
==
""
{
return
nil
,
fmt
.
Errorf
(
"upstream account missing base_url or api_key"
)
}
baseURL
=
strings
.
TrimSuffix
(
baseURL
,
"/"
)
// 解析请求获取模型信息
var
claudeReq
antigravity
.
ClaudeRequest
if
err
:=
json
.
Unmarshal
(
body
,
&
claudeReq
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"parse claude request: %w"
,
err
)
}
if
strings
.
TrimSpace
(
claudeReq
.
Model
)
==
""
{
return
nil
,
fmt
.
Errorf
(
"missing model"
)
}
originalModel
:=
claudeReq
.
Model
billingModel
:=
originalModel
// 构建上游请求 URL
upstreamURL
:=
baseURL
+
"/v1/messages"
// 创建请求
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodPost
,
upstreamURL
,
bytes
.
NewReader
(
body
))
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"create upstream request: %w"
,
err
)
}
// 设置请求头
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
apiKey
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
)
// Claude API 兼容
// 透传 Claude 相关 headers
if
v
:=
c
.
GetHeader
(
"anthropic-version"
);
v
!=
""
{
req
.
Header
.
Set
(
"anthropic-version"
,
v
)
}
if
v
:=
c
.
GetHeader
(
"anthropic-beta"
);
v
!=
""
{
req
.
Header
.
Set
(
"anthropic-beta"
,
v
)
}
// 代理 URL
proxyURL
:=
""
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
proxyURL
=
account
.
Proxy
.
URL
()
}
// 发送请求
resp
,
err
:=
s
.
httpUpstream
.
Do
(
req
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
err
!=
nil
{
log
.
Printf
(
"%s upstream request failed: %v"
,
prefix
,
err
)
return
nil
,
fmt
.
Errorf
(
"upstream request failed: %w"
,
err
)
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
// 处理错误响应
if
resp
.
StatusCode
>=
400
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
// 429 错误时标记账号限流
if
resp
.
StatusCode
==
http
.
StatusTooManyRequests
{
s
.
handleUpstreamError
(
ctx
,
prefix
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
,
AntigravityQuotaScopeClaude
)
}
// 透传上游错误
c
.
Header
(
"Content-Type"
,
resp
.
Header
.
Get
(
"Content-Type"
))
c
.
Status
(
resp
.
StatusCode
)
_
,
_
=
c
.
Writer
.
Write
(
respBody
)
return
&
ForwardResult
{
Model
:
billingModel
,
},
nil
}
// 处理成功响应(流式/非流式)
var
usage
*
ClaudeUsage
var
firstTokenMs
*
int
if
claudeReq
.
Stream
{
// 流式响应:透传
c
.
Header
(
"Content-Type"
,
"text/event-stream"
)
c
.
Header
(
"Cache-Control"
,
"no-cache"
)
c
.
Header
(
"Connection"
,
"keep-alive"
)
c
.
Header
(
"X-Accel-Buffering"
,
"no"
)
c
.
Status
(
http
.
StatusOK
)
usage
,
firstTokenMs
=
s
.
streamUpstreamResponse
(
c
,
resp
,
startTime
)
}
else
{
// 非流式响应:直接透传
respBody
,
err
:=
io
.
ReadAll
(
resp
.
Body
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"read upstream response: %w"
,
err
)
}
// 提取 usage
usage
=
s
.
extractClaudeUsage
(
respBody
)
c
.
Header
(
"Content-Type"
,
resp
.
Header
.
Get
(
"Content-Type"
))
c
.
Status
(
http
.
StatusOK
)
_
,
_
=
c
.
Writer
.
Write
(
respBody
)
}
// 构建计费结果
duration
:=
time
.
Since
(
startTime
)
log
.
Printf
(
"%s status=success duration_ms=%d"
,
prefix
,
duration
.
Milliseconds
())
return
&
ForwardResult
{
Model
:
billingModel
,
Stream
:
claudeReq
.
Stream
,
Duration
:
duration
,
FirstTokenMs
:
firstTokenMs
,
Usage
:
ClaudeUsage
{
InputTokens
:
usage
.
InputTokens
,
OutputTokens
:
usage
.
OutputTokens
,
CacheReadInputTokens
:
usage
.
CacheReadInputTokens
,
CacheCreationInputTokens
:
usage
.
CacheCreationInputTokens
,
},
},
nil
}
// streamUpstreamResponse 透传上游流式响应并提取 usage
func
(
s
*
AntigravityGatewayService
)
streamUpstreamResponse
(
c
*
gin
.
Context
,
resp
*
http
.
Response
,
startTime
time
.
Time
)
(
*
ClaudeUsage
,
*
int
)
{
usage
:=
&
ClaudeUsage
{}
var
firstTokenMs
*
int
var
firstTokenRecorded
bool
scanner
:=
bufio
.
NewScanner
(
resp
.
Body
)
buf
:=
make
([]
byte
,
0
,
64
*
1024
)
scanner
.
Buffer
(
buf
,
1024
*
1024
)
for
scanner
.
Scan
()
{
line
:=
scanner
.
Bytes
()
// 记录首 token 时间
if
!
firstTokenRecorded
&&
len
(
line
)
>
0
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
firstTokenMs
=
&
ms
firstTokenRecorded
=
true
}
// 尝试从 message_delta 或 message_stop 事件提取 usage
if
bytes
.
HasPrefix
(
line
,
[]
byte
(
"data: "
))
{
dataStr
:=
bytes
.
TrimPrefix
(
line
,
[]
byte
(
"data: "
))
var
event
map
[
string
]
any
if
json
.
Unmarshal
(
dataStr
,
&
event
)
==
nil
{
if
u
,
ok
:=
event
[
"usage"
]
.
(
map
[
string
]
any
);
ok
{
if
v
,
ok
:=
u
[
"input_tokens"
]
.
(
float64
);
ok
&&
int
(
v
)
>
0
{
usage
.
InputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"output_tokens"
]
.
(
float64
);
ok
&&
int
(
v
)
>
0
{
usage
.
OutputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"cache_read_input_tokens"
]
.
(
float64
);
ok
&&
int
(
v
)
>
0
{
usage
.
CacheReadInputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"cache_creation_input_tokens"
]
.
(
float64
);
ok
&&
int
(
v
)
>
0
{
usage
.
CacheCreationInputTokens
=
int
(
v
)
}
}
}
}
// 透传行
_
,
_
=
c
.
Writer
.
Write
(
line
)
_
,
_
=
c
.
Writer
.
Write
([]
byte
(
"
\n
"
))
c
.
Writer
.
Flush
()
}
return
usage
,
firstTokenMs
}
// extractClaudeUsage 从非流式 Claude 响应提取 usage
func
(
s
*
AntigravityGatewayService
)
extractClaudeUsage
(
body
[]
byte
)
*
ClaudeUsage
{
usage
:=
&
ClaudeUsage
{}
var
resp
map
[
string
]
any
if
json
.
Unmarshal
(
body
,
&
resp
)
!=
nil
{
return
usage
}
if
u
,
ok
:=
resp
[
"usage"
]
.
(
map
[
string
]
any
);
ok
{
if
v
,
ok
:=
u
[
"input_tokens"
]
.
(
float64
);
ok
{
usage
.
InputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"output_tokens"
]
.
(
float64
);
ok
{
usage
.
OutputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"cache_read_input_tokens"
]
.
(
float64
);
ok
{
usage
.
CacheReadInputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"cache_creation_input_tokens"
]
.
(
float64
);
ok
{
usage
.
CacheCreationInputTokens
=
int
(
v
)
}
}
return
usage
}
// ForwardGemini 转发 Gemini 协议请求
func
(
s
*
AntigravityGatewayService
)
ForwardGemini
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
originalModel
string
,
action
string
,
stream
bool
,
body
[]
byte
,
isStickySession
bool
)
(
*
ForwardResult
,
error
)
{
startTime
:=
time
.
Now
()
...
...
@@ -1613,7 +1815,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
Usage
:
ClaudeUsage
{},
Model
:
originalModel
,
Stream
:
false
,
Duration
:
time
.
Since
(
time
.
Now
()
),
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
nil
,
},
nil
default
:
...
...
@@ -2288,7 +2490,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
if
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
MaxLineSize
>
0
{
maxLineSize
=
s
.
settingService
.
cfg
.
Gateway
.
MaxLineSize
}
scanner
.
Buffer
(
make
([]
byte
,
64
*
1024
),
maxLineSize
)
scanBuf
:=
getSSEScannerBuf64K
()
scanner
.
Buffer
(
scanBuf
[
:
0
],
maxLineSize
)
usage
:=
&
ClaudeUsage
{}
var
firstTokenMs
*
int
...
...
@@ -2309,7 +2512,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
}
var
lastReadAt
int64
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
go
func
()
{
go
func
(
scanBuf
*
sseScannerBuf64K
)
{
defer
putSSEScannerBuf64K
(
scanBuf
)
defer
close
(
events
)
for
scanner
.
Scan
()
{
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
...
...
@@ -2320,7 +2524,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
if
err
:=
scanner
.
Err
();
err
!=
nil
{
_
=
sendEvent
(
scanEvent
{
err
:
err
})
}
}()
}(
scanBuf
)
defer
close
(
done
)
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
...
...
@@ -2445,7 +2649,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
if
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
MaxLineSize
>
0
{
maxLineSize
=
s
.
settingService
.
cfg
.
Gateway
.
MaxLineSize
}
scanner
.
Buffer
(
make
([]
byte
,
64
*
1024
),
maxLineSize
)
scanBuf
:=
getSSEScannerBuf64K
()
scanner
.
Buffer
(
scanBuf
[
:
0
],
maxLineSize
)
usage
:=
&
ClaudeUsage
{}
var
firstTokenMs
*
int
...
...
@@ -2473,7 +2678,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
var
lastReadAt
int64
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
go
func
()
{
go
func
(
scanBuf
*
sseScannerBuf64K
)
{
defer
putSSEScannerBuf64K
(
scanBuf
)
defer
close
(
events
)
for
scanner
.
Scan
()
{
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
...
...
@@ -2484,7 +2690,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
if
err
:=
scanner
.
Err
();
err
!=
nil
{
_
=
sendEvent
(
scanEvent
{
err
:
err
})
}
}()
}(
scanBuf
)
defer
close
(
done
)
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
...
...
@@ -2888,7 +3094,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
if
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
MaxLineSize
>
0
{
maxLineSize
=
s
.
settingService
.
cfg
.
Gateway
.
MaxLineSize
}
scanner
.
Buffer
(
make
([]
byte
,
64
*
1024
),
maxLineSize
)
scanBuf
:=
getSSEScannerBuf64K
()
scanner
.
Buffer
(
scanBuf
[
:
0
],
maxLineSize
)
var
firstTokenMs
*
int
var
last
map
[
string
]
any
...
...
@@ -2914,7 +3121,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
var
lastReadAt
int64
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
go
func
()
{
go
func
(
scanBuf
*
sseScannerBuf64K
)
{
defer
putSSEScannerBuf64K
(
scanBuf
)
defer
close
(
events
)
for
scanner
.
Scan
()
{
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
...
...
@@ -2925,7 +3133,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
if
err
:=
scanner
.
Err
();
err
!=
nil
{
_
=
sendEvent
(
scanEvent
{
err
:
err
})
}
}()
}(
scanBuf
)
defer
close
(
done
)
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
...
...
@@ -3068,7 +3276,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
if
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
MaxLineSize
>
0
{
maxLineSize
=
s
.
settingService
.
cfg
.
Gateway
.
MaxLineSize
}
scanner
.
Buffer
(
make
([]
byte
,
64
*
1024
),
maxLineSize
)
scanBuf
:=
getSSEScannerBuf64K
()
scanner
.
Buffer
(
scanBuf
[
:
0
],
maxLineSize
)
// 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage
convertUsage
:=
func
(
agUsage
*
antigravity
.
ClaudeUsage
)
*
ClaudeUsage
{
...
...
@@ -3100,7 +3309,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
}
var
lastReadAt
int64
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
go
func
()
{
go
func
(
scanBuf
*
sseScannerBuf64K
)
{
defer
putSSEScannerBuf64K
(
scanBuf
)
defer
close
(
events
)
for
scanner
.
Scan
()
{
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
...
...
@@ -3111,7 +3321,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
if
err
:=
scanner
.
Err
();
err
!=
nil
{
_
=
sendEvent
(
scanEvent
{
err
:
err
})
}
}()
}(
scanBuf
)
defer
close
(
done
)
streamInterval
:=
time
.
Duration
(
0
)
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment