Commit b764d3b8 authored by ius's avatar ius
Browse files

Merge remote-tracking branch 'origin/main' into feat/billing-ledger-decouple-usage-log-20260312

parents 611fd884 826090e0
package gemini
import "testing"
func TestDefaultModels_ContainsImageModels(t *testing.T) {
t.Parallel()
models := DefaultModels()
byName := make(map[string]Model, len(models))
for _, model := range models {
byName[model.Name] = model
}
required := []string{
"models/gemini-2.5-flash-image",
"models/gemini-3.1-flash-image",
}
for _, name := range required {
model, ok := byName[name]
if !ok {
t.Fatalf("expected fallback model %q to exist", name)
}
if len(model.SupportedGenerationMethods) == 0 {
t.Fatalf("expected fallback model %q to advertise generation methods", name)
}
}
}
...@@ -13,10 +13,12 @@ type Model struct { ...@@ -13,10 +13,12 @@ type Model struct {
var DefaultModels = []Model{ var DefaultModels = []Model{
{ID: "gemini-2.0-flash", Type: "model", DisplayName: "Gemini 2.0 Flash", CreatedAt: ""}, {ID: "gemini-2.0-flash", Type: "model", DisplayName: "Gemini 2.0 Flash", CreatedAt: ""},
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""}, {ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
{ID: "gemini-2.5-flash-image", Type: "model", DisplayName: "Gemini 2.5 Flash Image", CreatedAt: ""},
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""}, {ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""}, {ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""}, {ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
{ID: "gemini-3.1-pro-preview", Type: "model", DisplayName: "Gemini 3.1 Pro Preview", CreatedAt: ""}, {ID: "gemini-3.1-pro-preview", Type: "model", DisplayName: "Gemini 3.1 Pro Preview", CreatedAt: ""},
{ID: "gemini-3.1-flash-image", Type: "model", DisplayName: "Gemini 3.1 Flash Image", CreatedAt: ""},
} }
// DefaultTestModel is the default model to preselect in test flows. // DefaultTestModel is the default model to preselect in test flows.
......
package geminicli
import "testing"
func TestDefaultModels_ContainsImageModels(t *testing.T) {
t.Parallel()
byID := make(map[string]Model, len(DefaultModels))
for _, model := range DefaultModels {
byID[model.ID] = model
}
required := []string{
"gemini-2.5-flash-image",
"gemini-3.1-flash-image",
}
for _, id := range required {
if _, ok := byID[id]; !ok {
t.Fatalf("expected curated Gemini model %q to exist", id)
}
}
}
...@@ -16,6 +16,7 @@ import ( ...@@ -16,6 +16,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"strconv" "strconv"
"strings"
"time" "time"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
...@@ -50,6 +51,18 @@ type accountRepository struct { ...@@ -50,6 +51,18 @@ type accountRepository struct {
schedulerCache service.SchedulerCache schedulerCache service.SchedulerCache
} }
var schedulerNeutralExtraKeyPrefixes = []string{
"codex_primary_",
"codex_secondary_",
"codex_5h_",
"codex_7d_",
}
var schedulerNeutralExtraKeys = map[string]struct{}{
"codex_usage_updated_at": {},
"session_window_utilization": {},
}
// NewAccountRepository 创建账户仓储实例。 // NewAccountRepository 创建账户仓储实例。
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。 // 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository { func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository {
...@@ -1185,12 +1198,48 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m ...@@ -1185,12 +1198,48 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
if affected == 0 { if affected == 0 {
return service.ErrAccountNotFound return service.ErrAccountNotFound
} }
if shouldEnqueueSchedulerOutboxForExtraUpdates(updates) {
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err) logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
} }
} else {
// 观测型 extra 字段不需要触发 bucket 重建,但仍同步单账号快照,
// 让 sticky session / GetAccount 命中缓存时也能读到最新数据,
// 同时避免缓存局部 patch 覆盖掉并发写入的其它账号字段。
r.syncSchedulerAccountSnapshot(ctx, id)
}
return nil return nil
} }
func shouldEnqueueSchedulerOutboxForExtraUpdates(updates map[string]any) bool {
if len(updates) == 0 {
return false
}
for key := range updates {
if isSchedulerNeutralExtraKey(key) {
continue
}
return true
}
return false
}
func isSchedulerNeutralExtraKey(key string) bool {
key = strings.TrimSpace(key)
if key == "" {
return false
}
if _, ok := schedulerNeutralExtraKeys[key]; ok {
return true
}
for _, prefix := range schedulerNeutralExtraKeyPrefixes {
if strings.HasPrefix(key, prefix) {
return true
}
}
return false
}
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) { func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
if len(ids) == 0 { if len(ids) == 0 {
return 0, nil return 0, nil
......
...@@ -23,6 +23,7 @@ type AccountRepoSuite struct { ...@@ -23,6 +23,7 @@ type AccountRepoSuite struct {
type schedulerCacheRecorder struct { type schedulerCacheRecorder struct {
setAccounts []*service.Account setAccounts []*service.Account
accounts map[int64]*service.Account
} }
func (s *schedulerCacheRecorder) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) { func (s *schedulerCacheRecorder) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
...@@ -34,11 +35,20 @@ func (s *schedulerCacheRecorder) SetSnapshot(ctx context.Context, bucket service ...@@ -34,11 +35,20 @@ func (s *schedulerCacheRecorder) SetSnapshot(ctx context.Context, bucket service
} }
func (s *schedulerCacheRecorder) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) { func (s *schedulerCacheRecorder) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) {
if s.accounts == nil {
return nil, nil return nil, nil
}
return s.accounts[accountID], nil
} }
func (s *schedulerCacheRecorder) SetAccount(ctx context.Context, account *service.Account) error { func (s *schedulerCacheRecorder) SetAccount(ctx context.Context, account *service.Account) error {
s.setAccounts = append(s.setAccounts, account) s.setAccounts = append(s.setAccounts, account)
if s.accounts == nil {
s.accounts = make(map[int64]*service.Account)
}
if account != nil {
s.accounts[account.ID] = account
}
return nil return nil
} }
...@@ -623,6 +633,96 @@ func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() { ...@@ -623,6 +633,96 @@ func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
s.Require().Equal("val", got.Extra["key"]) s.Require().Equal("val", got.Extra["key"])
} }
func (s *AccountRepoSuite) TestUpdateExtra_SchedulerNeutralSkipsOutboxAndSyncsFreshSnapshot() {
account := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "acc-extra-neutral",
Platform: service.PlatformOpenAI,
Extra: map[string]any{"codex_usage_updated_at": "old"},
})
cacheRecorder := &schedulerCacheRecorder{
accounts: map[int64]*service.Account{
account.ID: {
ID: account.ID,
Platform: account.Platform,
Status: service.StatusDisabled,
Extra: map[string]any{
"codex_usage_updated_at": "old",
},
},
},
}
s.repo.schedulerCache = cacheRecorder
updates := map[string]any{
"codex_usage_updated_at": "2026-03-11T10:00:00Z",
"codex_5h_used_percent": 88.5,
"session_window_utilization": 0.42,
}
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, updates))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().Equal("2026-03-11T10:00:00Z", got.Extra["codex_usage_updated_at"])
s.Require().Equal(88.5, got.Extra["codex_5h_used_percent"])
s.Require().Equal(0.42, got.Extra["session_window_utilization"])
var outboxCount int
s.Require().NoError(scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &outboxCount))
s.Require().Zero(outboxCount)
s.Require().Len(cacheRecorder.setAccounts, 1)
s.Require().NotNil(cacheRecorder.accounts[account.ID])
s.Require().Equal(service.StatusActive, cacheRecorder.accounts[account.ID].Status)
s.Require().Equal("2026-03-11T10:00:00Z", cacheRecorder.accounts[account.ID].Extra["codex_usage_updated_at"])
}
func (s *AccountRepoSuite) TestUpdateExtra_ExhaustedCodexSnapshotSyncsSchedulerCache() {
account := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "acc-extra-codex-exhausted",
Platform: service.PlatformOpenAI,
Type: service.AccountTypeOAuth,
Extra: map[string]any{},
})
cacheRecorder := &schedulerCacheRecorder{}
s.repo.schedulerCache = cacheRecorder
_, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox")
s.Require().NoError(err)
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{
"codex_7d_used_percent": 100.0,
"codex_7d_reset_at": "2026-03-12T13:00:00Z",
"codex_7d_reset_after_seconds": 86400,
}))
var count int
err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count)
s.Require().NoError(err)
s.Require().Equal(0, count)
s.Require().Len(cacheRecorder.setAccounts, 1)
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
s.Require().Equal(service.StatusActive, cacheRecorder.setAccounts[0].Status)
s.Require().Equal(100.0, cacheRecorder.setAccounts[0].Extra["codex_7d_used_percent"])
}
func (s *AccountRepoSuite) TestUpdateExtra_SchedulerRelevantStillEnqueuesOutbox() {
account := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "acc-extra-mixed",
Platform: service.PlatformAntigravity,
Extra: map[string]any{},
})
_, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox")
s.Require().NoError(err)
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{
"mixed_scheduling": true,
"codex_usage_updated_at": "2026-03-11T10:00:00Z",
}))
var count int
err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count)
s.Require().NoError(err)
s.Require().Equal(1, count)
}
// --- GetByCRSAccountID --- // --- GetByCRSAccountID ---
func (s *AccountRepoSuite) TestGetByCRSAccountID() { func (s *AccountRepoSuite) TestGetByCRSAccountID() {
......
...@@ -452,6 +452,32 @@ func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amo ...@@ -452,6 +452,32 @@ func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amo
return updated.QuotaUsed, nil return updated.QuotaUsed, nil
} }
// IncrementQuotaUsedAndGetState atomically increments quota_used, conditionally marks the key
// as quota_exhausted, and returns the latest quota state in one round trip.
func (r *apiKeyRepository) IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*service.APIKeyQuotaUsageState, error) {
query := `
UPDATE api_keys
SET
quota_used = quota_used + $1,
status = CASE
WHEN quota > 0 AND quota_used + $1 >= quota THEN $2
ELSE status
END,
updated_at = NOW()
WHERE id = $3 AND deleted_at IS NULL
RETURNING quota_used, quota, key, status
`
state := &service.APIKeyQuotaUsageState{}
if err := scanSingleRow(ctx, r.sql, query, []any{amount, service.StatusAPIKeyQuotaExhausted, id}, &state.QuotaUsed, &state.Quota, &state.Key, &state.Status); err != nil {
if err == sql.ErrNoRows {
return nil, service.ErrAPIKeyNotFound
}
return nil, err
}
return state, nil
}
func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error { func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
affected, err := r.client.APIKey.Update(). affected, err := r.client.APIKey.Update().
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()). Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
......
...@@ -417,6 +417,27 @@ func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() { ...@@ -417,6 +417,27 @@ func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() {
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound") s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound")
} }
func (s *APIKeyRepoSuite) TestIncrementQuotaUsedAndGetState() {
user := s.mustCreateUser("quota-state@test.com")
key := s.mustCreateApiKey(user.ID, "sk-quota-state", "QuotaState", nil)
key.Quota = 3
key.QuotaUsed = 1
s.Require().NoError(s.repo.Update(s.ctx, key), "Update quota")
state, err := s.repo.IncrementQuotaUsedAndGetState(s.ctx, key.ID, 2.5)
s.Require().NoError(err, "IncrementQuotaUsedAndGetState")
s.Require().NotNil(state)
s.Require().Equal(3.5, state.QuotaUsed)
s.Require().Equal(3.0, state.Quota)
s.Require().Equal(service.StatusAPIKeyQuotaExhausted, state.Status)
s.Require().Equal(key.Key, state.Key)
got, err := s.repo.GetByID(s.ctx, key.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal(3.5, got.QuotaUsed)
s.Require().Equal(service.StatusAPIKeyQuotaExhausted, got.Status)
}
// TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。 // TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。
// 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。 // 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。
func TestIncrementQuotaUsed_Concurrent(t *testing.T) { func TestIncrementQuotaUsed_Concurrent(t *testing.T) {
......
...@@ -16,19 +16,7 @@ type opsRepository struct { ...@@ -16,19 +16,7 @@ type opsRepository struct {
db *sql.DB db *sql.DB
} }
func NewOpsRepository(db *sql.DB) service.OpsRepository { const insertOpsErrorLogSQL = `
return &opsRepository{db: db}
}
func (r *opsRepository) InsertErrorLog(ctx context.Context, input *service.OpsInsertErrorLogInput) (int64, error) {
if r == nil || r.db == nil {
return 0, fmt.Errorf("nil ops repository")
}
if input == nil {
return 0, fmt.Errorf("nil input")
}
q := `
INSERT INTO ops_error_logs ( INSERT INTO ops_error_logs (
request_id, request_id,
client_request_id, client_request_id,
...@@ -70,12 +58,77 @@ INSERT INTO ops_error_logs ( ...@@ -70,12 +58,77 @@ INSERT INTO ops_error_logs (
created_at created_at
) VALUES ( ) VALUES (
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38 $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38
) RETURNING id` )`
func NewOpsRepository(db *sql.DB) service.OpsRepository {
return &opsRepository{db: db}
}
func (r *opsRepository) InsertErrorLog(ctx context.Context, input *service.OpsInsertErrorLogInput) (int64, error) {
if r == nil || r.db == nil {
return 0, fmt.Errorf("nil ops repository")
}
if input == nil {
return 0, fmt.Errorf("nil input")
}
var id int64 var id int64
err := r.db.QueryRowContext( err := r.db.QueryRowContext(
ctx, ctx,
q, insertOpsErrorLogSQL+" RETURNING id",
opsInsertErrorLogArgs(input)...,
).Scan(&id)
if err != nil {
return 0, err
}
return id, nil
}
func (r *opsRepository) BatchInsertErrorLogs(ctx context.Context, inputs []*service.OpsInsertErrorLogInput) (int64, error) {
if r == nil || r.db == nil {
return 0, fmt.Errorf("nil ops repository")
}
if len(inputs) == 0 {
return 0, nil
}
tx, err := r.db.BeginTx(ctx, nil)
if err != nil {
return 0, err
}
defer func() {
if err != nil {
_ = tx.Rollback()
}
}()
stmt, err := tx.PrepareContext(ctx, insertOpsErrorLogSQL)
if err != nil {
return 0, err
}
defer func() {
_ = stmt.Close()
}()
var inserted int64
for _, input := range inputs {
if input == nil {
continue
}
if _, err = stmt.ExecContext(ctx, opsInsertErrorLogArgs(input)...); err != nil {
return inserted, err
}
inserted++
}
if err = tx.Commit(); err != nil {
return inserted, err
}
return inserted, nil
}
func opsInsertErrorLogArgs(input *service.OpsInsertErrorLogInput) []any {
return []any{
opsNullString(input.RequestID), opsNullString(input.RequestID),
opsNullString(input.ClientRequestID), opsNullString(input.ClientRequestID),
opsNullInt64(input.UserID), opsNullInt64(input.UserID),
...@@ -114,11 +167,7 @@ INSERT INTO ops_error_logs ( ...@@ -114,11 +167,7 @@ INSERT INTO ops_error_logs (
input.IsRetryable, input.IsRetryable,
input.RetryCount, input.RetryCount,
input.CreatedAt, input.CreatedAt,
).Scan(&id)
if err != nil {
return 0, err
} }
return id, nil
} }
func (r *opsRepository) ListErrorLogs(ctx context.Context, filter *service.OpsErrorLogFilter) (*service.OpsErrorLogList, error) { func (r *opsRepository) ListErrorLogs(ctx context.Context, filter *service.OpsErrorLogFilter) (*service.OpsErrorLogList, error) {
......
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func TestOpsRepositoryBatchInsertErrorLogs(t *testing.T) {
ctx := context.Background()
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE ops_error_logs RESTART IDENTITY")
repo := NewOpsRepository(integrationDB).(*opsRepository)
now := time.Now().UTC()
inserted, err := repo.BatchInsertErrorLogs(ctx, []*service.OpsInsertErrorLogInput{
{
RequestID: "batch-ops-1",
ErrorPhase: "upstream",
ErrorType: "upstream_error",
Severity: "error",
StatusCode: 429,
ErrorMessage: "rate limited",
CreatedAt: now,
},
{
RequestID: "batch-ops-2",
ErrorPhase: "internal",
ErrorType: "api_error",
Severity: "error",
StatusCode: 500,
ErrorMessage: "internal error",
CreatedAt: now.Add(time.Millisecond),
},
})
require.NoError(t, err)
require.EqualValues(t, 2, inserted)
var count int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM ops_error_logs WHERE request_id IN ('batch-ops-1', 'batch-ops-2')").Scan(&count))
require.Equal(t, 2, count)
}
func TestEnqueueSchedulerOutbox_DeduplicatesIdempotentEvents(t *testing.T) {
ctx := context.Background()
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox RESTART IDENTITY")
accountID := int64(12345)
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
var count int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountChanged).Scan(&count))
require.Equal(t, 1, count)
time.Sleep(schedulerOutboxDedupWindow + 150*time.Millisecond)
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountChanged).Scan(&count))
require.Equal(t, 2, count)
}
func TestEnqueueSchedulerOutbox_DoesNotDeduplicateLastUsed(t *testing.T) {
ctx := context.Background()
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox RESTART IDENTITY")
accountID := int64(67890)
payload1 := map[string]any{"last_used": map[string]int64{"67890": 100}}
payload2 := map[string]any{"last_used": map[string]int64{"67890": 200}}
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountLastUsed, &accountID, nil, payload1))
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountLastUsed, &accountID, nil, payload2))
var count int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountLastUsed).Scan(&count))
require.Equal(t, 2, count)
}
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"time"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
) )
...@@ -12,6 +13,8 @@ type schedulerOutboxRepository struct { ...@@ -12,6 +13,8 @@ type schedulerOutboxRepository struct {
db *sql.DB db *sql.DB
} }
const schedulerOutboxDedupWindow = time.Second
func NewSchedulerOutboxRepository(db *sql.DB) service.SchedulerOutboxRepository { func NewSchedulerOutboxRepository(db *sql.DB) service.SchedulerOutboxRepository {
return &schedulerOutboxRepository{db: db} return &schedulerOutboxRepository{db: db}
} }
...@@ -88,9 +91,37 @@ func enqueueSchedulerOutbox(ctx context.Context, exec sqlExecutor, eventType str ...@@ -88,9 +91,37 @@ func enqueueSchedulerOutbox(ctx context.Context, exec sqlExecutor, eventType str
} }
payloadArg = encoded payloadArg = encoded
} }
_, err := exec.ExecContext(ctx, ` query := `
INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload) INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload)
VALUES ($1, $2, $3, $4) VALUES ($1, $2, $3, $4)
`, eventType, accountID, groupID, payloadArg) `
args := []any{eventType, accountID, groupID, payloadArg}
if schedulerOutboxEventSupportsDedup(eventType) {
query = `
INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload)
SELECT $1, $2, $3, $4
WHERE NOT EXISTS (
SELECT 1
FROM scheduler_outbox
WHERE event_type = $1
AND account_id IS NOT DISTINCT FROM $2
AND group_id IS NOT DISTINCT FROM $3
AND created_at >= NOW() - make_interval(secs => $5)
)
`
args = append(args, schedulerOutboxDedupWindow.Seconds())
}
_, err := exec.ExecContext(ctx, query, args...)
return err return err
} }
func schedulerOutboxEventSupportsDedup(eventType string) bool {
switch eventType {
case service.SchedulerOutboxEventAccountChanged,
service.SchedulerOutboxEventGroupChanged,
service.SchedulerOutboxEventFullRebuild:
return true
default:
return false
}
}
...@@ -456,6 +456,7 @@ func registerSubscriptionRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -456,6 +456,7 @@ func registerSubscriptionRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
subscriptions.POST("/assign", h.Admin.Subscription.Assign) subscriptions.POST("/assign", h.Admin.Subscription.Assign)
subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign) subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign)
subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend) subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend)
subscriptions.POST("/:id/reset-quota", h.Admin.Subscription.ResetQuota)
subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke) subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke)
} }
......
...@@ -71,15 +71,8 @@ func RegisterGatewayRoutes( ...@@ -71,15 +71,8 @@ func RegisterGatewayRoutes(
gateway.POST("/responses", h.OpenAIGateway.Responses) gateway.POST("/responses", h.OpenAIGateway.Responses)
gateway.POST("/responses/*subpath", h.OpenAIGateway.Responses) gateway.POST("/responses/*subpath", h.OpenAIGateway.Responses)
gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket) gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket)
// 明确阻止旧协议入口:OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。 // OpenAI Chat Completions API
gateway.POST("/chat/completions", func(c *gin.Context) { gateway.POST("/chat/completions", h.OpenAIGateway.ChatCompletions)
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"type": "invalid_request_error",
"message": "Unsupported legacy protocol: /v1/chat/completions is not supported. Please use /v1/responses.",
},
})
})
} }
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
...@@ -100,6 +93,8 @@ func RegisterGatewayRoutes( ...@@ -100,6 +93,8 @@ func RegisterGatewayRoutes(
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses) r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses) r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket) r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
// OpenAI Chat Completions API(不带v1前缀的别名)
r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions)
// Antigravity 模型列表 // Antigravity 模型列表
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels) r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels)
......
...@@ -50,11 +50,18 @@ type TestEvent struct { ...@@ -50,11 +50,18 @@ type TestEvent struct {
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
Status string `json:"status,omitempty"` Status string `json:"status,omitempty"`
Code string `json:"code,omitempty"` Code string `json:"code,omitempty"`
ImageURL string `json:"image_url,omitempty"`
MimeType string `json:"mime_type,omitempty"`
Data any `json:"data,omitempty"` Data any `json:"data,omitempty"`
Success bool `json:"success,omitempty"` Success bool `json:"success,omitempty"`
Error string `json:"error,omitempty"` Error string `json:"error,omitempty"`
} }
const (
defaultGeminiTextTestPrompt = "hi"
defaultGeminiImageTestPrompt = "Generate a cute orange cat astronaut sticker on a clean pastel background."
)
// AccountTestService handles account testing operations // AccountTestService handles account testing operations
type AccountTestService struct { type AccountTestService struct {
accountRepo AccountRepository accountRepo AccountRepository
...@@ -161,7 +168,7 @@ func createTestPayload(modelID string) (map[string]any, error) { ...@@ -161,7 +168,7 @@ func createTestPayload(modelID string) (map[string]any, error) {
// TestAccountConnection tests an account's connection by sending a test request // TestAccountConnection tests an account's connection by sending a test request
// All account types use full Claude Code client characteristics, only auth header differs // All account types use full Claude Code client characteristics, only auth header differs
// modelID is optional - if empty, defaults to claude.DefaultTestModel // modelID is optional - if empty, defaults to claude.DefaultTestModel
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string) error { func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string, prompt string) error {
ctx := c.Request.Context() ctx := c.Request.Context()
// Get account // Get account
...@@ -176,11 +183,11 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int ...@@ -176,11 +183,11 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
} }
if account.IsGemini() { if account.IsGemini() {
return s.testGeminiAccountConnection(c, account, modelID) return s.testGeminiAccountConnection(c, account, modelID, prompt)
} }
if account.Platform == PlatformAntigravity { if account.Platform == PlatformAntigravity {
return s.routeAntigravityTest(c, account, modelID) return s.routeAntigravityTest(c, account, modelID, prompt)
} }
if account.Platform == PlatformSora { if account.Platform == PlatformSora {
...@@ -435,7 +442,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account ...@@ -435,7 +442,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
} }
// testGeminiAccountConnection tests a Gemini account's connection // testGeminiAccountConnection tests a Gemini account's connection
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string) error { func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error {
ctx := c.Request.Context() ctx := c.Request.Context()
// Determine the model to use // Determine the model to use
...@@ -462,7 +469,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account ...@@ -462,7 +469,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
c.Writer.Flush() c.Writer.Flush()
// Create test payload (Gemini format) // Create test payload (Gemini format)
payload := createGeminiTestPayload() payload := createGeminiTestPayload(testModelID, prompt)
// Build request based on account type // Build request based on account type
var req *http.Request var req *http.Request
...@@ -1198,10 +1205,10 @@ func truncateSoraErrorBody(body []byte, max int) string { ...@@ -1198,10 +1205,10 @@ func truncateSoraErrorBody(body []byte, max int) string {
// routeAntigravityTest 路由 Antigravity 账号的测试请求。 // routeAntigravityTest 路由 Antigravity 账号的测试请求。
// APIKey 类型走原生协议(与 gateway_handler 路由一致),OAuth/Upstream 走 CRS 中转。 // APIKey 类型走原生协议(与 gateway_handler 路由一致),OAuth/Upstream 走 CRS 中转。
func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string) error { func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string, prompt string) error {
if account.Type == AccountTypeAPIKey { if account.Type == AccountTypeAPIKey {
if strings.HasPrefix(modelID, "gemini-") { if strings.HasPrefix(modelID, "gemini-") {
return s.testGeminiAccountConnection(c, account, modelID) return s.testGeminiAccountConnection(c, account, modelID, prompt)
} }
return s.testClaudeAccountConnection(c, account, modelID) return s.testClaudeAccountConnection(c, account, modelID)
} }
...@@ -1349,14 +1356,46 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT ...@@ -1349,14 +1356,46 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT
return req, nil return req, nil
} }
// createGeminiTestPayload creates a minimal test payload for Gemini API // createGeminiTestPayload creates a minimal test payload for Gemini API.
func createGeminiTestPayload() []byte { // Image models use the image-generation path so the frontend can preview the returned image.
func createGeminiTestPayload(modelID string, prompt string) []byte {
if isImageGenerationModel(modelID) {
imagePrompt := strings.TrimSpace(prompt)
if imagePrompt == "" {
imagePrompt = defaultGeminiImageTestPrompt
}
payload := map[string]any{
"contents": []map[string]any{
{
"role": "user",
"parts": []map[string]any{
{"text": imagePrompt},
},
},
},
"generationConfig": map[string]any{
"responseModalities": []string{"TEXT", "IMAGE"},
"imageConfig": map[string]any{
"aspectRatio": "1:1",
},
},
}
bytes, _ := json.Marshal(payload)
return bytes
}
textPrompt := strings.TrimSpace(prompt)
if textPrompt == "" {
textPrompt = defaultGeminiTextTestPrompt
}
payload := map[string]any{ payload := map[string]any{
"contents": []map[string]any{ "contents": []map[string]any{
{ {
"role": "user", "role": "user",
"parts": []map[string]any{ "parts": []map[string]any{
{"text": "hi"}, {"text": textPrompt},
}, },
}, },
}, },
...@@ -1416,6 +1455,17 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader) ...@@ -1416,6 +1455,17 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader)
if text, ok := partMap["text"].(string); ok && text != "" { if text, ok := partMap["text"].(string); ok && text != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: text}) s.sendEvent(c, TestEvent{Type: "content", Text: text})
} }
if inlineData, ok := partMap["inlineData"].(map[string]any); ok {
mimeType, _ := inlineData["mimeType"].(string)
data, _ := inlineData["data"].(string)
if strings.HasPrefix(strings.ToLower(mimeType), "image/") && data != "" {
s.sendEvent(c, TestEvent{
Type: "image",
ImageURL: fmt.Sprintf("data:%s;base64,%s", mimeType, data),
MimeType: mimeType,
})
}
}
} }
} }
} }
...@@ -1602,7 +1652,7 @@ func (s *AccountTestService) RunTestBackground(ctx context.Context, accountID in ...@@ -1602,7 +1652,7 @@ func (s *AccountTestService) RunTestBackground(ctx context.Context, accountID in
ginCtx, _ := gin.CreateTestContext(w) ginCtx, _ := gin.CreateTestContext(w)
ginCtx.Request = (&http.Request{}).WithContext(ctx) ginCtx.Request = (&http.Request{}).WithContext(ctx)
testErr := s.TestAccountConnection(ginCtx, accountID, modelID) testErr := s.TestAccountConnection(ginCtx, accountID, modelID, "")
finishedAt := time.Now() finishedAt := time.Now()
body := w.Body.String() body := w.Body.String()
......
//go:build unit
package service
import (
"encoding/json"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestCreateGeminiTestPayload_ImageModel(t *testing.T) {
t.Parallel()
payload := createGeminiTestPayload("gemini-2.5-flash-image", "draw a tiny robot")
var parsed struct {
Contents []struct {
Parts []struct {
Text string `json:"text"`
} `json:"parts"`
} `json:"contents"`
GenerationConfig struct {
ResponseModalities []string `json:"responseModalities"`
ImageConfig struct {
AspectRatio string `json:"aspectRatio"`
} `json:"imageConfig"`
} `json:"generationConfig"`
}
require.NoError(t, json.Unmarshal(payload, &parsed))
require.Len(t, parsed.Contents, 1)
require.Len(t, parsed.Contents[0].Parts, 1)
require.Equal(t, "draw a tiny robot", parsed.Contents[0].Parts[0].Text)
require.Equal(t, []string{"TEXT", "IMAGE"}, parsed.GenerationConfig.ResponseModalities)
require.Equal(t, "1:1", parsed.GenerationConfig.ImageConfig.AspectRatio)
}
func TestProcessGeminiStream_EmitsImageEvent(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
ctx, recorder := newSoraTestContext()
svc := &AccountTestService{}
stream := strings.NewReader("data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"},{\"inlineData\":{\"mimeType\":\"image/png\",\"data\":\"QUJD\"}}]}}]}\n\ndata: [DONE]\n\n")
err := svc.processGeminiStream(ctx, stream)
require.NoError(t, err)
body := recorder.Body.String()
require.Contains(t, body, "\"type\":\"content\"")
require.Contains(t, body, "\"text\":\"ok\"")
require.Contains(t, body, "\"type\":\"image\"")
require.Contains(t, body, "\"image_url\":\"data:image/png;base64,QUJD\"")
require.Contains(t, body, "\"mime_type\":\"image/png\"")
}
...@@ -369,8 +369,11 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou ...@@ -369,8 +369,11 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
} }
if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) { if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) {
if updates, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && len(updates) > 0 { if updates, resetAt, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && (len(updates) > 0 || resetAt != nil) {
mergeAccountExtra(account, updates) mergeAccountExtra(account, updates)
if resetAt != nil {
account.RateLimitResetAt = resetAt
}
if usage.UpdatedAt == nil { if usage.UpdatedAt == nil {
usage.UpdatedAt = &now usage.UpdatedAt = &now
} }
...@@ -457,26 +460,26 @@ func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, no ...@@ -457,26 +460,26 @@ func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, no
return true return true
} }
func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, error) { func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, *time.Time, error) {
if account == nil || !account.IsOAuth() { if account == nil || !account.IsOAuth() {
return nil, nil return nil, nil, nil
} }
accessToken := account.GetOpenAIAccessToken() accessToken := account.GetOpenAIAccessToken()
if accessToken == "" { if accessToken == "" {
return nil, fmt.Errorf("no access token available") return nil, nil, fmt.Errorf("no access token available")
} }
modelID := openaipkg.DefaultTestModel modelID := openaipkg.DefaultTestModel
payload := createOpenAITestPayload(modelID, true) payload := createOpenAITestPayload(modelID, true)
payloadBytes, err := json.Marshal(payload) payloadBytes, err := json.Marshal(payload)
if err != nil { if err != nil {
return nil, fmt.Errorf("marshal openai probe payload: %w", err) return nil, nil, fmt.Errorf("marshal openai probe payload: %w", err)
} }
reqCtx, cancel := context.WithTimeout(ctx, 15*time.Second) reqCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
defer cancel() defer cancel()
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, chatgptCodexURL, bytes.NewReader(payloadBytes)) req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, chatgptCodexURL, bytes.NewReader(payloadBytes))
if err != nil { if err != nil {
return nil, fmt.Errorf("create openai probe request: %w", err) return nil, nil, fmt.Errorf("create openai probe request: %w", err)
} }
req.Host = "chatgpt.com" req.Host = "chatgpt.com"
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
...@@ -505,43 +508,67 @@ func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, acco ...@@ -505,43 +508,67 @@ func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, acco
ResponseHeaderTimeout: 10 * time.Second, ResponseHeaderTimeout: 10 * time.Second,
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("build openai probe client: %w", err) return nil, nil, fmt.Errorf("build openai probe client: %w", err)
} }
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("openai codex probe request failed: %w", err) return nil, nil, fmt.Errorf("openai codex probe request failed: %w", err)
} }
defer func() { _ = resp.Body.Close() }() defer func() { _ = resp.Body.Close() }()
updates, err := extractOpenAICodexProbeUpdates(resp) updates, resetAt, err := extractOpenAICodexProbeSnapshot(resp)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
if len(updates) > 0 { if len(updates) > 0 || resetAt != nil {
go func(accountID int64, updates map[string]any) { s.persistOpenAICodexProbeSnapshot(account.ID, updates, resetAt)
return updates, resetAt, nil
}
return nil, nil, nil
}
func (s *AccountUsageService) persistOpenAICodexProbeSnapshot(accountID int64, updates map[string]any, resetAt *time.Time) {
if s == nil || s.accountRepo == nil || accountID <= 0 {
return
}
if len(updates) == 0 && resetAt == nil {
return
}
go func() {
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second) updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer updateCancel() defer updateCancel()
if len(updates) > 0 {
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates) _ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
}(account.ID, updates)
return updates, nil
} }
return nil, nil if resetAt != nil {
_ = s.accountRepo.SetRateLimited(updateCtx, accountID, *resetAt)
}
}()
} }
func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) { func extractOpenAICodexProbeSnapshot(resp *http.Response) (map[string]any, *time.Time, error) {
if resp == nil { if resp == nil {
return nil, nil return nil, nil, nil
} }
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
updates := buildCodexUsageExtraUpdates(snapshot, time.Now()) baseTime := time.Now()
updates := buildCodexUsageExtraUpdates(snapshot, baseTime)
resetAt := codexRateLimitResetAtFromSnapshot(snapshot, baseTime)
if len(updates) > 0 { if len(updates) > 0 {
return updates, nil return updates, resetAt, nil
} }
return nil, resetAt, nil
} }
if resp.StatusCode < 200 || resp.StatusCode >= 300 { if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode) return nil, nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
} }
return nil, nil return nil, nil, nil
}
func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) {
updates, _, err := extractOpenAICodexProbeSnapshot(resp)
return updates, err
} }
func mergeAccountExtra(account *Account, updates map[string]any) { func mergeAccountExtra(account *Account, updates map[string]any) {
......
package service package service
import ( import (
"context"
"net/http" "net/http"
"testing" "testing"
"time" "time"
) )
type accountUsageCodexProbeRepo struct {
stubOpenAIAccountRepo
updateExtraCh chan map[string]any
rateLimitCh chan time.Time
}
func (r *accountUsageCodexProbeRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
if r.updateExtraCh != nil {
copied := make(map[string]any, len(updates))
for k, v := range updates {
copied[k] = v
}
r.updateExtraCh <- copied
}
return nil
}
func (r *accountUsageCodexProbeRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error {
if r.rateLimitCh != nil {
r.rateLimitCh <- resetAt
}
return nil
}
func TestShouldRefreshOpenAICodexSnapshot(t *testing.T) { func TestShouldRefreshOpenAICodexSnapshot(t *testing.T) {
t.Parallel() t.Parallel()
...@@ -66,3 +91,60 @@ func TestExtractOpenAICodexProbeUpdatesAccepts429WithCodexHeaders(t *testing.T) ...@@ -66,3 +91,60 @@ func TestExtractOpenAICodexProbeUpdatesAccepts429WithCodexHeaders(t *testing.T)
t.Fatalf("codex_7d_used_percent = %v, want 100", got) t.Fatalf("codex_7d_used_percent = %v, want 100", got)
} }
} }
func TestExtractOpenAICodexProbeSnapshotAccepts429WithResetAt(t *testing.T) {
t.Parallel()
headers := make(http.Header)
headers.Set("x-codex-primary-used-percent", "100")
headers.Set("x-codex-primary-reset-after-seconds", "604800")
headers.Set("x-codex-primary-window-minutes", "10080")
headers.Set("x-codex-secondary-used-percent", "100")
headers.Set("x-codex-secondary-reset-after-seconds", "18000")
headers.Set("x-codex-secondary-window-minutes", "300")
updates, resetAt, err := extractOpenAICodexProbeSnapshot(&http.Response{StatusCode: http.StatusTooManyRequests, Header: headers})
if err != nil {
t.Fatalf("extractOpenAICodexProbeSnapshot() error = %v", err)
}
if len(updates) == 0 {
t.Fatal("expected codex probe updates from 429 headers")
}
if resetAt == nil {
t.Fatal("expected resetAt from exhausted codex headers")
}
}
func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *testing.T) {
t.Parallel()
repo := &accountUsageCodexProbeRepo{
updateExtraCh: make(chan map[string]any, 1),
rateLimitCh: make(chan time.Time, 1),
}
svc := &AccountUsageService{accountRepo: repo}
resetAt := time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second)
svc.persistOpenAICodexProbeSnapshot(321, map[string]any{
"codex_7d_used_percent": 100.0,
"codex_7d_reset_at": resetAt.Format(time.RFC3339),
}, &resetAt)
select {
case updates := <-repo.updateExtraCh:
if got := updates["codex_7d_used_percent"]; got != 100.0 {
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
}
case <-time.After(2 * time.Second):
t.Fatal("waiting for codex probe extra persistence timed out")
}
select {
case got := <-repo.rateLimitCh:
if got.Before(resetAt.Add(-time.Second)) || got.After(resetAt.Add(time.Second)) {
t.Fatalf("rate limit resetAt = %v, want around %v", got, resetAt)
}
case <-time.After(2 * time.Second):
t.Fatal("waiting for codex probe rate limit persistence timed out")
}
}
...@@ -2164,6 +2164,112 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ...@@ -2164,6 +2164,112 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
} }
} }
// Gemini 原生请求中的 thoughtSignature 可能来自旧上下文/旧账号,触发上游严格校验后返回
// "Corrupted thought signature."。检测到此类 400 时,将 thoughtSignature 清理为 dummy 值后重试一次。
signatureCheckBody := respBody
if unwrapped, unwrapErr := s.unwrapV1InternalResponse(respBody); unwrapErr == nil && len(unwrapped) > 0 {
signatureCheckBody = unwrapped
}
if resp.StatusCode == http.StatusBadRequest &&
s.settingService != nil &&
s.settingService.IsSignatureRectifierEnabled(ctx) &&
isSignatureRelatedError(signatureCheckBody) &&
bytes.Contains(injectedBody, []byte(`"thoughtSignature"`)) {
upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractAntigravityErrorMessage(signatureCheckBody)))
upstreamDetail := s.getUpstreamErrorDetail(signatureCheckBody)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "signature_error",
Message: upstreamMsg,
Detail: upstreamDetail,
})
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: detected signature-related 400, retrying with cleaned thought signatures", account.ID)
cleanedInjectedBody := CleanGeminiNativeThoughtSignatures(injectedBody)
retryWrappedBody, wrapErr := s.wrapV1InternalRequest(projectID, mappedModel, cleanedInjectedBody)
if wrapErr == nil {
retryResult, retryErr := s.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctx,
prefix: prefix,
account: account,
proxyURL: proxyURL,
accessToken: accessToken,
action: upstreamAction,
body: retryWrappedBody,
c: c,
httpUpstream: s.httpUpstream,
settingService: s.settingService,
accountRepo: s.accountRepo,
handleError: s.handleUpstreamError,
requestedModel: originalModel,
isStickySession: isStickySession,
groupID: 0,
sessionHash: "",
})
if retryErr == nil {
retryResp := retryResult.resp
if retryResp.StatusCode < 400 {
resp = retryResp
} else {
retryRespBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
_ = retryResp.Body.Close()
retryOpsBody := retryRespBody
if retryUnwrapped, unwrapErr := s.unwrapV1InternalResponse(retryRespBody); unwrapErr == nil && len(retryUnwrapped) > 0 {
retryOpsBody = retryUnwrapped
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: retryResp.StatusCode,
UpstreamRequestID: retryResp.Header.Get("x-request-id"),
Kind: "signature_retry",
Message: sanitizeUpstreamErrorMessage(strings.TrimSpace(extractAntigravityErrorMessage(retryOpsBody))),
Detail: s.getUpstreamErrorDetail(retryOpsBody),
})
respBody = retryRespBody
resp = &http.Response{
StatusCode: retryResp.StatusCode,
Header: retryResp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(retryRespBody)),
}
contentType = resp.Header.Get("Content-Type")
}
} else {
if switchErr, ok := IsAntigravityAccountSwitchError(retryErr); ok {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: http.StatusServiceUnavailable,
Kind: "failover",
Message: sanitizeUpstreamErrorMessage(retryErr.Error()),
})
return nil, &UpstreamFailoverError{
StatusCode: http.StatusServiceUnavailable,
ForceCacheBilling: switchErr.IsStickySession,
}
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "signature_retry_request_error",
Message: sanitizeUpstreamErrorMessage(retryErr.Error()),
})
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: signature retry request failed: %v", account.ID, retryErr)
}
} else {
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: signature retry wrap failed: %v", account.ID, wrapErr)
}
}
// fallback 成功:继续按正常响应处理 // fallback 成功:继续按正常响应处理
if resp.StatusCode < 400 { if resp.StatusCode < 400 {
goto handleSuccess goto handleSuccess
......
...@@ -134,6 +134,47 @@ func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, ...@@ -134,6 +134,47 @@ func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int,
return s.resp, s.err return s.resp, s.err
} }
type queuedHTTPUpstreamStub struct {
responses []*http.Response
errors []error
requestBodies [][]byte
callCount int
onCall func(*http.Request, *queuedHTTPUpstreamStub)
}
func (s *queuedHTTPUpstreamStub) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
if req != nil && req.Body != nil {
body, _ := io.ReadAll(req.Body)
s.requestBodies = append(s.requestBodies, body)
req.Body = io.NopCloser(bytes.NewReader(body))
} else {
s.requestBodies = append(s.requestBodies, nil)
}
idx := s.callCount
s.callCount++
if s.onCall != nil {
s.onCall(req, s)
}
var resp *http.Response
if idx < len(s.responses) {
resp = s.responses[idx]
}
var err error
if idx < len(s.errors) {
err = s.errors[idx]
}
if resp == nil && err == nil {
return nil, errors.New("unexpected upstream call")
}
return resp, err
}
func (s *queuedHTTPUpstreamStub) DoWithTLS(req *http.Request, proxyURL string, accountID int64, concurrency int, _ bool) (*http.Response, error) {
return s.Do(req, proxyURL, accountID, concurrency)
}
type antigravitySettingRepoStub struct{} type antigravitySettingRepoStub struct{}
func (s *antigravitySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) { func (s *antigravitySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
...@@ -556,6 +597,177 @@ func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing ...@@ -556,6 +597,177 @@ func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing
require.Equal(t, mappedModel, result.Model) require.Equal(t, mappedModel, result.Model)
} }
func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{
"contents": []map[string]any{
{"role": "user", "parts": []map[string]any{{"text": "hello"}}},
{"role": "model", "parts": []map[string]any{{"text": "thinking", "thought": true, "thoughtSignature": "sig_bad_1"}}},
{"role": "model", "parts": []map[string]any{{"functionCall": map[string]any{"name": "toolA", "args": map[string]any{"x": 1}}, "thoughtSignature": "sig_bad_2"}}},
},
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/antigravity/v1beta/models/gemini-3.1-pro-preview:streamGenerateContent", bytes.NewReader(body))
c.Request = req
firstRespBody := []byte(`{"response":{"error":{"code":400,"message":"Corrupted thought signature.","status":"INVALID_ARGUMENT"}}}`)
secondRespBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
upstream := &queuedHTTPUpstreamStub{
responses: []*http.Response{
{
StatusCode: http.StatusBadRequest,
Header: http.Header{
"Content-Type": []string{"application/json"},
"X-Request-Id": []string{"req-sig-1"},
},
Body: io.NopCloser(bytes.NewReader(firstRespBody)),
},
{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"text/event-stream"},
"X-Request-Id": []string{"req-sig-2"},
},
Body: io.NopCloser(bytes.NewReader(secondRespBody)),
},
},
}
svc := &AntigravityGatewayService{
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: upstream,
}
const originalModel = "gemini-3.1-pro-preview"
const mappedModel = "gemini-3.1-pro-high"
account := &Account{
ID: 7,
Name: "acc-gemini-signature",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
"model_mapping": map[string]any{
originalModel: mappedModel,
},
},
}
result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, mappedModel, result.Model)
require.Len(t, upstream.requestBodies, 2, "signature error should trigger exactly one retry")
firstReq := string(upstream.requestBodies[0])
secondReq := string(upstream.requestBodies[1])
require.Contains(t, firstReq, `"thoughtSignature":"sig_bad_1"`)
require.Contains(t, firstReq, `"thoughtSignature":"sig_bad_2"`)
require.Contains(t, secondReq, `"thoughtSignature":"skip_thought_signature_validator"`)
require.NotContains(t, secondReq, `"thoughtSignature":"sig_bad_1"`)
require.NotContains(t, secondReq, `"thoughtSignature":"sig_bad_2"`)
raw, ok := c.Get(OpsUpstreamErrorsKey)
require.True(t, ok)
events, ok := raw.([]*OpsUpstreamErrorEvent)
require.True(t, ok)
require.NotEmpty(t, events)
require.Equal(t, "signature_error", events[0].Kind)
}
func TestAntigravityGatewayService_ForwardGemini_SignatureRetryPropagatesFailover(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{
"contents": []map[string]any{
{"role": "user", "parts": []map[string]any{{"text": "hello"}}},
{"role": "model", "parts": []map[string]any{{"text": "thinking", "thought": true, "thoughtSignature": "sig_bad_1"}}},
},
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/antigravity/v1beta/models/gemini-3.1-pro-preview:streamGenerateContent", bytes.NewReader(body))
c.Request = req
firstRespBody := []byte(`{"response":{"error":{"code":400,"message":"Corrupted thought signature.","status":"INVALID_ARGUMENT"}}}`)
const originalModel = "gemini-3.1-pro-preview"
const mappedModel = "gemini-3.1-pro-high"
account := &Account{
ID: 8,
Name: "acc-gemini-signature-failover",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
"model_mapping": map[string]any{
originalModel: mappedModel,
},
},
}
upstream := &queuedHTTPUpstreamStub{
responses: []*http.Response{
{
StatusCode: http.StatusBadRequest,
Header: http.Header{
"Content-Type": []string{"application/json"},
"X-Request-Id": []string{"req-sig-failover-1"},
},
Body: io.NopCloser(bytes.NewReader(firstRespBody)),
},
},
onCall: func(_ *http.Request, stub *queuedHTTPUpstreamStub) {
if stub.callCount != 1 {
return
}
futureResetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339)
account.Extra = map[string]any{
modelRateLimitsKey: map[string]any{
mappedModel: map[string]any{
"rate_limit_reset_at": futureResetAt,
},
},
}
},
}
svc := &AntigravityGatewayService{
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: upstream,
}
result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, true)
require.Nil(t, result)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr, "signature retry should propagate failover instead of falling back to the original 400")
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
require.True(t, failoverErr.ForceCacheBilling)
require.Len(t, upstream.requestBodies, 1, "retry should stop at preflight failover and not issue a second upstream request")
raw, ok := c.Get(OpsUpstreamErrorsKey)
require.True(t, ok)
events, ok := raw.([]*OpsUpstreamErrorEvent)
require.True(t, ok)
require.Len(t, events, 2)
require.Equal(t, "signature_error", events[0].Kind)
require.Equal(t, "failover", events[1].Kind)
}
// TestStreamUpstreamResponse_UsageAndFirstToken // TestStreamUpstreamResponse_UsageAndFirstToken
// 验证:usage 字段可被累积/覆盖更新,并且能记录首 token 时间 // 验证:usage 字段可被累积/覆盖更新,并且能记录首 token 时间
func TestStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) { func TestStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) {
......
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
...@@ -110,6 +111,15 @@ func (d *APIKeyRateLimitData) EffectiveUsage7d() float64 { ...@@ -110,6 +111,15 @@ func (d *APIKeyRateLimitData) EffectiveUsage7d() float64 {
return d.Usage7d return d.Usage7d
} }
// APIKeyQuotaUsageState captures the latest quota fields after an atomic quota update.
// It is intentionally small so repositories can return it from a single SQL statement.
type APIKeyQuotaUsageState struct {
QuotaUsed float64
Quota float64
Key string
Status string
}
// APIKeyCache defines cache operations for API key service // APIKeyCache defines cache operations for API key service
type APIKeyCache interface { type APIKeyCache interface {
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
...@@ -817,6 +827,21 @@ func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cos ...@@ -817,6 +827,21 @@ func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cos
return nil return nil
} }
type quotaStateReader interface {
IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*APIKeyQuotaUsageState, error)
}
if repo, ok := s.apiKeyRepo.(quotaStateReader); ok {
state, err := repo.IncrementQuotaUsedAndGetState(ctx, apiKeyID, cost)
if err != nil {
return fmt.Errorf("increment quota used: %w", err)
}
if state != nil && state.Status == StatusAPIKeyQuotaExhausted && strings.TrimSpace(state.Key) != "" {
s.InvalidateAuthCacheByKey(ctx, state.Key)
}
return nil
}
// Use repository to atomically increment quota_used // Use repository to atomically increment quota_used
newQuotaUsed, err := s.apiKeyRepo.IncrementQuotaUsed(ctx, apiKeyID, cost) newQuotaUsed, err := s.apiKeyRepo.IncrementQuotaUsed(ctx, apiKeyID, cost)
if err != nil { if err != nil {
......
//go:build unit
package service
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type quotaStateRepoStub struct {
quotaBaseAPIKeyRepoStub
stateCalls int
state *APIKeyQuotaUsageState
stateErr error
}
func (s *quotaStateRepoStub) IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*APIKeyQuotaUsageState, error) {
s.stateCalls++
if s.stateErr != nil {
return nil, s.stateErr
}
if s.state == nil {
return nil, nil
}
out := *s.state
return &out, nil
}
type quotaStateCacheStub struct {
deleteAuthKeys []string
}
func (s *quotaStateCacheStub) GetCreateAttemptCount(context.Context, int64) (int, error) {
return 0, nil
}
func (s *quotaStateCacheStub) IncrementCreateAttemptCount(context.Context, int64) error {
return nil
}
func (s *quotaStateCacheStub) DeleteCreateAttemptCount(context.Context, int64) error {
return nil
}
func (s *quotaStateCacheStub) IncrementDailyUsage(context.Context, string) error {
return nil
}
func (s *quotaStateCacheStub) SetDailyUsageExpiry(context.Context, string, time.Duration) error {
return nil
}
func (s *quotaStateCacheStub) GetAuthCache(context.Context, string) (*APIKeyAuthCacheEntry, error) {
return nil, nil
}
func (s *quotaStateCacheStub) SetAuthCache(context.Context, string, *APIKeyAuthCacheEntry, time.Duration) error {
return nil
}
func (s *quotaStateCacheStub) DeleteAuthCache(_ context.Context, key string) error {
s.deleteAuthKeys = append(s.deleteAuthKeys, key)
return nil
}
func (s *quotaStateCacheStub) PublishAuthCacheInvalidation(context.Context, string) error {
return nil
}
func (s *quotaStateCacheStub) SubscribeAuthCacheInvalidation(context.Context, func(string)) error {
return nil
}
type quotaBaseAPIKeyRepoStub struct {
getByIDCalls int
}
func (s *quotaBaseAPIKeyRepoStub) Create(context.Context, *APIKey) error {
panic("unexpected Create call")
}
func (s *quotaBaseAPIKeyRepoStub) GetByID(context.Context, int64) (*APIKey, error) {
s.getByIDCalls++
return nil, nil
}
func (s *quotaBaseAPIKeyRepoStub) GetKeyAndOwnerID(context.Context, int64) (string, int64, error) {
panic("unexpected GetKeyAndOwnerID call")
}
func (s *quotaBaseAPIKeyRepoStub) GetByKey(context.Context, string) (*APIKey, error) {
panic("unexpected GetByKey call")
}
func (s *quotaBaseAPIKeyRepoStub) GetByKeyForAuth(context.Context, string) (*APIKey, error) {
panic("unexpected GetByKeyForAuth call")
}
func (s *quotaBaseAPIKeyRepoStub) Update(context.Context, *APIKey) error {
panic("unexpected Update call")
}
func (s *quotaBaseAPIKeyRepoStub) Delete(context.Context, int64) error {
panic("unexpected Delete call")
}
func (s *quotaBaseAPIKeyRepoStub) ListByUserID(context.Context, int64, pagination.PaginationParams, APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByUserID call")
}
func (s *quotaBaseAPIKeyRepoStub) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
panic("unexpected VerifyOwnership call")
}
func (s *quotaBaseAPIKeyRepoStub) CountByUserID(context.Context, int64) (int64, error) {
panic("unexpected CountByUserID call")
}
func (s *quotaBaseAPIKeyRepoStub) ExistsByKey(context.Context, string) (bool, error) {
panic("unexpected ExistsByKey call")
}
func (s *quotaBaseAPIKeyRepoStub) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByGroupID call")
}
func (s *quotaBaseAPIKeyRepoStub) SearchAPIKeys(context.Context, int64, string, int) ([]APIKey, error) {
panic("unexpected SearchAPIKeys call")
}
func (s *quotaBaseAPIKeyRepoStub) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
panic("unexpected ClearGroupIDByGroupID call")
}
func (s *quotaBaseAPIKeyRepoStub) CountByGroupID(context.Context, int64) (int64, error) {
panic("unexpected CountByGroupID call")
}
func (s *quotaBaseAPIKeyRepoStub) ListKeysByUserID(context.Context, int64) ([]string, error) {
panic("unexpected ListKeysByUserID call")
}
func (s *quotaBaseAPIKeyRepoStub) ListKeysByGroupID(context.Context, int64) ([]string, error) {
panic("unexpected ListKeysByGroupID call")
}
func (s *quotaBaseAPIKeyRepoStub) IncrementQuotaUsed(context.Context, int64, float64) (float64, error) {
panic("unexpected IncrementQuotaUsed call")
}
func (s *quotaBaseAPIKeyRepoStub) UpdateLastUsed(context.Context, int64, time.Time) error {
panic("unexpected UpdateLastUsed call")
}
func (s *quotaBaseAPIKeyRepoStub) IncrementRateLimitUsage(context.Context, int64, float64) error {
panic("unexpected IncrementRateLimitUsage call")
}
func (s *quotaBaseAPIKeyRepoStub) ResetRateLimitWindows(context.Context, int64) error {
panic("unexpected ResetRateLimitWindows call")
}
func (s *quotaBaseAPIKeyRepoStub) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) {
panic("unexpected GetRateLimitData call")
}
func TestAPIKeyService_UpdateQuotaUsed_UsesAtomicStatePath(t *testing.T) {
repo := &quotaStateRepoStub{
state: &APIKeyQuotaUsageState{
QuotaUsed: 12,
Quota: 10,
Key: "sk-test-quota",
Status: StatusAPIKeyQuotaExhausted,
},
}
cache := &quotaStateCacheStub{}
svc := &APIKeyService{
apiKeyRepo: repo,
cache: cache,
}
err := svc.UpdateQuotaUsed(context.Background(), 101, 2)
require.NoError(t, err)
require.Equal(t, 1, repo.stateCalls)
require.Equal(t, 0, repo.getByIDCalls, "fast path should not re-read API key by id")
require.Equal(t, []string{svc.authCacheKey("sk-test-quota")}, cache.deleteAuthKeys)
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment