Commit d8aff3a7 authored by ius's avatar ius
Browse files

Merge origin/main into fix/account-extra-scheduler-pressure-20260311

parents 2b30e3b6 c0110cb5
package gemini
import "testing"
func TestDefaultModels_ContainsImageModels(t *testing.T) {
t.Parallel()
models := DefaultModels()
byName := make(map[string]Model, len(models))
for _, model := range models {
byName[model.Name] = model
}
required := []string{
"models/gemini-2.5-flash-image",
"models/gemini-3.1-flash-image",
}
for _, name := range required {
model, ok := byName[name]
if !ok {
t.Fatalf("expected fallback model %q to exist", name)
}
if len(model.SupportedGenerationMethods) == 0 {
t.Fatalf("expected fallback model %q to advertise generation methods", name)
}
}
}
......@@ -13,10 +13,12 @@ type Model struct {
var DefaultModels = []Model{
{ID: "gemini-2.0-flash", Type: "model", DisplayName: "Gemini 2.0 Flash", CreatedAt: ""},
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
{ID: "gemini-2.5-flash-image", Type: "model", DisplayName: "Gemini 2.5 Flash Image", CreatedAt: ""},
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
{ID: "gemini-3.1-pro-preview", Type: "model", DisplayName: "Gemini 3.1 Pro Preview", CreatedAt: ""},
{ID: "gemini-3.1-flash-image", Type: "model", DisplayName: "Gemini 3.1 Flash Image", CreatedAt: ""},
}
// DefaultTestModel is the default model to preselect in test flows.
......
package geminicli
import "testing"
func TestDefaultModels_ContainsImageModels(t *testing.T) {
t.Parallel()
byID := make(map[string]Model, len(DefaultModels))
for _, model := range DefaultModels {
byID[model.ID] = model
}
required := []string{
"gemini-2.5-flash-image",
"gemini-3.1-flash-image",
}
for _, id := range required {
if _, ok := byID[id]; !ok {
t.Fatalf("expected curated Gemini model %q to exist", id)
}
}
}
......@@ -626,29 +626,6 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac
}
}
func (r *accountRepository) patchSchedulerAccountExtra(ctx context.Context, accountID int64, updates map[string]any) {
if r == nil || r.schedulerCache == nil || accountID <= 0 || len(updates) == 0 {
return
}
account, err := r.schedulerCache.GetAccount(ctx, accountID)
if err != nil {
logger.LegacyPrintf("repository.account", "[Scheduler] patch account extra read failed: id=%d err=%v", accountID, err)
return
}
if account == nil {
return
}
if account.Extra == nil {
account.Extra = make(map[string]any, len(updates))
}
for key, value := range updates {
account.Extra[key] = value
}
if err := r.schedulerCache.SetAccount(ctx, account); err != nil {
logger.LegacyPrintf("repository.account", "[Scheduler] patch account extra write failed: id=%d err=%v", accountID, err)
}
}
func (r *accountRepository) syncSchedulerAccountSnapshots(ctx context.Context, accountIDs []int64) {
if r == nil || r.schedulerCache == nil || len(accountIDs) == 0 {
return
......@@ -1221,15 +1198,15 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
if affected == 0 {
return service.ErrAccountNotFound
}
if shouldEnqueueSchedulerOutboxForExtraUpdates(updates) {
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
}
} else {
// 观测型 extra 字段不需要触发 bucket 重建,但尽量把单账号缓存补到最新,
// 让 sticky session / GetAccount 命中缓存时也能读到最新快照。
r.patchSchedulerAccountExtra(ctx, id, updates)
// 观测型 extra 字段不需要触发 bucket 重建,但仍同步单账号快照,
// 让 sticky session / GetAccount 命中缓存时也能读到最新数据,
// 同时避免缓存局部 patch 覆盖掉并发写入的其它账号字段。
r.syncSchedulerAccountSnapshot(ctx, id)
}
return nil
}
......@@ -1239,9 +1216,10 @@ func shouldEnqueueSchedulerOutboxForExtraUpdates(updates map[string]any) bool {
return false
}
for key := range updates {
if !isSchedulerNeutralExtraKey(key) {
return true
if isSchedulerNeutralExtraKey(key) {
continue
}
return true
}
return false
}
......@@ -1262,6 +1240,82 @@ func isSchedulerNeutralExtraKey(key string) bool {
return false
}
func shouldSyncSchedulerSnapshotForExtraUpdates(updates map[string]any) bool {
return codexExtraIndicatesRateLimit(updates, "7d") || codexExtraIndicatesRateLimit(updates, "5h")
}
func codexExtraIndicatesRateLimit(updates map[string]any, window string) bool {
if len(updates) == 0 {
return false
}
usedValue, ok := updates["codex_"+window+"_used_percent"]
if !ok || !extraValueIndicatesExhausted(usedValue) {
return false
}
return extraValueHasResetMarker(updates["codex_"+window+"_reset_at"]) ||
extraValueHasPositiveNumber(updates["codex_"+window+"_reset_after_seconds"])
}
func extraValueIndicatesExhausted(value any) bool {
number, ok := extraValueToFloat64(value)
return ok && number >= 100-1e-9
}
func extraValueHasPositiveNumber(value any) bool {
number, ok := extraValueToFloat64(value)
return ok && number > 0
}
func extraValueHasResetMarker(value any) bool {
switch v := value.(type) {
case string:
return strings.TrimSpace(v) != ""
case time.Time:
return !v.IsZero()
case *time.Time:
return v != nil && !v.IsZero()
default:
return false
}
}
func extraValueToFloat64(value any) (float64, bool) {
switch v := value.(type) {
case float64:
return v, true
case float32:
return float64(v), true
case int:
return float64(v), true
case int8:
return float64(v), true
case int16:
return float64(v), true
case int32:
return float64(v), true
case int64:
return float64(v), true
case uint:
return float64(v), true
case uint8:
return float64(v), true
case uint16:
return float64(v), true
case uint32:
return float64(v), true
case uint64:
return float64(v), true
case json.Number:
parsed, err := v.Float64()
return parsed, err == nil
case string:
parsed, err := strconv.ParseFloat(strings.TrimSpace(v), 64)
return parsed, err == nil
default:
return 0, false
}
}
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
if len(ids) == 0 {
return 0, nil
......
......@@ -633,7 +633,7 @@ func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
s.Require().Equal("val", got.Extra["key"])
}
func (s *AccountRepoSuite) TestUpdateExtra_SchedulerNeutralSkipsOutboxAndPatchesCache() {
func (s *AccountRepoSuite) TestUpdateExtra_SchedulerNeutralSkipsOutboxAndSyncsFreshSnapshot() {
account := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "acc-extra-neutral",
Platform: service.PlatformOpenAI,
......@@ -644,6 +644,7 @@ func (s *AccountRepoSuite) TestUpdateExtra_SchedulerNeutralSkipsOutboxAndPatches
account.ID: {
ID: account.ID,
Platform: account.Platform,
Status: service.StatusDisabled,
Extra: map[string]any{
"codex_usage_updated_at": "old",
},
......@@ -670,25 +671,56 @@ func (s *AccountRepoSuite) TestUpdateExtra_SchedulerNeutralSkipsOutboxAndPatches
s.Require().Zero(outboxCount)
s.Require().Len(cacheRecorder.setAccounts, 1)
s.Require().NotNil(cacheRecorder.accounts[account.ID])
s.Require().Equal(service.StatusActive, cacheRecorder.accounts[account.ID].Status)
s.Require().Equal("2026-03-11T10:00:00Z", cacheRecorder.accounts[account.ID].Extra["codex_usage_updated_at"])
}
func (s *AccountRepoSuite) TestUpdateExtra_ExhaustedCodexSnapshotSyncsSchedulerCache() {
account := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "acc-extra-codex-exhausted",
Platform: service.PlatformOpenAI,
Type: service.AccountTypeOAuth,
Extra: map[string]any{},
})
cacheRecorder := &schedulerCacheRecorder{}
s.repo.schedulerCache = cacheRecorder
_, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox")
s.Require().NoError(err)
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{
"codex_7d_used_percent": 100.0,
"codex_7d_reset_at": "2026-03-12T13:00:00Z",
"codex_7d_reset_after_seconds": 86400,
}))
var count int
err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count)
s.Require().NoError(err)
s.Require().Equal(0, count)
s.Require().Len(cacheRecorder.setAccounts, 1)
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
s.Require().Equal(service.StatusActive, cacheRecorder.setAccounts[0].Status)
s.Require().Equal(100.0, cacheRecorder.setAccounts[0].Extra["codex_7d_used_percent"])
}
func (s *AccountRepoSuite) TestUpdateExtra_SchedulerRelevantStillEnqueuesOutbox() {
account := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "acc-extra-mixed",
Platform: service.PlatformAntigravity,
Extra: map[string]any{},
})
_, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox")
s.Require().NoError(err)
updates := map[string]any{
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{
"mixed_scheduling": true,
"codex_usage_updated_at": "2026-03-11T10:00:00Z",
}
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, updates))
}))
var outboxCount int
s.Require().NoError(scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &outboxCount))
s.Require().Equal(1, outboxCount)
var count int
err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count)
s.Require().NoError(err)
s.Require().Equal(1, count)
}
// --- GetByCRSAccountID ---
......
......@@ -452,6 +452,32 @@ func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amo
return updated.QuotaUsed, nil
}
// IncrementQuotaUsedAndGetState atomically increments quota_used, conditionally marks the key
// as quota_exhausted, and returns the latest quota state in one round trip.
func (r *apiKeyRepository) IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*service.APIKeyQuotaUsageState, error) {
query := `
UPDATE api_keys
SET
quota_used = quota_used + $1,
status = CASE
WHEN quota > 0 AND quota_used + $1 >= quota THEN $2
ELSE status
END,
updated_at = NOW()
WHERE id = $3 AND deleted_at IS NULL
RETURNING quota_used, quota, key, status
`
state := &service.APIKeyQuotaUsageState{}
if err := scanSingleRow(ctx, r.sql, query, []any{amount, service.StatusAPIKeyQuotaExhausted, id}, &state.QuotaUsed, &state.Quota, &state.Key, &state.Status); err != nil {
if err == sql.ErrNoRows {
return nil, service.ErrAPIKeyNotFound
}
return nil, err
}
return state, nil
}
func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
affected, err := r.client.APIKey.Update().
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
......
......@@ -417,6 +417,27 @@ func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() {
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound")
}
func (s *APIKeyRepoSuite) TestIncrementQuotaUsedAndGetState() {
user := s.mustCreateUser("quota-state@test.com")
key := s.mustCreateApiKey(user.ID, "sk-quota-state", "QuotaState", nil)
key.Quota = 3
key.QuotaUsed = 1
s.Require().NoError(s.repo.Update(s.ctx, key), "Update quota")
state, err := s.repo.IncrementQuotaUsedAndGetState(s.ctx, key.ID, 2.5)
s.Require().NoError(err, "IncrementQuotaUsedAndGetState")
s.Require().NotNil(state)
s.Require().Equal(3.5, state.QuotaUsed)
s.Require().Equal(3.0, state.Quota)
s.Require().Equal(service.StatusAPIKeyQuotaExhausted, state.Status)
s.Require().Equal(key.Key, state.Key)
got, err := s.repo.GetByID(s.ctx, key.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal(3.5, got.QuotaUsed)
s.Require().Equal(service.StatusAPIKeyQuotaExhausted, got.Status)
}
// TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。
// 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。
func TestIncrementQuotaUsed_Concurrent(t *testing.T) {
......
......@@ -16,19 +16,7 @@ type opsRepository struct {
db *sql.DB
}
func NewOpsRepository(db *sql.DB) service.OpsRepository {
return &opsRepository{db: db}
}
func (r *opsRepository) InsertErrorLog(ctx context.Context, input *service.OpsInsertErrorLogInput) (int64, error) {
if r == nil || r.db == nil {
return 0, fmt.Errorf("nil ops repository")
}
if input == nil {
return 0, fmt.Errorf("nil input")
}
q := `
const insertOpsErrorLogSQL = `
INSERT INTO ops_error_logs (
request_id,
client_request_id,
......@@ -70,12 +58,77 @@ INSERT INTO ops_error_logs (
created_at
) VALUES (
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38
) RETURNING id`
)`
func NewOpsRepository(db *sql.DB) service.OpsRepository {
return &opsRepository{db: db}
}
func (r *opsRepository) InsertErrorLog(ctx context.Context, input *service.OpsInsertErrorLogInput) (int64, error) {
if r == nil || r.db == nil {
return 0, fmt.Errorf("nil ops repository")
}
if input == nil {
return 0, fmt.Errorf("nil input")
}
var id int64
err := r.db.QueryRowContext(
ctx,
q,
insertOpsErrorLogSQL+" RETURNING id",
opsInsertErrorLogArgs(input)...,
).Scan(&id)
if err != nil {
return 0, err
}
return id, nil
}
func (r *opsRepository) BatchInsertErrorLogs(ctx context.Context, inputs []*service.OpsInsertErrorLogInput) (int64, error) {
if r == nil || r.db == nil {
return 0, fmt.Errorf("nil ops repository")
}
if len(inputs) == 0 {
return 0, nil
}
tx, err := r.db.BeginTx(ctx, nil)
if err != nil {
return 0, err
}
defer func() {
if err != nil {
_ = tx.Rollback()
}
}()
stmt, err := tx.PrepareContext(ctx, insertOpsErrorLogSQL)
if err != nil {
return 0, err
}
defer func() {
_ = stmt.Close()
}()
var inserted int64
for _, input := range inputs {
if input == nil {
continue
}
if _, err = stmt.ExecContext(ctx, opsInsertErrorLogArgs(input)...); err != nil {
return inserted, err
}
inserted++
}
if err = tx.Commit(); err != nil {
return inserted, err
}
return inserted, nil
}
func opsInsertErrorLogArgs(input *service.OpsInsertErrorLogInput) []any {
return []any{
opsNullString(input.RequestID),
opsNullString(input.ClientRequestID),
opsNullInt64(input.UserID),
......@@ -114,11 +167,7 @@ INSERT INTO ops_error_logs (
input.IsRetryable,
input.RetryCount,
input.CreatedAt,
).Scan(&id)
if err != nil {
return 0, err
}
return id, nil
}
func (r *opsRepository) ListErrorLogs(ctx context.Context, filter *service.OpsErrorLogFilter) (*service.OpsErrorLogList, error) {
......
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func TestOpsRepositoryBatchInsertErrorLogs(t *testing.T) {
ctx := context.Background()
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE ops_error_logs RESTART IDENTITY")
repo := NewOpsRepository(integrationDB).(*opsRepository)
now := time.Now().UTC()
inserted, err := repo.BatchInsertErrorLogs(ctx, []*service.OpsInsertErrorLogInput{
{
RequestID: "batch-ops-1",
ErrorPhase: "upstream",
ErrorType: "upstream_error",
Severity: "error",
StatusCode: 429,
ErrorMessage: "rate limited",
CreatedAt: now,
},
{
RequestID: "batch-ops-2",
ErrorPhase: "internal",
ErrorType: "api_error",
Severity: "error",
StatusCode: 500,
ErrorMessage: "internal error",
CreatedAt: now.Add(time.Millisecond),
},
})
require.NoError(t, err)
require.EqualValues(t, 2, inserted)
var count int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM ops_error_logs WHERE request_id IN ('batch-ops-1', 'batch-ops-2')").Scan(&count))
require.Equal(t, 2, count)
}
func TestEnqueueSchedulerOutbox_DeduplicatesIdempotentEvents(t *testing.T) {
ctx := context.Background()
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox RESTART IDENTITY")
accountID := int64(12345)
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
var count int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountChanged).Scan(&count))
require.Equal(t, 1, count)
time.Sleep(schedulerOutboxDedupWindow + 150*time.Millisecond)
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountChanged).Scan(&count))
require.Equal(t, 2, count)
}
func TestEnqueueSchedulerOutbox_DoesNotDeduplicateLastUsed(t *testing.T) {
ctx := context.Background()
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox RESTART IDENTITY")
accountID := int64(67890)
payload1 := map[string]any{"last_used": map[string]int64{"67890": 100}}
payload2 := map[string]any{"last_used": map[string]int64{"67890": 200}}
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountLastUsed, &accountID, nil, payload1))
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountLastUsed, &accountID, nil, payload2))
var count int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountLastUsed).Scan(&count))
require.Equal(t, 2, count)
}
......@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"encoding/json"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
......@@ -12,6 +13,8 @@ type schedulerOutboxRepository struct {
db *sql.DB
}
const schedulerOutboxDedupWindow = time.Second
func NewSchedulerOutboxRepository(db *sql.DB) service.SchedulerOutboxRepository {
return &schedulerOutboxRepository{db: db}
}
......@@ -88,9 +91,37 @@ func enqueueSchedulerOutbox(ctx context.Context, exec sqlExecutor, eventType str
}
payloadArg = encoded
}
_, err := exec.ExecContext(ctx, `
query := `
INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload)
VALUES ($1, $2, $3, $4)
`, eventType, accountID, groupID, payloadArg)
`
args := []any{eventType, accountID, groupID, payloadArg}
if schedulerOutboxEventSupportsDedup(eventType) {
query = `
INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload)
SELECT $1, $2, $3, $4
WHERE NOT EXISTS (
SELECT 1
FROM scheduler_outbox
WHERE event_type = $1
AND account_id IS NOT DISTINCT FROM $2
AND group_id IS NOT DISTINCT FROM $3
AND created_at >= NOW() - make_interval(secs => $5)
)
`
args = append(args, schedulerOutboxDedupWindow.Seconds())
}
_, err := exec.ExecContext(ctx, query, args...)
return err
}
func schedulerOutboxEventSupportsDedup(eventType string) bool {
switch eventType {
case service.SchedulerOutboxEventAccountChanged,
service.SchedulerOutboxEventGroupChanged,
service.SchedulerOutboxEventFullRebuild:
return true
default:
return false
}
}
......@@ -456,6 +456,7 @@ func registerSubscriptionRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
subscriptions.POST("/assign", h.Admin.Subscription.Assign)
subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign)
subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend)
subscriptions.POST("/:id/reset-quota", h.Admin.Subscription.ResetQuota)
subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke)
}
......
......@@ -71,15 +71,8 @@ func RegisterGatewayRoutes(
gateway.POST("/responses", h.OpenAIGateway.Responses)
gateway.POST("/responses/*subpath", h.OpenAIGateway.Responses)
gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket)
// 明确阻止旧协议入口:OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。
gateway.POST("/chat/completions", func(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"type": "invalid_request_error",
"message": "Unsupported legacy protocol: /v1/chat/completions is not supported. Please use /v1/responses.",
},
})
})
// OpenAI Chat Completions API
gateway.POST("/chat/completions", h.OpenAIGateway.ChatCompletions)
}
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
......@@ -100,6 +93,8 @@ func RegisterGatewayRoutes(
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
// OpenAI Chat Completions API(不带v1前缀的别名)
r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions)
// Antigravity 模型列表
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels)
......
......@@ -45,16 +45,23 @@ const (
// TestEvent represents a SSE event for account testing
type TestEvent struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Model string `json:"model,omitempty"`
Status string `json:"status,omitempty"`
Code string `json:"code,omitempty"`
Data any `json:"data,omitempty"`
Success bool `json:"success,omitempty"`
Error string `json:"error,omitempty"`
Type string `json:"type"`
Text string `json:"text,omitempty"`
Model string `json:"model,omitempty"`
Status string `json:"status,omitempty"`
Code string `json:"code,omitempty"`
ImageURL string `json:"image_url,omitempty"`
MimeType string `json:"mime_type,omitempty"`
Data any `json:"data,omitempty"`
Success bool `json:"success,omitempty"`
Error string `json:"error,omitempty"`
}
const (
defaultGeminiTextTestPrompt = "hi"
defaultGeminiImageTestPrompt = "Generate a cute orange cat astronaut sticker on a clean pastel background."
)
// AccountTestService handles account testing operations
type AccountTestService struct {
accountRepo AccountRepository
......@@ -161,7 +168,7 @@ func createTestPayload(modelID string) (map[string]any, error) {
// TestAccountConnection tests an account's connection by sending a test request
// All account types use full Claude Code client characteristics, only auth header differs
// modelID is optional - if empty, defaults to claude.DefaultTestModel
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string) error {
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string, prompt string) error {
ctx := c.Request.Context()
// Get account
......@@ -176,11 +183,11 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
}
if account.IsGemini() {
return s.testGeminiAccountConnection(c, account, modelID)
return s.testGeminiAccountConnection(c, account, modelID, prompt)
}
if account.Platform == PlatformAntigravity {
return s.routeAntigravityTest(c, account, modelID)
return s.routeAntigravityTest(c, account, modelID, prompt)
}
if account.Platform == PlatformSora {
......@@ -435,7 +442,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
}
// testGeminiAccountConnection tests a Gemini account's connection
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string) error {
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error {
ctx := c.Request.Context()
// Determine the model to use
......@@ -462,7 +469,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
c.Writer.Flush()
// Create test payload (Gemini format)
payload := createGeminiTestPayload()
payload := createGeminiTestPayload(testModelID, prompt)
// Build request based on account type
var req *http.Request
......@@ -1198,10 +1205,10 @@ func truncateSoraErrorBody(body []byte, max int) string {
// routeAntigravityTest 路由 Antigravity 账号的测试请求。
// APIKey 类型走原生协议(与 gateway_handler 路由一致),OAuth/Upstream 走 CRS 中转。
func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string) error {
func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string, prompt string) error {
if account.Type == AccountTypeAPIKey {
if strings.HasPrefix(modelID, "gemini-") {
return s.testGeminiAccountConnection(c, account, modelID)
return s.testGeminiAccountConnection(c, account, modelID, prompt)
}
return s.testClaudeAccountConnection(c, account, modelID)
}
......@@ -1349,14 +1356,46 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT
return req, nil
}
// createGeminiTestPayload creates a minimal test payload for Gemini API
func createGeminiTestPayload() []byte {
// createGeminiTestPayload creates a minimal test payload for Gemini API.
// Image models use the image-generation path so the frontend can preview the returned image.
func createGeminiTestPayload(modelID string, prompt string) []byte {
if isImageGenerationModel(modelID) {
imagePrompt := strings.TrimSpace(prompt)
if imagePrompt == "" {
imagePrompt = defaultGeminiImageTestPrompt
}
payload := map[string]any{
"contents": []map[string]any{
{
"role": "user",
"parts": []map[string]any{
{"text": imagePrompt},
},
},
},
"generationConfig": map[string]any{
"responseModalities": []string{"TEXT", "IMAGE"},
"imageConfig": map[string]any{
"aspectRatio": "1:1",
},
},
}
bytes, _ := json.Marshal(payload)
return bytes
}
textPrompt := strings.TrimSpace(prompt)
if textPrompt == "" {
textPrompt = defaultGeminiTextTestPrompt
}
payload := map[string]any{
"contents": []map[string]any{
{
"role": "user",
"parts": []map[string]any{
{"text": "hi"},
{"text": textPrompt},
},
},
},
......@@ -1416,6 +1455,17 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader)
if text, ok := partMap["text"].(string); ok && text != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: text})
}
if inlineData, ok := partMap["inlineData"].(map[string]any); ok {
mimeType, _ := inlineData["mimeType"].(string)
data, _ := inlineData["data"].(string)
if strings.HasPrefix(strings.ToLower(mimeType), "image/") && data != "" {
s.sendEvent(c, TestEvent{
Type: "image",
ImageURL: fmt.Sprintf("data:%s;base64,%s", mimeType, data),
MimeType: mimeType,
})
}
}
}
}
}
......@@ -1602,7 +1652,7 @@ func (s *AccountTestService) RunTestBackground(ctx context.Context, accountID in
ginCtx, _ := gin.CreateTestContext(w)
ginCtx.Request = (&http.Request{}).WithContext(ctx)
testErr := s.TestAccountConnection(ginCtx, accountID, modelID)
testErr := s.TestAccountConnection(ginCtx, accountID, modelID, "")
finishedAt := time.Now()
body := w.Body.String()
......
//go:build unit
package service
import (
"encoding/json"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestCreateGeminiTestPayload_ImageModel(t *testing.T) {
t.Parallel()
payload := createGeminiTestPayload("gemini-2.5-flash-image", "draw a tiny robot")
var parsed struct {
Contents []struct {
Parts []struct {
Text string `json:"text"`
} `json:"parts"`
} `json:"contents"`
GenerationConfig struct {
ResponseModalities []string `json:"responseModalities"`
ImageConfig struct {
AspectRatio string `json:"aspectRatio"`
} `json:"imageConfig"`
} `json:"generationConfig"`
}
require.NoError(t, json.Unmarshal(payload, &parsed))
require.Len(t, parsed.Contents, 1)
require.Len(t, parsed.Contents[0].Parts, 1)
require.Equal(t, "draw a tiny robot", parsed.Contents[0].Parts[0].Text)
require.Equal(t, []string{"TEXT", "IMAGE"}, parsed.GenerationConfig.ResponseModalities)
require.Equal(t, "1:1", parsed.GenerationConfig.ImageConfig.AspectRatio)
}
func TestProcessGeminiStream_EmitsImageEvent(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
ctx, recorder := newSoraTestContext()
svc := &AccountTestService{}
stream := strings.NewReader("data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"},{\"inlineData\":{\"mimeType\":\"image/png\",\"data\":\"QUJD\"}}]}}]}\n\ndata: [DONE]\n\n")
err := svc.processGeminiStream(ctx, stream)
require.NoError(t, err)
body := recorder.Body.String()
require.Contains(t, body, "\"type\":\"content\"")
require.Contains(t, body, "\"text\":\"ok\"")
require.Contains(t, body, "\"type\":\"image\"")
require.Contains(t, body, "\"image_url\":\"data:image/png;base64,QUJD\"")
require.Contains(t, body, "\"mime_type\":\"image/png\"")
}
......@@ -369,8 +369,11 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
}
if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) {
if updates, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && len(updates) > 0 {
if updates, resetAt, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && (len(updates) > 0 || resetAt != nil) {
mergeAccountExtra(account, updates)
if resetAt != nil {
account.RateLimitResetAt = resetAt
}
if usage.UpdatedAt == nil {
usage.UpdatedAt = &now
}
......@@ -457,26 +460,26 @@ func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, no
return true
}
func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, error) {
func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, *time.Time, error) {
if account == nil || !account.IsOAuth() {
return nil, nil
return nil, nil, nil
}
accessToken := account.GetOpenAIAccessToken()
if accessToken == "" {
return nil, fmt.Errorf("no access token available")
return nil, nil, fmt.Errorf("no access token available")
}
modelID := openaipkg.DefaultTestModel
payload := createOpenAITestPayload(modelID, true)
payloadBytes, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("marshal openai probe payload: %w", err)
return nil, nil, fmt.Errorf("marshal openai probe payload: %w", err)
}
reqCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, chatgptCodexURL, bytes.NewReader(payloadBytes))
if err != nil {
return nil, fmt.Errorf("create openai probe request: %w", err)
return nil, nil, fmt.Errorf("create openai probe request: %w", err)
}
req.Host = "chatgpt.com"
req.Header.Set("Content-Type", "application/json")
......@@ -505,43 +508,67 @@ func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, acco
ResponseHeaderTimeout: 10 * time.Second,
})
if err != nil {
return nil, fmt.Errorf("build openai probe client: %w", err)
return nil, nil, fmt.Errorf("build openai probe client: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("openai codex probe request failed: %w", err)
return nil, nil, fmt.Errorf("openai codex probe request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
updates, err := extractOpenAICodexProbeUpdates(resp)
updates, resetAt, err := extractOpenAICodexProbeSnapshot(resp)
if err != nil {
return nil, err
return nil, nil, err
}
if len(updates) > 0 {
go func(accountID int64, updates map[string]any) {
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer updateCancel()
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
}(account.ID, updates)
return updates, nil
if len(updates) > 0 || resetAt != nil {
s.persistOpenAICodexProbeSnapshot(account.ID, updates, resetAt)
return updates, resetAt, nil
}
return nil, nil
return nil, nil, nil
}
func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) {
func (s *AccountUsageService) persistOpenAICodexProbeSnapshot(accountID int64, updates map[string]any, resetAt *time.Time) {
if s == nil || s.accountRepo == nil || accountID <= 0 {
return
}
if len(updates) == 0 && resetAt == nil {
return
}
go func() {
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer updateCancel()
if len(updates) > 0 {
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
}
if resetAt != nil {
_ = s.accountRepo.SetRateLimited(updateCtx, accountID, *resetAt)
}
}()
}
func extractOpenAICodexProbeSnapshot(resp *http.Response) (map[string]any, *time.Time, error) {
if resp == nil {
return nil, nil
return nil, nil, nil
}
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
updates := buildCodexUsageExtraUpdates(snapshot, time.Now())
baseTime := time.Now()
updates := buildCodexUsageExtraUpdates(snapshot, baseTime)
resetAt := codexRateLimitResetAtFromSnapshot(snapshot, baseTime)
if len(updates) > 0 {
return updates, nil
return updates, resetAt, nil
}
return nil, resetAt, nil
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
return nil, nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
}
return nil, nil
return nil, nil, nil
}
func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) {
updates, _, err := extractOpenAICodexProbeSnapshot(resp)
return updates, err
}
func mergeAccountExtra(account *Account, updates map[string]any) {
......
package service
import (
"context"
"net/http"
"testing"
"time"
)
type accountUsageCodexProbeRepo struct {
stubOpenAIAccountRepo
updateExtraCh chan map[string]any
rateLimitCh chan time.Time
}
func (r *accountUsageCodexProbeRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
if r.updateExtraCh != nil {
copied := make(map[string]any, len(updates))
for k, v := range updates {
copied[k] = v
}
r.updateExtraCh <- copied
}
return nil
}
func (r *accountUsageCodexProbeRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error {
if r.rateLimitCh != nil {
r.rateLimitCh <- resetAt
}
return nil
}
func TestShouldRefreshOpenAICodexSnapshot(t *testing.T) {
t.Parallel()
......@@ -66,3 +91,60 @@ func TestExtractOpenAICodexProbeUpdatesAccepts429WithCodexHeaders(t *testing.T)
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
}
}
func TestExtractOpenAICodexProbeSnapshotAccepts429WithResetAt(t *testing.T) {
t.Parallel()
headers := make(http.Header)
headers.Set("x-codex-primary-used-percent", "100")
headers.Set("x-codex-primary-reset-after-seconds", "604800")
headers.Set("x-codex-primary-window-minutes", "10080")
headers.Set("x-codex-secondary-used-percent", "100")
headers.Set("x-codex-secondary-reset-after-seconds", "18000")
headers.Set("x-codex-secondary-window-minutes", "300")
updates, resetAt, err := extractOpenAICodexProbeSnapshot(&http.Response{StatusCode: http.StatusTooManyRequests, Header: headers})
if err != nil {
t.Fatalf("extractOpenAICodexProbeSnapshot() error = %v", err)
}
if len(updates) == 0 {
t.Fatal("expected codex probe updates from 429 headers")
}
if resetAt == nil {
t.Fatal("expected resetAt from exhausted codex headers")
}
}
func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *testing.T) {
t.Parallel()
repo := &accountUsageCodexProbeRepo{
updateExtraCh: make(chan map[string]any, 1),
rateLimitCh: make(chan time.Time, 1),
}
svc := &AccountUsageService{accountRepo: repo}
resetAt := time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second)
svc.persistOpenAICodexProbeSnapshot(321, map[string]any{
"codex_7d_used_percent": 100.0,
"codex_7d_reset_at": resetAt.Format(time.RFC3339),
}, &resetAt)
select {
case updates := <-repo.updateExtraCh:
if got := updates["codex_7d_used_percent"]; got != 100.0 {
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
}
case <-time.After(2 * time.Second):
t.Fatal("waiting for codex probe extra persistence timed out")
}
select {
case got := <-repo.rateLimitCh:
if got.Before(resetAt.Add(-time.Second)) || got.After(resetAt.Add(time.Second)) {
t.Fatalf("rate limit resetAt = %v, want around %v", got, resetAt)
}
case <-time.After(2 * time.Second):
t.Fatal("waiting for codex probe rate limit persistence timed out")
}
}
......@@ -6,6 +6,7 @@ import (
"encoding/hex"
"fmt"
"strconv"
"strings"
"sync"
"time"
......@@ -110,6 +111,15 @@ func (d *APIKeyRateLimitData) EffectiveUsage7d() float64 {
return d.Usage7d
}
// APIKeyQuotaUsageState captures the latest quota fields after an atomic quota update.
// It is intentionally small so repositories can return it from a single SQL statement.
type APIKeyQuotaUsageState struct {
QuotaUsed float64
Quota float64
Key string
Status string
}
// APIKeyCache defines cache operations for API key service
type APIKeyCache interface {
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
......@@ -817,6 +827,21 @@ func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cos
return nil
}
type quotaStateReader interface {
IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*APIKeyQuotaUsageState, error)
}
if repo, ok := s.apiKeyRepo.(quotaStateReader); ok {
state, err := repo.IncrementQuotaUsedAndGetState(ctx, apiKeyID, cost)
if err != nil {
return fmt.Errorf("increment quota used: %w", err)
}
if state != nil && state.Status == StatusAPIKeyQuotaExhausted && strings.TrimSpace(state.Key) != "" {
s.InvalidateAuthCacheByKey(ctx, state.Key)
}
return nil
}
// Use repository to atomically increment quota_used
newQuotaUsed, err := s.apiKeyRepo.IncrementQuotaUsed(ctx, apiKeyID, cost)
if err != nil {
......
//go:build unit
package service
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type quotaStateRepoStub struct {
quotaBaseAPIKeyRepoStub
stateCalls int
state *APIKeyQuotaUsageState
stateErr error
}
func (s *quotaStateRepoStub) IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*APIKeyQuotaUsageState, error) {
s.stateCalls++
if s.stateErr != nil {
return nil, s.stateErr
}
if s.state == nil {
return nil, nil
}
out := *s.state
return &out, nil
}
type quotaStateCacheStub struct {
deleteAuthKeys []string
}
func (s *quotaStateCacheStub) GetCreateAttemptCount(context.Context, int64) (int, error) {
return 0, nil
}
func (s *quotaStateCacheStub) IncrementCreateAttemptCount(context.Context, int64) error {
return nil
}
func (s *quotaStateCacheStub) DeleteCreateAttemptCount(context.Context, int64) error {
return nil
}
func (s *quotaStateCacheStub) IncrementDailyUsage(context.Context, string) error {
return nil
}
func (s *quotaStateCacheStub) SetDailyUsageExpiry(context.Context, string, time.Duration) error {
return nil
}
func (s *quotaStateCacheStub) GetAuthCache(context.Context, string) (*APIKeyAuthCacheEntry, error) {
return nil, nil
}
func (s *quotaStateCacheStub) SetAuthCache(context.Context, string, *APIKeyAuthCacheEntry, time.Duration) error {
return nil
}
func (s *quotaStateCacheStub) DeleteAuthCache(_ context.Context, key string) error {
s.deleteAuthKeys = append(s.deleteAuthKeys, key)
return nil
}
func (s *quotaStateCacheStub) PublishAuthCacheInvalidation(context.Context, string) error {
return nil
}
func (s *quotaStateCacheStub) SubscribeAuthCacheInvalidation(context.Context, func(string)) error {
return nil
}
type quotaBaseAPIKeyRepoStub struct {
getByIDCalls int
}
func (s *quotaBaseAPIKeyRepoStub) Create(context.Context, *APIKey) error {
panic("unexpected Create call")
}
func (s *quotaBaseAPIKeyRepoStub) GetByID(context.Context, int64) (*APIKey, error) {
s.getByIDCalls++
return nil, nil
}
func (s *quotaBaseAPIKeyRepoStub) GetKeyAndOwnerID(context.Context, int64) (string, int64, error) {
panic("unexpected GetKeyAndOwnerID call")
}
func (s *quotaBaseAPIKeyRepoStub) GetByKey(context.Context, string) (*APIKey, error) {
panic("unexpected GetByKey call")
}
func (s *quotaBaseAPIKeyRepoStub) GetByKeyForAuth(context.Context, string) (*APIKey, error) {
panic("unexpected GetByKeyForAuth call")
}
func (s *quotaBaseAPIKeyRepoStub) Update(context.Context, *APIKey) error {
panic("unexpected Update call")
}
func (s *quotaBaseAPIKeyRepoStub) Delete(context.Context, int64) error {
panic("unexpected Delete call")
}
func (s *quotaBaseAPIKeyRepoStub) ListByUserID(context.Context, int64, pagination.PaginationParams, APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByUserID call")
}
func (s *quotaBaseAPIKeyRepoStub) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
panic("unexpected VerifyOwnership call")
}
func (s *quotaBaseAPIKeyRepoStub) CountByUserID(context.Context, int64) (int64, error) {
panic("unexpected CountByUserID call")
}
func (s *quotaBaseAPIKeyRepoStub) ExistsByKey(context.Context, string) (bool, error) {
panic("unexpected ExistsByKey call")
}
func (s *quotaBaseAPIKeyRepoStub) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByGroupID call")
}
func (s *quotaBaseAPIKeyRepoStub) SearchAPIKeys(context.Context, int64, string, int) ([]APIKey, error) {
panic("unexpected SearchAPIKeys call")
}
func (s *quotaBaseAPIKeyRepoStub) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
panic("unexpected ClearGroupIDByGroupID call")
}
func (s *quotaBaseAPIKeyRepoStub) CountByGroupID(context.Context, int64) (int64, error) {
panic("unexpected CountByGroupID call")
}
func (s *quotaBaseAPIKeyRepoStub) ListKeysByUserID(context.Context, int64) ([]string, error) {
panic("unexpected ListKeysByUserID call")
}
func (s *quotaBaseAPIKeyRepoStub) ListKeysByGroupID(context.Context, int64) ([]string, error) {
panic("unexpected ListKeysByGroupID call")
}
func (s *quotaBaseAPIKeyRepoStub) IncrementQuotaUsed(context.Context, int64, float64) (float64, error) {
panic("unexpected IncrementQuotaUsed call")
}
func (s *quotaBaseAPIKeyRepoStub) UpdateLastUsed(context.Context, int64, time.Time) error {
panic("unexpected UpdateLastUsed call")
}
func (s *quotaBaseAPIKeyRepoStub) IncrementRateLimitUsage(context.Context, int64, float64) error {
panic("unexpected IncrementRateLimitUsage call")
}
func (s *quotaBaseAPIKeyRepoStub) ResetRateLimitWindows(context.Context, int64) error {
panic("unexpected ResetRateLimitWindows call")
}
func (s *quotaBaseAPIKeyRepoStub) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) {
panic("unexpected GetRateLimitData call")
}
func TestAPIKeyService_UpdateQuotaUsed_UsesAtomicStatePath(t *testing.T) {
repo := &quotaStateRepoStub{
state: &APIKeyQuotaUsageState{
QuotaUsed: 12,
Quota: 10,
Key: "sk-test-quota",
Status: StatusAPIKeyQuotaExhausted,
},
}
cache := &quotaStateCacheStub{}
svc := &APIKeyService{
apiKeyRepo: repo,
cache: cache,
}
err := svc.UpdateQuotaUsed(context.Background(), 101, 2)
require.NoError(t, err)
require.Equal(t, 1, repo.stateCalls)
require.Equal(t, 0, repo.getByIDCalls, "fast path should not re-read API key by id")
require.Equal(t, []string{svc.authCacheKey("sk-test-quota")}, cache.deleteAuthKeys)
}
......@@ -5998,6 +5998,22 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
intervalCh = intervalTicker.C
}
// 下游 keepalive:防止代理/Cloudflare Tunnel 因连接空闲而断开
keepaliveInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
}
var keepaliveTicker *time.Ticker
if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
lastDataAt := time.Now()
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
errorEventSent := false
sendErrorEvent := func(reason string) {
......@@ -6187,6 +6203,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
break
}
flusher.Flush()
lastDataAt = time.Now()
}
if data != "" {
if firstTokenMs == nil && data != "[DONE]" {
......@@ -6220,6 +6237,22 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
sendErrorEvent("stream_timeout")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
if clientDisconnected {
continue
}
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
// SSE ping 事件:Anthropic 原生格式,客户端会正确处理,
// 同时保持连接活跃防止 Cloudflare Tunnel 等代理断开
if _, werr := fmt.Fprint(w, "event: ping\ndata: {\"type\": \"ping\"}\n\n"); werr != nil {
clientDisconnected = true
logger.LegacyPrintf("service.gateway", "Client disconnected during keepalive ping, continuing to drain upstream for billing")
continue
}
flusher.Flush()
}
}
......
package service
import (
"fmt"
"strings"
)
......@@ -226,6 +227,29 @@ func normalizeCodexModel(model string) string {
return "gpt-5.1"
}
func SupportsVerbosity(model string) bool {
if !strings.HasPrefix(model, "gpt-") {
return true
}
var major, minor int
n, _ := fmt.Sscanf(model, "gpt-%d.%d", &major, &minor)
if major > 5 {
return true
}
if major < 5 {
return false
}
// gpt-5
if n == 1 {
return true
}
return minor >= 3
}
func getNormalizedCodexModel(modelID string) string {
if modelID == "" {
return ""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment