Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
6cc7f997
Commit
6cc7f997
authored
Jan 02, 2026
by
song
Browse files
merge: 合并 upstream/main
parents
95d09f60
106e59b7
Changes
135
Show whitespace changes
Inline
Side-by-side
backend/internal/repository/concurrency_cache_benchmark_test.go
View file @
6cc7f997
...
...
@@ -22,7 +22,7 @@ func BenchmarkAccountConcurrency(b *testing.B) {
_
=
rdb
.
Close
()
}()
cache
,
_
:=
NewConcurrencyCache
(
rdb
,
benchSlotTTLMinutes
)
.
(
*
concurrencyCache
)
cache
,
_
:=
NewConcurrencyCache
(
rdb
,
benchSlotTTLMinutes
,
int
(
benchSlotTTL
.
Seconds
())
)
.
(
*
concurrencyCache
)
ctx
:=
context
.
Background
()
for
_
,
size
:=
range
[]
int
{
10
,
100
,
1000
}
{
...
...
backend/internal/repository/concurrency_cache_integration_test.go
View file @
6cc7f997
...
...
@@ -27,7 +27,7 @@ type ConcurrencyCacheSuite struct {
func
(
s
*
ConcurrencyCacheSuite
)
SetupTest
()
{
s
.
IntegrationRedisSuite
.
SetupTest
()
s
.
cache
=
NewConcurrencyCache
(
s
.
rdb
,
testSlotTTLMinutes
)
s
.
cache
=
NewConcurrencyCache
(
s
.
rdb
,
testSlotTTLMinutes
,
int
(
testSlotTTL
.
Seconds
())
)
}
func
(
s
*
ConcurrencyCacheSuite
)
TestAccountSlot_AcquireAndRelease
()
{
...
...
@@ -218,6 +218,48 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
require
.
GreaterOrEqual
(
s
.
T
(),
val
,
0
,
"expected non-negative wait count"
)
}
func
(
s
*
ConcurrencyCacheSuite
)
TestAccountWaitQueue_IncrementAndDecrement
()
{
accountID
:=
int64
(
30
)
waitKey
:=
fmt
.
Sprintf
(
"%s%d"
,
accountWaitKeyPrefix
,
accountID
)
ok
,
err
:=
s
.
cache
.
IncrementAccountWaitCount
(
s
.
ctx
,
accountID
,
2
)
require
.
NoError
(
s
.
T
(),
err
,
"IncrementAccountWaitCount 1"
)
require
.
True
(
s
.
T
(),
ok
)
ok
,
err
=
s
.
cache
.
IncrementAccountWaitCount
(
s
.
ctx
,
accountID
,
2
)
require
.
NoError
(
s
.
T
(),
err
,
"IncrementAccountWaitCount 2"
)
require
.
True
(
s
.
T
(),
ok
)
ok
,
err
=
s
.
cache
.
IncrementAccountWaitCount
(
s
.
ctx
,
accountID
,
2
)
require
.
NoError
(
s
.
T
(),
err
,
"IncrementAccountWaitCount 3"
)
require
.
False
(
s
.
T
(),
ok
,
"expected account wait increment over max to fail"
)
ttl
,
err
:=
s
.
rdb
.
TTL
(
s
.
ctx
,
waitKey
)
.
Result
()
require
.
NoError
(
s
.
T
(),
err
,
"TTL account waitKey"
)
s
.
AssertTTLWithin
(
ttl
,
1
*
time
.
Second
,
testSlotTTL
)
require
.
NoError
(
s
.
T
(),
s
.
cache
.
DecrementAccountWaitCount
(
s
.
ctx
,
accountID
),
"DecrementAccountWaitCount"
)
val
,
err
:=
s
.
rdb
.
Get
(
s
.
ctx
,
waitKey
)
.
Int
()
if
!
errors
.
Is
(
err
,
redis
.
Nil
)
{
require
.
NoError
(
s
.
T
(),
err
,
"Get waitKey"
)
}
require
.
Equal
(
s
.
T
(),
1
,
val
,
"expected account wait count 1"
)
}
func
(
s
*
ConcurrencyCacheSuite
)
TestAccountWaitQueue_DecrementNoNegative
()
{
accountID
:=
int64
(
301
)
waitKey
:=
fmt
.
Sprintf
(
"%s%d"
,
accountWaitKeyPrefix
,
accountID
)
require
.
NoError
(
s
.
T
(),
s
.
cache
.
DecrementAccountWaitCount
(
s
.
ctx
,
accountID
),
"DecrementAccountWaitCount on non-existent key"
)
val
,
err
:=
s
.
rdb
.
Get
(
s
.
ctx
,
waitKey
)
.
Int
()
if
!
errors
.
Is
(
err
,
redis
.
Nil
)
{
require
.
NoError
(
s
.
T
(),
err
,
"Get waitKey"
)
}
require
.
GreaterOrEqual
(
s
.
T
(),
val
,
0
,
"expected non-negative account wait count after decrement on empty"
)
}
func
(
s
*
ConcurrencyCacheSuite
)
TestGetAccountConcurrency_Missing
()
{
// When no slots exist, GetAccountConcurrency should return 0
cur
,
err
:=
s
.
cache
.
GetAccountConcurrency
(
s
.
ctx
,
999
)
...
...
@@ -232,6 +274,139 @@ func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() {
require
.
Equal
(
s
.
T
(),
0
,
cur
)
}
func
(
s
*
ConcurrencyCacheSuite
)
TestGetAccountsLoadBatch
()
{
s
.
T
()
.
Skip
(
"TODO: Fix this test - CurrentConcurrency returns 0 instead of expected value in CI"
)
// Setup: Create accounts with different load states
account1
:=
int64
(
100
)
account2
:=
int64
(
101
)
account3
:=
int64
(
102
)
// Account 1: 2/3 slots used, 1 waiting
ok
,
err
:=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
account1
,
3
,
"req1"
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
True
(
s
.
T
(),
ok
)
ok
,
err
=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
account1
,
3
,
"req2"
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
True
(
s
.
T
(),
ok
)
ok
,
err
=
s
.
cache
.
IncrementAccountWaitCount
(
s
.
ctx
,
account1
,
5
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
True
(
s
.
T
(),
ok
)
// Account 2: 1/2 slots used, 0 waiting
ok
,
err
=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
account2
,
2
,
"req3"
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
True
(
s
.
T
(),
ok
)
// Account 3: 0/1 slots used, 0 waiting (idle)
// Query batch load
accounts
:=
[]
service
.
AccountWithConcurrency
{
{
ID
:
account1
,
MaxConcurrency
:
3
},
{
ID
:
account2
,
MaxConcurrency
:
2
},
{
ID
:
account3
,
MaxConcurrency
:
1
},
}
loadMap
,
err
:=
s
.
cache
.
GetAccountsLoadBatch
(
s
.
ctx
,
accounts
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
Len
(
s
.
T
(),
loadMap
,
3
)
// Verify account1: (2 + 1) / 3 = 100%
load1
:=
loadMap
[
account1
]
require
.
NotNil
(
s
.
T
(),
load1
)
require
.
Equal
(
s
.
T
(),
account1
,
load1
.
AccountID
)
require
.
Equal
(
s
.
T
(),
2
,
load1
.
CurrentConcurrency
)
require
.
Equal
(
s
.
T
(),
1
,
load1
.
WaitingCount
)
require
.
Equal
(
s
.
T
(),
100
,
load1
.
LoadRate
)
// Verify account2: (1 + 0) / 2 = 50%
load2
:=
loadMap
[
account2
]
require
.
NotNil
(
s
.
T
(),
load2
)
require
.
Equal
(
s
.
T
(),
account2
,
load2
.
AccountID
)
require
.
Equal
(
s
.
T
(),
1
,
load2
.
CurrentConcurrency
)
require
.
Equal
(
s
.
T
(),
0
,
load2
.
WaitingCount
)
require
.
Equal
(
s
.
T
(),
50
,
load2
.
LoadRate
)
// Verify account3: (0 + 0) / 1 = 0%
load3
:=
loadMap
[
account3
]
require
.
NotNil
(
s
.
T
(),
load3
)
require
.
Equal
(
s
.
T
(),
account3
,
load3
.
AccountID
)
require
.
Equal
(
s
.
T
(),
0
,
load3
.
CurrentConcurrency
)
require
.
Equal
(
s
.
T
(),
0
,
load3
.
WaitingCount
)
require
.
Equal
(
s
.
T
(),
0
,
load3
.
LoadRate
)
}
func
(
s
*
ConcurrencyCacheSuite
)
TestGetAccountsLoadBatch_Empty
()
{
// Test with empty account list
loadMap
,
err
:=
s
.
cache
.
GetAccountsLoadBatch
(
s
.
ctx
,
[]
service
.
AccountWithConcurrency
{})
require
.
NoError
(
s
.
T
(),
err
)
require
.
Empty
(
s
.
T
(),
loadMap
)
}
func
(
s
*
ConcurrencyCacheSuite
)
TestCleanupExpiredAccountSlots
()
{
accountID
:=
int64
(
200
)
slotKey
:=
fmt
.
Sprintf
(
"%s%d"
,
accountSlotKeyPrefix
,
accountID
)
// Acquire 3 slots
ok
,
err
:=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
accountID
,
5
,
"req1"
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
True
(
s
.
T
(),
ok
)
ok
,
err
=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
accountID
,
5
,
"req2"
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
True
(
s
.
T
(),
ok
)
ok
,
err
=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
accountID
,
5
,
"req3"
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
True
(
s
.
T
(),
ok
)
// Verify 3 slots exist
cur
,
err
:=
s
.
cache
.
GetAccountConcurrency
(
s
.
ctx
,
accountID
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
Equal
(
s
.
T
(),
3
,
cur
)
// Manually set old timestamps for req1 and req2 (simulate expired slots)
now
:=
time
.
Now
()
.
Unix
()
expiredTime
:=
now
-
int64
(
testSlotTTL
.
Seconds
())
-
10
// 10 seconds past TTL
err
=
s
.
rdb
.
ZAdd
(
s
.
ctx
,
slotKey
,
redis
.
Z
{
Score
:
float64
(
expiredTime
),
Member
:
"req1"
})
.
Err
()
require
.
NoError
(
s
.
T
(),
err
)
err
=
s
.
rdb
.
ZAdd
(
s
.
ctx
,
slotKey
,
redis
.
Z
{
Score
:
float64
(
expiredTime
),
Member
:
"req2"
})
.
Err
()
require
.
NoError
(
s
.
T
(),
err
)
// Run cleanup
err
=
s
.
cache
.
CleanupExpiredAccountSlots
(
s
.
ctx
,
accountID
)
require
.
NoError
(
s
.
T
(),
err
)
// Verify only 1 slot remains (req3)
cur
,
err
=
s
.
cache
.
GetAccountConcurrency
(
s
.
ctx
,
accountID
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
Equal
(
s
.
T
(),
1
,
cur
)
// Verify req3 still exists
members
,
err
:=
s
.
rdb
.
ZRange
(
s
.
ctx
,
slotKey
,
0
,
-
1
)
.
Result
()
require
.
NoError
(
s
.
T
(),
err
)
require
.
Len
(
s
.
T
(),
members
,
1
)
require
.
Equal
(
s
.
T
(),
"req3"
,
members
[
0
])
}
func
(
s
*
ConcurrencyCacheSuite
)
TestCleanupExpiredAccountSlots_NoExpired
()
{
accountID
:=
int64
(
201
)
// Acquire 2 fresh slots
ok
,
err
:=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
accountID
,
5
,
"req1"
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
True
(
s
.
T
(),
ok
)
ok
,
err
=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
accountID
,
5
,
"req2"
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
True
(
s
.
T
(),
ok
)
// Run cleanup (should not remove anything)
err
=
s
.
cache
.
CleanupExpiredAccountSlots
(
s
.
ctx
,
accountID
)
require
.
NoError
(
s
.
T
(),
err
)
// Verify both slots still exist
cur
,
err
:=
s
.
cache
.
GetAccountConcurrency
(
s
.
ctx
,
accountID
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
Equal
(
s
.
T
(),
2
,
cur
)
}
func
TestConcurrencyCacheSuite
(
t
*
testing
.
T
)
{
suite
.
Run
(
t
,
new
(
ConcurrencyCacheSuite
))
}
backend/internal/repository/fixtures_integration_test.go
View file @
6cc7f997
...
...
@@ -40,7 +40,6 @@ func mustCreateUser(t *testing.T, client *dbent.Client, u *service.User) *servic
SetBalance
(
u
.
Balance
)
.
SetConcurrency
(
u
.
Concurrency
)
.
SetUsername
(
u
.
Username
)
.
SetWechat
(
u
.
Wechat
)
.
SetNotes
(
u
.
Notes
)
if
!
u
.
CreatedAt
.
IsZero
()
{
create
.
SetCreatedAt
(
u
.
CreatedAt
)
...
...
backend/internal/repository/migrations_runner.go
View file @
6cc7f997
...
...
@@ -127,7 +127,15 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
if
existing
!=
checksum
{
// 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。
// 正确的做法是创建新的迁移文件来进行变更。
return
fmt
.
Errorf
(
"migration %s checksum mismatch (db=%s file=%s)"
,
name
,
existing
,
checksum
)
return
fmt
.
Errorf
(
"migration %s checksum mismatch (db=%s file=%s)
\n
"
+
"This means the migration file was modified after being applied to the database.
\n
"
+
"Solutions:
\n
"
+
" 1. Revert to original: git log --oneline -- migrations/%s && git checkout <commit> -- migrations/%s
\n
"
+
" 2. For new changes, create a new migration file instead of modifying existing ones
\n
"
+
"Note: Modifying applied migrations breaks the immutability principle and can cause inconsistencies across environments"
,
name
,
existing
,
checksum
,
name
,
name
,
)
}
continue
// 迁移已应用且校验和匹配,跳过
}
...
...
backend/internal/repository/migrations_schema_integration_test.go
View file @
6cc7f997
...
...
@@ -23,7 +23,6 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
// users: columns required by repository queries
requireColumn
(
t
,
tx
,
"users"
,
"username"
,
"character varying"
,
100
,
false
)
requireColumn
(
t
,
tx
,
"users"
,
"wechat"
,
"character varying"
,
100
,
false
)
requireColumn
(
t
,
tx
,
"users"
,
"notes"
,
"text"
,
0
,
false
)
// accounts: schedulable and rate-limit fields
...
...
backend/internal/repository/user_attribute_repo.go
0 → 100644
View file @
6cc7f997
package
repository
import
(
"context"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// UserAttributeDefinitionRepository implementation
type
userAttributeDefinitionRepository
struct
{
client
*
dbent
.
Client
}
// NewUserAttributeDefinitionRepository creates a new repository instance
func
NewUserAttributeDefinitionRepository
(
client
*
dbent
.
Client
)
service
.
UserAttributeDefinitionRepository
{
return
&
userAttributeDefinitionRepository
{
client
:
client
}
}
func
(
r
*
userAttributeDefinitionRepository
)
Create
(
ctx
context
.
Context
,
def
*
service
.
UserAttributeDefinition
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
created
,
err
:=
client
.
UserAttributeDefinition
.
Create
()
.
SetKey
(
def
.
Key
)
.
SetName
(
def
.
Name
)
.
SetDescription
(
def
.
Description
)
.
SetType
(
string
(
def
.
Type
))
.
SetOptions
(
toEntOptions
(
def
.
Options
))
.
SetRequired
(
def
.
Required
)
.
SetValidation
(
toEntValidation
(
def
.
Validation
))
.
SetPlaceholder
(
def
.
Placeholder
)
.
SetEnabled
(
def
.
Enabled
)
.
Save
(
ctx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
nil
,
service
.
ErrAttributeKeyExists
)
}
def
.
ID
=
created
.
ID
def
.
DisplayOrder
=
created
.
DisplayOrder
def
.
CreatedAt
=
created
.
CreatedAt
def
.
UpdatedAt
=
created
.
UpdatedAt
return
nil
}
func
(
r
*
userAttributeDefinitionRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
UserAttributeDefinition
,
error
)
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
e
,
err
:=
client
.
UserAttributeDefinition
.
Query
()
.
Where
(
userattributedefinition
.
IDEQ
(
id
))
.
Only
(
ctx
)
if
err
!=
nil
{
return
nil
,
translatePersistenceError
(
err
,
service
.
ErrAttributeDefinitionNotFound
,
nil
)
}
return
defEntityToService
(
e
),
nil
}
func
(
r
*
userAttributeDefinitionRepository
)
GetByKey
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
UserAttributeDefinition
,
error
)
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
e
,
err
:=
client
.
UserAttributeDefinition
.
Query
()
.
Where
(
userattributedefinition
.
KeyEQ
(
key
))
.
Only
(
ctx
)
if
err
!=
nil
{
return
nil
,
translatePersistenceError
(
err
,
service
.
ErrAttributeDefinitionNotFound
,
nil
)
}
return
defEntityToService
(
e
),
nil
}
func
(
r
*
userAttributeDefinitionRepository
)
Update
(
ctx
context
.
Context
,
def
*
service
.
UserAttributeDefinition
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
updated
,
err
:=
client
.
UserAttributeDefinition
.
UpdateOneID
(
def
.
ID
)
.
SetName
(
def
.
Name
)
.
SetDescription
(
def
.
Description
)
.
SetType
(
string
(
def
.
Type
))
.
SetOptions
(
toEntOptions
(
def
.
Options
))
.
SetRequired
(
def
.
Required
)
.
SetValidation
(
toEntValidation
(
def
.
Validation
))
.
SetPlaceholder
(
def
.
Placeholder
)
.
SetDisplayOrder
(
def
.
DisplayOrder
)
.
SetEnabled
(
def
.
Enabled
)
.
Save
(
ctx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrAttributeDefinitionNotFound
,
service
.
ErrAttributeKeyExists
)
}
def
.
UpdatedAt
=
updated
.
UpdatedAt
return
nil
}
func
(
r
*
userAttributeDefinitionRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
_
,
err
:=
client
.
UserAttributeDefinition
.
Delete
()
.
Where
(
userattributedefinition
.
IDEQ
(
id
))
.
Exec
(
ctx
)
return
translatePersistenceError
(
err
,
service
.
ErrAttributeDefinitionNotFound
,
nil
)
}
func
(
r
*
userAttributeDefinitionRepository
)
List
(
ctx
context
.
Context
,
enabledOnly
bool
)
([]
service
.
UserAttributeDefinition
,
error
)
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
q
:=
client
.
UserAttributeDefinition
.
Query
()
if
enabledOnly
{
q
=
q
.
Where
(
userattributedefinition
.
EnabledEQ
(
true
))
}
entities
,
err
:=
q
.
Order
(
dbent
.
Asc
(
userattributedefinition
.
FieldDisplayOrder
))
.
All
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
result
:=
make
([]
service
.
UserAttributeDefinition
,
0
,
len
(
entities
))
for
_
,
e
:=
range
entities
{
result
=
append
(
result
,
*
defEntityToService
(
e
))
}
return
result
,
nil
}
func
(
r
*
userAttributeDefinitionRepository
)
UpdateDisplayOrders
(
ctx
context
.
Context
,
orders
map
[
int64
]
int
)
error
{
tx
,
err
:=
r
.
client
.
Tx
(
ctx
)
if
err
!=
nil
{
return
err
}
defer
func
()
{
_
=
tx
.
Rollback
()
}()
for
id
,
order
:=
range
orders
{
if
_
,
err
:=
tx
.
UserAttributeDefinition
.
UpdateOneID
(
id
)
.
SetDisplayOrder
(
order
)
.
Save
(
ctx
);
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrAttributeDefinitionNotFound
,
nil
)
}
}
return
tx
.
Commit
()
}
func
(
r
*
userAttributeDefinitionRepository
)
ExistsByKey
(
ctx
context
.
Context
,
key
string
)
(
bool
,
error
)
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
return
client
.
UserAttributeDefinition
.
Query
()
.
Where
(
userattributedefinition
.
KeyEQ
(
key
))
.
Exist
(
ctx
)
}
// UserAttributeValueRepository implementation
type
userAttributeValueRepository
struct
{
client
*
dbent
.
Client
}
// NewUserAttributeValueRepository creates a new repository instance
func
NewUserAttributeValueRepository
(
client
*
dbent
.
Client
)
service
.
UserAttributeValueRepository
{
return
&
userAttributeValueRepository
{
client
:
client
}
}
func
(
r
*
userAttributeValueRepository
)
GetByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
service
.
UserAttributeValue
,
error
)
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
entities
,
err
:=
client
.
UserAttributeValue
.
Query
()
.
Where
(
userattributevalue
.
UserIDEQ
(
userID
))
.
All
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
result
:=
make
([]
service
.
UserAttributeValue
,
0
,
len
(
entities
))
for
_
,
e
:=
range
entities
{
result
=
append
(
result
,
service
.
UserAttributeValue
{
ID
:
e
.
ID
,
UserID
:
e
.
UserID
,
AttributeID
:
e
.
AttributeID
,
Value
:
e
.
Value
,
CreatedAt
:
e
.
CreatedAt
,
UpdatedAt
:
e
.
UpdatedAt
,
})
}
return
result
,
nil
}
func
(
r
*
userAttributeValueRepository
)
GetByUserIDs
(
ctx
context
.
Context
,
userIDs
[]
int64
)
([]
service
.
UserAttributeValue
,
error
)
{
if
len
(
userIDs
)
==
0
{
return
[]
service
.
UserAttributeValue
{},
nil
}
client
:=
clientFromContext
(
ctx
,
r
.
client
)
entities
,
err
:=
client
.
UserAttributeValue
.
Query
()
.
Where
(
userattributevalue
.
UserIDIn
(
userIDs
...
))
.
All
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
result
:=
make
([]
service
.
UserAttributeValue
,
0
,
len
(
entities
))
for
_
,
e
:=
range
entities
{
result
=
append
(
result
,
service
.
UserAttributeValue
{
ID
:
e
.
ID
,
UserID
:
e
.
UserID
,
AttributeID
:
e
.
AttributeID
,
Value
:
e
.
Value
,
CreatedAt
:
e
.
CreatedAt
,
UpdatedAt
:
e
.
UpdatedAt
,
})
}
return
result
,
nil
}
func
(
r
*
userAttributeValueRepository
)
UpsertBatch
(
ctx
context
.
Context
,
userID
int64
,
inputs
[]
service
.
UpdateUserAttributeInput
)
error
{
if
len
(
inputs
)
==
0
{
return
nil
}
tx
,
err
:=
r
.
client
.
Tx
(
ctx
)
if
err
!=
nil
{
return
err
}
defer
func
()
{
_
=
tx
.
Rollback
()
}()
for
_
,
input
:=
range
inputs
{
// Use upsert (ON CONFLICT DO UPDATE)
err
:=
tx
.
UserAttributeValue
.
Create
()
.
SetUserID
(
userID
)
.
SetAttributeID
(
input
.
AttributeID
)
.
SetValue
(
input
.
Value
)
.
OnConflictColumns
(
userattributevalue
.
FieldUserID
,
userattributevalue
.
FieldAttributeID
)
.
UpdateValue
()
.
UpdateUpdatedAt
()
.
Exec
(
ctx
)
if
err
!=
nil
{
return
err
}
}
return
tx
.
Commit
()
}
func
(
r
*
userAttributeValueRepository
)
DeleteByAttributeID
(
ctx
context
.
Context
,
attributeID
int64
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
_
,
err
:=
client
.
UserAttributeValue
.
Delete
()
.
Where
(
userattributevalue
.
AttributeIDEQ
(
attributeID
))
.
Exec
(
ctx
)
return
err
}
func
(
r
*
userAttributeValueRepository
)
DeleteByUserID
(
ctx
context
.
Context
,
userID
int64
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
_
,
err
:=
client
.
UserAttributeValue
.
Delete
()
.
Where
(
userattributevalue
.
UserIDEQ
(
userID
))
.
Exec
(
ctx
)
return
err
}
// Helper functions for entity to service conversion
func
defEntityToService
(
e
*
dbent
.
UserAttributeDefinition
)
*
service
.
UserAttributeDefinition
{
if
e
==
nil
{
return
nil
}
return
&
service
.
UserAttributeDefinition
{
ID
:
e
.
ID
,
Key
:
e
.
Key
,
Name
:
e
.
Name
,
Description
:
e
.
Description
,
Type
:
service
.
UserAttributeType
(
e
.
Type
),
Options
:
toServiceOptions
(
e
.
Options
),
Required
:
e
.
Required
,
Validation
:
toServiceValidation
(
e
.
Validation
),
Placeholder
:
e
.
Placeholder
,
DisplayOrder
:
e
.
DisplayOrder
,
Enabled
:
e
.
Enabled
,
CreatedAt
:
e
.
CreatedAt
,
UpdatedAt
:
e
.
UpdatedAt
,
}
}
// Type conversion helpers (map types <-> service types)
func
toEntOptions
(
opts
[]
service
.
UserAttributeOption
)
[]
map
[
string
]
any
{
if
opts
==
nil
{
return
[]
map
[
string
]
any
{}
}
result
:=
make
([]
map
[
string
]
any
,
len
(
opts
))
for
i
,
o
:=
range
opts
{
result
[
i
]
=
map
[
string
]
any
{
"value"
:
o
.
Value
,
"label"
:
o
.
Label
}
}
return
result
}
func
toServiceOptions
(
opts
[]
map
[
string
]
any
)
[]
service
.
UserAttributeOption
{
if
opts
==
nil
{
return
[]
service
.
UserAttributeOption
{}
}
result
:=
make
([]
service
.
UserAttributeOption
,
len
(
opts
))
for
i
,
o
:=
range
opts
{
result
[
i
]
=
service
.
UserAttributeOption
{
Value
:
getString
(
o
,
"value"
),
Label
:
getString
(
o
,
"label"
),
}
}
return
result
}
func
toEntValidation
(
v
service
.
UserAttributeValidation
)
map
[
string
]
any
{
result
:=
map
[
string
]
any
{}
if
v
.
MinLength
!=
nil
{
result
[
"min_length"
]
=
*
v
.
MinLength
}
if
v
.
MaxLength
!=
nil
{
result
[
"max_length"
]
=
*
v
.
MaxLength
}
if
v
.
Min
!=
nil
{
result
[
"min"
]
=
*
v
.
Min
}
if
v
.
Max
!=
nil
{
result
[
"max"
]
=
*
v
.
Max
}
if
v
.
Pattern
!=
nil
{
result
[
"pattern"
]
=
*
v
.
Pattern
}
if
v
.
Message
!=
nil
{
result
[
"message"
]
=
*
v
.
Message
}
return
result
}
func
toServiceValidation
(
v
map
[
string
]
any
)
service
.
UserAttributeValidation
{
result
:=
service
.
UserAttributeValidation
{}
if
val
:=
getInt
(
v
,
"min_length"
);
val
!=
nil
{
result
.
MinLength
=
val
}
if
val
:=
getInt
(
v
,
"max_length"
);
val
!=
nil
{
result
.
MaxLength
=
val
}
if
val
:=
getInt
(
v
,
"min"
);
val
!=
nil
{
result
.
Min
=
val
}
if
val
:=
getInt
(
v
,
"max"
);
val
!=
nil
{
result
.
Max
=
val
}
if
val
:=
getStringPtr
(
v
,
"pattern"
);
val
!=
nil
{
result
.
Pattern
=
val
}
if
val
:=
getStringPtr
(
v
,
"message"
);
val
!=
nil
{
result
.
Message
=
val
}
return
result
}
// Helper functions for type conversion
func
getString
(
m
map
[
string
]
any
,
key
string
)
string
{
if
v
,
ok
:=
m
[
key
];
ok
{
if
s
,
ok
:=
v
.
(
string
);
ok
{
return
s
}
}
return
""
}
func
getStringPtr
(
m
map
[
string
]
any
,
key
string
)
*
string
{
if
v
,
ok
:=
m
[
key
];
ok
{
if
s
,
ok
:=
v
.
(
string
);
ok
{
return
&
s
}
}
return
nil
}
func
getInt
(
m
map
[
string
]
any
,
key
string
)
*
int
{
if
v
,
ok
:=
m
[
key
];
ok
{
switch
n
:=
v
.
(
type
)
{
case
int
:
return
&
n
case
int64
:
i
:=
int
(
n
)
return
&
i
case
float64
:
i
:=
int
(
n
)
return
&
i
}
}
return
nil
}
backend/internal/repository/user_repo.go
View file @
6cc7f997
...
...
@@ -9,6 +9,7 @@ import (
dbent
"github.com/Wei-Shaw/sub2api/ent"
dbuser
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
...
...
@@ -50,7 +51,6 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
created
,
err
:=
txClient
.
User
.
Create
()
.
SetEmail
(
userIn
.
Email
)
.
SetUsername
(
userIn
.
Username
)
.
SetWechat
(
userIn
.
Wechat
)
.
SetNotes
(
userIn
.
Notes
)
.
SetPasswordHash
(
userIn
.
PasswordHash
)
.
SetRole
(
userIn
.
Role
)
.
...
...
@@ -133,7 +133,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
updated
,
err
:=
txClient
.
User
.
UpdateOneID
(
userIn
.
ID
)
.
SetEmail
(
userIn
.
Email
)
.
SetUsername
(
userIn
.
Username
)
.
SetWechat
(
userIn
.
Wechat
)
.
SetNotes
(
userIn
.
Notes
)
.
SetPasswordHash
(
userIn
.
PasswordHash
)
.
SetRole
(
userIn
.
Role
)
.
...
...
@@ -171,28 +170,38 @@ func (r *userRepository) Delete(ctx context.Context, id int64) error {
}
func
(
r
*
userRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
service
.
User
,
*
pagination
.
PaginationResult
,
error
)
{
return
r
.
ListWithFilters
(
ctx
,
params
,
""
,
""
,
""
)
return
r
.
ListWithFilters
(
ctx
,
params
,
service
.
UserListFilters
{}
)
}
func
(
r
*
userRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
status
,
role
,
search
string
)
([]
service
.
User
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
userRepository
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
filters
service
.
UserListFilters
)
([]
service
.
User
,
*
pagination
.
PaginationResult
,
error
)
{
q
:=
r
.
client
.
User
.
Query
()
if
s
tatus
!=
""
{
q
=
q
.
Where
(
dbuser
.
StatusEQ
(
s
tatus
))
if
filters
.
S
tatus
!=
""
{
q
=
q
.
Where
(
dbuser
.
StatusEQ
(
filters
.
S
tatus
))
}
if
r
ole
!=
""
{
q
=
q
.
Where
(
dbuser
.
RoleEQ
(
r
ole
))
if
filters
.
R
ole
!=
""
{
q
=
q
.
Where
(
dbuser
.
RoleEQ
(
filters
.
R
ole
))
}
if
s
earch
!=
""
{
if
filters
.
S
earch
!=
""
{
q
=
q
.
Where
(
dbuser
.
Or
(
dbuser
.
EmailContainsFold
(
search
),
dbuser
.
UsernameContainsFold
(
search
),
dbuser
.
WechatContainsFold
(
search
),
dbuser
.
EmailContainsFold
(
filters
.
Search
),
dbuser
.
UsernameContainsFold
(
filters
.
Search
),
),
)
}
// If attribute filters are specified, we need to filter by user IDs first
var
allowedUserIDs
[]
int64
if
len
(
filters
.
Attributes
)
>
0
{
allowedUserIDs
=
r
.
filterUsersByAttributes
(
ctx
,
filters
.
Attributes
)
if
len
(
allowedUserIDs
)
==
0
{
// No users match the attribute filters
return
[]
service
.
User
{},
paginationResultFromTotal
(
0
,
params
),
nil
}
q
=
q
.
Where
(
dbuser
.
IDIn
(
allowedUserIDs
...
))
}
total
,
err
:=
q
.
Clone
()
.
Count
(
ctx
)
if
err
!=
nil
{
return
nil
,
nil
,
err
...
...
@@ -252,6 +261,59 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
return
outUsers
,
paginationResultFromTotal
(
int64
(
total
),
params
),
nil
}
// filterUsersByAttributes returns user IDs that match ALL the given attribute filters
func
(
r
*
userRepository
)
filterUsersByAttributes
(
ctx
context
.
Context
,
attrs
map
[
int64
]
string
)
[]
int64
{
if
len
(
attrs
)
==
0
{
return
nil
}
// For each attribute filter, get the set of matching user IDs
// Then intersect all sets to get users matching ALL filters
var
resultSet
map
[
int64
]
struct
{}
first
:=
true
for
attrID
,
value
:=
range
attrs
{
// Query user_attribute_values for this attribute
values
,
err
:=
r
.
client
.
UserAttributeValue
.
Query
()
.
Where
(
userattributevalue
.
AttributeIDEQ
(
attrID
),
userattributevalue
.
ValueContainsFold
(
value
),
)
.
All
(
ctx
)
if
err
!=
nil
{
continue
}
currentSet
:=
make
(
map
[
int64
]
struct
{},
len
(
values
))
for
_
,
v
:=
range
values
{
currentSet
[
v
.
UserID
]
=
struct
{}{}
}
if
first
{
resultSet
=
currentSet
first
=
false
}
else
{
// Intersect with previous results
for
userID
:=
range
resultSet
{
if
_
,
ok
:=
currentSet
[
userID
];
!
ok
{
delete
(
resultSet
,
userID
)
}
}
}
// Early exit if no users match
if
len
(
resultSet
)
==
0
{
return
nil
}
}
result
:=
make
([]
int64
,
0
,
len
(
resultSet
))
for
userID
:=
range
resultSet
{
result
=
append
(
result
,
userID
)
}
return
result
}
func
(
r
*
userRepository
)
UpdateBalance
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
n
,
err
:=
client
.
User
.
Update
()
.
Where
(
dbuser
.
IDEQ
(
id
))
.
AddBalance
(
amount
)
.
Save
(
ctx
)
...
...
backend/internal/repository/user_repo_integration_test.go
View file @
6cc7f997
...
...
@@ -166,7 +166,7 @@ func (s *UserRepoSuite) TestListWithFilters_Status() {
s
.
mustCreateUser
(
&
service
.
User
{
Email
:
"active@test.com"
,
Status
:
service
.
StatusActive
})
s
.
mustCreateUser
(
&
service
.
User
{
Email
:
"disabled@test.com"
,
Status
:
service
.
StatusDisabled
})
users
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
service
.
StatusActive
,
""
,
""
)
users
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
service
.
UserListFilters
{
Status
:
service
.
StatusActive
}
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
users
,
1
)
s
.
Require
()
.
Equal
(
service
.
StatusActive
,
users
[
0
]
.
Status
)
...
...
@@ -176,7 +176,7 @@ func (s *UserRepoSuite) TestListWithFilters_Role() {
s
.
mustCreateUser
(
&
service
.
User
{
Email
:
"user@test.com"
,
Role
:
service
.
RoleUser
})
s
.
mustCreateUser
(
&
service
.
User
{
Email
:
"admin@test.com"
,
Role
:
service
.
RoleAdmin
})
users
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
""
,
service
.
RoleAdmin
,
""
)
users
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
service
.
UserListFilters
{
Role
:
service
.
RoleAdmin
}
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
users
,
1
)
s
.
Require
()
.
Equal
(
service
.
RoleAdmin
,
users
[
0
]
.
Role
)
...
...
@@ -186,7 +186,7 @@ func (s *UserRepoSuite) TestListWithFilters_Search() {
s
.
mustCreateUser
(
&
service
.
User
{
Email
:
"alice@test.com"
,
Username
:
"Alice"
})
s
.
mustCreateUser
(
&
service
.
User
{
Email
:
"bob@test.com"
,
Username
:
"Bob"
})
users
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
""
,
""
,
"alice"
)
users
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
service
.
UserListFilters
{
Search
:
"alice"
}
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
users
,
1
)
s
.
Require
()
.
Contains
(
users
[
0
]
.
Email
,
"alice"
)
...
...
@@ -196,22 +196,12 @@ func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() {
s
.
mustCreateUser
(
&
service
.
User
{
Email
:
"u1@test.com"
,
Username
:
"JohnDoe"
})
s
.
mustCreateUser
(
&
service
.
User
{
Email
:
"u2@test.com"
,
Username
:
"JaneSmith"
})
users
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
""
,
""
,
"john"
)
users
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
service
.
UserListFilters
{
Search
:
"john"
}
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
users
,
1
)
s
.
Require
()
.
Equal
(
"JohnDoe"
,
users
[
0
]
.
Username
)
}
func
(
s
*
UserRepoSuite
)
TestListWithFilters_SearchByWechat
()
{
s
.
mustCreateUser
(
&
service
.
User
{
Email
:
"w1@test.com"
,
Wechat
:
"wx_hello"
})
s
.
mustCreateUser
(
&
service
.
User
{
Email
:
"w2@test.com"
,
Wechat
:
"wx_world"
})
users
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
""
,
""
,
"wx_hello"
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
users
,
1
)
s
.
Require
()
.
Equal
(
"wx_hello"
,
users
[
0
]
.
Wechat
)
}
func
(
s
*
UserRepoSuite
)
TestListWithFilters_LoadsActiveSubscriptions
()
{
user
:=
s
.
mustCreateUser
(
&
service
.
User
{
Email
:
"sub@test.com"
,
Status
:
service
.
StatusActive
})
groupActive
:=
s
.
mustCreateGroup
(
"g-sub-active"
)
...
...
@@ -226,7 +216,7 @@ func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
c
.
SetExpiresAt
(
time
.
Now
()
.
Add
(
-
1
*
time
.
Hour
))
})
users
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
""
,
""
,
"sub@"
)
users
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
service
.
UserListFilters
{
Search
:
"sub@"
}
)
s
.
Require
()
.
NoError
(
err
,
"ListWithFilters"
)
s
.
Require
()
.
Len
(
users
,
1
,
"expected 1 user"
)
s
.
Require
()
.
Len
(
users
[
0
]
.
Subscriptions
,
1
,
"expected 1 active subscription"
)
...
...
@@ -238,7 +228,6 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
s
.
mustCreateUser
(
&
service
.
User
{
Email
:
"a@example.com"
,
Username
:
"Alice"
,
Wechat
:
"wx_a"
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
...
...
@@ -246,7 +235,6 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
target
:=
s
.
mustCreateUser
(
&
service
.
User
{
Email
:
"b@example.com"
,
Username
:
"Bob"
,
Wechat
:
"wx_b"
,
Role
:
service
.
RoleAdmin
,
Status
:
service
.
StatusActive
,
Balance
:
1
,
...
...
@@ -257,7 +245,7 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
Status
:
service
.
StatusDisabled
,
})
users
,
page
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
service
.
StatusActive
,
service
.
RoleAdmin
,
"b@"
)
users
,
page
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
},
service
.
UserListFilters
{
Status
:
service
.
StatusActive
,
Role
:
service
.
RoleAdmin
,
Search
:
"b@"
}
)
s
.
Require
()
.
NoError
(
err
,
"ListWithFilters"
)
s
.
Require
()
.
Equal
(
int64
(
1
),
page
.
Total
,
"ListWithFilters total mismatch"
)
s
.
Require
()
.
Len
(
users
,
1
,
"ListWithFilters len mismatch"
)
...
...
@@ -448,7 +436,6 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
user1
:=
s
.
mustCreateUser
(
&
service
.
User
{
Email
:
"a@example.com"
,
Username
:
"Alice"
,
Wechat
:
"wx_a"
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
...
...
@@ -456,7 +443,6 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
user2
:=
s
.
mustCreateUser
(
&
service
.
User
{
Email
:
"b@example.com"
,
Username
:
"Bob"
,
Wechat
:
"wx_b"
,
Role
:
service
.
RoleAdmin
,
Status
:
service
.
StatusActive
,
Balance
:
1
,
...
...
@@ -501,7 +487,7 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
s
.
Require
()
.
Equal
(
user1
.
Concurrency
+
3
,
got5
.
Concurrency
)
params
:=
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
}
users
,
page
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
params
,
service
.
StatusActive
,
service
.
RoleAdmin
,
"b@"
)
users
,
page
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
params
,
service
.
UserListFilters
{
Status
:
service
.
StatusActive
,
Role
:
service
.
RoleAdmin
,
Search
:
"b@"
}
)
s
.
Require
()
.
NoError
(
err
,
"ListWithFilters"
)
s
.
Require
()
.
Equal
(
int64
(
1
),
page
.
Total
,
"ListWithFilters total mismatch"
)
s
.
Require
()
.
Len
(
users
,
1
,
"ListWithFilters len mismatch"
)
...
...
backend/internal/repository/wire.go
View file @
6cc7f997
...
...
@@ -15,7 +15,14 @@ import (
// ProvideConcurrencyCache 创建并发控制缓存,从配置读取 TTL 参数
// 性能优化:TTL 可配置,支持长时间运行的 LLM 请求场景
func
ProvideConcurrencyCache
(
rdb
*
redis
.
Client
,
cfg
*
config
.
Config
)
service
.
ConcurrencyCache
{
return
NewConcurrencyCache
(
rdb
,
cfg
.
Gateway
.
ConcurrencySlotTTLMinutes
)
waitTTLSeconds
:=
int
(
cfg
.
Gateway
.
Scheduling
.
StickySessionWaitTimeout
.
Seconds
())
if
cfg
.
Gateway
.
Scheduling
.
FallbackWaitTimeout
>
cfg
.
Gateway
.
Scheduling
.
StickySessionWaitTimeout
{
waitTTLSeconds
=
int
(
cfg
.
Gateway
.
Scheduling
.
FallbackWaitTimeout
.
Seconds
())
}
if
waitTTLSeconds
<=
0
{
waitTTLSeconds
=
cfg
.
Gateway
.
ConcurrencySlotTTLMinutes
*
60
}
return
NewConcurrencyCache
(
rdb
,
cfg
.
Gateway
.
ConcurrencySlotTTLMinutes
,
waitTTLSeconds
)
}
// ProviderSet is the Wire provider set for all repositories
...
...
@@ -29,6 +36,8 @@ var ProviderSet = wire.NewSet(
NewUsageLogRepository
,
NewSettingRepository
,
NewUserSubscriptionRepository
,
NewUserAttributeDefinitionRepository
,
NewUserAttributeValueRepository
,
// Cache implementations
NewGatewayCache
,
...
...
backend/internal/server/api_contract_test.go
View file @
6cc7f997
...
...
@@ -51,7 +51,6 @@ func TestAPIContracts(t *testing.T) {
"id": 1,
"email": "alice@example.com",
"username": "alice",
"wechat": "wx_alice",
"notes": "hello",
"role": "user",
"balance": 12.5,
...
...
@@ -348,7 +347,6 @@ func newContractDeps(t *testing.T) *contractDeps {
ID
:
1
,
Email
:
"alice@example.com"
,
Username
:
"alice"
,
Wechat
:
"wx_alice"
,
Notes
:
"hello"
,
Role
:
service
.
RoleUser
,
Balance
:
12.5
,
...
...
@@ -503,7 +501,7 @@ func (r *stubUserRepo) List(ctx context.Context, params pagination.PaginationPar
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUserRepo
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
status
,
role
,
search
string
)
([]
service
.
User
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
r
*
stubUserRepo
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
filters
service
.
UserListFilters
)
([]
service
.
User
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
...
...
backend/internal/server/routes/admin.go
View file @
6cc7f997
...
...
@@ -54,6 +54,9 @@ func RegisterAdminRoutes(
// 使用记录管理
registerUsageRoutes
(
admin
,
h
)
// 用户属性管理
registerUserAttributeRoutes
(
admin
,
h
)
}
}
...
...
@@ -82,6 +85,10 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
users
.
POST
(
"/:id/balance"
,
h
.
Admin
.
User
.
UpdateBalance
)
users
.
GET
(
"/:id/api-keys"
,
h
.
Admin
.
User
.
GetUserAPIKeys
)
users
.
GET
(
"/:id/usage"
,
h
.
Admin
.
User
.
GetUserUsage
)
// User attribute values
users
.
GET
(
"/:id/attributes"
,
h
.
Admin
.
UserAttribute
.
GetUserAttributes
)
users
.
PUT
(
"/:id/attributes"
,
h
.
Admin
.
UserAttribute
.
UpdateUserAttributes
)
}
}
...
...
@@ -110,6 +117,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts
.
DELETE
(
"/:id"
,
h
.
Admin
.
Account
.
Delete
)
accounts
.
POST
(
"/:id/test"
,
h
.
Admin
.
Account
.
Test
)
accounts
.
POST
(
"/:id/refresh"
,
h
.
Admin
.
Account
.
Refresh
)
accounts
.
POST
(
"/:id/refresh-tier"
,
h
.
Admin
.
Account
.
RefreshTier
)
accounts
.
GET
(
"/:id/stats"
,
h
.
Admin
.
Account
.
GetStats
)
accounts
.
POST
(
"/:id/clear-error"
,
h
.
Admin
.
Account
.
ClearError
)
accounts
.
GET
(
"/:id/usage"
,
h
.
Admin
.
Account
.
GetUsage
)
...
...
@@ -119,6 +127,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts
.
GET
(
"/:id/models"
,
h
.
Admin
.
Account
.
GetAvailableModels
)
accounts
.
POST
(
"/batch"
,
h
.
Admin
.
Account
.
BatchCreate
)
accounts
.
POST
(
"/batch-update-credentials"
,
h
.
Admin
.
Account
.
BatchUpdateCredentials
)
accounts
.
POST
(
"/batch-refresh-tier"
,
h
.
Admin
.
Account
.
BatchRefreshTier
)
accounts
.
POST
(
"/bulk-update"
,
h
.
Admin
.
Account
.
BulkUpdate
)
// Claude OAuth routes
...
...
@@ -242,3 +251,15 @@ func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
usage
.
GET
(
"/search-api-keys"
,
h
.
Admin
.
Usage
.
SearchApiKeys
)
}
}
func
registerUserAttributeRoutes
(
admin
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
)
{
attrs
:=
admin
.
Group
(
"/user-attributes"
)
{
attrs
.
GET
(
""
,
h
.
Admin
.
UserAttribute
.
ListDefinitions
)
attrs
.
POST
(
""
,
h
.
Admin
.
UserAttribute
.
CreateDefinition
)
attrs
.
POST
(
"/batch"
,
h
.
Admin
.
UserAttribute
.
GetBatchUserAttributes
)
attrs
.
PUT
(
"/reorder"
,
h
.
Admin
.
UserAttribute
.
ReorderDefinitions
)
attrs
.
PUT
(
"/:id"
,
h
.
Admin
.
UserAttribute
.
UpdateDefinition
)
attrs
.
DELETE
(
"/:id"
,
h
.
Admin
.
UserAttribute
.
DeleteDefinition
)
}
}
backend/internal/service/account.go
View file @
6cc7f997
...
...
@@ -3,6 +3,7 @@ package service
import
(
"encoding/json"
"strconv"
"strings"
"time"
)
...
...
@@ -78,6 +79,36 @@ func (a *Account) IsGemini() bool {
return
a
.
Platform
==
PlatformGemini
}
func
(
a
*
Account
)
GeminiOAuthType
()
string
{
if
a
.
Platform
!=
PlatformGemini
||
a
.
Type
!=
AccountTypeOAuth
{
return
""
}
oauthType
:=
strings
.
TrimSpace
(
a
.
GetCredential
(
"oauth_type"
))
if
oauthType
==
""
&&
strings
.
TrimSpace
(
a
.
GetCredential
(
"project_id"
))
!=
""
{
return
"code_assist"
}
return
oauthType
}
func
(
a
*
Account
)
GeminiTierID
()
string
{
tierID
:=
strings
.
TrimSpace
(
a
.
GetCredential
(
"tier_id"
))
if
tierID
==
""
{
return
""
}
return
strings
.
ToUpper
(
tierID
)
}
func
(
a
*
Account
)
IsGeminiCodeAssist
()
bool
{
if
a
.
Platform
!=
PlatformGemini
||
a
.
Type
!=
AccountTypeOAuth
{
return
false
}
oauthType
:=
a
.
GeminiOAuthType
()
if
oauthType
==
""
{
return
strings
.
TrimSpace
(
a
.
GetCredential
(
"project_id"
))
!=
""
}
return
oauthType
==
"code_assist"
}
func
(
a
*
Account
)
CanGetUsage
()
bool
{
return
a
.
Type
==
AccountTypeOAuth
}
...
...
backend/internal/service/account_service.go
View file @
6cc7f997
...
...
@@ -17,6 +17,9 @@ var (
type
AccountRepository
interface
{
Create
(
ctx
context
.
Context
,
account
*
Account
)
error
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Account
,
error
)
// GetByIDs fetches accounts by IDs in a single query.
// It should return all accounts found (missing IDs are ignored).
GetByIDs
(
ctx
context
.
Context
,
ids
[]
int64
)
([]
*
Account
,
error
)
// ExistsByID 检查账号是否存在,仅返回布尔值,用于删除前的轻量级存在性检查
ExistsByID
(
ctx
context
.
Context
,
id
int64
)
(
bool
,
error
)
// GetByCRSAccountID finds an account previously synced from CRS.
...
...
backend/internal/service/account_service_delete_test.go
View file @
6cc7f997
...
...
@@ -40,6 +40,10 @@ func (s *accountRepoStub) GetByID(ctx context.Context, id int64) (*Account, erro
panic
(
"unexpected GetByID call"
)
}
func
(
s
*
accountRepoStub
)
GetByIDs
(
ctx
context
.
Context
,
ids
[]
int64
)
([]
*
Account
,
error
)
{
panic
(
"unexpected GetByIDs call"
)
}
// ExistsByID 返回预设的存在性检查结果。
// 这是 Delete 方法调用的第一个仓储方法,用于验证账号是否存在。
func
(
s
*
accountRepoStub
)
ExistsByID
(
ctx
context
.
Context
,
id
int64
)
(
bool
,
error
)
{
...
...
backend/internal/service/account_usage_service.go
View file @
6cc7f997
...
...
@@ -97,6 +97,8 @@ type UsageInfo struct {
FiveHour
*
UsageProgress
`json:"five_hour"`
// 5小时窗口
SevenDay
*
UsageProgress
`json:"seven_day,omitempty"`
// 7天窗口
SevenDaySonnet
*
UsageProgress
`json:"seven_day_sonnet,omitempty"`
// 7天Sonnet窗口
GeminiProDaily
*
UsageProgress
`json:"gemini_pro_daily,omitempty"`
// Gemini Pro 日配额
GeminiFlashDaily
*
UsageProgress
`json:"gemini_flash_daily,omitempty"`
// Gemini Flash 日配额
}
// ClaudeUsageResponse Anthropic API返回的usage结构
...
...
@@ -125,14 +127,16 @@ type AccountUsageService struct {
accountRepo
AccountRepository
usageLogRepo
UsageLogRepository
usageFetcher
ClaudeUsageFetcher
geminiQuotaService
*
GeminiQuotaService
}
// NewAccountUsageService 创建AccountUsageService实例
func
NewAccountUsageService
(
accountRepo
AccountRepository
,
usageLogRepo
UsageLogRepository
,
usageFetcher
ClaudeUsageFetcher
)
*
AccountUsageService
{
func
NewAccountUsageService
(
accountRepo
AccountRepository
,
usageLogRepo
UsageLogRepository
,
usageFetcher
ClaudeUsageFetcher
,
geminiQuotaService
*
GeminiQuotaService
)
*
AccountUsageService
{
return
&
AccountUsageService
{
accountRepo
:
accountRepo
,
usageLogRepo
:
usageLogRepo
,
usageFetcher
:
usageFetcher
,
geminiQuotaService
:
geminiQuotaService
,
}
}
...
...
@@ -146,6 +150,10 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return
nil
,
fmt
.
Errorf
(
"get account failed: %w"
,
err
)
}
if
account
.
Platform
==
PlatformGemini
{
return
s
.
getGeminiUsage
(
ctx
,
account
)
}
// 只有oauth类型账号可以通过API获取usage(有profile scope)
if
account
.
CanGetUsage
()
{
var
apiResp
*
ClaudeUsageResponse
...
...
@@ -192,6 +200,36 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return
nil
,
fmt
.
Errorf
(
"account type %s does not support usage query"
,
account
.
Type
)
}
func
(
s
*
AccountUsageService
)
getGeminiUsage
(
ctx
context
.
Context
,
account
*
Account
)
(
*
UsageInfo
,
error
)
{
now
:=
time
.
Now
()
usage
:=
&
UsageInfo
{
UpdatedAt
:
&
now
,
}
if
s
.
geminiQuotaService
==
nil
||
s
.
usageLogRepo
==
nil
{
return
usage
,
nil
}
quota
,
ok
:=
s
.
geminiQuotaService
.
QuotaForAccount
(
ctx
,
account
)
if
!
ok
{
return
usage
,
nil
}
start
:=
geminiDailyWindowStart
(
now
)
stats
,
err
:=
s
.
usageLogRepo
.
GetModelStatsWithFilters
(
ctx
,
start
,
now
,
0
,
0
,
account
.
ID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get gemini usage stats failed: %w"
,
err
)
}
totals
:=
geminiAggregateUsage
(
stats
)
resetAt
:=
geminiDailyResetTime
(
now
)
usage
.
GeminiProDaily
=
buildGeminiUsageProgress
(
totals
.
ProRequests
,
quota
.
ProRPD
,
resetAt
,
totals
.
ProTokens
,
totals
.
ProCost
,
now
)
usage
.
GeminiFlashDaily
=
buildGeminiUsageProgress
(
totals
.
FlashRequests
,
quota
.
FlashRPD
,
resetAt
,
totals
.
FlashTokens
,
totals
.
FlashCost
,
now
)
return
usage
,
nil
}
// addWindowStats 为 usage 数据添加窗口期统计
// 使用独立缓存(1 分钟),与 API 缓存分离
func
(
s
*
AccountUsageService
)
addWindowStats
(
ctx
context
.
Context
,
account
*
Account
,
usage
*
UsageInfo
)
{
...
...
@@ -388,3 +426,25 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
// Setup Token无法获取7d数据
return
info
}
func
buildGeminiUsageProgress
(
used
,
limit
int64
,
resetAt
time
.
Time
,
tokens
int64
,
cost
float64
,
now
time
.
Time
)
*
UsageProgress
{
if
limit
<=
0
{
return
nil
}
utilization
:=
(
float64
(
used
)
/
float64
(
limit
))
*
100
remainingSeconds
:=
int
(
resetAt
.
Sub
(
now
)
.
Seconds
())
if
remainingSeconds
<
0
{
remainingSeconds
=
0
}
resetCopy
:=
resetAt
return
&
UsageProgress
{
Utilization
:
utilization
,
ResetsAt
:
&
resetCopy
,
RemainingSeconds
:
remainingSeconds
,
WindowStats
:
&
WindowStats
{
Requests
:
used
,
Tokens
:
tokens
,
Cost
:
cost
,
},
}
}
backend/internal/service/admin_service.go
View file @
6cc7f997
...
...
@@ -13,7 +13,7 @@ import (
// AdminService interface defines admin management operations
type
AdminService
interface
{
// User management
ListUsers
(
ctx
context
.
Context
,
page
,
pageSize
int
,
status
,
role
,
search
string
)
([]
User
,
int64
,
error
)
ListUsers
(
ctx
context
.
Context
,
page
,
pageSize
int
,
filters
UserListFilters
)
([]
User
,
int64
,
error
)
GetUser
(
ctx
context
.
Context
,
id
int64
)
(
*
User
,
error
)
CreateUser
(
ctx
context
.
Context
,
input
*
CreateUserInput
)
(
*
User
,
error
)
UpdateUser
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateUserInput
)
(
*
User
,
error
)
...
...
@@ -35,6 +35,7 @@ type AdminService interface {
// Account management
ListAccounts
(
ctx
context
.
Context
,
page
,
pageSize
int
,
platform
,
accountType
,
status
,
search
string
)
([]
Account
,
int64
,
error
)
GetAccount
(
ctx
context
.
Context
,
id
int64
)
(
*
Account
,
error
)
GetAccountsByIDs
(
ctx
context
.
Context
,
ids
[]
int64
)
([]
*
Account
,
error
)
CreateAccount
(
ctx
context
.
Context
,
input
*
CreateAccountInput
)
(
*
Account
,
error
)
UpdateAccount
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateAccountInput
)
(
*
Account
,
error
)
DeleteAccount
(
ctx
context
.
Context
,
id
int64
)
error
...
...
@@ -69,7 +70,6 @@ type CreateUserInput struct {
Email
string
Password
string
Username
string
Wechat
string
Notes
string
Balance
float64
Concurrency
int
...
...
@@ -80,7 +80,6 @@ type UpdateUserInput struct {
Email
string
Password
string
Username
*
string
Wechat
*
string
Notes
*
string
Balance
*
float64
// 使用指针区分"未提供"和"设置为0"
Concurrency
*
int
// 使用指针区分"未提供"和"设置为0"
...
...
@@ -251,9 +250,9 @@ func NewAdminService(
}
// User management implementations
func
(
s
*
adminServiceImpl
)
ListUsers
(
ctx
context
.
Context
,
page
,
pageSize
int
,
status
,
role
,
search
string
)
([]
User
,
int64
,
error
)
{
func
(
s
*
adminServiceImpl
)
ListUsers
(
ctx
context
.
Context
,
page
,
pageSize
int
,
filters
UserListFilters
)
([]
User
,
int64
,
error
)
{
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
users
,
result
,
err
:=
s
.
userRepo
.
ListWithFilters
(
ctx
,
params
,
status
,
role
,
search
)
users
,
result
,
err
:=
s
.
userRepo
.
ListWithFilters
(
ctx
,
params
,
filters
)
if
err
!=
nil
{
return
nil
,
0
,
err
}
...
...
@@ -268,7 +267,6 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
user
:=
&
User
{
Email
:
input
.
Email
,
Username
:
input
.
Username
,
Wechat
:
input
.
Wechat
,
Notes
:
input
.
Notes
,
Role
:
RoleUser
,
// Always create as regular user, never admin
Balance
:
input
.
Balance
,
...
...
@@ -310,9 +308,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
if
input
.
Username
!=
nil
{
user
.
Username
=
*
input
.
Username
}
if
input
.
Wechat
!=
nil
{
user
.
Wechat
=
*
input
.
Wechat
}
if
input
.
Notes
!=
nil
{
user
.
Notes
=
*
input
.
Notes
}
...
...
@@ -611,6 +606,19 @@ func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*Account,
return
s
.
accountRepo
.
GetByID
(
ctx
,
id
)
}
func
(
s
*
adminServiceImpl
)
GetAccountsByIDs
(
ctx
context
.
Context
,
ids
[]
int64
)
([]
*
Account
,
error
)
{
if
len
(
ids
)
==
0
{
return
[]
*
Account
{},
nil
}
accounts
,
err
:=
s
.
accountRepo
.
GetByIDs
(
ctx
,
ids
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to get accounts by IDs: %w"
,
err
)
}
return
accounts
,
nil
}
func
(
s
*
adminServiceImpl
)
CreateAccount
(
ctx
context
.
Context
,
input
*
CreateAccountInput
)
(
*
Account
,
error
)
{
account
:=
&
Account
{
Name
:
input
.
Name
,
...
...
backend/internal/service/admin_service_create_user_test.go
View file @
6cc7f997
...
...
@@ -18,7 +18,6 @@ func TestAdminService_CreateUser_Success(t *testing.T) {
Email
:
"user@test.com"
,
Password
:
"strong-pass"
,
Username
:
"tester"
,
Wechat
:
"wx"
,
Notes
:
"note"
,
Balance
:
12.5
,
Concurrency
:
7
,
...
...
@@ -31,7 +30,6 @@ func TestAdminService_CreateUser_Success(t *testing.T) {
require
.
Equal
(
t
,
int64
(
10
),
user
.
ID
)
require
.
Equal
(
t
,
input
.
Email
,
user
.
Email
)
require
.
Equal
(
t
,
input
.
Username
,
user
.
Username
)
require
.
Equal
(
t
,
input
.
Wechat
,
user
.
Wechat
)
require
.
Equal
(
t
,
input
.
Notes
,
user
.
Notes
)
require
.
Equal
(
t
,
input
.
Balance
,
user
.
Balance
)
require
.
Equal
(
t
,
input
.
Concurrency
,
user
.
Concurrency
)
...
...
backend/internal/service/admin_service_delete_test.go
View file @
6cc7f997
...
...
@@ -66,7 +66,7 @@ func (s *userRepoStub) List(ctx context.Context, params pagination.PaginationPar
panic
(
"unexpected List call"
)
}
func
(
s
*
userRepoStub
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
status
,
role
,
search
string
)
([]
User
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
s
*
userRepoStub
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
filters
UserListFilters
)
([]
User
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected ListWithFilters call"
)
}
...
...
backend/internal/service/antigravity_gateway_service.go
View file @
6cc7f997
...
...
@@ -25,7 +25,7 @@ const (
antigravityRetryMaxDelay
=
16
*
time
.
Second
)
// Antigravity 直接支持的模型
// Antigravity 直接支持的模型
(精确匹配透传)
var
antigravitySupportedModels
=
map
[
string
]
bool
{
"claude-opus-4-5-thinking"
:
true
,
"claude-sonnet-4-5"
:
true
,
...
...
@@ -36,23 +36,26 @@ var antigravitySupportedModels = map[string]bool{
"gemini-3-flash"
:
true
,
"gemini-3-pro-low"
:
true
,
"gemini-3-pro-high"
:
true
,
"gemini-3-pro-preview"
:
true
,
"gemini-3-pro-image"
:
true
,
}
// Antigravity 系统默认模型映射表(不支持 → 支持)
var
antigravityModelMapping
=
map
[
string
]
string
{
"claude-3-5-sonnet-20241022"
:
"claude-sonnet-4-5"
,
"claude-3-5-sonnet-20240620"
:
"claude-sonnet-4-5"
,
"claude-sonnet-4-5-20250929"
:
"claude-sonnet-4-5-thinking"
,
"claude-opus-4"
:
"claude-opus-4-5-thinking"
,
"claude-opus-4-5-20251101"
:
"claude-opus-4-5-thinking"
,
"claude-haiku-4"
:
"gemini-3-flash"
,
"claude-haiku-4-5"
:
"gemini-3-flash"
,
"claude-3-haiku-20240307"
:
"gemini-3-flash"
,
"claude-haiku-4-5-20251001"
:
"gemini-3-flash"
,
// 生图模型:官方名 → Antigravity 内部名
"gemini-3-pro-image-preview"
:
"gemini-3-pro-image"
,
// Antigravity 前缀映射表(按前缀长度降序排列,确保最长匹配优先)
// 用于处理模型版本号变化(如 -20251111, -thinking, -preview 等后缀)
var
antigravityPrefixMapping
=
[]
struct
{
prefix
string
target
string
}{
// 长前缀优先
{
"gemini-3-pro-image"
,
"gemini-3-pro-image"
},
// gemini-3-pro-image-preview 等
{
"claude-3-5-sonnet"
,
"claude-sonnet-4-5"
},
// 旧版 claude-3-5-sonnet-xxx
{
"claude-sonnet-4-5"
,
"claude-sonnet-4-5"
},
// claude-sonnet-4-5-xxx
{
"claude-haiku-4-5"
,
"gemini-3-flash"
},
// claude-haiku-4-5-xxx
{
"claude-opus-4-5"
,
"claude-opus-4-5-thinking"
},
{
"claude-3-haiku"
,
"gemini-3-flash"
},
// 旧版 claude-3-haiku-xxx
{
"claude-sonnet-4"
,
"claude-sonnet-4-5"
},
{
"claude-haiku-4"
,
"gemini-3-flash"
},
{
"claude-opus-4"
,
"claude-opus-4-5-thinking"
},
{
"gemini-3-pro"
,
"gemini-3-pro-high"
},
// gemini-3-pro, gemini-3-pro-preview 等
}
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发
...
...
@@ -84,24 +87,27 @@ func (s *AntigravityGatewayService) GetTokenProvider() *AntigravityTokenProvider
}
// getMappedModel 获取映射后的模型名
// 逻辑:账户映射 → 直接支持透传 → 前缀映射 → gemini透传 → 默认值
func
(
s
*
AntigravityGatewayService
)
getMappedModel
(
account
*
Account
,
requestedModel
string
)
string
{
// 1.
优先使用
账户级映射(
复用现有方法
)
// 1. 账户级映射(
用户自定义优先
)
if
mapped
:=
account
.
GetMappedModel
(
requestedModel
);
mapped
!=
requestedModel
{
return
mapped
}
// 2.
系统默认映射
if
mapped
,
ok
:=
antigravityModelMapping
[
requestedModel
]
;
ok
{
return
mapped
// 2.
直接支持的模型透传
if
antigravitySupportedModels
[
requestedModel
]
{
return
requestedModel
}
// 3. Gemini 模型透传
if
strings
.
HasPrefix
(
requestedModel
,
"gemini-"
)
{
return
requestedModel
// 3. 前缀映射(处理版本号变化,如 -20251111, -thinking, -preview)
for
_
,
pm
:=
range
antigravityPrefixMapping
{
if
strings
.
HasPrefix
(
requestedModel
,
pm
.
prefix
)
{
return
pm
.
target
}
}
// 4.
Claude 前缀透传直接支持的
模型
if
antigravitySupportedModels
[
requestedModel
]
{
// 4.
Gemini 模型透传(未匹配到前缀的 gemini
模型
)
if
strings
.
HasPrefix
(
requestedModel
,
"gemini-"
)
{
return
requestedModel
}
...
...
@@ -110,24 +116,10 @@ func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedMo
}
// IsModelSupported 检查模型是否被支持
// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持
func
(
s
*
AntigravityGatewayService
)
IsModelSupported
(
requestedModel
string
)
bool
{
// 直接支持的模型
if
antigravitySupportedModels
[
requestedModel
]
{
return
true
}
// 可映射的模型
if
_
,
ok
:=
antigravityModelMapping
[
requestedModel
];
ok
{
return
true
}
// Gemini 前缀透传
if
strings
.
HasPrefix
(
requestedModel
,
"gemini-"
)
{
return
true
}
// Claude 模型支持(通过默认映射)
if
strings
.
HasPrefix
(
requestedModel
,
"claude-"
)
{
return
true
}
return
false
return
strings
.
HasPrefix
(
requestedModel
,
"claude-"
)
||
strings
.
HasPrefix
(
requestedModel
,
"gemini-"
)
}
// TestConnectionResult 测试连接结果
...
...
@@ -358,6 +350,15 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return
nil
,
fmt
.
Errorf
(
"transform request: %w"
,
err
)
}
// 调试:记录转换后的请求体(仅记录前 2000 字符)
if
bodyJSON
,
err
:=
json
.
Marshal
(
geminiBody
);
err
==
nil
{
truncated
:=
string
(
bodyJSON
)
if
len
(
truncated
)
>
2000
{
truncated
=
truncated
[
:
2000
]
+
"..."
}
log
.
Printf
(
"[Debug] Transformed Gemini request: %s"
,
truncated
)
}
// 构建上游 action
action
:=
"generateContent"
if
claudeReq
.
Stream
{
...
...
backend/internal/service/antigravity_model_mapping_test.go
View file @
6cc7f997
...
...
@@ -131,7 +131,7 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
name
:
"系统映射 - claude-sonnet-4-5-20250929"
,
requestedModel
:
"claude-sonnet-4-5-20250929"
,
accountMapping
:
nil
,
expected
:
"claude-sonnet-4-5
-thinking
"
,
expected
:
"claude-sonnet-4-5"
,
},
// 3. Gemini 透传
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment