Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
005d0c5f
Unverified
Commit
005d0c5f
authored
Mar 06, 2026
by
Wesley Liddick
Committed by
GitHub
Mar 06, 2026
Browse files
Merge pull request #815 from mt21625457/pr/openai-user-group-rate-upstream
fix(openai): 统一专属倍率计费链路并补齐回归测试
parents
8aaaeb29
230f8abd
Changes
8
Hide whitespace changes
Inline
Side-by-side
backend/cmd/server/wire_gen.go
View file @
005d0c5f
...
...
@@ -164,7 +164,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
digestSessionStore
:=
service
.
NewDigestSessionStore
()
gatewayService
:=
service
.
NewGatewayService
(
accountRepository
,
groupRepository
,
usageLogRepository
,
userRepository
,
userSubscriptionRepository
,
userGroupRateRepository
,
gatewayCache
,
configConfig
,
schedulerSnapshotService
,
concurrencyService
,
billingService
,
rateLimitService
,
billingCacheService
,
identityService
,
httpUpstream
,
deferredService
,
claudeTokenProvider
,
sessionLimitCache
,
rpmCache
,
digestSessionStore
)
openAITokenProvider
:=
service
.
NewOpenAITokenProvider
(
accountRepository
,
geminiTokenCache
,
openAIOAuthService
)
openAIGatewayService
:=
service
.
NewOpenAIGatewayService
(
accountRepository
,
usageLogRepository
,
userRepository
,
userSubscriptionRepository
,
gatewayCache
,
configConfig
,
schedulerSnapshotService
,
concurrencyService
,
billingService
,
rateLimitService
,
billingCacheService
,
httpUpstream
,
deferredService
,
openAITokenProvider
)
openAIGatewayService
:=
service
.
NewOpenAIGatewayService
(
accountRepository
,
usageLogRepository
,
userRepository
,
userSubscriptionRepository
,
userGroupRateRepository
,
gatewayCache
,
configConfig
,
schedulerSnapshotService
,
concurrencyService
,
billingService
,
rateLimitService
,
billingCacheService
,
httpUpstream
,
deferredService
,
openAITokenProvider
)
geminiMessagesCompatService
:=
service
.
NewGeminiMessagesCompatService
(
accountRepository
,
groupRepository
,
gatewayCache
,
schedulerSnapshotService
,
geminiTokenProvider
,
rateLimitService
,
httpUpstream
,
antigravityGatewayService
,
configConfig
)
opsSystemLogSink
:=
service
.
ProvideOpsSystemLogSink
(
opsRepository
)
opsService
:=
service
.
NewOpsService
(
opsRepository
,
settingRepository
,
configConfig
,
accountRepository
,
userRepository
,
concurrencyService
,
gatewayService
,
openAIGatewayService
,
geminiMessagesCompatService
,
antigravityGatewayService
,
opsSystemLogSink
)
...
...
backend/internal/service/gateway_service.go
View file @
005d0c5f
...
...
@@ -501,33 +501,34 @@ func (s *GatewayService) TempUnscheduleRetryableError(ctx context.Context, accou
// GatewayService handles API gateway operations
type
GatewayService
struct
{
accountRepo
AccountRepository
groupRepo
GroupRepository
usageLogRepo
UsageLogRepository
userRepo
UserRepository
userSubRepo
UserSubscriptionRepository
userGroupRateRepo
UserGroupRateRepository
cache
GatewayCache
digestStore
*
DigestSessionStore
cfg
*
config
.
Config
schedulerSnapshot
*
SchedulerSnapshotService
billingService
*
BillingService
rateLimitService
*
RateLimitService
billingCacheService
*
BillingCacheService
identityService
*
IdentityService
httpUpstream
HTTPUpstream
deferredService
*
DeferredService
concurrencyService
*
ConcurrencyService
claudeTokenProvider
*
ClaudeTokenProvider
sessionLimitCache
SessionLimitCache
// 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
rpmCache
RPMCache
// RPM 计数缓存(仅 Anthropic OAuth/SetupToken)
userGroupRateCache
*
gocache
.
Cache
userGroupRateSF
singleflight
.
Group
modelsListCache
*
gocache
.
Cache
modelsListCacheTTL
time
.
Duration
responseHeaderFilter
*
responseheaders
.
CompiledHeaderFilter
debugModelRouting
atomic
.
Bool
debugClaudeMimic
atomic
.
Bool
accountRepo
AccountRepository
groupRepo
GroupRepository
usageLogRepo
UsageLogRepository
userRepo
UserRepository
userSubRepo
UserSubscriptionRepository
userGroupRateRepo
UserGroupRateRepository
cache
GatewayCache
digestStore
*
DigestSessionStore
cfg
*
config
.
Config
schedulerSnapshot
*
SchedulerSnapshotService
billingService
*
BillingService
rateLimitService
*
RateLimitService
billingCacheService
*
BillingCacheService
identityService
*
IdentityService
httpUpstream
HTTPUpstream
deferredService
*
DeferredService
concurrencyService
*
ConcurrencyService
claudeTokenProvider
*
ClaudeTokenProvider
sessionLimitCache
SessionLimitCache
// 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
rpmCache
RPMCache
// RPM 计数缓存(仅 Anthropic OAuth/SetupToken)
userGroupRateResolver
*
userGroupRateResolver
userGroupRateCache
*
gocache
.
Cache
userGroupRateSF
singleflight
.
Group
modelsListCache
*
gocache
.
Cache
modelsListCacheTTL
time
.
Duration
responseHeaderFilter
*
responseheaders
.
CompiledHeaderFilter
debugModelRouting
atomic
.
Bool
debugClaudeMimic
atomic
.
Bool
}
// NewGatewayService creates a new GatewayService
...
...
@@ -582,6 +583,13 @@ func NewGatewayService(
modelsListCacheTTL
:
modelsListTTL
,
responseHeaderFilter
:
compileResponseHeaderFilter
(
cfg
),
}
svc
.
userGroupRateResolver
=
newUserGroupRateResolver
(
userGroupRateRepo
,
svc
.
userGroupRateCache
,
userGroupRateTTL
,
&
svc
.
userGroupRateSF
,
"service.gateway"
,
)
svc
.
debugModelRouting
.
Store
(
parseDebugEnvBool
(
os
.
Getenv
(
"SUB2API_DEBUG_MODEL_ROUTING"
)))
svc
.
debugClaudeMimic
.
Store
(
parseDebugEnvBool
(
os
.
Getenv
(
"SUB2API_DEBUG_CLAUDE_MIMIC"
)))
return
svc
...
...
@@ -6336,63 +6344,20 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
}
func
(
s
*
GatewayService
)
getUserGroupRateMultiplier
(
ctx
context
.
Context
,
userID
,
groupID
int64
,
groupDefaultMultiplier
float64
)
float64
{
if
s
==
nil
||
userID
<=
0
||
groupID
<=
0
{
return
groupDefaultMultiplier
}
key
:=
fmt
.
Sprintf
(
"%d:%d"
,
userID
,
groupID
)
if
s
.
userGroupRateCache
!=
nil
{
if
cached
,
ok
:=
s
.
userGroupRateCache
.
Get
(
key
);
ok
{
if
multiplier
,
castOK
:=
cached
.
(
float64
);
castOK
{
userGroupRateCacheHitTotal
.
Add
(
1
)
return
multiplier
}
}
}
if
s
.
userGroupRateRepo
==
nil
{
return
groupDefaultMultiplier
}
userGroupRateCacheMissTotal
.
Add
(
1
)
value
,
err
,
shared
:=
s
.
userGroupRateSF
.
Do
(
key
,
func
()
(
any
,
error
)
{
if
s
.
userGroupRateCache
!=
nil
{
if
cached
,
ok
:=
s
.
userGroupRateCache
.
Get
(
key
);
ok
{
if
multiplier
,
castOK
:=
cached
.
(
float64
);
castOK
{
userGroupRateCacheHitTotal
.
Add
(
1
)
return
multiplier
,
nil
}
}
}
userGroupRateCacheLoadTotal
.
Add
(
1
)
userRate
,
repoErr
:=
s
.
userGroupRateRepo
.
GetByUserAndGroup
(
ctx
,
userID
,
groupID
)
if
repoErr
!=
nil
{
return
nil
,
repoErr
}
multiplier
:=
groupDefaultMultiplier
if
userRate
!=
nil
{
multiplier
=
*
userRate
}
if
s
.
userGroupRateCache
!=
nil
{
s
.
userGroupRateCache
.
Set
(
key
,
multiplier
,
resolveUserGroupRateCacheTTL
(
s
.
cfg
))
}
return
multiplier
,
nil
})
if
shared
{
userGroupRateCacheSFSharedTotal
.
Add
(
1
)
}
if
err
!=
nil
{
userGroupRateCacheFallbackTotal
.
Add
(
1
)
logger
.
LegacyPrintf
(
"service.gateway"
,
"get user group rate failed, fallback to group default: user=%d group=%d err=%v"
,
userID
,
groupID
,
err
)
if
s
==
nil
{
return
groupDefaultMultiplier
}
multiplier
,
ok
:=
value
.
(
float64
)
if
!
ok
{
userGroupRateCacheFallbackTotal
.
Add
(
1
)
return
groupDefaultMultiplier
resolver
:=
s
.
userGroupRateResolver
if
resolver
==
nil
{
resolver
=
newUserGroupRateResolver
(
s
.
userGroupRateRepo
,
s
.
userGroupRateCache
,
resolveUserGroupRateCacheTTL
(
s
.
cfg
),
&
s
.
userGroupRateSF
,
"service.gateway"
,
)
}
return
m
ultiplier
return
resolver
.
Resolve
(
ctx
,
userID
,
groupID
,
groupDefaultM
ultiplier
)
}
// RecordUsageInput 记录使用量的输入参数
...
...
backend/internal/service/openai_gateway_record_usage_test.go
0 → 100644
View file @
005d0c5f
package
service
import
(
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type
openAIRecordUsageLogRepoStub
struct
{
UsageLogRepository
inserted
bool
err
error
calls
int
lastLog
*
UsageLog
}
func
(
s
*
openAIRecordUsageLogRepoStub
)
Create
(
ctx
context
.
Context
,
log
*
UsageLog
)
(
bool
,
error
)
{
s
.
calls
++
s
.
lastLog
=
log
return
s
.
inserted
,
s
.
err
}
type
openAIRecordUsageUserRepoStub
struct
{
UserRepository
deductCalls
int
deductErr
error
lastAmount
float64
}
func
(
s
*
openAIRecordUsageUserRepoStub
)
DeductBalance
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
error
{
s
.
deductCalls
++
s
.
lastAmount
=
amount
return
s
.
deductErr
}
type
openAIRecordUsageSubRepoStub
struct
{
UserSubscriptionRepository
incrementCalls
int
incrementErr
error
}
func
(
s
*
openAIRecordUsageSubRepoStub
)
IncrementUsage
(
ctx
context
.
Context
,
id
int64
,
costUSD
float64
)
error
{
s
.
incrementCalls
++
return
s
.
incrementErr
}
type
openAIRecordUsageAPIKeyQuotaStub
struct
{
quotaCalls
int
rateLimitCalls
int
err
error
lastAmount
float64
}
func
(
s
*
openAIRecordUsageAPIKeyQuotaStub
)
UpdateQuotaUsed
(
ctx
context
.
Context
,
apiKeyID
int64
,
cost
float64
)
error
{
s
.
quotaCalls
++
s
.
lastAmount
=
cost
return
s
.
err
}
func
(
s
*
openAIRecordUsageAPIKeyQuotaStub
)
UpdateRateLimitUsage
(
ctx
context
.
Context
,
apiKeyID
int64
,
cost
float64
)
error
{
s
.
rateLimitCalls
++
s
.
lastAmount
=
cost
return
s
.
err
}
type
openAIUserGroupRateRepoStub
struct
{
UserGroupRateRepository
rate
*
float64
err
error
calls
int
}
func
(
s
*
openAIUserGroupRateRepoStub
)
GetByUserAndGroup
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
float64
,
error
)
{
s
.
calls
++
if
s
.
err
!=
nil
{
return
nil
,
s
.
err
}
return
s
.
rate
,
nil
}
func
i64p
(
v
int64
)
*
int64
{
return
&
v
}
func
newOpenAIRecordUsageServiceForTest
(
usageRepo
UsageLogRepository
,
userRepo
UserRepository
,
subRepo
UserSubscriptionRepository
,
rateRepo
UserGroupRateRepository
)
*
OpenAIGatewayService
{
cfg
:=
&
config
.
Config
{}
cfg
.
Default
.
RateMultiplier
=
1.1
return
&
OpenAIGatewayService
{
usageLogRepo
:
usageRepo
,
userRepo
:
userRepo
,
userSubRepo
:
subRepo
,
cfg
:
cfg
,
billingService
:
NewBillingService
(
cfg
,
nil
),
billingCacheService
:
&
BillingCacheService
{},
deferredService
:
&
DeferredService
{},
userGroupRateResolver
:
newUserGroupRateResolver
(
rateRepo
,
nil
,
resolveUserGroupRateCacheTTL
(
cfg
),
nil
,
"service.openai_gateway.test"
,
),
}
}
func
expectedOpenAICost
(
t
*
testing
.
T
,
svc
*
OpenAIGatewayService
,
model
string
,
usage
OpenAIUsage
,
multiplier
float64
)
*
CostBreakdown
{
t
.
Helper
()
cost
,
err
:=
svc
.
billingService
.
CalculateCost
(
model
,
UsageTokens
{
InputTokens
:
max
(
usage
.
InputTokens
-
usage
.
CacheReadInputTokens
,
0
),
OutputTokens
:
usage
.
OutputTokens
,
CacheCreationTokens
:
usage
.
CacheCreationInputTokens
,
CacheReadTokens
:
usage
.
CacheReadInputTokens
,
},
multiplier
)
require
.
NoError
(
t
,
err
)
return
cost
}
func
max
(
a
,
b
int
)
int
{
if
a
>
b
{
return
a
}
return
b
}
func
TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
11
)
groupRate
:=
1.4
userRate
:=
1.8
usage
:=
OpenAIUsage
{
InputTokens
:
15
,
OutputTokens
:
4
,
CacheReadInputTokens
:
3
}
usageRepo
:=
&
openAIRecordUsageLogRepoStub
{
inserted
:
true
}
userRepo
:=
&
openAIRecordUsageUserRepoStub
{}
subRepo
:=
&
openAIRecordUsageSubRepoStub
{}
rateRepo
:=
&
openAIUserGroupRateRepoStub
{
rate
:
&
userRate
}
svc
:=
newOpenAIRecordUsageServiceForTest
(
usageRepo
,
userRepo
,
subRepo
,
rateRepo
)
err
:=
svc
.
RecordUsage
(
context
.
Background
(),
&
OpenAIRecordUsageInput
{
Result
:
&
OpenAIForwardResult
{
RequestID
:
"resp_user_group_rate"
,
Usage
:
usage
,
Model
:
"gpt-5.1"
,
Duration
:
time
.
Second
,
},
APIKey
:
&
APIKey
{
ID
:
1001
,
GroupID
:
i64p
(
groupID
),
Group
:
&
Group
{
ID
:
groupID
,
RateMultiplier
:
groupRate
,
},
},
User
:
&
User
{
ID
:
2001
},
Account
:
&
Account
{
ID
:
3001
},
})
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
rateRepo
.
calls
)
require
.
NotNil
(
t
,
usageRepo
.
lastLog
)
require
.
Equal
(
t
,
userRate
,
usageRepo
.
lastLog
.
RateMultiplier
)
require
.
Equal
(
t
,
12
,
usageRepo
.
lastLog
.
InputTokens
)
require
.
Equal
(
t
,
3
,
usageRepo
.
lastLog
.
CacheReadTokens
)
expected
:=
expectedOpenAICost
(
t
,
svc
,
"gpt-5.1"
,
usage
,
userRate
)
require
.
InDelta
(
t
,
expected
.
ActualCost
,
usageRepo
.
lastLog
.
ActualCost
,
1e-12
)
require
.
InDelta
(
t
,
expected
.
ActualCost
,
userRepo
.
lastAmount
,
1e-12
)
require
.
Equal
(
t
,
1
,
userRepo
.
deductCalls
)
}
func
TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateOnResolverError
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
12
)
groupRate
:=
1.6
usage
:=
OpenAIUsage
{
InputTokens
:
10
,
OutputTokens
:
5
,
CacheReadInputTokens
:
2
}
usageRepo
:=
&
openAIRecordUsageLogRepoStub
{
inserted
:
true
}
userRepo
:=
&
openAIRecordUsageUserRepoStub
{}
subRepo
:=
&
openAIRecordUsageSubRepoStub
{}
rateRepo
:=
&
openAIUserGroupRateRepoStub
{
err
:
errors
.
New
(
"db unavailable"
)}
svc
:=
newOpenAIRecordUsageServiceForTest
(
usageRepo
,
userRepo
,
subRepo
,
rateRepo
)
err
:=
svc
.
RecordUsage
(
context
.
Background
(),
&
OpenAIRecordUsageInput
{
Result
:
&
OpenAIForwardResult
{
RequestID
:
"resp_group_default_on_error"
,
Usage
:
usage
,
Model
:
"gpt-5.1"
,
Duration
:
time
.
Second
,
},
APIKey
:
&
APIKey
{
ID
:
1002
,
GroupID
:
i64p
(
groupID
),
Group
:
&
Group
{
ID
:
groupID
,
RateMultiplier
:
groupRate
,
},
},
User
:
&
User
{
ID
:
2002
},
Account
:
&
Account
{
ID
:
3002
},
})
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
rateRepo
.
calls
)
require
.
NotNil
(
t
,
usageRepo
.
lastLog
)
require
.
Equal
(
t
,
groupRate
,
usageRepo
.
lastLog
.
RateMultiplier
)
expected
:=
expectedOpenAICost
(
t
,
svc
,
"gpt-5.1"
,
usage
,
groupRate
)
require
.
InDelta
(
t
,
expected
.
ActualCost
,
userRepo
.
lastAmount
,
1e-12
)
}
func
TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateWhenResolverMissing
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
13
)
groupRate
:=
1.25
usage
:=
OpenAIUsage
{
InputTokens
:
9
,
OutputTokens
:
4
,
CacheReadInputTokens
:
1
}
usageRepo
:=
&
openAIRecordUsageLogRepoStub
{
inserted
:
true
}
userRepo
:=
&
openAIRecordUsageUserRepoStub
{}
subRepo
:=
&
openAIRecordUsageSubRepoStub
{}
svc
:=
newOpenAIRecordUsageServiceForTest
(
usageRepo
,
userRepo
,
subRepo
,
nil
)
svc
.
userGroupRateResolver
=
nil
err
:=
svc
.
RecordUsage
(
context
.
Background
(),
&
OpenAIRecordUsageInput
{
Result
:
&
OpenAIForwardResult
{
RequestID
:
"resp_group_default_nil_resolver"
,
Usage
:
usage
,
Model
:
"gpt-5.1"
,
Duration
:
time
.
Second
,
},
APIKey
:
&
APIKey
{
ID
:
1003
,
GroupID
:
i64p
(
groupID
),
Group
:
&
Group
{
ID
:
groupID
,
RateMultiplier
:
groupRate
,
},
},
User
:
&
User
{
ID
:
2003
},
Account
:
&
Account
{
ID
:
3003
},
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
usageRepo
.
lastLog
)
require
.
Equal
(
t
,
groupRate
,
usageRepo
.
lastLog
.
RateMultiplier
)
}
func
TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling
(
t
*
testing
.
T
)
{
usageRepo
:=
&
openAIRecordUsageLogRepoStub
{
inserted
:
false
}
userRepo
:=
&
openAIRecordUsageUserRepoStub
{}
subRepo
:=
&
openAIRecordUsageSubRepoStub
{}
svc
:=
newOpenAIRecordUsageServiceForTest
(
usageRepo
,
userRepo
,
subRepo
,
nil
)
err
:=
svc
.
RecordUsage
(
context
.
Background
(),
&
OpenAIRecordUsageInput
{
Result
:
&
OpenAIForwardResult
{
RequestID
:
"resp_duplicate"
,
Usage
:
OpenAIUsage
{
InputTokens
:
8
,
OutputTokens
:
4
,
},
Model
:
"gpt-5.1"
,
Duration
:
time
.
Second
,
},
APIKey
:
&
APIKey
{
ID
:
1004
},
User
:
&
User
{
ID
:
2004
},
Account
:
&
Account
{
ID
:
3004
},
})
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
usageRepo
.
calls
)
require
.
Equal
(
t
,
0
,
userRepo
.
deductCalls
)
require
.
Equal
(
t
,
0
,
subRepo
.
incrementCalls
)
}
func
TestOpenAIGatewayServiceRecordUsage_UpdatesAPIKeyQuotaWhenConfigured
(
t
*
testing
.
T
)
{
usage
:=
OpenAIUsage
{
InputTokens
:
10
,
OutputTokens
:
6
,
CacheReadInputTokens
:
2
}
usageRepo
:=
&
openAIRecordUsageLogRepoStub
{
inserted
:
true
}
userRepo
:=
&
openAIRecordUsageUserRepoStub
{}
subRepo
:=
&
openAIRecordUsageSubRepoStub
{}
quotaSvc
:=
&
openAIRecordUsageAPIKeyQuotaStub
{}
svc
:=
newOpenAIRecordUsageServiceForTest
(
usageRepo
,
userRepo
,
subRepo
,
nil
)
err
:=
svc
.
RecordUsage
(
context
.
Background
(),
&
OpenAIRecordUsageInput
{
Result
:
&
OpenAIForwardResult
{
RequestID
:
"resp_quota_update"
,
Usage
:
usage
,
Model
:
"gpt-5.1"
,
Duration
:
time
.
Second
,
},
APIKey
:
&
APIKey
{
ID
:
1005
,
Quota
:
100
,
},
User
:
&
User
{
ID
:
2005
},
Account
:
&
Account
{
ID
:
3005
},
APIKeyService
:
quotaSvc
,
})
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
quotaSvc
.
quotaCalls
)
require
.
Equal
(
t
,
0
,
quotaSvc
.
rateLimitCalls
)
expected
:=
expectedOpenAICost
(
t
,
svc
,
"gpt-5.1"
,
usage
,
1.1
)
require
.
InDelta
(
t
,
expected
.
ActualCost
,
quotaSvc
.
lastAmount
,
1e-12
)
}
func
TestOpenAIGatewayServiceRecordUsage_ClampsActualInputTokensToZero
(
t
*
testing
.
T
)
{
usageRepo
:=
&
openAIRecordUsageLogRepoStub
{
inserted
:
true
}
userRepo
:=
&
openAIRecordUsageUserRepoStub
{}
subRepo
:=
&
openAIRecordUsageSubRepoStub
{}
svc
:=
newOpenAIRecordUsageServiceForTest
(
usageRepo
,
userRepo
,
subRepo
,
nil
)
err
:=
svc
.
RecordUsage
(
context
.
Background
(),
&
OpenAIRecordUsageInput
{
Result
:
&
OpenAIForwardResult
{
RequestID
:
"resp_clamp_actual_input"
,
Usage
:
OpenAIUsage
{
InputTokens
:
2
,
OutputTokens
:
1
,
CacheReadInputTokens
:
5
,
},
Model
:
"gpt-5.1"
,
Duration
:
time
.
Second
,
},
APIKey
:
&
APIKey
{
ID
:
1006
},
User
:
&
User
{
ID
:
2006
},
Account
:
&
Account
{
ID
:
3006
},
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
usageRepo
.
lastLog
)
require
.
Equal
(
t
,
0
,
usageRepo
.
lastLog
.
InputTokens
)
}
backend/internal/service/openai_gateway_service.go
View file @
005d0c5f
...
...
@@ -245,23 +245,24 @@ type openAIWSRetryMetrics struct {
// OpenAIGatewayService handles OpenAI API gateway operations
type
OpenAIGatewayService
struct
{
accountRepo
AccountRepository
usageLogRepo
UsageLogRepository
userRepo
UserRepository
userSubRepo
UserSubscriptionRepository
cache
GatewayCache
cfg
*
config
.
Config
codexDetector
CodexClientRestrictionDetector
schedulerSnapshot
*
SchedulerSnapshotService
concurrencyService
*
ConcurrencyService
billingService
*
BillingService
rateLimitService
*
RateLimitService
billingCacheService
*
BillingCacheService
httpUpstream
HTTPUpstream
deferredService
*
DeferredService
openAITokenProvider
*
OpenAITokenProvider
toolCorrector
*
CodexToolCorrector
openaiWSResolver
OpenAIWSProtocolResolver
accountRepo
AccountRepository
usageLogRepo
UsageLogRepository
userRepo
UserRepository
userSubRepo
UserSubscriptionRepository
cache
GatewayCache
cfg
*
config
.
Config
codexDetector
CodexClientRestrictionDetector
schedulerSnapshot
*
SchedulerSnapshotService
concurrencyService
*
ConcurrencyService
billingService
*
BillingService
rateLimitService
*
RateLimitService
billingCacheService
*
BillingCacheService
userGroupRateResolver
*
userGroupRateResolver
httpUpstream
HTTPUpstream
deferredService
*
DeferredService
openAITokenProvider
*
OpenAITokenProvider
toolCorrector
*
CodexToolCorrector
openaiWSResolver
OpenAIWSProtocolResolver
openaiWSPoolOnce
sync
.
Once
openaiWSStateStoreOnce
sync
.
Once
...
...
@@ -284,6 +285,7 @@ func NewOpenAIGatewayService(
usageLogRepo
UsageLogRepository
,
userRepo
UserRepository
,
userSubRepo
UserSubscriptionRepository
,
userGroupRateRepo
UserGroupRateRepository
,
cache
GatewayCache
,
cfg
*
config
.
Config
,
schedulerSnapshot
*
SchedulerSnapshotService
,
...
...
@@ -296,18 +298,25 @@ func NewOpenAIGatewayService(
openAITokenProvider
*
OpenAITokenProvider
,
)
*
OpenAIGatewayService
{
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
accountRepo
,
usageLogRepo
:
usageLogRepo
,
userRepo
:
userRepo
,
userSubRepo
:
userSubRepo
,
cache
:
cache
,
cfg
:
cfg
,
codexDetector
:
NewOpenAICodexClientRestrictionDetector
(
cfg
),
schedulerSnapshot
:
schedulerSnapshot
,
concurrencyService
:
concurrencyService
,
billingService
:
billingService
,
rateLimitService
:
rateLimitService
,
billingCacheService
:
billingCacheService
,
accountRepo
:
accountRepo
,
usageLogRepo
:
usageLogRepo
,
userRepo
:
userRepo
,
userSubRepo
:
userSubRepo
,
cache
:
cache
,
cfg
:
cfg
,
codexDetector
:
NewOpenAICodexClientRestrictionDetector
(
cfg
),
schedulerSnapshot
:
schedulerSnapshot
,
concurrencyService
:
concurrencyService
,
billingService
:
billingService
,
rateLimitService
:
rateLimitService
,
billingCacheService
:
billingCacheService
,
userGroupRateResolver
:
newUserGroupRateResolver
(
userGroupRateRepo
,
nil
,
resolveUserGroupRateCacheTTL
(
cfg
),
nil
,
"service.openai_gateway"
,
),
httpUpstream
:
httpUpstream
,
deferredService
:
deferredService
,
openAITokenProvider
:
openAITokenProvider
,
...
...
@@ -3261,6 +3270,14 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.
// Correct tool calls in final response
body
=
s
.
correctToolCallsInResponseBody
(
body
)
}
else
{
terminalType
,
terminalPayload
,
terminalOK
:=
extractOpenAISSETerminalEvent
(
bodyText
)
if
terminalOK
&&
terminalType
==
"response.failed"
{
msg
:=
extractOpenAISSEErrorMessage
(
terminalPayload
)
if
msg
==
""
{
msg
=
"Upstream compact response failed"
}
return
nil
,
s
.
writeOpenAINonStreamingProtocolError
(
resp
,
c
,
msg
)
}
usage
=
s
.
parseSSEUsageFromBody
(
bodyText
)
if
originalModel
!=
mappedModel
{
bodyText
=
s
.
replaceModelInSSEBody
(
bodyText
,
mappedModel
,
originalModel
)
...
...
@@ -3282,6 +3299,51 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.
return
usage
,
nil
}
func
extractOpenAISSETerminalEvent
(
body
string
)
(
string
,
[]
byte
,
bool
)
{
lines
:=
strings
.
Split
(
body
,
"
\n
"
)
for
_
,
line
:=
range
lines
{
data
,
ok
:=
extractOpenAISSEDataLine
(
line
)
if
!
ok
||
data
==
""
||
data
==
"[DONE]"
{
continue
}
eventType
:=
strings
.
TrimSpace
(
gjson
.
Get
(
data
,
"type"
)
.
String
())
switch
eventType
{
case
"response.completed"
,
"response.done"
,
"response.failed"
:
return
eventType
,
[]
byte
(
data
),
true
}
}
return
""
,
nil
,
false
}
func
extractOpenAISSEErrorMessage
(
payload
[]
byte
)
string
{
if
len
(
payload
)
==
0
{
return
""
}
for
_
,
path
:=
range
[]
string
{
"response.error.message"
,
"error.message"
,
"message"
}
{
if
msg
:=
strings
.
TrimSpace
(
gjson
.
GetBytes
(
payload
,
path
)
.
String
());
msg
!=
""
{
return
sanitizeUpstreamErrorMessage
(
msg
)
}
}
return
sanitizeUpstreamErrorMessage
(
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
payload
)))
}
func
(
s
*
OpenAIGatewayService
)
writeOpenAINonStreamingProtocolError
(
resp
*
http
.
Response
,
c
*
gin
.
Context
,
message
string
)
error
{
message
=
sanitizeUpstreamErrorMessage
(
strings
.
TrimSpace
(
message
))
if
message
==
""
{
message
=
"Upstream returned an invalid non-streaming response"
}
setOpsUpstreamError
(
c
,
http
.
StatusBadGateway
,
message
,
""
)
responseheaders
.
WriteFilteredHeaders
(
c
.
Writer
.
Header
(),
resp
.
Header
,
s
.
responseHeaderFilter
)
c
.
Writer
.
Header
()
.
Set
(
"Content-Type"
,
"application/json; charset=utf-8"
)
c
.
JSON
(
http
.
StatusBadGateway
,
gin
.
H
{
"error"
:
gin
.
H
{
"type"
:
"upstream_error"
,
"message"
:
message
,
},
})
return
fmt
.
Errorf
(
"non-streaming openai protocol error: %s"
,
message
)
}
func
extractCodexFinalResponse
(
body
string
)
([]
byte
,
bool
)
{
lines
:=
strings
.
Split
(
body
,
"
\n
"
)
for
_
,
line
:=
range
lines
{
...
...
@@ -3413,7 +3475,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
// Get rate multiplier
multiplier
:=
s
.
cfg
.
Default
.
RateMultiplier
if
apiKey
.
GroupID
!=
nil
&&
apiKey
.
Group
!=
nil
{
multiplier
=
apiKey
.
Group
.
RateMultiplier
resolver
:=
s
.
userGroupRateResolver
if
resolver
==
nil
{
resolver
=
newUserGroupRateResolver
(
nil
,
nil
,
resolveUserGroupRateCacheTTL
(
s
.
cfg
),
nil
,
"service.openai_gateway"
)
}
multiplier
=
resolver
.
Resolve
(
ctx
,
user
.
ID
,
*
apiKey
.
GroupID
,
apiKey
.
Group
.
RateMultiplier
)
}
cost
,
err
:=
s
.
billingService
.
CalculateCost
(
result
.
Model
,
tokens
,
multiplier
)
...
...
backend/internal/service/openai_gateway_service_test.go
View file @
005d0c5f
...
...
@@ -1576,3 +1576,27 @@ func TestHandleOAuthSSEToJSON_NoFinalResponseKeepsSSEBody(t *testing.T) {
require
.
Contains
(
t
,
rec
.
Header
()
.
Get
(
"Content-Type"
),
"text/event-stream"
)
require
.
Contains
(
t
,
rec
.
Body
.
String
(),
`data: {"type":"response.in_progress"`
)
}
func
TestHandleOAuthSSEToJSON_ResponseFailedReturnsProtocolError
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
svc
:=
&
OpenAIGatewayService
{
cfg
:
&
config
.
Config
{}}
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"text/event-stream"
}},
}
body
:=
[]
byte
(
strings
.
Join
([]
string
{
`data: {"type":"response.failed","error":{"message":"upstream rejected request"}}`
,
`data: [DONE]`
,
},
"
\n
"
))
usage
,
err
:=
svc
.
handleOAuthSSEToJSON
(
resp
,
c
,
body
,
"gpt-4o"
,
"gpt-4o"
)
require
.
Nil
(
t
,
usage
)
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
http
.
StatusBadGateway
,
rec
.
Code
)
require
.
Contains
(
t
,
rec
.
Body
.
String
(),
"upstream rejected request"
)
require
.
Contains
(
t
,
rec
.
Header
()
.
Get
(
"Content-Type"
),
"application/json"
)
}
backend/internal/service/openai_ws_protocol_forward_test.go
View file @
005d0c5f
...
...
@@ -391,6 +391,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil
,
nil
,
nil
,
nil
,
cfg
,
nil
,
nil
,
...
...
backend/internal/service/user_group_rate_resolver.go
0 → 100644
View file @
005d0c5f
package
service
import
(
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
gocache
"github.com/patrickmn/go-cache"
"golang.org/x/sync/singleflight"
)
type
userGroupRateResolver
struct
{
repo
UserGroupRateRepository
cache
*
gocache
.
Cache
cacheTTL
time
.
Duration
sf
*
singleflight
.
Group
logComponent
string
}
func
newUserGroupRateResolver
(
repo
UserGroupRateRepository
,
cache
*
gocache
.
Cache
,
cacheTTL
time
.
Duration
,
sf
*
singleflight
.
Group
,
logComponent
string
)
*
userGroupRateResolver
{
if
cacheTTL
<=
0
{
cacheTTL
=
defaultUserGroupRateCacheTTL
}
if
cache
==
nil
{
cache
=
gocache
.
New
(
cacheTTL
,
time
.
Minute
)
}
if
logComponent
==
""
{
logComponent
=
"service.gateway"
}
if
sf
==
nil
{
sf
=
&
singleflight
.
Group
{}
}
return
&
userGroupRateResolver
{
repo
:
repo
,
cache
:
cache
,
cacheTTL
:
cacheTTL
,
sf
:
sf
,
logComponent
:
logComponent
,
}
}
func
(
r
*
userGroupRateResolver
)
Resolve
(
ctx
context
.
Context
,
userID
,
groupID
int64
,
groupDefaultMultiplier
float64
)
float64
{
if
r
==
nil
||
userID
<=
0
||
groupID
<=
0
{
return
groupDefaultMultiplier
}
key
:=
fmt
.
Sprintf
(
"%d:%d"
,
userID
,
groupID
)
if
r
.
cache
!=
nil
{
if
cached
,
ok
:=
r
.
cache
.
Get
(
key
);
ok
{
if
multiplier
,
castOK
:=
cached
.
(
float64
);
castOK
{
userGroupRateCacheHitTotal
.
Add
(
1
)
return
multiplier
}
}
}
if
r
.
repo
==
nil
{
return
groupDefaultMultiplier
}
userGroupRateCacheMissTotal
.
Add
(
1
)
value
,
err
,
shared
:=
r
.
sf
.
Do
(
key
,
func
()
(
any
,
error
)
{
if
r
.
cache
!=
nil
{
if
cached
,
ok
:=
r
.
cache
.
Get
(
key
);
ok
{
if
multiplier
,
castOK
:=
cached
.
(
float64
);
castOK
{
userGroupRateCacheHitTotal
.
Add
(
1
)
return
multiplier
,
nil
}
}
}
userGroupRateCacheLoadTotal
.
Add
(
1
)
userRate
,
repoErr
:=
r
.
repo
.
GetByUserAndGroup
(
ctx
,
userID
,
groupID
)
if
repoErr
!=
nil
{
return
nil
,
repoErr
}
multiplier
:=
groupDefaultMultiplier
if
userRate
!=
nil
{
multiplier
=
*
userRate
}
if
r
.
cache
!=
nil
{
r
.
cache
.
Set
(
key
,
multiplier
,
r
.
cacheTTL
)
}
return
multiplier
,
nil
})
if
shared
{
userGroupRateCacheSFSharedTotal
.
Add
(
1
)
}
if
err
!=
nil
{
userGroupRateCacheFallbackTotal
.
Add
(
1
)
logger
.
LegacyPrintf
(
r
.
logComponent
,
"get user group rate failed, fallback to group default: user=%d group=%d err=%v"
,
userID
,
groupID
,
err
)
return
groupDefaultMultiplier
}
multiplier
,
ok
:=
value
.
(
float64
)
if
!
ok
{
userGroupRateCacheFallbackTotal
.
Add
(
1
)
return
groupDefaultMultiplier
}
return
multiplier
}
backend/internal/service/user_group_rate_resolver_test.go
0 → 100644
View file @
005d0c5f
package
service
import
(
"context"
"testing"
"time"
gocache
"github.com/patrickmn/go-cache"
"github.com/stretchr/testify/require"
)
type
userGroupRateResolverRepoStub
struct
{
UserGroupRateRepository
rate
*
float64
err
error
calls
int
}
func
(
s
*
userGroupRateResolverRepoStub
)
GetByUserAndGroup
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
float64
,
error
)
{
s
.
calls
++
if
s
.
err
!=
nil
{
return
nil
,
s
.
err
}
return
s
.
rate
,
nil
}
func
TestNewUserGroupRateResolver_Defaults
(
t
*
testing
.
T
)
{
resolver
:=
newUserGroupRateResolver
(
nil
,
nil
,
0
,
nil
,
""
)
require
.
NotNil
(
t
,
resolver
)
require
.
NotNil
(
t
,
resolver
.
cache
)
require
.
Equal
(
t
,
defaultUserGroupRateCacheTTL
,
resolver
.
cacheTTL
)
require
.
NotNil
(
t
,
resolver
.
sf
)
require
.
Equal
(
t
,
"service.gateway"
,
resolver
.
logComponent
)
}
func
TestUserGroupRateResolverResolve_FallbackForNilResolverAndInvalidIDs
(
t
*
testing
.
T
)
{
var
nilResolver
*
userGroupRateResolver
require
.
Equal
(
t
,
1.4
,
nilResolver
.
Resolve
(
context
.
Background
(),
101
,
202
,
1.4
))
resolver
:=
newUserGroupRateResolver
(
nil
,
nil
,
time
.
Second
,
nil
,
"service.test"
)
require
.
Equal
(
t
,
1.4
,
resolver
.
Resolve
(
context
.
Background
(),
0
,
202
,
1.4
))
require
.
Equal
(
t
,
1.4
,
resolver
.
Resolve
(
context
.
Background
(),
101
,
0
,
1.4
))
}
func
TestUserGroupRateResolverResolve_InvalidCacheEntryLoadsRepoAndCaches
(
t
*
testing
.
T
)
{
resetGatewayHotpathStatsForTest
()
rate
:=
1.7
repo
:=
&
userGroupRateResolverRepoStub
{
rate
:
&
rate
}
cache
:=
gocache
.
New
(
time
.
Minute
,
time
.
Minute
)
cache
.
Set
(
"101:202"
,
"bad-cache"
,
time
.
Minute
)
resolver
:=
newUserGroupRateResolver
(
repo
,
cache
,
time
.
Minute
,
nil
,
"service.test"
)
got
:=
resolver
.
Resolve
(
context
.
Background
(),
101
,
202
,
1.2
)
require
.
Equal
(
t
,
rate
,
got
)
require
.
Equal
(
t
,
1
,
repo
.
calls
)
cached
,
ok
:=
cache
.
Get
(
"101:202"
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
rate
,
cached
)
hit
,
miss
,
load
,
_
,
fallback
:=
GatewayUserGroupRateCacheStats
()
require
.
Equal
(
t
,
int64
(
0
),
hit
)
require
.
Equal
(
t
,
int64
(
1
),
miss
)
require
.
Equal
(
t
,
int64
(
1
),
load
)
require
.
Equal
(
t
,
int64
(
0
),
fallback
)
}
func
TestGatewayServiceGetUserGroupRateMultiplier_FallbacksAndUsesExistingResolver
(
t
*
testing
.
T
)
{
var
nilSvc
*
GatewayService
require
.
Equal
(
t
,
1.3
,
nilSvc
.
getUserGroupRateMultiplier
(
context
.
Background
(),
101
,
202
,
1.3
))
rate
:=
1.9
repo
:=
&
userGroupRateResolverRepoStub
{
rate
:
&
rate
}
resolver
:=
newUserGroupRateResolver
(
repo
,
nil
,
time
.
Minute
,
nil
,
"service.gateway"
)
svc
:=
&
GatewayService
{
userGroupRateResolver
:
resolver
}
got
:=
svc
.
getUserGroupRateMultiplier
(
context
.
Background
(),
101
,
202
,
1.2
)
require
.
Equal
(
t
,
rate
,
got
)
require
.
Equal
(
t
,
1
,
repo
.
calls
)
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment