"backend/internal/git@web.lueluesay.top:chenxi/sub2api.git" did not exist on "aaaa68ea7fbd85611f02b18f1dd3e633f35c8101"
Commit 6cc7f997 authored by song's avatar song
Browse files

merge: 合并 upstream/main

parents 95d09f60 106e59b7
...@@ -22,7 +22,7 @@ func BenchmarkAccountConcurrency(b *testing.B) { ...@@ -22,7 +22,7 @@ func BenchmarkAccountConcurrency(b *testing.B) {
_ = rdb.Close() _ = rdb.Close()
}() }()
cache, _ := NewConcurrencyCache(rdb, benchSlotTTLMinutes).(*concurrencyCache) cache, _ := NewConcurrencyCache(rdb, benchSlotTTLMinutes, int(benchSlotTTL.Seconds())).(*concurrencyCache)
ctx := context.Background() ctx := context.Background()
for _, size := range []int{10, 100, 1000} { for _, size := range []int{10, 100, 1000} {
......
...@@ -27,7 +27,7 @@ type ConcurrencyCacheSuite struct { ...@@ -27,7 +27,7 @@ type ConcurrencyCacheSuite struct {
func (s *ConcurrencyCacheSuite) SetupTest() { func (s *ConcurrencyCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest() s.IntegrationRedisSuite.SetupTest()
s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes) s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds()))
} }
func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() { func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
...@@ -218,6 +218,48 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() { ...@@ -218,6 +218,48 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count") require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
} }
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {
accountID := int64(30)
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
ok, err := s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
require.NoError(s.T(), err, "IncrementAccountWaitCount 1")
require.True(s.T(), ok)
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
require.NoError(s.T(), err, "IncrementAccountWaitCount 2")
require.True(s.T(), ok)
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
require.NoError(s.T(), err, "IncrementAccountWaitCount 3")
require.False(s.T(), ok, "expected account wait increment over max to fail")
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
require.NoError(s.T(), err, "TTL account waitKey")
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount")
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.Equal(s.T(), 1, val, "expected account wait count 1")
}
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_DecrementNoNegative() {
accountID := int64(301)
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount on non-existent key")
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative account wait count after decrement on empty")
}
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() { func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
// When no slots exist, GetAccountConcurrency should return 0 // When no slots exist, GetAccountConcurrency should return 0
cur, err := s.cache.GetAccountConcurrency(s.ctx, 999) cur, err := s.cache.GetAccountConcurrency(s.ctx, 999)
...@@ -232,6 +274,139 @@ func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() { ...@@ -232,6 +274,139 @@ func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() {
require.Equal(s.T(), 0, cur) require.Equal(s.T(), 0, cur)
} }
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch() {
s.T().Skip("TODO: Fix this test - CurrentConcurrency returns 0 instead of expected value in CI")
// Setup: Create accounts with different load states
account1 := int64(100)
account2 := int64(101)
account3 := int64(102)
// Account 1: 2/3 slots used, 1 waiting
ok, err := s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req1")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req2")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, account1, 5)
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Account 2: 1/2 slots used, 0 waiting
ok, err = s.cache.AcquireAccountSlot(s.ctx, account2, 2, "req3")
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Account 3: 0/1 slots used, 0 waiting (idle)
// Query batch load
accounts := []service.AccountWithConcurrency{
{ID: account1, MaxConcurrency: 3},
{ID: account2, MaxConcurrency: 2},
{ID: account3, MaxConcurrency: 1},
}
loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, accounts)
require.NoError(s.T(), err)
require.Len(s.T(), loadMap, 3)
// Verify account1: (2 + 1) / 3 = 100%
load1 := loadMap[account1]
require.NotNil(s.T(), load1)
require.Equal(s.T(), account1, load1.AccountID)
require.Equal(s.T(), 2, load1.CurrentConcurrency)
require.Equal(s.T(), 1, load1.WaitingCount)
require.Equal(s.T(), 100, load1.LoadRate)
// Verify account2: (1 + 0) / 2 = 50%
load2 := loadMap[account2]
require.NotNil(s.T(), load2)
require.Equal(s.T(), account2, load2.AccountID)
require.Equal(s.T(), 1, load2.CurrentConcurrency)
require.Equal(s.T(), 0, load2.WaitingCount)
require.Equal(s.T(), 50, load2.LoadRate)
// Verify account3: (0 + 0) / 1 = 0%
load3 := loadMap[account3]
require.NotNil(s.T(), load3)
require.Equal(s.T(), account3, load3.AccountID)
require.Equal(s.T(), 0, load3.CurrentConcurrency)
require.Equal(s.T(), 0, load3.WaitingCount)
require.Equal(s.T(), 0, load3.LoadRate)
}
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch_Empty() {
// Test with empty account list
loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, []service.AccountWithConcurrency{})
require.NoError(s.T(), err)
require.Empty(s.T(), loadMap)
}
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots() {
accountID := int64(200)
slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
// Acquire 3 slots
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req3")
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Verify 3 slots exist
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 3, cur)
// Manually set old timestamps for req1 and req2 (simulate expired slots)
now := time.Now().Unix()
expiredTime := now - int64(testSlotTTL.Seconds()) - 10 // 10 seconds past TTL
err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req1"}).Err()
require.NoError(s.T(), err)
err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req2"}).Err()
require.NoError(s.T(), err)
// Run cleanup
err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
require.NoError(s.T(), err)
// Verify only 1 slot remains (req3)
cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 1, cur)
// Verify req3 still exists
members, err := s.rdb.ZRange(s.ctx, slotKey, 0, -1).Result()
require.NoError(s.T(), err)
require.Len(s.T(), members, 1)
require.Equal(s.T(), "req3", members[0])
}
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() {
accountID := int64(201)
// Acquire 2 fresh slots
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Run cleanup (should not remove anything)
err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
require.NoError(s.T(), err)
// Verify both slots still exist
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 2, cur)
}
func TestConcurrencyCacheSuite(t *testing.T) { func TestConcurrencyCacheSuite(t *testing.T) {
suite.Run(t, new(ConcurrencyCacheSuite)) suite.Run(t, new(ConcurrencyCacheSuite))
} }
...@@ -40,7 +40,6 @@ func mustCreateUser(t *testing.T, client *dbent.Client, u *service.User) *servic ...@@ -40,7 +40,6 @@ func mustCreateUser(t *testing.T, client *dbent.Client, u *service.User) *servic
SetBalance(u.Balance). SetBalance(u.Balance).
SetConcurrency(u.Concurrency). SetConcurrency(u.Concurrency).
SetUsername(u.Username). SetUsername(u.Username).
SetWechat(u.Wechat).
SetNotes(u.Notes) SetNotes(u.Notes)
if !u.CreatedAt.IsZero() { if !u.CreatedAt.IsZero() {
create.SetCreatedAt(u.CreatedAt) create.SetCreatedAt(u.CreatedAt)
......
...@@ -127,7 +127,15 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error { ...@@ -127,7 +127,15 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
if existing != checksum { if existing != checksum {
// 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。 // 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。
// 正确的做法是创建新的迁移文件来进行变更。 // 正确的做法是创建新的迁移文件来进行变更。
return fmt.Errorf("migration %s checksum mismatch (db=%s file=%s)", name, existing, checksum) return fmt.Errorf(
"migration %s checksum mismatch (db=%s file=%s)\n"+
"This means the migration file was modified after being applied to the database.\n"+
"Solutions:\n"+
" 1. Revert to original: git log --oneline -- migrations/%s && git checkout <commit> -- migrations/%s\n"+
" 2. For new changes, create a new migration file instead of modifying existing ones\n"+
"Note: Modifying applied migrations breaks the immutability principle and can cause inconsistencies across environments",
name, existing, checksum, name, name,
)
} }
continue // 迁移已应用且校验和匹配,跳过 continue // 迁移已应用且校验和匹配,跳过
} }
......
...@@ -23,7 +23,6 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) { ...@@ -23,7 +23,6 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
// users: columns required by repository queries // users: columns required by repository queries
requireColumn(t, tx, "users", "username", "character varying", 100, false) requireColumn(t, tx, "users", "username", "character varying", 100, false)
requireColumn(t, tx, "users", "wechat", "character varying", 100, false)
requireColumn(t, tx, "users", "notes", "text", 0, false) requireColumn(t, tx, "users", "notes", "text", 0, false)
// accounts: schedulable and rate-limit fields // accounts: schedulable and rate-limit fields
......
package repository
import (
"context"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// UserAttributeDefinitionRepository implementation
type userAttributeDefinitionRepository struct {
client *dbent.Client
}
// NewUserAttributeDefinitionRepository creates a new repository instance
func NewUserAttributeDefinitionRepository(client *dbent.Client) service.UserAttributeDefinitionRepository {
return &userAttributeDefinitionRepository{client: client}
}
func (r *userAttributeDefinitionRepository) Create(ctx context.Context, def *service.UserAttributeDefinition) error {
client := clientFromContext(ctx, r.client)
created, err := client.UserAttributeDefinition.Create().
SetKey(def.Key).
SetName(def.Name).
SetDescription(def.Description).
SetType(string(def.Type)).
SetOptions(toEntOptions(def.Options)).
SetRequired(def.Required).
SetValidation(toEntValidation(def.Validation)).
SetPlaceholder(def.Placeholder).
SetEnabled(def.Enabled).
Save(ctx)
if err != nil {
return translatePersistenceError(err, nil, service.ErrAttributeKeyExists)
}
def.ID = created.ID
def.DisplayOrder = created.DisplayOrder
def.CreatedAt = created.CreatedAt
def.UpdatedAt = created.UpdatedAt
return nil
}
func (r *userAttributeDefinitionRepository) GetByID(ctx context.Context, id int64) (*service.UserAttributeDefinition, error) {
client := clientFromContext(ctx, r.client)
e, err := client.UserAttributeDefinition.Query().
Where(userattributedefinition.IDEQ(id)).
Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, nil)
}
return defEntityToService(e), nil
}
func (r *userAttributeDefinitionRepository) GetByKey(ctx context.Context, key string) (*service.UserAttributeDefinition, error) {
client := clientFromContext(ctx, r.client)
e, err := client.UserAttributeDefinition.Query().
Where(userattributedefinition.KeyEQ(key)).
Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, nil)
}
return defEntityToService(e), nil
}
func (r *userAttributeDefinitionRepository) Update(ctx context.Context, def *service.UserAttributeDefinition) error {
client := clientFromContext(ctx, r.client)
updated, err := client.UserAttributeDefinition.UpdateOneID(def.ID).
SetName(def.Name).
SetDescription(def.Description).
SetType(string(def.Type)).
SetOptions(toEntOptions(def.Options)).
SetRequired(def.Required).
SetValidation(toEntValidation(def.Validation)).
SetPlaceholder(def.Placeholder).
SetDisplayOrder(def.DisplayOrder).
SetEnabled(def.Enabled).
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, service.ErrAttributeKeyExists)
}
def.UpdatedAt = updated.UpdatedAt
return nil
}
func (r *userAttributeDefinitionRepository) Delete(ctx context.Context, id int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.UserAttributeDefinition.Delete().
Where(userattributedefinition.IDEQ(id)).
Exec(ctx)
return translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, nil)
}
func (r *userAttributeDefinitionRepository) List(ctx context.Context, enabledOnly bool) ([]service.UserAttributeDefinition, error) {
client := clientFromContext(ctx, r.client)
q := client.UserAttributeDefinition.Query()
if enabledOnly {
q = q.Where(userattributedefinition.EnabledEQ(true))
}
entities, err := q.Order(dbent.Asc(userattributedefinition.FieldDisplayOrder)).All(ctx)
if err != nil {
return nil, err
}
result := make([]service.UserAttributeDefinition, 0, len(entities))
for _, e := range entities {
result = append(result, *defEntityToService(e))
}
return result, nil
}
func (r *userAttributeDefinitionRepository) UpdateDisplayOrders(ctx context.Context, orders map[int64]int) error {
tx, err := r.client.Tx(ctx)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
for id, order := range orders {
if _, err := tx.UserAttributeDefinition.UpdateOneID(id).
SetDisplayOrder(order).
Save(ctx); err != nil {
return translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, nil)
}
}
return tx.Commit()
}
func (r *userAttributeDefinitionRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
client := clientFromContext(ctx, r.client)
return client.UserAttributeDefinition.Query().
Where(userattributedefinition.KeyEQ(key)).
Exist(ctx)
}
// UserAttributeValueRepository implementation
type userAttributeValueRepository struct {
client *dbent.Client
}
// NewUserAttributeValueRepository creates a new repository instance
func NewUserAttributeValueRepository(client *dbent.Client) service.UserAttributeValueRepository {
return &userAttributeValueRepository{client: client}
}
func (r *userAttributeValueRepository) GetByUserID(ctx context.Context, userID int64) ([]service.UserAttributeValue, error) {
client := clientFromContext(ctx, r.client)
entities, err := client.UserAttributeValue.Query().
Where(userattributevalue.UserIDEQ(userID)).
All(ctx)
if err != nil {
return nil, err
}
result := make([]service.UserAttributeValue, 0, len(entities))
for _, e := range entities {
result = append(result, service.UserAttributeValue{
ID: e.ID,
UserID: e.UserID,
AttributeID: e.AttributeID,
Value: e.Value,
CreatedAt: e.CreatedAt,
UpdatedAt: e.UpdatedAt,
})
}
return result, nil
}
func (r *userAttributeValueRepository) GetByUserIDs(ctx context.Context, userIDs []int64) ([]service.UserAttributeValue, error) {
if len(userIDs) == 0 {
return []service.UserAttributeValue{}, nil
}
client := clientFromContext(ctx, r.client)
entities, err := client.UserAttributeValue.Query().
Where(userattributevalue.UserIDIn(userIDs...)).
All(ctx)
if err != nil {
return nil, err
}
result := make([]service.UserAttributeValue, 0, len(entities))
for _, e := range entities {
result = append(result, service.UserAttributeValue{
ID: e.ID,
UserID: e.UserID,
AttributeID: e.AttributeID,
Value: e.Value,
CreatedAt: e.CreatedAt,
UpdatedAt: e.UpdatedAt,
})
}
return result, nil
}
func (r *userAttributeValueRepository) UpsertBatch(ctx context.Context, userID int64, inputs []service.UpdateUserAttributeInput) error {
if len(inputs) == 0 {
return nil
}
tx, err := r.client.Tx(ctx)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
for _, input := range inputs {
// Use upsert (ON CONFLICT DO UPDATE)
err := tx.UserAttributeValue.Create().
SetUserID(userID).
SetAttributeID(input.AttributeID).
SetValue(input.Value).
OnConflictColumns(userattributevalue.FieldUserID, userattributevalue.FieldAttributeID).
UpdateValue().
UpdateUpdatedAt().
Exec(ctx)
if err != nil {
return err
}
}
return tx.Commit()
}
func (r *userAttributeValueRepository) DeleteByAttributeID(ctx context.Context, attributeID int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.UserAttributeValue.Delete().
Where(userattributevalue.AttributeIDEQ(attributeID)).
Exec(ctx)
return err
}
func (r *userAttributeValueRepository) DeleteByUserID(ctx context.Context, userID int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.UserAttributeValue.Delete().
Where(userattributevalue.UserIDEQ(userID)).
Exec(ctx)
return err
}
// Helper functions for entity to service conversion
func defEntityToService(e *dbent.UserAttributeDefinition) *service.UserAttributeDefinition {
if e == nil {
return nil
}
return &service.UserAttributeDefinition{
ID: e.ID,
Key: e.Key,
Name: e.Name,
Description: e.Description,
Type: service.UserAttributeType(e.Type),
Options: toServiceOptions(e.Options),
Required: e.Required,
Validation: toServiceValidation(e.Validation),
Placeholder: e.Placeholder,
DisplayOrder: e.DisplayOrder,
Enabled: e.Enabled,
CreatedAt: e.CreatedAt,
UpdatedAt: e.UpdatedAt,
}
}
// Type conversion helpers (map types <-> service types)
func toEntOptions(opts []service.UserAttributeOption) []map[string]any {
if opts == nil {
return []map[string]any{}
}
result := make([]map[string]any, len(opts))
for i, o := range opts {
result[i] = map[string]any{"value": o.Value, "label": o.Label}
}
return result
}
func toServiceOptions(opts []map[string]any) []service.UserAttributeOption {
if opts == nil {
return []service.UserAttributeOption{}
}
result := make([]service.UserAttributeOption, len(opts))
for i, o := range opts {
result[i] = service.UserAttributeOption{
Value: getString(o, "value"),
Label: getString(o, "label"),
}
}
return result
}
func toEntValidation(v service.UserAttributeValidation) map[string]any {
result := map[string]any{}
if v.MinLength != nil {
result["min_length"] = *v.MinLength
}
if v.MaxLength != nil {
result["max_length"] = *v.MaxLength
}
if v.Min != nil {
result["min"] = *v.Min
}
if v.Max != nil {
result["max"] = *v.Max
}
if v.Pattern != nil {
result["pattern"] = *v.Pattern
}
if v.Message != nil {
result["message"] = *v.Message
}
return result
}
func toServiceValidation(v map[string]any) service.UserAttributeValidation {
result := service.UserAttributeValidation{}
if val := getInt(v, "min_length"); val != nil {
result.MinLength = val
}
if val := getInt(v, "max_length"); val != nil {
result.MaxLength = val
}
if val := getInt(v, "min"); val != nil {
result.Min = val
}
if val := getInt(v, "max"); val != nil {
result.Max = val
}
if val := getStringPtr(v, "pattern"); val != nil {
result.Pattern = val
}
if val := getStringPtr(v, "message"); val != nil {
result.Message = val
}
return result
}
// Helper functions for type conversion
func getString(m map[string]any, key string) string {
if v, ok := m[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}
func getStringPtr(m map[string]any, key string) *string {
if v, ok := m[key]; ok {
if s, ok := v.(string); ok {
return &s
}
}
return nil
}
func getInt(m map[string]any, key string) *int {
if v, ok := m[key]; ok {
switch n := v.(type) {
case int:
return &n
case int64:
i := int(n)
return &i
case float64:
i := int(n)
return &i
}
}
return nil
}
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
dbuser "github.com/Wei-Shaw/sub2api/ent/user" dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -50,7 +51,6 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error ...@@ -50,7 +51,6 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
created, err := txClient.User.Create(). created, err := txClient.User.Create().
SetEmail(userIn.Email). SetEmail(userIn.Email).
SetUsername(userIn.Username). SetUsername(userIn.Username).
SetWechat(userIn.Wechat).
SetNotes(userIn.Notes). SetNotes(userIn.Notes).
SetPasswordHash(userIn.PasswordHash). SetPasswordHash(userIn.PasswordHash).
SetRole(userIn.Role). SetRole(userIn.Role).
...@@ -133,7 +133,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error ...@@ -133,7 +133,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
updated, err := txClient.User.UpdateOneID(userIn.ID). updated, err := txClient.User.UpdateOneID(userIn.ID).
SetEmail(userIn.Email). SetEmail(userIn.Email).
SetUsername(userIn.Username). SetUsername(userIn.Username).
SetWechat(userIn.Wechat).
SetNotes(userIn.Notes). SetNotes(userIn.Notes).
SetPasswordHash(userIn.PasswordHash). SetPasswordHash(userIn.PasswordHash).
SetRole(userIn.Role). SetRole(userIn.Role).
...@@ -171,28 +170,38 @@ func (r *userRepository) Delete(ctx context.Context, id int64) error { ...@@ -171,28 +170,38 @@ func (r *userRepository) Delete(ctx context.Context, id int64) error {
} }
func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "") return r.ListWithFilters(ctx, params, service.UserListFilters{})
} }
func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]service.User, *pagination.PaginationResult, error) { func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
q := r.client.User.Query() q := r.client.User.Query()
if status != "" { if filters.Status != "" {
q = q.Where(dbuser.StatusEQ(status)) q = q.Where(dbuser.StatusEQ(filters.Status))
} }
if role != "" { if filters.Role != "" {
q = q.Where(dbuser.RoleEQ(role)) q = q.Where(dbuser.RoleEQ(filters.Role))
} }
if search != "" { if filters.Search != "" {
q = q.Where( q = q.Where(
dbuser.Or( dbuser.Or(
dbuser.EmailContainsFold(search), dbuser.EmailContainsFold(filters.Search),
dbuser.UsernameContainsFold(search), dbuser.UsernameContainsFold(filters.Search),
dbuser.WechatContainsFold(search),
), ),
) )
} }
// If attribute filters are specified, we need to filter by user IDs first
var allowedUserIDs []int64
if len(filters.Attributes) > 0 {
allowedUserIDs = r.filterUsersByAttributes(ctx, filters.Attributes)
if len(allowedUserIDs) == 0 {
// No users match the attribute filters
return []service.User{}, paginationResultFromTotal(0, params), nil
}
q = q.Where(dbuser.IDIn(allowedUserIDs...))
}
total, err := q.Clone().Count(ctx) total, err := q.Clone().Count(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
...@@ -252,6 +261,59 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination. ...@@ -252,6 +261,59 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
return outUsers, paginationResultFromTotal(int64(total), params), nil return outUsers, paginationResultFromTotal(int64(total), params), nil
} }
// filterUsersByAttributes returns user IDs that match ALL the given attribute filters
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) []int64 {
if len(attrs) == 0 {
return nil
}
// For each attribute filter, get the set of matching user IDs
// Then intersect all sets to get users matching ALL filters
var resultSet map[int64]struct{}
first := true
for attrID, value := range attrs {
// Query user_attribute_values for this attribute
values, err := r.client.UserAttributeValue.Query().
Where(
userattributevalue.AttributeIDEQ(attrID),
userattributevalue.ValueContainsFold(value),
).
All(ctx)
if err != nil {
continue
}
currentSet := make(map[int64]struct{}, len(values))
for _, v := range values {
currentSet[v.UserID] = struct{}{}
}
if first {
resultSet = currentSet
first = false
} else {
// Intersect with previous results
for userID := range resultSet {
if _, ok := currentSet[userID]; !ok {
delete(resultSet, userID)
}
}
}
// Early exit if no users match
if len(resultSet) == 0 {
return nil
}
}
result := make([]int64, 0, len(resultSet))
for userID := range resultSet {
result = append(result, userID)
}
return result
}
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error { func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
client := clientFromContext(ctx, r.client) client := clientFromContext(ctx, r.client)
n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx) n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx)
......
...@@ -166,7 +166,7 @@ func (s *UserRepoSuite) TestListWithFilters_Status() { ...@@ -166,7 +166,7 @@ func (s *UserRepoSuite) TestListWithFilters_Status() {
s.mustCreateUser(&service.User{Email: "active@test.com", Status: service.StatusActive}) s.mustCreateUser(&service.User{Email: "active@test.com", Status: service.StatusActive})
s.mustCreateUser(&service.User{Email: "disabled@test.com", Status: service.StatusDisabled}) s.mustCreateUser(&service.User{Email: "disabled@test.com", Status: service.StatusDisabled})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.StatusActive, "", "") users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Status: service.StatusActive})
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(users, 1) s.Require().Len(users, 1)
s.Require().Equal(service.StatusActive, users[0].Status) s.Require().Equal(service.StatusActive, users[0].Status)
...@@ -176,7 +176,7 @@ func (s *UserRepoSuite) TestListWithFilters_Role() { ...@@ -176,7 +176,7 @@ func (s *UserRepoSuite) TestListWithFilters_Role() {
s.mustCreateUser(&service.User{Email: "user@test.com", Role: service.RoleUser}) s.mustCreateUser(&service.User{Email: "user@test.com", Role: service.RoleUser})
s.mustCreateUser(&service.User{Email: "admin@test.com", Role: service.RoleAdmin}) s.mustCreateUser(&service.User{Email: "admin@test.com", Role: service.RoleAdmin})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.RoleAdmin, "") users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Role: service.RoleAdmin})
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(users, 1) s.Require().Len(users, 1)
s.Require().Equal(service.RoleAdmin, users[0].Role) s.Require().Equal(service.RoleAdmin, users[0].Role)
...@@ -186,7 +186,7 @@ func (s *UserRepoSuite) TestListWithFilters_Search() { ...@@ -186,7 +186,7 @@ func (s *UserRepoSuite) TestListWithFilters_Search() {
s.mustCreateUser(&service.User{Email: "alice@test.com", Username: "Alice"}) s.mustCreateUser(&service.User{Email: "alice@test.com", Username: "Alice"})
s.mustCreateUser(&service.User{Email: "bob@test.com", Username: "Bob"}) s.mustCreateUser(&service.User{Email: "bob@test.com", Username: "Bob"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alice") users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Search: "alice"})
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(users, 1) s.Require().Len(users, 1)
s.Require().Contains(users[0].Email, "alice") s.Require().Contains(users[0].Email, "alice")
...@@ -196,22 +196,12 @@ func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() { ...@@ -196,22 +196,12 @@ func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() {
s.mustCreateUser(&service.User{Email: "u1@test.com", Username: "JohnDoe"}) s.mustCreateUser(&service.User{Email: "u1@test.com", Username: "JohnDoe"})
s.mustCreateUser(&service.User{Email: "u2@test.com", Username: "JaneSmith"}) s.mustCreateUser(&service.User{Email: "u2@test.com", Username: "JaneSmith"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "john") users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Search: "john"})
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(users, 1) s.Require().Len(users, 1)
s.Require().Equal("JohnDoe", users[0].Username) s.Require().Equal("JohnDoe", users[0].Username)
} }
func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() {
s.mustCreateUser(&service.User{Email: "w1@test.com", Wechat: "wx_hello"})
s.mustCreateUser(&service.User{Email: "w2@test.com", Wechat: "wx_world"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "wx_hello")
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal("wx_hello", users[0].Wechat)
}
func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() { func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
user := s.mustCreateUser(&service.User{Email: "sub@test.com", Status: service.StatusActive}) user := s.mustCreateUser(&service.User{Email: "sub@test.com", Status: service.StatusActive})
groupActive := s.mustCreateGroup("g-sub-active") groupActive := s.mustCreateGroup("g-sub-active")
...@@ -226,7 +216,7 @@ func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() { ...@@ -226,7 +216,7 @@ func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
c.SetExpiresAt(time.Now().Add(-1 * time.Hour)) c.SetExpiresAt(time.Now().Add(-1 * time.Hour))
}) })
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "sub@") users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Search: "sub@"})
s.Require().NoError(err, "ListWithFilters") s.Require().NoError(err, "ListWithFilters")
s.Require().Len(users, 1, "expected 1 user") s.Require().Len(users, 1, "expected 1 user")
s.Require().Len(users[0].Subscriptions, 1, "expected 1 active subscription") s.Require().Len(users[0].Subscriptions, 1, "expected 1 active subscription")
...@@ -238,7 +228,6 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() { ...@@ -238,7 +228,6 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
s.mustCreateUser(&service.User{ s.mustCreateUser(&service.User{
Email: "a@example.com", Email: "a@example.com",
Username: "Alice", Username: "Alice",
Wechat: "wx_a",
Role: service.RoleUser, Role: service.RoleUser,
Status: service.StatusActive, Status: service.StatusActive,
Balance: 10, Balance: 10,
...@@ -246,7 +235,6 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() { ...@@ -246,7 +235,6 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
target := s.mustCreateUser(&service.User{ target := s.mustCreateUser(&service.User{
Email: "b@example.com", Email: "b@example.com",
Username: "Bob", Username: "Bob",
Wechat: "wx_b",
Role: service.RoleAdmin, Role: service.RoleAdmin,
Status: service.StatusActive, Status: service.StatusActive,
Balance: 1, Balance: 1,
...@@ -257,7 +245,7 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() { ...@@ -257,7 +245,7 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
Status: service.StatusDisabled, Status: service.StatusDisabled,
}) })
users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.StatusActive, service.RoleAdmin, "b@") users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Status: service.StatusActive, Role: service.RoleAdmin, Search: "b@"})
s.Require().NoError(err, "ListWithFilters") s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch") s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
s.Require().Len(users, 1, "ListWithFilters len mismatch") s.Require().Len(users, 1, "ListWithFilters len mismatch")
...@@ -448,7 +436,6 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() { ...@@ -448,7 +436,6 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
user1 := s.mustCreateUser(&service.User{ user1 := s.mustCreateUser(&service.User{
Email: "a@example.com", Email: "a@example.com",
Username: "Alice", Username: "Alice",
Wechat: "wx_a",
Role: service.RoleUser, Role: service.RoleUser,
Status: service.StatusActive, Status: service.StatusActive,
Balance: 10, Balance: 10,
...@@ -456,7 +443,6 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() { ...@@ -456,7 +443,6 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
user2 := s.mustCreateUser(&service.User{ user2 := s.mustCreateUser(&service.User{
Email: "b@example.com", Email: "b@example.com",
Username: "Bob", Username: "Bob",
Wechat: "wx_b",
Role: service.RoleAdmin, Role: service.RoleAdmin,
Status: service.StatusActive, Status: service.StatusActive,
Balance: 1, Balance: 1,
...@@ -501,7 +487,7 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() { ...@@ -501,7 +487,7 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
s.Require().Equal(user1.Concurrency+3, got5.Concurrency) s.Require().Equal(user1.Concurrency+3, got5.Concurrency)
params := pagination.PaginationParams{Page: 1, PageSize: 10} params := pagination.PaginationParams{Page: 1, PageSize: 10}
users, page, err := s.repo.ListWithFilters(s.ctx, params, service.StatusActive, service.RoleAdmin, "b@") users, page, err := s.repo.ListWithFilters(s.ctx, params, service.UserListFilters{Status: service.StatusActive, Role: service.RoleAdmin, Search: "b@"})
s.Require().NoError(err, "ListWithFilters") s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch") s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
s.Require().Len(users, 1, "ListWithFilters len mismatch") s.Require().Len(users, 1, "ListWithFilters len mismatch")
......
...@@ -15,7 +15,14 @@ import ( ...@@ -15,7 +15,14 @@ import (
// ProvideConcurrencyCache 创建并发控制缓存,从配置读取 TTL 参数 // ProvideConcurrencyCache 创建并发控制缓存,从配置读取 TTL 参数
// 性能优化:TTL 可配置,支持长时间运行的 LLM 请求场景 // 性能优化:TTL 可配置,支持长时间运行的 LLM 请求场景
func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.ConcurrencyCache { func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.ConcurrencyCache {
return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes) waitTTLSeconds := int(cfg.Gateway.Scheduling.StickySessionWaitTimeout.Seconds())
if cfg.Gateway.Scheduling.FallbackWaitTimeout > cfg.Gateway.Scheduling.StickySessionWaitTimeout {
waitTTLSeconds = int(cfg.Gateway.Scheduling.FallbackWaitTimeout.Seconds())
}
if waitTTLSeconds <= 0 {
waitTTLSeconds = cfg.Gateway.ConcurrencySlotTTLMinutes * 60
}
return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes, waitTTLSeconds)
} }
// ProviderSet is the Wire provider set for all repositories // ProviderSet is the Wire provider set for all repositories
...@@ -29,6 +36,8 @@ var ProviderSet = wire.NewSet( ...@@ -29,6 +36,8 @@ var ProviderSet = wire.NewSet(
NewUsageLogRepository, NewUsageLogRepository,
NewSettingRepository, NewSettingRepository,
NewUserSubscriptionRepository, NewUserSubscriptionRepository,
NewUserAttributeDefinitionRepository,
NewUserAttributeValueRepository,
// Cache implementations // Cache implementations
NewGatewayCache, NewGatewayCache,
......
...@@ -51,7 +51,6 @@ func TestAPIContracts(t *testing.T) { ...@@ -51,7 +51,6 @@ func TestAPIContracts(t *testing.T) {
"id": 1, "id": 1,
"email": "alice@example.com", "email": "alice@example.com",
"username": "alice", "username": "alice",
"wechat": "wx_alice",
"notes": "hello", "notes": "hello",
"role": "user", "role": "user",
"balance": 12.5, "balance": 12.5,
...@@ -348,7 +347,6 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -348,7 +347,6 @@ func newContractDeps(t *testing.T) *contractDeps {
ID: 1, ID: 1,
Email: "alice@example.com", Email: "alice@example.com",
Username: "alice", Username: "alice",
Wechat: "wx_alice",
Notes: "hello", Notes: "hello",
Role: service.RoleUser, Role: service.RoleUser,
Balance: 12.5, Balance: 12.5,
...@@ -503,7 +501,7 @@ func (r *stubUserRepo) List(ctx context.Context, params pagination.PaginationPar ...@@ -503,7 +501,7 @@ func (r *stubUserRepo) List(ctx context.Context, params pagination.PaginationPar
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (r *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]service.User, *pagination.PaginationResult, error) { func (r *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
......
...@@ -54,6 +54,9 @@ func RegisterAdminRoutes( ...@@ -54,6 +54,9 @@ func RegisterAdminRoutes(
// 使用记录管理 // 使用记录管理
registerUsageRoutes(admin, h) registerUsageRoutes(admin, h)
// 用户属性管理
registerUserAttributeRoutes(admin, h)
} }
} }
...@@ -82,6 +85,10 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -82,6 +85,10 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
users.POST("/:id/balance", h.Admin.User.UpdateBalance) users.POST("/:id/balance", h.Admin.User.UpdateBalance)
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys) users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
users.GET("/:id/usage", h.Admin.User.GetUserUsage) users.GET("/:id/usage", h.Admin.User.GetUserUsage)
// User attribute values
users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes)
users.PUT("/:id/attributes", h.Admin.UserAttribute.UpdateUserAttributes)
} }
} }
...@@ -110,6 +117,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -110,6 +117,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts.DELETE("/:id", h.Admin.Account.Delete) accounts.DELETE("/:id", h.Admin.Account.Delete)
accounts.POST("/:id/test", h.Admin.Account.Test) accounts.POST("/:id/test", h.Admin.Account.Test)
accounts.POST("/:id/refresh", h.Admin.Account.Refresh) accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
accounts.POST("/:id/refresh-tier", h.Admin.Account.RefreshTier)
accounts.GET("/:id/stats", h.Admin.Account.GetStats) accounts.GET("/:id/stats", h.Admin.Account.GetStats)
accounts.POST("/:id/clear-error", h.Admin.Account.ClearError) accounts.POST("/:id/clear-error", h.Admin.Account.ClearError)
accounts.GET("/:id/usage", h.Admin.Account.GetUsage) accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
...@@ -119,6 +127,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -119,6 +127,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels) accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
accounts.POST("/batch", h.Admin.Account.BatchCreate) accounts.POST("/batch", h.Admin.Account.BatchCreate)
accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials) accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials)
accounts.POST("/batch-refresh-tier", h.Admin.Account.BatchRefreshTier)
accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate) accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate)
// Claude OAuth routes // Claude OAuth routes
...@@ -242,3 +251,15 @@ func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -242,3 +251,15 @@ func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys) usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys)
} }
} }
func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
attrs := admin.Group("/user-attributes")
{
attrs.GET("", h.Admin.UserAttribute.ListDefinitions)
attrs.POST("", h.Admin.UserAttribute.CreateDefinition)
attrs.POST("/batch", h.Admin.UserAttribute.GetBatchUserAttributes)
attrs.PUT("/reorder", h.Admin.UserAttribute.ReorderDefinitions)
attrs.PUT("/:id", h.Admin.UserAttribute.UpdateDefinition)
attrs.DELETE("/:id", h.Admin.UserAttribute.DeleteDefinition)
}
}
...@@ -3,6 +3,7 @@ package service ...@@ -3,6 +3,7 @@ package service
import ( import (
"encoding/json" "encoding/json"
"strconv" "strconv"
"strings"
"time" "time"
) )
...@@ -78,6 +79,36 @@ func (a *Account) IsGemini() bool { ...@@ -78,6 +79,36 @@ func (a *Account) IsGemini() bool {
return a.Platform == PlatformGemini return a.Platform == PlatformGemini
} }
func (a *Account) GeminiOAuthType() string {
if a.Platform != PlatformGemini || a.Type != AccountTypeOAuth {
return ""
}
oauthType := strings.TrimSpace(a.GetCredential("oauth_type"))
if oauthType == "" && strings.TrimSpace(a.GetCredential("project_id")) != "" {
return "code_assist"
}
return oauthType
}
func (a *Account) GeminiTierID() string {
tierID := strings.TrimSpace(a.GetCredential("tier_id"))
if tierID == "" {
return ""
}
return strings.ToUpper(tierID)
}
func (a *Account) IsGeminiCodeAssist() bool {
if a.Platform != PlatformGemini || a.Type != AccountTypeOAuth {
return false
}
oauthType := a.GeminiOAuthType()
if oauthType == "" {
return strings.TrimSpace(a.GetCredential("project_id")) != ""
}
return oauthType == "code_assist"
}
func (a *Account) CanGetUsage() bool { func (a *Account) CanGetUsage() bool {
return a.Type == AccountTypeOAuth return a.Type == AccountTypeOAuth
} }
......
...@@ -17,6 +17,9 @@ var ( ...@@ -17,6 +17,9 @@ var (
type AccountRepository interface { type AccountRepository interface {
Create(ctx context.Context, account *Account) error Create(ctx context.Context, account *Account) error
GetByID(ctx context.Context, id int64) (*Account, error) GetByID(ctx context.Context, id int64) (*Account, error)
// GetByIDs fetches accounts by IDs in a single query.
// It should return all accounts found (missing IDs are ignored).
GetByIDs(ctx context.Context, ids []int64) ([]*Account, error)
// ExistsByID 检查账号是否存在,仅返回布尔值,用于删除前的轻量级存在性检查 // ExistsByID 检查账号是否存在,仅返回布尔值,用于删除前的轻量级存在性检查
ExistsByID(ctx context.Context, id int64) (bool, error) ExistsByID(ctx context.Context, id int64) (bool, error)
// GetByCRSAccountID finds an account previously synced from CRS. // GetByCRSAccountID finds an account previously synced from CRS.
......
...@@ -40,6 +40,10 @@ func (s *accountRepoStub) GetByID(ctx context.Context, id int64) (*Account, erro ...@@ -40,6 +40,10 @@ func (s *accountRepoStub) GetByID(ctx context.Context, id int64) (*Account, erro
panic("unexpected GetByID call") panic("unexpected GetByID call")
} }
func (s *accountRepoStub) GetByIDs(ctx context.Context, ids []int64) ([]*Account, error) {
panic("unexpected GetByIDs call")
}
// ExistsByID 返回预设的存在性检查结果。 // ExistsByID 返回预设的存在性检查结果。
// 这是 Delete 方法调用的第一个仓储方法,用于验证账号是否存在。 // 这是 Delete 方法调用的第一个仓储方法,用于验证账号是否存在。
func (s *accountRepoStub) ExistsByID(ctx context.Context, id int64) (bool, error) { func (s *accountRepoStub) ExistsByID(ctx context.Context, id int64) (bool, error) {
......
...@@ -93,10 +93,12 @@ type UsageProgress struct { ...@@ -93,10 +93,12 @@ type UsageProgress struct {
// UsageInfo 账号使用量信息 // UsageInfo 账号使用量信息
type UsageInfo struct { type UsageInfo struct {
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间 UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口 FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口 SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口 SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额
GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额
} }
// ClaudeUsageResponse Anthropic API返回的usage结构 // ClaudeUsageResponse Anthropic API返回的usage结构
...@@ -122,17 +124,19 @@ type ClaudeUsageFetcher interface { ...@@ -122,17 +124,19 @@ type ClaudeUsageFetcher interface {
// AccountUsageService 账号使用量查询服务 // AccountUsageService 账号使用量查询服务
type AccountUsageService struct { type AccountUsageService struct {
accountRepo AccountRepository accountRepo AccountRepository
usageLogRepo UsageLogRepository usageLogRepo UsageLogRepository
usageFetcher ClaudeUsageFetcher usageFetcher ClaudeUsageFetcher
geminiQuotaService *GeminiQuotaService
} }
// NewAccountUsageService 创建AccountUsageService实例 // NewAccountUsageService 创建AccountUsageService实例
func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher) *AccountUsageService { func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher, geminiQuotaService *GeminiQuotaService) *AccountUsageService {
return &AccountUsageService{ return &AccountUsageService{
accountRepo: accountRepo, accountRepo: accountRepo,
usageLogRepo: usageLogRepo, usageLogRepo: usageLogRepo,
usageFetcher: usageFetcher, usageFetcher: usageFetcher,
geminiQuotaService: geminiQuotaService,
} }
} }
...@@ -146,6 +150,10 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U ...@@ -146,6 +150,10 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return nil, fmt.Errorf("get account failed: %w", err) return nil, fmt.Errorf("get account failed: %w", err)
} }
if account.Platform == PlatformGemini {
return s.getGeminiUsage(ctx, account)
}
// 只有oauth类型账号可以通过API获取usage(有profile scope) // 只有oauth类型账号可以通过API获取usage(有profile scope)
if account.CanGetUsage() { if account.CanGetUsage() {
var apiResp *ClaudeUsageResponse var apiResp *ClaudeUsageResponse
...@@ -192,6 +200,36 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U ...@@ -192,6 +200,36 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return nil, fmt.Errorf("account type %s does not support usage query", account.Type) return nil, fmt.Errorf("account type %s does not support usage query", account.Type)
} }
func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
now := time.Now()
usage := &UsageInfo{
UpdatedAt: &now,
}
if s.geminiQuotaService == nil || s.usageLogRepo == nil {
return usage, nil
}
quota, ok := s.geminiQuotaService.QuotaForAccount(ctx, account)
if !ok {
return usage, nil
}
start := geminiDailyWindowStart(now)
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
if err != nil {
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
}
totals := geminiAggregateUsage(stats)
resetAt := geminiDailyResetTime(now)
usage.GeminiProDaily = buildGeminiUsageProgress(totals.ProRequests, quota.ProRPD, resetAt, totals.ProTokens, totals.ProCost, now)
usage.GeminiFlashDaily = buildGeminiUsageProgress(totals.FlashRequests, quota.FlashRPD, resetAt, totals.FlashTokens, totals.FlashCost, now)
return usage, nil
}
// addWindowStats 为 usage 数据添加窗口期统计 // addWindowStats 为 usage 数据添加窗口期统计
// 使用独立缓存(1 分钟),与 API 缓存分离 // 使用独立缓存(1 分钟),与 API 缓存分离
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) { func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) {
...@@ -388,3 +426,25 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn ...@@ -388,3 +426,25 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
// Setup Token无法获取7d数据 // Setup Token无法获取7d数据
return info return info
} }
func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64, cost float64, now time.Time) *UsageProgress {
if limit <= 0 {
return nil
}
utilization := (float64(used) / float64(limit)) * 100
remainingSeconds := int(resetAt.Sub(now).Seconds())
if remainingSeconds < 0 {
remainingSeconds = 0
}
resetCopy := resetAt
return &UsageProgress{
Utilization: utilization,
ResetsAt: &resetCopy,
RemainingSeconds: remainingSeconds,
WindowStats: &WindowStats{
Requests: used,
Tokens: tokens,
Cost: cost,
},
}
}
...@@ -13,7 +13,7 @@ import ( ...@@ -13,7 +13,7 @@ import (
// AdminService interface defines admin management operations // AdminService interface defines admin management operations
type AdminService interface { type AdminService interface {
// User management // User management
ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]User, int64, error) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error)
GetUser(ctx context.Context, id int64) (*User, error) GetUser(ctx context.Context, id int64) (*User, error)
CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error)
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
...@@ -35,6 +35,7 @@ type AdminService interface { ...@@ -35,6 +35,7 @@ type AdminService interface {
// Account management // Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
GetAccount(ctx context.Context, id int64) (*Account, error) GetAccount(ctx context.Context, id int64) (*Account, error)
GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error)
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error) UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error)
DeleteAccount(ctx context.Context, id int64) error DeleteAccount(ctx context.Context, id int64) error
...@@ -69,7 +70,6 @@ type CreateUserInput struct { ...@@ -69,7 +70,6 @@ type CreateUserInput struct {
Email string Email string
Password string Password string
Username string Username string
Wechat string
Notes string Notes string
Balance float64 Balance float64
Concurrency int Concurrency int
...@@ -80,7 +80,6 @@ type UpdateUserInput struct { ...@@ -80,7 +80,6 @@ type UpdateUserInput struct {
Email string Email string
Password string Password string
Username *string Username *string
Wechat *string
Notes *string Notes *string
Balance *float64 // 使用指针区分"未提供"和"设置为0" Balance *float64 // 使用指针区分"未提供"和"设置为0"
Concurrency *int // 使用指针区分"未提供"和"设置为0" Concurrency *int // 使用指针区分"未提供"和"设置为0"
...@@ -251,9 +250,9 @@ func NewAdminService( ...@@ -251,9 +250,9 @@ func NewAdminService(
} }
// User management implementations // User management implementations
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]User, int64, error) { func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
users, result, err := s.userRepo.ListWithFilters(ctx, params, status, role, search) users, result, err := s.userRepo.ListWithFilters(ctx, params, filters)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
...@@ -268,7 +267,6 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu ...@@ -268,7 +267,6 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
user := &User{ user := &User{
Email: input.Email, Email: input.Email,
Username: input.Username, Username: input.Username,
Wechat: input.Wechat,
Notes: input.Notes, Notes: input.Notes,
Role: RoleUser, // Always create as regular user, never admin Role: RoleUser, // Always create as regular user, never admin
Balance: input.Balance, Balance: input.Balance,
...@@ -310,9 +308,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda ...@@ -310,9 +308,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
if input.Username != nil { if input.Username != nil {
user.Username = *input.Username user.Username = *input.Username
} }
if input.Wechat != nil {
user.Wechat = *input.Wechat
}
if input.Notes != nil { if input.Notes != nil {
user.Notes = *input.Notes user.Notes = *input.Notes
} }
...@@ -611,6 +606,19 @@ func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*Account, ...@@ -611,6 +606,19 @@ func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*Account,
return s.accountRepo.GetByID(ctx, id) return s.accountRepo.GetByID(ctx, id)
} }
func (s *adminServiceImpl) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) {
if len(ids) == 0 {
return []*Account{}, nil
}
accounts, err := s.accountRepo.GetByIDs(ctx, ids)
if err != nil {
return nil, fmt.Errorf("failed to get accounts by IDs: %w", err)
}
return accounts, nil
}
func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) { func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) {
account := &Account{ account := &Account{
Name: input.Name, Name: input.Name,
......
...@@ -18,7 +18,6 @@ func TestAdminService_CreateUser_Success(t *testing.T) { ...@@ -18,7 +18,6 @@ func TestAdminService_CreateUser_Success(t *testing.T) {
Email: "user@test.com", Email: "user@test.com",
Password: "strong-pass", Password: "strong-pass",
Username: "tester", Username: "tester",
Wechat: "wx",
Notes: "note", Notes: "note",
Balance: 12.5, Balance: 12.5,
Concurrency: 7, Concurrency: 7,
...@@ -31,7 +30,6 @@ func TestAdminService_CreateUser_Success(t *testing.T) { ...@@ -31,7 +30,6 @@ func TestAdminService_CreateUser_Success(t *testing.T) {
require.Equal(t, int64(10), user.ID) require.Equal(t, int64(10), user.ID)
require.Equal(t, input.Email, user.Email) require.Equal(t, input.Email, user.Email)
require.Equal(t, input.Username, user.Username) require.Equal(t, input.Username, user.Username)
require.Equal(t, input.Wechat, user.Wechat)
require.Equal(t, input.Notes, user.Notes) require.Equal(t, input.Notes, user.Notes)
require.Equal(t, input.Balance, user.Balance) require.Equal(t, input.Balance, user.Balance)
require.Equal(t, input.Concurrency, user.Concurrency) require.Equal(t, input.Concurrency, user.Concurrency)
......
...@@ -66,7 +66,7 @@ func (s *userRepoStub) List(ctx context.Context, params pagination.PaginationPar ...@@ -66,7 +66,7 @@ func (s *userRepoStub) List(ctx context.Context, params pagination.PaginationPar
panic("unexpected List call") panic("unexpected List call")
} }
func (s *userRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]User, *pagination.PaginationResult, error) { func (s *userRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call") panic("unexpected ListWithFilters call")
} }
......
...@@ -25,7 +25,7 @@ const ( ...@@ -25,7 +25,7 @@ const (
antigravityRetryMaxDelay = 16 * time.Second antigravityRetryMaxDelay = 16 * time.Second
) )
// Antigravity 直接支持的模型 // Antigravity 直接支持的模型(精确匹配透传)
var antigravitySupportedModels = map[string]bool{ var antigravitySupportedModels = map[string]bool{
"claude-opus-4-5-thinking": true, "claude-opus-4-5-thinking": true,
"claude-sonnet-4-5": true, "claude-sonnet-4-5": true,
...@@ -36,23 +36,26 @@ var antigravitySupportedModels = map[string]bool{ ...@@ -36,23 +36,26 @@ var antigravitySupportedModels = map[string]bool{
"gemini-3-flash": true, "gemini-3-flash": true,
"gemini-3-pro-low": true, "gemini-3-pro-low": true,
"gemini-3-pro-high": true, "gemini-3-pro-high": true,
"gemini-3-pro-preview": true,
"gemini-3-pro-image": true, "gemini-3-pro-image": true,
} }
// Antigravity 系统默认模型映射表(不支持 → 支持) // Antigravity 前缀映射表(按前缀长度降序排列,确保最长匹配优先)
var antigravityModelMapping = map[string]string{ // 用于处理模型版本号变化(如 -20251111, -thinking, -preview 等后缀)
"claude-3-5-sonnet-20241022": "claude-sonnet-4-5", var antigravityPrefixMapping = []struct {
"claude-3-5-sonnet-20240620": "claude-sonnet-4-5", prefix string
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5-thinking", target string
"claude-opus-4": "claude-opus-4-5-thinking", }{
"claude-opus-4-5-20251101": "claude-opus-4-5-thinking", // 长前缀优先
"claude-haiku-4": "gemini-3-flash", {"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等
"claude-haiku-4-5": "gemini-3-flash", {"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx
"claude-3-haiku-20240307": "gemini-3-flash", {"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
"claude-haiku-4-5-20251001": "gemini-3-flash", {"claude-haiku-4-5", "gemini-3-flash"}, // claude-haiku-4-5-xxx
// 生图模型:官方名 → Antigravity 内部名 {"claude-opus-4-5", "claude-opus-4-5-thinking"},
"gemini-3-pro-image-preview": "gemini-3-pro-image", {"claude-3-haiku", "gemini-3-flash"}, // 旧版 claude-3-haiku-xxx
{"claude-sonnet-4", "claude-sonnet-4-5"},
{"claude-haiku-4", "gemini-3-flash"},
{"claude-opus-4", "claude-opus-4-5-thinking"},
{"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview 等
} }
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发 // AntigravityGatewayService 处理 Antigravity 平台的 API 转发
...@@ -84,24 +87,27 @@ func (s *AntigravityGatewayService) GetTokenProvider() *AntigravityTokenProvider ...@@ -84,24 +87,27 @@ func (s *AntigravityGatewayService) GetTokenProvider() *AntigravityTokenProvider
} }
// getMappedModel 获取映射后的模型名 // getMappedModel 获取映射后的模型名
// 逻辑:账户映射 → 直接支持透传 → 前缀映射 → gemini透传 → 默认值
func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedModel string) string { func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedModel string) string {
// 1. 优先使用账户级映射(复用现有方法 // 1. 账户级映射(用户自定义优先
if mapped := account.GetMappedModel(requestedModel); mapped != requestedModel { if mapped := account.GetMappedModel(requestedModel); mapped != requestedModel {
return mapped return mapped
} }
// 2. 系统默认映射 // 2. 直接支持的模型透传
if mapped, ok := antigravityModelMapping[requestedModel]; ok { if antigravitySupportedModels[requestedModel] {
return mapped return requestedModel
} }
// 3. Gemini 模型透传 // 3. 前缀映射(处理版本号变化,如 -20251111, -thinking, -preview)
if strings.HasPrefix(requestedModel, "gemini-") { for _, pm := range antigravityPrefixMapping {
return requestedModel if strings.HasPrefix(requestedModel, pm.prefix) {
return pm.target
}
} }
// 4. Claude 前缀透传直接支持的模型 // 4. Gemini 模型透传(未匹配到前缀的 gemini 模型
if antigravitySupportedModels[requestedModel] { if strings.HasPrefix(requestedModel, "gemini-") {
return requestedModel return requestedModel
} }
...@@ -110,24 +116,10 @@ func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedMo ...@@ -110,24 +116,10 @@ func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedMo
} }
// IsModelSupported 检查模型是否被支持 // IsModelSupported 检查模型是否被支持
// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持
func (s *AntigravityGatewayService) IsModelSupported(requestedModel string) bool { func (s *AntigravityGatewayService) IsModelSupported(requestedModel string) bool {
// 直接支持的模型 return strings.HasPrefix(requestedModel, "claude-") ||
if antigravitySupportedModels[requestedModel] { strings.HasPrefix(requestedModel, "gemini-")
return true
}
// 可映射的模型
if _, ok := antigravityModelMapping[requestedModel]; ok {
return true
}
// Gemini 前缀透传
if strings.HasPrefix(requestedModel, "gemini-") {
return true
}
// Claude 模型支持(通过默认映射)
if strings.HasPrefix(requestedModel, "claude-") {
return true
}
return false
} }
// TestConnectionResult 测试连接结果 // TestConnectionResult 测试连接结果
...@@ -358,6 +350,15 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -358,6 +350,15 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return nil, fmt.Errorf("transform request: %w", err) return nil, fmt.Errorf("transform request: %w", err)
} }
// 调试:记录转换后的请求体(仅记录前 2000 字符)
if bodyJSON, err := json.Marshal(geminiBody); err == nil {
truncated := string(bodyJSON)
if len(truncated) > 2000 {
truncated = truncated[:2000] + "..."
}
log.Printf("[Debug] Transformed Gemini request: %s", truncated)
}
// 构建上游 action // 构建上游 action
action := "generateContent" action := "generateContent"
if claudeReq.Stream { if claudeReq.Stream {
......
...@@ -131,7 +131,7 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) { ...@@ -131,7 +131,7 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
name: "系统映射 - claude-sonnet-4-5-20250929", name: "系统映射 - claude-sonnet-4-5-20250929",
requestedModel: "claude-sonnet-4-5-20250929", requestedModel: "claude-sonnet-4-5-20250929",
accountMapping: nil, accountMapping: nil,
expected: "claude-sonnet-4-5-thinking", expected: "claude-sonnet-4-5",
}, },
// 3. Gemini 透传 // 3. Gemini 透传
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment