Commit 5e060b22 authored by erio's avatar erio
Browse files

Merge remote-tracking branch 'upstream/main' into feat/channel-insights

# Conflicts:
#	backend/cmd/server/wire_gen.go
parents 6f04c25e 0a80ec80
......@@ -263,7 +263,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t
},
}
snapshot := svc.snapshotFromAPIKey(apiKey)
snapshot := svc.snapshotFromAPIKey(context.Background(), apiKey)
roundTrip := svc.snapshotToAPIKey(apiKey.Key, snapshot)
require.NotNil(t, roundTrip)
......
......@@ -196,6 +196,12 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
grantPlan := s.resolveSignupGrantPlan(ctx, "email")
// 新用户默认 RPM(0 = 不限制)。注册时写入,后续作为用户级兜底。
var defaultRPMLimit int
if s.settingService != nil {
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
}
// 创建用户
user := &User{
Email: email,
......@@ -203,6 +209,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
Role: RoleUser,
Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency,
RPMLimit: defaultRPMLimit,
Status: StatusActive,
}
......@@ -481,6 +488,10 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
signupSource := inferLegacySignupSource(email)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
var defaultRPMLimit int
if s.settingService != nil {
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
}
newUser := &User{
Email: email,
......@@ -489,6 +500,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
Role: RoleUser,
Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency,
RPMLimit: defaultRPMLimit,
Status: StatusActive,
SignupSource: signupSource,
}
......@@ -592,6 +604,10 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
signupSource := inferLegacySignupSource(email)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
var defaultRPMLimit int
if s.settingService != nil {
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
}
newUser := &User{
Email: email,
......@@ -600,6 +616,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
Role: RoleUser,
Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency,
RPMLimit: defaultRPMLimit,
Status: StatusActive,
SignupSource: signupSource,
}
......
......@@ -20,6 +20,9 @@ import (
var (
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.")
// RPM 超限错误。gateway_handler 负责映射为 HTTP 429。
ErrGroupRPMExceeded = infraerrors.TooManyRequests("GROUP_RPM_EXCEEDED", "group requests-per-minute limit exceeded")
ErrUserRPMExceeded = infraerrors.TooManyRequests("USER_RPM_EXCEEDED", "user requests-per-minute limit exceeded")
)
// subscriptionCacheData 订阅缓存数据结构(内部使用)
......@@ -87,6 +90,8 @@ type BillingCacheService struct {
userRepo UserRepository
subRepo UserSubscriptionRepository
apiKeyRateLimitLoader apiKeyRateLimitLoader
userRPMCache UserRPMCache
userGroupRateRepo UserGroupRateRepository
cfg *config.Config
circuitBreaker *billingCircuitBreaker
......@@ -104,12 +109,22 @@ type BillingCacheService struct {
}
// NewBillingCacheService 创建计费缓存服务
func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, apiKeyRepo APIKeyRepository, cfg *config.Config) *BillingCacheService {
func NewBillingCacheService(
cache BillingCache,
userRepo UserRepository,
subRepo UserSubscriptionRepository,
apiKeyRepo APIKeyRepository,
userRPMCache UserRPMCache,
userGroupRateRepo UserGroupRateRepository,
cfg *config.Config,
) *BillingCacheService {
svc := &BillingCacheService{
cache: cache,
userRepo: userRepo,
subRepo: subRepo,
apiKeyRateLimitLoader: apiKeyRepo,
userRPMCache: userRPMCache,
userGroupRateRepo: userGroupRateRepo,
cfg: cfg,
}
svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
......@@ -664,6 +679,95 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
}
}
// RPM limiting: override/group limits and the user-level hard cap are enforced together (any breach rejects); kept last so requests doomed by earlier checks don't inflate the counters.
if err := s.checkRPM(ctx, user, group); err != nil {
return err
}
return nil
}
// checkRPM enforces per-minute request limits. All applicable limits are
// checked for the same request and any breach rejects it:
//
//  1. (user, group) rpm_override — finest granularity: an admin-set quota for a
//     specific user inside a specific group. override=0 exempts the user at the
//     group level (green light), but the user-level global cap still applies.
//  2. group.rpm_limit — group level: the group's uniform RPM capacity (only
//     consulted when no override exists).
//  3. user.rpm_limit — user-level global hard cap: always enforced, regardless
//     of how override/group are configured.
//
// Unlike the old "cascading, mutually exclusive" design, user.rpm_limit acts as
// a global ceiling that neither group limits nor overrides can lift.
// Redis failures always fail open (log a warning, never block traffic).
func (s *BillingCacheService) checkRPM(ctx context.Context, user *User, group *Group) error {
	if s == nil || s.userRPMCache == nil || user == nil {
		return nil
	}
	// ── Layer 1: group-scoped check (override or group.rpm_limit) ──
	if group != nil {
		// Resolve the override: prefer the auth-cache snapshot; fall back to DB when nil.
		var override *int
		if user.UserGroupRPMOverride != nil {
			override = user.UserGroupRPMOverride
		} else if s.userGroupRateRepo != nil {
			dbOverride, err := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, user.ID, group.ID)
			if err != nil {
				logger.LegacyPrintf(
					"service.billing_cache",
					"Warning: rpm override lookup failed for user=%d group=%d: %v",
					user.ID, group.ID, err,
				)
			} else {
				override = dbOverride
			}
		}
		if override != nil {
			// override=0 → this user is exempt in this group (the user-level cap is still checked below).
			if *override > 0 {
				count, incErr := s.userRPMCache.IncrementUserGroupRPM(ctx, user.ID, group.ID)
				if incErr != nil {
					logger.LegacyPrintf(
						"service.billing_cache",
						"Warning: rpm increment (override) failed for user=%d group=%d: %v",
						user.ID, group.ID, incErr,
					)
					// fail-open
				} else if count > *override {
					return ErrGroupRPMExceeded
				}
			}
			// An override supersedes group.rpm_limit, so skip it — but do not return:
			// the user-level check below still runs.
		} else if group.RPMLimit > 0 {
			// No override: enforce group.rpm_limit.
			count, err := s.userRPMCache.IncrementUserGroupRPM(ctx, user.ID, group.ID)
			if err != nil {
				logger.LegacyPrintf(
					"service.billing_cache",
					"Warning: rpm increment (group) failed for user=%d group=%d: %v",
					user.ID, group.ID, err,
				)
				// fail-open
			} else if count > group.RPMLimit {
				return ErrGroupRPMExceeded
			}
		}
	}
	// ── Layer 2: user-level global hard cap (always enforced) ──
	if user.RPMLimit > 0 {
		count, err := s.userRPMCache.IncrementUserRPM(ctx, user.ID)
		if err != nil {
			logger.LegacyPrintf(
				"service.billing_cache",
				"Warning: rpm increment (user) failed for user=%d: %v",
				user.ID, err,
			)
			return nil // fail-open
		}
		if count > user.RPMLimit {
			return ErrUserRPMExceeded
		}
	}
	return nil
}
......
//go:build unit
package service
import (
"context"
"errors"
"sync/atomic"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// userRPMCacheStub 记录每种计数器被调用的次数,并可注入返回值与错误。
type userRPMCacheStub struct {
userGroupCalls int32
userCalls int32
userGroupCounts []int // 依次返回的计数值
userGroupErr error
userCounts []int
userErr error
}
func (s *userRPMCacheStub) IncrementUserGroupRPM(_ context.Context, _, _ int64) (int, error) {
idx := int(atomic.AddInt32(&s.userGroupCalls, 1)) - 1
if s.userGroupErr != nil {
return 0, s.userGroupErr
}
if idx < len(s.userGroupCounts) {
return s.userGroupCounts[idx], nil
}
return 1, nil
}
func (s *userRPMCacheStub) IncrementUserRPM(_ context.Context, _ int64) (int, error) {
idx := int(atomic.AddInt32(&s.userCalls, 1)) - 1
if s.userErr != nil {
return 0, s.userErr
}
if idx < len(s.userCounts) {
return s.userCounts[idx], nil
}
return 1, nil
}
func (s *userRPMCacheStub) GetUserGroupRPM(_ context.Context, _, _ int64) (int, error) {
return 0, nil
}
func (s *userRPMCacheStub) GetUserRPM(_ context.Context, _ int64) (int, error) {
return 0, nil
}
// rpmOverrideRepoStub 专用于 checkRPM 分支测试,只实现必要方法。
type rpmOverrideRepoStub struct {
UserGroupRateRepository
override *int
err error
calls int32
}
func (s *rpmOverrideRepoStub) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) {
atomic.AddInt32(&s.calls, 1)
if s.err != nil {
return nil, s.err
}
return s.override, nil
}
// newBillingServiceForRPM wires a BillingCacheService with only the RPM-related
// dependencies populated so checkRPM can be exercised directly.
func newBillingServiceForRPM(t *testing.T, cache UserRPMCache, rateRepo UserGroupRateRepository) *BillingCacheService {
	t.Helper()
	// A nil BillingCache takes the "no cache" path and keeps
	// CheckBillingEligibility side effects out of the picture — these tests
	// only call checkRPM directly.
	service := NewBillingCacheService(nil, nil, nil, nil, cache, rateRepo, &config.Config{})
	t.Cleanup(service.Stop)
	return service
}
// Verifies a per-(user, group) override replaces the group limit: the
// user-group counter is bounded by the override (2), and once it trips, the
// request is rejected before the user-level counter is touched.
func TestBillingCacheService_CheckRPM_OverrideTakesPrecedenceOverGroup(t *testing.T) {
	override := 2
	// user-group counts: 1, 2, 3; user counts: default 1 (far below RPMLimit=100, non-interfering)
	cache := &userRPMCacheStub{userGroupCounts: []int{1, 2, 3}}
	repo := &rpmOverrideRepoStub{override: &override}
	svc := newBillingServiceForRPM(t, cache, repo)
	user := &User{ID: 1, RPMLimit: 100} // high global cap so it does not interfere with the override
	group := &Group{ID: 10, RPMLimit: 100}
	require.NoError(t, svc.checkRPM(context.Background(), user, group))
	require.NoError(t, svc.checkRPM(context.Background(), user, group))
	require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrGroupRPMExceeded)
	require.EqualValues(t, 3, atomic.LoadInt32(&cache.userGroupCalls), "override 命中分支应走 user-group 计数")
	// Parallel design: the first 2 requests stay under the override → user cap is
	// also checked; the 3rd trips the override → early return, user counter untouched.
	require.EqualValues(t, 2, atomic.LoadInt32(&cache.userCalls), "override 超限前 user 计数器应被调用")
	require.EqualValues(t, 3, atomic.LoadInt32(&repo.calls))
}
// Verifies user.rpm_limit is a global hard cap: even with a generous
// override (100), the user-level limit (2) rejects the third request.
func TestBillingCacheService_CheckRPM_UserLimitIsGlobalHardCap(t *testing.T) {
	override := 100 // override is deliberately high
	// user-group counts: default 1 (far below the override); user counts: 1, 2, 3
	cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}}
	repo := &rpmOverrideRepoStub{override: &override}
	svc := newBillingServiceForRPM(t, cache, repo)
	user := &User{ID: 1, RPMLimit: 2} // global hard cap = 2, must win over override=100
	group := &Group{ID: 10, RPMLimit: 100}
	require.NoError(t, svc.checkRPM(context.Background(), user, group))
	require.NoError(t, svc.checkRPM(context.Background(), user, group))
	require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded, "user 全局硬上限应优先于 override")
}
// Verifies override=0 skips group counting entirely while the user-level cap
// (5) still rejects the sixth request.
func TestBillingCacheService_CheckRPM_OverrideZeroSkipsGroupButUserStillApplies(t *testing.T) {
	zero := 0
	// user counts: 1..6 in order
	cache := &userRPMCacheStub{userCounts: []int{1, 2, 3, 4, 5, 6}}
	repo := &rpmOverrideRepoStub{override: &zero}
	svc := newBillingServiceForRPM(t, cache, repo)
	user := &User{ID: 1, RPMLimit: 5}
	group := &Group{ID: 10, RPMLimit: 100}
	// override=0 skips the group counter, but user.RPMLimit=5 still applies
	for i := 0; i < 5; i++ {
		require.NoError(t, svc.checkRPM(context.Background(), user, group), "request %d should pass", i+1)
	}
	require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded,
		"override=0 跳过分组但 user 全局上限仍应生效")
	require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "override=0 不应触发分组计数器")
	require.EqualValues(t, 6, atomic.LoadInt32(&cache.userCalls), "user 计数器应被调用")
}
// With override=0 and user.RPMLimit=0 no limit applies at all: 50 requests
// pass and neither counter is ever incremented.
func TestBillingCacheService_CheckRPM_OverrideZeroAndUserZeroIsFullyUnlimited(t *testing.T) {
	zero := 0
	cache := &userRPMCacheStub{}
	repo := &rpmOverrideRepoStub{override: &zero}
	svc := newBillingServiceForRPM(t, cache, repo)
	user := &User{ID: 1, RPMLimit: 0} // user level unlimited as well
	group := &Group{ID: 10, RPMLimit: 100}
	for i := 0; i < 50; i++ {
		require.NoError(t, svc.checkRPM(context.Background(), user, group))
	}
	require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "override=0 不触发分组计数")
	require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls), "user.RPMLimit=0 也不触发用户计数")
}
// Verifies a nil override falls through to group.rpm_limit, and that tripping
// the group limit short-circuits before the user counter runs.
func TestBillingCacheService_CheckRPM_NilOverrideFallsThroughToGroup(t *testing.T) {
	// user-group counts: 5, 6; user counts: default 1 (non-interfering)
	cache := &userRPMCacheStub{userGroupCounts: []int{5, 6}}
	repo := &rpmOverrideRepoStub{override: nil}
	svc := newBillingServiceForRPM(t, cache, repo)
	user := &User{ID: 1, RPMLimit: 999} // high global cap so the group limit trips first
	group := &Group{ID: 10, RPMLimit: 5}
	require.NoError(t, svc.checkRPM(context.Background(), user, group)) // ug=5, user=1, both within limits
	require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrGroupRPMExceeded) // ug=6 > 5
	require.EqualValues(t, 2, atomic.LoadInt32(&cache.userGroupCalls))
	// Parallel mode: 1st request stays under the group limit → user is checked too;
	// 2nd trips the group limit → early return, user counter untouched.
	require.EqualValues(t, 1, atomic.LoadInt32(&cache.userCalls), "group 未超时 user 也应检查;group 超时直接返回")
}
// Verifies a failed override lookup degrades to the group limit instead of
// rejecting the request outright.
func TestBillingCacheService_CheckRPM_OverrideLookupErrorFallsThroughToGroup(t *testing.T) {
	cache := &userRPMCacheStub{userGroupCounts: []int{3}}
	repo := &rpmOverrideRepoStub{err: errors.New("db down")}
	svc := newBillingServiceForRPM(t, cache, repo)
	user := &User{ID: 1, RPMLimit: 0}
	group := &Group{ID: 10, RPMLimit: 10}
	// After the override lookup fails, the group branch should still run (no hard reject)
	require.NoError(t, svc.checkRPM(context.Background(), user, group))
	require.EqualValues(t, 1, atomic.LoadInt32(&cache.userGroupCalls))
	require.EqualValues(t, 1, atomic.LoadInt32(&repo.calls))
}
// Verifies the user-level limit still applies when the group has no limit
// configured, and that an unlimited group never touches the user-group key.
func TestBillingCacheService_CheckRPM_UserLevelFallbackWhenGroupUnlimited(t *testing.T) {
	cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}}
	repo := &rpmOverrideRepoStub{override: nil}
	svc := newBillingServiceForRPM(t, cache, repo)
	user := &User{ID: 1, RPMLimit: 2}
	group := &Group{ID: 10, RPMLimit: 0} // group has no RPM limit configured
	require.NoError(t, svc.checkRPM(context.Background(), user, group))
	require.NoError(t, svc.checkRPM(context.Background(), user, group))
	require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded)
	require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "group 未设限时不应 INCR user-group 键")
	require.EqualValues(t, 3, atomic.LoadInt32(&cache.userCalls))
}
// With neither a group limit nor a user limit configured, checkRPM is a no-op:
// every request passes and no counter is incremented.
func TestBillingCacheService_CheckRPM_NoLimitsConfiguredIsNoop(t *testing.T) {
	cache := &userRPMCacheStub{}
	repo := &rpmOverrideRepoStub{override: nil}
	svc := newBillingServiceForRPM(t, cache, repo)
	user := &User{ID: 1, RPMLimit: 0}
	group := &Group{ID: 10, RPMLimit: 0}
	for i := 0; i < 10; i++ {
		require.NoError(t, svc.checkRPM(context.Background(), user, group))
	}
	require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls))
	require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls))
}
// Verifies counter-store (Redis) failures fail open: the increment error is
// swallowed and the request is allowed through.
func TestBillingCacheService_CheckRPM_RedisErrorFailOpen(t *testing.T) {
	cache := &userRPMCacheStub{userGroupErr: errors.New("redis unavailable")}
	repo := &rpmOverrideRepoStub{override: nil}
	svc := newBillingServiceForRPM(t, cache, repo)
	user := &User{ID: 1, RPMLimit: 0}
	group := &Group{ID: 10, RPMLimit: 5}
	// On a Redis failure the limiter must fail open, not reject the request
	require.NoError(t, svc.checkRPM(context.Background(), user, group))
	require.EqualValues(t, 1, atomic.LoadInt32(&cache.userGroupCalls))
}
// Verifies the pure user-level path: with no group, only user.rpm_limit
// applies and the override repository is never queried.
func TestBillingCacheService_CheckRPM_NoGroupUsesUserOnly(t *testing.T) {
	cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}}
	repo := &rpmOverrideRepoStub{}
	svc := newBillingServiceForRPM(t, cache, repo)
	user := &User{ID: 1, RPMLimit: 2}
	// No group (user-level-only limiting scenario); rpm_override must not be looked up.
	require.NoError(t, svc.checkRPM(context.Background(), user, nil))
	require.NoError(t, svc.checkRPM(context.Background(), user, nil))
	require.ErrorIs(t, svc.checkRPM(context.Background(), user, nil), ErrUserRPMExceeded)
	require.EqualValues(t, 0, atomic.LoadInt32(&repo.calls), "无 group 时不应查询 rpm_override")
	require.EqualValues(t, 3, atomic.LoadInt32(&cache.userCalls))
}
// A nil user must be a silent no-op: no error, no counters, no repo lookups.
func TestBillingCacheService_CheckRPM_NilUserIsNoop(t *testing.T) {
	cache := &userRPMCacheStub{}
	repo := &rpmOverrideRepoStub{}
	svc := newBillingServiceForRPM(t, cache, repo)
	require.NoError(t, svc.checkRPM(context.Background(), nil, &Group{ID: 1, RPMLimit: 10}))
	require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls))
	require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls))
	require.EqualValues(t, 0, atomic.LoadInt32(&repo.calls))
}
......@@ -100,7 +100,7 @@ func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
delay: 80 * time.Millisecond,
balance: 12.34,
}
svc := NewBillingCacheService(cache, userRepo, nil, nil, &config.Config{})
svc := NewBillingCacheService(cache, userRepo, nil, nil, nil, nil, &config.Config{})
t.Cleanup(svc.Stop)
const goroutines = 16
......
......@@ -70,7 +70,7 @@ func (b *billingCacheWorkerStub) InvalidateAPIKeyRateLimit(ctx context.Context,
func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
cache := &billingCacheWorkerStub{}
svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{})
svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{})
t.Cleanup(svc.Stop)
start := time.Now()
......@@ -92,7 +92,7 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) {
cache := &billingCacheWorkerStub{}
svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{})
svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{})
svc.Stop()
enqueued := svc.enqueueCacheWrite(cacheWriteTask{
......
......@@ -217,6 +217,9 @@ func (s *BillingService) initFallbackPricing() {
LongContextInputMultiplier: openAIGPT54LongContextInputMultiplier,
LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier,
}
// GPT-5.5 暂无独立定价,回退到 GPT-5.4
s.fallbackPrices["gpt-5.5"] = s.fallbackPrices["gpt-5.4"]
s.fallbackPrices["gpt-5.4-mini"] = &ModelPricing{
InputPricePerToken: 7.5e-7,
OutputPricePerToken: 4.5e-6,
......@@ -288,6 +291,8 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
if strings.Contains(modelLower, "gpt-5") || strings.Contains(modelLower, "codex") {
normalized := normalizeCodexModel(modelLower)
switch normalized {
case "gpt-5.5":
return s.fallbackPrices["gpt-5.5"]
case "gpt-5.4-mini":
return s.fallbackPrices["gpt-5.4-mini"]
case "gpt-5.4":
......@@ -637,7 +642,8 @@ func isOpenAIGPT54Model(model string) bool {
if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") {
return false
}
return normalizeCodexModel(trimmed) == "gpt-5.4"
normalized := normalizeCodexModel(trimmed)
return normalized == "gpt-5.4" || normalized == "gpt-5.5"
}
// CalculateCostWithConfig 使用配置中的默认倍率计算费用
......
......@@ -173,6 +173,7 @@ const (
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON)
SettingKeyDefaultUserRPMLimit = "default_user_rpm_limit" // 新用户默认 RPM 限制(0 = 不限制)
// 第三方认证来源默认授予配置
SettingKeyAuthSourceDefaultEmailBalance = "auth_source_default_email_balance"
......
......@@ -59,6 +59,10 @@ type Group struct {
DefaultMappedModel string
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig
// RPMLimit 分组级每分钟请求数上限(0 = 不限制)。
// 一旦设置即接管该分组用户的限流(覆盖用户级 rpm_limit),可被 user-group rpm_override 进一步覆盖。
RPMLimit int
CreatedAt time.Time
UpdatedAt time.Time
......
package service
import "context"
// OpenAI403CounterCache tracks consecutive 403 failures for an OpenAI account.
type OpenAI403CounterCache interface {
	// IncrementOpenAI403Count atomically increments the 403 counter within the
	// given window (minutes) and returns the current value.
	IncrementOpenAI403Count(ctx context.Context, accountID int64, windowMinutes int) (int64, error)
	// ResetOpenAI403Count zeroes the counter after a successful request.
	ResetOpenAI403Count(ctx context.Context, accountID int64) error
}
......@@ -6,6 +6,7 @@ import (
)
var codexModelMap = map[string]string{
"gpt-5.5": "gpt-5.5",
"gpt-5.4": "gpt-5.4",
"gpt-5.4-mini": "gpt-5.4-mini",
"gpt-5.4-none": "gpt-5.4",
......@@ -207,6 +208,9 @@ func normalizeCodexModel(model string) string {
normalized := strings.ToLower(modelID)
if strings.Contains(normalized, "gpt-5.5") || strings.Contains(normalized, "gpt 5.5") {
return "gpt-5.5"
}
if strings.Contains(normalized, "gpt-5.4-mini") || strings.Contains(normalized, "gpt 5.4 mini") {
return "gpt-5.4-mini"
}
......
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
type openAI403CounterResetStub struct {
resetCalls []int64
}
func (s *openAI403CounterResetStub) IncrementOpenAI403Count(context.Context, int64, int) (int64, error) {
return 0, nil
}
func (s *openAI403CounterResetStub) ResetOpenAI403Count(_ context.Context, accountID int64) error {
s.resetCalls = append(s.resetCalls, accountID)
return nil
}
// Verifies RecordUsage resets the account's 403 counter even when the result
// carries zero usage (the zero-usage early return must not skip the reset).
func TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterBeforeZeroUsageReturn(t *testing.T) {
	counter := &openAI403CounterResetStub{}
	rateLimitSvc := NewRateLimitService(nil, nil, nil, nil, nil)
	rateLimitSvc.SetOpenAI403CounterCache(counter)
	svc := &OpenAIGatewayService{
		rateLimitService: rateLimitSvc,
	}
	// Empty OpenAIForwardResult → all token counts zero → RecordUsage returns
	// before persisting anything, but the 403 reset must already have happened.
	err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
		Result: &OpenAIForwardResult{},
		Account: &Account{ID: 777, Platform: PlatformOpenAI},
	})
	require.NoError(t, err)
	require.Equal(t, []int64{777}, counter.resetCalls)
}
......@@ -1098,3 +1098,50 @@ func TestOpenAIGatewayServiceRecordUsage_ImageOnlyUsageStillPersists(t *testing.
require.NotNil(t, usageRepo.lastLog.BillingMode)
require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode)
}
// Verifies image requests bill per image (ImagePrice1K × ImageCount) even when
// the upstream usage also reports token counts: BillingMode stays "image" and
// every token-based cost component is zero.
func TestOpenAIGatewayServiceRecordUsage_ImageUsesPerImageBillingEvenWithUsageTokens(t *testing.T) {
	imagePrice := 0.02
	groupID := int64(12)
	usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
	userRepo := &openAIRecordUsageUserRepoStub{}
	subRepo := &openAIRecordUsageSubRepoStub{}
	svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
	// Usage deliberately carries non-zero token counts to prove they are ignored
	// for image billing.
	err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
		Result: &OpenAIForwardResult{
			RequestID: "resp_image_per_request",
			Model: "gpt-image-2",
			Usage: OpenAIUsage{
				InputTokens: 1110,
				OutputTokens: 1756,
				ImageOutputTokens: 1756,
			},
			ImageCount: 2,
			ImageSize: "1K",
			Duration: time.Second,
		},
		APIKey: &APIKey{
			ID: 1008,
			GroupID: i64p(groupID),
			Group: &Group{
				ID: groupID,
				RateMultiplier: 1.0,
				ImagePrice1K: &imagePrice,
			},
		},
		User: &User{ID: 2008},
		Account: &Account{ID: 3008},
	})
	require.NoError(t, err)
	require.NotNil(t, usageRepo.lastLog)
	require.NotNil(t, usageRepo.lastLog.BillingMode)
	require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode)
	require.Equal(t, 2, usageRepo.lastLog.ImageCount)
	// 2 images × 0.02 each = 0.04; token-derived components must not contribute.
	require.InDelta(t, 0.04, usageRepo.lastLog.TotalCost, 1e-12)
	require.InDelta(t, 0.04, usageRepo.lastLog.ActualCost, 1e-12)
	require.InDelta(t, 0.0, usageRepo.lastLog.InputCost, 1e-12)
	require.InDelta(t, 0.0, usageRepo.lastLog.OutputCost, 1e-12)
	require.InDelta(t, 0.0, usageRepo.lastLog.ImageOutputCost, 1e-12)
}
......@@ -4425,6 +4425,9 @@ type OpenAIRecordUsageInput struct {
// RecordUsage records usage and deducts balance
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
result := input.Result
if s.rateLimitService != nil && input != nil && input.Account != nil && input.Account.Platform == PlatformOpenAI {
s.rateLimitService.ResetOpenAI403Counter(ctx, input.Account.ID)
}
// 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库
if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 &&
......@@ -4622,12 +4625,6 @@ func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost(
serviceTier string,
) (*CostBreakdown, error) {
if result != nil && result.ImageCount > 0 {
if hasOpenAIImageUsageTokens(result) {
cost, err := s.calculateOpenAIImageTokenCost(ctx, apiKey, billingModel, multiplier, tokens, serviceTier, result.ImageSize)
if err == nil {
return cost, nil
}
}
return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, multiplier), nil
}
if s.resolver != nil && apiKey.Group != nil {
......@@ -4646,32 +4643,6 @@ func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost(
return s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
}
// calculateOpenAIImageTokenCost prices an image request by its token usage.
// With a pricing resolver and an API-key group it takes the unified cost path
// (group-scoped pricing plus size/service tier); otherwise it falls back to
// plain service-tier token pricing via the billing service.
func (s *OpenAIGatewayService) calculateOpenAIImageTokenCost(
	ctx context.Context,
	apiKey *APIKey,
	billingModel string,
	multiplier float64,
	tokens UsageTokens,
	serviceTier string,
	sizeTier string,
) (*CostBreakdown, error) {
	if s.resolver != nil && apiKey.Group != nil {
		gid := apiKey.Group.ID
		// Unified path: a single request priced with group-scoped rules.
		return s.billingService.CalculateCostUnified(CostInput{
			Ctx: ctx,
			Model: billingModel,
			GroupID: &gid,
			Tokens: tokens,
			RequestCount: 1,
			SizeTier: sizeTier,
			RateMultiplier: multiplier,
			ServiceTier: serviceTier,
			Resolver: s.resolver,
		})
	}
	return s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
}
func (s *OpenAIGatewayService) calculateOpenAIImageCost(
ctx context.Context,
billingModel string,
......@@ -4679,7 +4650,8 @@ func (s *OpenAIGatewayService) calculateOpenAIImageCost(
result *OpenAIForwardResult,
multiplier float64,
) *CostBreakdown {
if resolved := s.resolveOpenAIChannelPricing(ctx, billingModel, apiKey); resolved != nil {
if resolved := s.resolveOpenAIChannelPricing(ctx, billingModel, apiKey); resolved != nil &&
(resolved.Mode == BillingModePerRequest || resolved.Mode == BillingModeImage) {
gid := apiKey.Group.ID
cost, err := s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
......@@ -4720,17 +4692,6 @@ func (s *OpenAIGatewayService) resolveOpenAIChannelPricing(ctx context.Context,
return nil
}
// hasOpenAIImageUsageTokens reports whether the upstream result carries any
// token usage at all (input, output, cache-creation/read, or image-output).
func hasOpenAIImageUsageTokens(result *OpenAIForwardResult) bool {
	if result == nil {
		return false
	}
	usage := result.Usage
	switch {
	case usage.InputTokens > 0,
		usage.OutputTokens > 0,
		usage.CacheCreationInputTokens > 0,
		usage.CacheReadInputTokens > 0,
		usage.ImageOutputTokens > 0:
		return true
	}
	return false
}
// ParseCodexRateLimitHeaders extracts Codex usage limits from response headers.
// Exported for use in ratelimit_service when handling OpenAI 429 responses.
func ParseCodexRateLimitHeaders(headers http.Header) *OpenAICodexUsageSnapshot {
......
......@@ -5,27 +5,22 @@ import (
"bytes"
"context"
"crypto/sha256"
"crypto/sha3"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"mime"
"mime/multipart"
"net/http"
"net/textproto"
"sort"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/imroc/req/v3"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
......@@ -40,16 +35,10 @@ const (
openAIChatGPTStartURL = "https://chatgpt.com/"
openAIChatGPTFilesURL = "https://chatgpt.com/backend-api/files"
openAIChatGPTConversationInitURL = "https://chatgpt.com/backend-api/conversation/init"
openAIChatGPTConversationURL = "https://chatgpt.com/backend-api/f/conversation"
openAIChatGPTConversationPrepareURL = "https://chatgpt.com/backend-api/f/conversation/prepare"
openAIChatGPTChatRequirementsURL = "https://chatgpt.com/backend-api/sentinel/chat-requirements"
openAIImageBackendUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
openAIImageRequirementsDiff = "0fffff"
openAIImageLifecycleTimeout = 2 * time.Minute
openAIImageMaxDownloadBytes = 20 << 20 // 20MB per image download
openAIImageMaxUploadPartSize = 20 << 20 // 20MB per multipart upload part
openAIImagesResponsesMainModel = "gpt-5.4-mini"
)
type OpenAIImagesCapability string
......@@ -81,10 +70,21 @@ type OpenAIImagesRequest struct {
ExplicitSize bool
SizeTier string
ResponseFormat string
Quality string
Background string
OutputFormat string
Moderation string
InputFidelity string
Style string
OutputCompression *int
PartialImages *int
HasMask bool
HasNativeOptions bool
RequiredCapability OpenAIImagesCapability
InputImageURLs []string
MaskImageURL string
Uploads []OpenAIImagesUpload
MaskUpload *OpenAIImagesUpload
Body []byte
bodyHash string
}
......@@ -188,7 +188,54 @@ func parseOpenAIImagesJSONRequest(body []byte, req *OpenAIImagesRequest) error {
req.ExplicitSize = req.Size != ""
}
req.ResponseFormat = strings.ToLower(strings.TrimSpace(gjson.GetBytes(body, "response_format").String()))
req.Quality = strings.TrimSpace(gjson.GetBytes(body, "quality").String())
req.Background = strings.TrimSpace(gjson.GetBytes(body, "background").String())
req.OutputFormat = strings.TrimSpace(gjson.GetBytes(body, "output_format").String())
req.Moderation = strings.TrimSpace(gjson.GetBytes(body, "moderation").String())
req.InputFidelity = strings.TrimSpace(gjson.GetBytes(body, "input_fidelity").String())
req.Style = strings.TrimSpace(gjson.GetBytes(body, "style").String())
req.HasMask = gjson.GetBytes(body, "mask").Exists()
if outputCompression := gjson.GetBytes(body, "output_compression"); outputCompression.Exists() {
if outputCompression.Type != gjson.Number {
return fmt.Errorf("invalid output_compression field type")
}
v := int(outputCompression.Int())
req.OutputCompression = &v
}
if partialImages := gjson.GetBytes(body, "partial_images"); partialImages.Exists() {
if partialImages.Type != gjson.Number {
return fmt.Errorf("invalid partial_images field type")
}
v := int(partialImages.Int())
req.PartialImages = &v
}
if req.IsEdits() {
images := gjson.GetBytes(body, "images")
if images.Exists() {
if !images.IsArray() {
return fmt.Errorf("invalid images field type")
}
for _, item := range images.Array() {
if imageURL := strings.TrimSpace(item.Get("image_url").String()); imageURL != "" {
req.InputImageURLs = append(req.InputImageURLs, imageURL)
continue
}
if item.Get("file_id").Exists() {
return fmt.Errorf("images[].file_id is not supported (use images[].image_url instead)")
}
}
}
if maskImageURL := strings.TrimSpace(gjson.GetBytes(body, "mask.image_url").String()); maskImageURL != "" {
req.MaskImageURL = maskImageURL
req.HasMask = true
}
if gjson.GetBytes(body, "mask.file_id").Exists() {
return fmt.Errorf("mask.file_id is not supported (use mask.image_url instead)")
}
if len(req.InputImageURLs) == 0 {
return fmt.Errorf("images[].image_url is required")
}
}
req.HasNativeOptions = hasOpenAINativeImageOptions(func(path string) bool {
return gjson.GetBytes(body, path).Exists()
})
......@@ -231,6 +278,16 @@ func parseOpenAIImagesMultipartRequest(body []byte, contentType string, req *Ope
partContentType := strings.TrimSpace(part.Header.Get("Content-Type"))
if name == "mask" && len(data) > 0 {
req.HasMask = true
width, height := parseOpenAIImageDimensions(part.Header)
maskUpload := OpenAIImagesUpload{
FieldName: name,
FileName: fileName,
ContentType: partContentType,
Data: data,
Width: width,
Height: height,
}
req.MaskUpload = &maskUpload
}
if name == "image" || strings.HasPrefix(name, "image[") {
width, height := parseOpenAIImageDimensions(part.Header)
......@@ -270,6 +327,38 @@ func parseOpenAIImagesMultipartRequest(body []byte, contentType string, req *Ope
return fmt.Errorf("n must be a positive integer")
}
req.N = n
case "quality":
req.Quality = value
req.HasNativeOptions = true
case "background":
req.Background = value
req.HasNativeOptions = true
case "output_format":
req.OutputFormat = value
req.HasNativeOptions = true
case "moderation":
req.Moderation = value
req.HasNativeOptions = true
case "input_fidelity":
req.InputFidelity = value
req.HasNativeOptions = true
case "style":
req.Style = value
req.HasNativeOptions = true
case "output_compression":
n, err := strconv.Atoi(value)
if err != nil {
return fmt.Errorf("invalid output_compression field value")
}
req.OutputCompression = &n
req.HasNativeOptions = true
case "partial_images":
n, err := strconv.Atoi(value)
if err != nil {
return fmt.Errorf("invalid partial_images field value")
}
req.PartialImages = &n
req.HasNativeOptions = true
default:
if isOpenAINativeImageOption(name) && value != "" {
req.HasNativeOptions = true
......@@ -359,6 +448,8 @@ func hasOpenAINativeImageOptions(exists func(path string) bool) bool {
"output_format",
"output_compression",
"moderation",
"input_fidelity",
"partial_images",
} {
if exists(path) {
return true
......@@ -369,7 +460,7 @@ func hasOpenAINativeImageOptions(exists func(path string) bool) bool {
func isOpenAINativeImageOption(name string) bool {
switch strings.TrimSpace(strings.ToLower(name)) {
case "background", "quality", "style", "output_format", "output_compression", "moderation":
case "background", "quality", "style", "output_format", "output_compression", "moderation", "input_fidelity", "partial_images":
return true
default:
return false
......@@ -782,563 +873,6 @@ func extractOpenAIImageCountFromJSONBytes(body []byte) int {
return 0
}
// forwardOpenAIImagesOAuth serves an images request through the ChatGPT web
// backend-api using an OAuth account. The flow, in order:
//
//  1. resolve the effective model (channel mapping wins) and validate it;
//  2. obtain an access token, build a Chrome-impersonating client and
//     browser-like headers, and warm the session (best effort);
//  3. fetch sentinel chat requirements, rejecting unsupported Arkose
//     challenges, and compute a proof-of-work token when demanded;
//  4. prepare a picture_v2 conversation, upload any input images, then POST
//     the conversation request and stream out asset pointers;
//  5. poll the conversation on a detached, timeout-bound context until
//     assets appear, download them, and write an images-API-shaped JSON
//     body to c.
//
// Upstream failures are wrapped via wrapOpenAIImageBackendError so callers
// can fail over accounts where appropriate.
func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	parsed *OpenAIImagesRequest,
	channelMappedModel string,
) (*OpenAIForwardResult, error) {
	startTime := time.Now()
	requestModel := strings.TrimSpace(parsed.Model)
	// A channel-level model mapping overrides the client-requested model.
	if mapped := strings.TrimSpace(channelMappedModel); mapped != "" {
		requestModel = mapped
	}
	if err := validateOpenAIImagesModel(requestModel); err != nil {
		return nil, err
	}
	logger.LegacyPrintf(
		"service.openai_gateway",
		"[OpenAI] Images request routing request_model=%s endpoint=%s account_type=%s uploads=%d",
		requestModel,
		parsed.Endpoint,
		account.Type,
		len(parsed.Uploads),
	)
	token, _, err := s.GetAccessToken(ctx, account)
	if err != nil {
		return nil, err
	}
	client, err := newOpenAIBackendAPIClient(resolveOpenAIProxyURL(account))
	if err != nil {
		return nil, err
	}
	headers, err := s.buildOpenAIBackendAPIHeaders(account, token)
	if err != nil {
		return nil, err
	}
	// Session warm-up is best effort; failure is logged but never fatal.
	if bootstrapErr := bootstrapOpenAIBackendAPI(ctx, client, headers); bootstrapErr != nil {
		logger.LegacyPrintf("service.openai_gateway", "OpenAI image bootstrap failed: %v", bootstrapErr)
	}
	chatReqs, err := fetchOpenAIChatRequirements(ctx, client, headers)
	if err != nil {
		return nil, s.wrapOpenAIImageBackendError(ctx, c, account, err)
	}
	// Arkose challenges cannot be solved here; surface a synthetic 403.
	if chatReqs.Arkose.Required {
		return nil, s.wrapOpenAIImageBackendError(
			ctx,
			c,
			account,
			newOpenAIImageSyntheticStatusError(
				http.StatusForbidden,
				"chat-requirements requires unsupported challenge (arkose)",
				openAIChatGPTChatRequirementsURL,
			),
		)
	}
	parentMessageID := uuid.NewString()
	proofToken := generateOpenAIProofToken(chatReqs.ProofOfWork.Required, chatReqs.ProofOfWork.Seed, chatReqs.ProofOfWork.Difficulty, headers.Get("User-Agent"))
	// Init is best effort; prepare below is the authoritative gate.
	_ = initializeOpenAIImageConversation(ctx, client, headers)
	conduitToken, err := prepareOpenAIImageConversation(ctx, client, headers, parsed.Prompt, parentMessageID, chatReqs.Token, proofToken)
	if err != nil {
		return nil, s.wrapOpenAIImageBackendError(ctx, c, account, err)
	}
	uploads, err := uploadOpenAIImageFiles(ctx, client, headers, parsed.Uploads)
	if err != nil {
		return nil, s.wrapOpenAIImageBackendError(ctx, c, account, err)
	}
	convReq := buildOpenAIImageConversationRequest(parsed, parentMessageID, uploads)
	// Record the outbound body for ops diagnostics (best effort).
	if parsedContent, err := json.Marshal(convReq); err == nil {
		setOpsUpstreamRequestBody(c, parsedContent)
	}
	convHeaders := cloneHTTPHeader(headers)
	convHeaders.Set("Accept", "text/event-stream")
	convHeaders.Set("Content-Type", "application/json")
	convHeaders.Set("openai-sentinel-chat-requirements-token", chatReqs.Token)
	if conduitToken != "" {
		convHeaders.Set("x-conduit-token", conduitToken)
	}
	if proofToken != "" {
		convHeaders.Set("openai-sentinel-proof-token", proofToken)
	}
	resp, err := client.R().
		SetContext(ctx).
		DisableAutoReadResponse().
		SetHeaders(headerToMap(convHeaders)).
		SetBodyJsonMarshal(convReq).
		Post(openAIChatGPTConversationURL)
	if err != nil {
		return nil, fmt.Errorf("openai image conversation request failed: %w", err)
	}
	defer func() {
		if resp != nil && resp.Body != nil {
			_ = resp.Body.Close()
		}
	}()
	if resp.StatusCode >= 400 {
		return nil, s.wrapOpenAIImageBackendError(ctx, c, account, handleOpenAIImageBackendError(resp))
	}
	conversationID, pointerInfos, usage, firstTokenMs, err := readOpenAIImageConversationStream(resp, startTime)
	if err != nil {
		return nil, err
	}
	pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, nil)
	logger.LegacyPrintf(
		"service.openai_gateway",
		"[OpenAI] Image extraction stream conversation_id=%s total_assets=%d file_service_assets=%d direct_assets=%d",
		conversationID,
		len(pointerInfos),
		countOpenAIFileServicePointerInfos(pointerInfos),
		countOpenAIDirectImageAssets(pointerInfos),
	)
	// Polling/downloading runs on a context detached from request
	// cancellation but bounded by openAIImageLifecycleTimeout.
	lifecycleCtx, releaseLifecycleCtx := detachOpenAIImageLifecycleContext(ctx, openAIImageLifecycleTimeout)
	defer releaseLifecycleCtx()
	if conversationID != "" && !hasOpenAIFileServicePointerInfos(pointerInfos) {
		polledPointers, pollErr := pollOpenAIImageConversation(lifecycleCtx, client, headers, conversationID)
		if pollErr != nil {
			return nil, s.wrapOpenAIImageBackendError(ctx, c, account, pollErr)
		}
		pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, polledPointers)
	}
	// Prefer file-service assets over direct preview assets when present.
	pointerInfos = preferOpenAIFileServicePointerInfos(pointerInfos)
	if len(pointerInfos) == 0 {
		logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Image extraction yielded no assets conversation_id=%s", conversationID)
		return nil, fmt.Errorf("openai image conversation returned no downloadable images")
	}
	responseBody, imageCount, err := buildOpenAIImageResponse(lifecycleCtx, client, headers, conversationID, pointerInfos)
	if err != nil {
		return nil, s.wrapOpenAIImageBackendError(ctx, c, account, err)
	}
	c.Data(http.StatusOK, "application/json; charset=utf-8", responseBody)
	return &OpenAIForwardResult{
		RequestID:     resp.Header.Get("x-request-id"),
		Usage:         usage,
		Model:         requestModel,
		UpstreamModel: requestModel,
		Stream:        false,
		Duration:      time.Since(startTime),
		FirstTokenMs:  firstTokenMs,
		ImageCount:    imageCount,
		ImageSize:     parsed.SizeTier,
	}, nil
}
// resolveOpenAIProxyURL returns the proxy URL configured on the account, or
// "" when the account is nil or has no proxy attached.
func resolveOpenAIProxyURL(account *Account) string {
	if account == nil {
		return ""
	}
	if account.ProxyID == nil || account.Proxy == nil {
		return ""
	}
	return account.Proxy.URL()
}
// newOpenAIBackendAPIClient builds the HTTP client used against the ChatGPT
// backend-api: Chrome impersonation, a 3-minute timeout, and an optional
// proxy. An empty or whitespace proxy URL means direct connection.
func newOpenAIBackendAPIClient(proxyURL string) (*req.Client, error) {
	parsed, _, err := proxyurl.Parse(proxyURL)
	if err != nil {
		return nil, err
	}
	backendClient := req.C()
	backendClient.SetTimeout(180 * time.Second)
	backendClient.ImpersonateChrome()
	if parsed != "" {
		backendClient.SetProxyURL(parsed)
	}
	return backendClient, nil
}
// buildOpenAIBackendAPIHeaders assembles the browser-like header set used for
// every ChatGPT backend-api call: bearer auth, chatgpt.com origin headers, a
// default backend User-Agent (per-account override wins), the optional
// workspace account ID, and the per-account device/session identifiers.
// NOTE(review): session credentials are resolved with context.Background()
// rather than the caller's context — confirm this detachment is intentional.
func (s *OpenAIGatewayService) buildOpenAIBackendAPIHeaders(account *Account, token string) (http.Header, error) {
	deviceID, sessionID := s.ensureOpenAIImageSessionCredentials(context.Background(), account)
	headers := make(http.Header)
	headers.Set("Authorization", "Bearer "+token)
	headers.Set("Accept", "application/json")
	headers.Set("Origin", "https://chatgpt.com")
	headers.Set("Referer", "https://chatgpt.com/")
	headers.Set("Sec-Fetch-Dest", "empty")
	headers.Set("Sec-Fetch-Mode", "cors")
	headers.Set("Sec-Fetch-Site", "same-origin")
	headers.Set("User-Agent", openAIImageBackendUserAgent)
	// A custom per-account User-Agent replaces the default one.
	if customUA := strings.TrimSpace(account.GetOpenAIUserAgent()); customUA != "" {
		headers.Set("User-Agent", customUA)
	}
	if chatgptAccountID := strings.TrimSpace(account.GetChatGPTAccountID()); chatgptAccountID != "" {
		headers.Set("chatgpt-account-id", chatgptAccountID)
	}
	if deviceID != "" {
		// The device ID travels both as a header and as the oai-did cookie.
		headers.Set("oai-device-id", deviceID)
		headers.Set("Cookie", "oai-did="+deviceID)
	}
	if sessionID != "" {
		headers.Set("oai-session-id", sessionID)
	}
	return headers, nil
}
// ensureOpenAIImageSessionCredentials returns the account's ChatGPT device
// and session IDs, minting missing ones as fresh UUIDs. Newly minted values
// are mirrored onto account.Extra in memory and persisted best-effort (5s
// budget) through the account repository; a persistence failure is only
// logged — the freshly generated IDs are still returned and used.
func (s *OpenAIGatewayService) ensureOpenAIImageSessionCredentials(ctx context.Context, account *Account) (string, string) {
	if account == nil {
		return "", ""
	}
	deviceID := account.GetOpenAIDeviceID()
	sessionID := account.GetOpenAISessionID()
	// Fast path: both credentials already exist on the account.
	if deviceID != "" && sessionID != "" {
		return deviceID, sessionID
	}
	updates := map[string]any{}
	if deviceID == "" {
		deviceID = uuid.NewString()
		updates["openai_device_id"] = deviceID
	}
	if sessionID == "" {
		sessionID = uuid.NewString()
		updates["openai_session_id"] = sessionID
	}
	if account.Extra == nil {
		account.Extra = map[string]any{}
	}
	// Update the in-memory account before attempting persistence.
	for key, value := range updates {
		account.Extra[key] = value
	}
	if len(updates) == 0 || s == nil || s.accountRepo == nil {
		return deviceID, sessionID
	}
	updateCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()
	if err := s.accountRepo.UpdateExtra(updateCtx, account.ID, updates); err != nil {
		logger.LegacyPrintf("service.openai_gateway", "persist openai image session creds failed: account=%d err=%v", account.ID, err)
	}
	return deviceID, sessionID
}
// bootstrapOpenAIBackendAPI performs a warm-up GET against the ChatGPT start
// endpoint. The response status is deliberately ignored; only transport
// errors are reported back to the caller.
func bootstrapOpenAIBackendAPI(ctx context.Context, client *req.Client, headers http.Header) error {
	request := client.R().
		SetContext(ctx).
		DisableAutoReadResponse().
		SetHeaders(headerToMap(headers))
	resp, err := request.Get(openAIChatGPTStartURL)
	if err != nil {
		return err
	}
	if resp == nil || resp.Body == nil {
		return nil
	}
	// Drain so the underlying connection can be reused, then close.
	_, _ = io.Copy(io.Discard, resp.Body)
	_ = resp.Body.Close()
	return nil
}
// initializeOpenAIImageConversation primes a new picture_v2 conversation on
// the backend-api before the actual generation request is sent. A non-success
// HTTP status is converted into a typed status error.
func initializeOpenAIImageConversation(ctx context.Context, client *req.Client, headers http.Header) error {
	body := map[string]any{
		"conversation_id":         nil,
		"gizmo_id":                nil,
		"requested_default_model": nil,
		"system_hints":            []string{"picture_v2"},
		"timezone_offset_min":     openAITimezoneOffsetMinutes(),
	}
	resp, err := client.R().
		SetContext(ctx).
		SetHeaders(headerToMap(headers)).
		SetBodyJsonMarshal(body).
		Post(openAIChatGPTConversationInitURL)
	switch {
	case err != nil:
		return err
	case !resp.IsSuccessState():
		return newOpenAIImageStatusError(resp, "conversation init failed")
	default:
		return nil
	}
}
// openAIChatRequirements mirrors the backend-api chat-requirements response:
// a sentinel token plus which anti-bot challenges the upstream demands before
// a conversation may be started.
type openAIChatRequirements struct {
	// Token is echoed back via the
	// openai-sentinel-chat-requirements-token header.
	Token     string `json:"token"`
	Turnstile struct {
		Required bool `json:"required"`
	} `json:"turnstile"`
	// Arkose challenges are not supported by this gateway; callers reject
	// the request when Arkose.Required is set.
	Arkose struct {
		Required bool `json:"required"`
	} `json:"arkose"`
	// ProofOfWork carries the seed/difficulty consumed by
	// generateOpenAIProofToken.
	ProofOfWork struct {
		Required   bool   `json:"required"`
		Seed       string `json:"seed"`
		Difficulty string `json:"difficulty"`
	} `json:"proofofwork"`
}
// fetchOpenAIChatRequirements asks the sentinel chat-requirements endpoint
// for a conversation token. It tries two payloads in order — first
// {"p": null}, then a locally generated requirements token — and returns the
// first response that succeeds with a non-empty token, otherwise the last
// recorded error.
func fetchOpenAIChatRequirements(ctx context.Context, client *req.Client, headers http.Header) (*openAIChatRequirements, error) {
	var lastErr error
	for _, payload := range []map[string]any{
		{"p": nil},
		{"p": generateOpenAIRequirementsToken(headers.Get("User-Agent"))},
	} {
		var result openAIChatRequirements
		resp, err := client.R().
			SetContext(ctx).
			SetHeaders(headerToMap(headers)).
			SetBodyJsonMarshal(payload).
			SetSuccessResult(&result).
			Post(openAIChatGPTChatRequirementsURL)
		if err != nil {
			lastErr = err
			continue
		}
		// A success status with an empty token still counts as a failure.
		if resp.IsSuccessState() && strings.TrimSpace(result.Token) != "" {
			return &result, nil
		}
		lastErr = newOpenAIImageStatusError(resp, "chat-requirements failed")
	}
	// Defensive: make sure a failure is always reported with an error.
	if lastErr == nil {
		lastErr = fmt.Errorf("chat-requirements failed")
	}
	return nil, lastErr
}
// prepareOpenAIImageConversation performs the conversation "prepare" step for
// a picture_v2 turn, sending a partial query carrying the prompt together
// with the sentinel chat-requirements and proof-of-work tokens. On success it
// returns the conduit token to attach (as x-conduit-token) to the actual
// conversation request; the token may be empty if upstream omits it.
func prepareOpenAIImageConversation(
	ctx context.Context,
	client *req.Client,
	headers http.Header,
	prompt string,
	parentMessageID string,
	chatToken string,
	proofToken string,
) (string, error) {
	messageID := uuid.NewString()
	payload := map[string]any{
		"action":                "next",
		"client_prepare_state":  "success",
		"fork_from_shared_post": false,
		"parent_message_id":     parentMessageID,
		"model":                 "auto",
		"timezone_offset_min":   openAITimezoneOffsetMinutes(),
		"timezone":              openAITimezoneName(),
		"conversation_mode":     map[string]any{"kind": "primary_assistant"},
		"system_hints":          []string{"picture_v2"},
		"supports_buffering":    true,
		"supported_encodings":   []string{"v1"},
		// The prompt falls back to a generic instruction when empty.
		"partial_query": map[string]any{
			"id":     messageID,
			"author": map[string]any{"role": "user"},
			"content": map[string]any{
				"content_type": "text",
				"parts":        []string{coalesceOpenAIFileName(prompt, "Generate an image.")},
			},
		},
		"client_contextual_info": map[string]any{
			"app_name": "chatgpt.com",
		},
	}
	prepareHeaders := cloneHTTPHeader(headers)
	prepareHeaders.Set("Accept", "*/*")
	prepareHeaders.Set("Content-Type", "application/json")
	// Sentinel tokens are only attached when present.
	if strings.TrimSpace(chatToken) != "" {
		prepareHeaders.Set("openai-sentinel-chat-requirements-token", strings.TrimSpace(chatToken))
	}
	if strings.TrimSpace(proofToken) != "" {
		prepareHeaders.Set("openai-sentinel-proof-token", strings.TrimSpace(proofToken))
	}
	var result struct {
		ConduitToken string `json:"conduit_token"`
	}
	resp, err := client.R().
		SetContext(ctx).
		SetHeaders(headerToMap(prepareHeaders)).
		SetBodyJsonMarshal(payload).
		SetSuccessResult(&result).
		Post(openAIChatGPTConversationPrepareURL)
	if err != nil {
		return "", err
	}
	if !resp.IsSuccessState() {
		return "", newOpenAIImageStatusError(resp, "conversation prepare failed")
	}
	return strings.TrimSpace(result.ConduitToken), nil
}
// openAIUploadedImage describes one input image after it has been uploaded to
// the ChatGPT file service.
type openAIUploadedImage struct {
	FileID   string // file-service identifier returned when the slot was created
	FileName string // effective file name (caller-supplied or "image.png")
	FileSize int    // payload size in bytes
	MimeType string // content type (caller-supplied or "application/octet-stream")
	Width    int    // pixel width as supplied with the upload; 0 when unknown
	Height   int    // pixel height as supplied with the upload; 0 when unknown
}
// uploadOpenAIImageFiles pushes each input image through the ChatGPT file
// service in three steps per file: create an upload slot, PUT the raw bytes
// to the returned blob-storage URL, then mark the upload complete. It returns
// metadata for every uploaded file (nil when there is nothing to upload); any
// failure aborts the whole batch.
func uploadOpenAIImageFiles(ctx context.Context, client *req.Client, headers http.Header, uploads []OpenAIImagesUpload) ([]openAIUploadedImage, error) {
	if len(uploads) == 0 {
		return nil, nil
	}
	results := make([]openAIUploadedImage, 0, len(uploads))
	for i := range uploads {
		item := uploads[i]
		fileName := coalesceOpenAIFileName(item.FileName, "image.png")
		// Step 1: reserve an upload slot for this file.
		payload := map[string]any{
			"file_name": fileName,
			"file_size": len(item.Data),
			"use_case":  "multimodal",
		}
		var created struct {
			FileID    string `json:"file_id"`
			UploadURL string `json:"upload_url"`
		}
		resp, err := client.R().
			SetContext(ctx).
			SetHeaders(headerToMap(headers)).
			SetBodyJsonMarshal(payload).
			SetSuccessResult(&created).
			Post(openAIChatGPTFilesURL)
		if err != nil {
			return nil, err
		}
		if !resp.IsSuccessState() || strings.TrimSpace(created.FileID) == "" || strings.TrimSpace(created.UploadURL) == "" {
			return nil, newOpenAIImageStatusError(resp, "create upload slot failed")
		}
		// Step 2: PUT the bytes directly to the storage URL. The x-ms-*
		// headers target Azure blob's BlockBlob upload protocol.
		uploadHeaders := map[string]string{
			"Content-Type":   coalesceOpenAIFileName(item.ContentType, "application/octet-stream"),
			"Origin":         "https://chatgpt.com",
			"x-ms-blob-type": "BlockBlob",
			"x-ms-version":   "2020-04-08",
			"User-Agent":     headers.Get("User-Agent"),
		}
		putResp, err := client.R().
			SetContext(ctx).
			SetHeaders(uploadHeaders).
			SetBody(item.Data).
			DisableAutoReadResponse().
			Put(created.UploadURL)
		if err != nil {
			return nil, err
		}
		// Drain and close so the connection can be reused.
		if putResp.Response != nil && putResp.Body != nil {
			_, _ = io.Copy(io.Discard, putResp.Body)
			_ = putResp.Body.Close()
		}
		if putResp.StatusCode < 200 || putResp.StatusCode >= 300 {
			return nil, newOpenAIImageStatusError(putResp, "upload image bytes failed")
		}
		// Step 3: tell the file service the upload finished.
		uploadedResp, err := client.R().
			SetContext(ctx).
			SetHeaders(headerToMap(headers)).
			SetBodyJsonMarshal(map[string]any{}).
			Post(fmt.Sprintf("%s/%s/uploaded", openAIChatGPTFilesURL, created.FileID))
		if err != nil {
			return nil, err
		}
		if !uploadedResp.IsSuccessState() {
			return nil, newOpenAIImageStatusError(uploadedResp, "mark upload complete failed")
		}
		results = append(results, openAIUploadedImage{
			FileID:   created.FileID,
			FileName: fileName,
			FileSize: len(item.Data),
			MimeType: coalesceOpenAIFileName(item.ContentType, "application/octet-stream"),
			Width:    item.Width,
			Height:   item.Height,
		})
	}
	return results, nil
}
// coalesceOpenAIFileName returns value with surrounding whitespace removed,
// falling back to fallback when the trimmed value is empty. Despite the name
// it is used for any string-with-default, not only file names.
func coalesceOpenAIFileName(value string, fallback string) string {
	if trimmed := strings.TrimSpace(value); trimmed != "" {
		return trimmed
	}
	return fallback
}
// buildOpenAIImageConversationRequest constructs the backend-api conversation
// payload for an image generation/edit turn. With no uploads the message is a
// single text part holding the prompt (default "Generate an image."); with
// uploads it becomes multimodal_text: one image_asset_pointer part per upload
// followed by the prompt (default "Edit this image."), plus matching
// attachment metadata on the message.
func buildOpenAIImageConversationRequest(parsed *OpenAIImagesRequest, parentMessageID string, uploads []openAIUploadedImage) map[string]any {
	parts := []any{coalesceOpenAIFileName(parsed.Prompt, "Generate an image.")}
	attachments := make([]map[string]any, 0, len(uploads))
	if len(uploads) > 0 {
		// Rebuild parts from scratch: asset pointers first, prompt last.
		parts = make([]any, 0, len(uploads)+1)
		for _, upload := range uploads {
			parts = append(parts, map[string]any{
				"content_type":  "image_asset_pointer",
				"asset_pointer": "file-service://" + upload.FileID,
				"size_bytes":    upload.FileSize,
				"width":         upload.Width,
				"height":        upload.Height,
			})
			attachment := map[string]any{
				"id":       upload.FileID,
				"mimeType": upload.MimeType,
				"name":     upload.FileName,
				"size":     upload.FileSize,
			}
			// Dimensions are attached only when known (non-zero).
			if upload.Width > 0 {
				attachment["width"] = upload.Width
			}
			if upload.Height > 0 {
				attachment["height"] = upload.Height
			}
			attachments = append(attachments, attachment)
		}
		parts = append(parts, coalesceOpenAIFileName(parsed.Prompt, "Edit this image."))
	}
	contentType := "text"
	if len(uploads) > 0 {
		contentType = "multimodal_text"
	}
	metadata := map[string]any{
		"developer_mode_connector_ids": []any{},
		"selected_github_repos":        []any{},
		"selected_all_github_repos":    false,
		"system_hints":                 []string{"picture_v2"},
		"serialization_metadata": map[string]any{
			"custom_symbol_offsets": []any{},
		},
	}
	message := map[string]any{
		"id":     uuid.NewString(),
		"author": map[string]any{"role": "user"},
		"content": map[string]any{
			"content_type": contentType,
			"parts":        parts,
		},
		"metadata": metadata,
		// Seconds with sub-second precision, derived from UnixMilli.
		"create_time": float64(time.Now().UnixMilli()) / 1000,
	}
	if len(attachments) > 0 {
		metadata["attachments"] = attachments
	}
	return map[string]any{
		"action":                               "next",
		"client_prepare_state":                 "sent",
		"parent_message_id":                    parentMessageID,
		"model":                                "auto",
		"timezone_offset_min":                  openAITimezoneOffsetMinutes(),
		"timezone":                             openAITimezoneName(),
		"conversation_mode":                    map[string]any{"kind": "primary_assistant"},
		"enable_message_followups":             true,
		"system_hints":                         []string{"picture_v2"},
		"supports_buffering":                   true,
		"supported_encodings":                  []string{"v1"},
		"paragen_cot_summary_display_override": "allow",
		"force_parallel_switch":                "auto",
		// Fixed, browser-like client telemetry values.
		"client_contextual_info": map[string]any{
			"is_dark_mode":      false,
			"time_since_loaded": 200,
			"page_height":       900,
			"page_width":        1440,
			"pixel_ratio":       1,
			"screen_height":     1080,
			"screen_width":      1920,
			"app_name":          "chatgpt.com",
		},
		"messages": []any{message},
	}
}
type openAIImagePointerInfo struct {
Pointer string
DownloadURL string
......@@ -1347,51 +881,6 @@ type openAIImagePointerInfo struct {
Prompt string
}
// openAIImageToolMessage is a tool-authored image_gen message extracted from
// a conversation mapping, carrying the image asset pointers it produced.
type openAIImageToolMessage struct {
	MessageID    string
	CreateTime   float64 // upstream create_time, used for chronological sorting
	PointerInfos []openAIImagePointerInfo
}
// readOpenAIImageConversationStream consumes the SSE stream of a conversation
// response line by line, collecting the conversation ID, image asset
// pointers, merged usage, and first-token latency (set at the first non-blank
// line). It returns when the stream hits EOF; any other read error aborts and
// discards the partial pointer/usage state (first-token latency is kept).
func readOpenAIImageConversationStream(resp *req.Response, startTime time.Time) (string, []openAIImagePointerInfo, OpenAIUsage, *int, error) {
	if resp == nil || resp.Body == nil {
		return "", nil, OpenAIUsage{}, nil, fmt.Errorf("empty conversation response")
	}
	reader := bufio.NewReader(resp.Body)
	var (
		conversationID string
		firstTokenMs   *int
		usage          OpenAIUsage
		pointers       []openAIImagePointerInfo
	)
	for {
		line, err := reader.ReadString('\n')
		// First non-blank line marks time-to-first-token.
		if strings.TrimSpace(line) != "" && firstTokenMs == nil {
			ms := int(time.Since(startTime).Milliseconds())
			firstTokenMs = &ms
		}
		if data, ok := extractOpenAISSEDataLine(strings.TrimRight(line, "\r\n")); ok && data != "" && data != "[DONE]" {
			dataBytes := []byte(data)
			// The conversation ID appears either under "v" or at top level.
			if conversationID == "" {
				conversationID = strings.TrimSpace(gjson.GetBytes(dataBytes, "v.conversation_id").String())
				if conversationID == "" {
					conversationID = strings.TrimSpace(gjson.GetBytes(dataBytes, "conversation_id").String())
				}
			}
			mergeOpenAIUsage(&usage, dataBytes)
			pointers = mergeOpenAIImagePointerInfos(pointers, collectOpenAIImagePointers(dataBytes))
		}
		// ReadString may return data AND an error; process before checking.
		if err == io.EOF {
			break
		}
		if err != nil {
			return "", nil, OpenAIUsage{}, firstTokenMs, err
		}
	}
	return conversationID, pointers, usage, firstTokenMs, nil
}
func collectOpenAIImagePointers(body []byte) []openAIImagePointerInfo {
if len(body) == 0 {
return nil
......@@ -1517,222 +1006,6 @@ func mergeOpenAIImagePointerInfo(existing, next openAIImagePointerInfo) openAIIm
return merged
}
// hasOpenAIFileServicePointerInfos reports whether at least one pointer
// references a file-service asset.
func hasOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) bool {
	for i := range items {
		if strings.HasPrefix(items[i].Pointer, "file-service://") {
			return true
		}
	}
	return false
}
// countOpenAIFileServicePointerInfos returns how many pointers reference
// file-service assets.
func countOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) int {
	total := 0
	for i := range items {
		if strings.HasPrefix(items[i].Pointer, "file-service://") {
			total++
		}
	}
	return total
}
// countOpenAIDirectImageAssets returns how many entries already carry a
// direct download URL or inline base64 payload.
func countOpenAIDirectImageAssets(items []openAIImagePointerInfo) int {
	total := 0
	for i := range items {
		hasURL := strings.TrimSpace(items[i].DownloadURL) != ""
		hasInline := strings.TrimSpace(items[i].B64JSON) != ""
		if hasURL || hasInline {
			total++
		}
	}
	return total
}
// preferOpenAIFileServicePointerInfos narrows items down to file-service
// assets when at least one is present; otherwise it returns items unchanged.
func preferOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) []openAIImagePointerInfo {
	filtered := make([]openAIImagePointerInfo, 0, len(items))
	for _, item := range items {
		if strings.HasPrefix(item.Pointer, "file-service://") {
			filtered = append(filtered, item)
		}
	}
	if len(filtered) == 0 {
		return items
	}
	return filtered
}
// extractOpenAIImageToolMessages scans a conversation "mapping" object for
// tool-authored image_gen messages with multimodal content, extracting their
// asset pointers (and the image_gen_title as the revised prompt, when set).
// Messages that yield no pointers are dropped; the result is deduplicated per
// message and sorted by create_time ascending.
func extractOpenAIImageToolMessages(mapping map[string]any) []openAIImageToolMessage {
	if len(mapping) == 0 {
		return nil
	}
	out := make([]openAIImageToolMessage, 0, 4)
	for messageID, raw := range mapping {
		node, _ := raw.(map[string]any)
		if node == nil {
			continue
		}
		message, _ := node["message"].(map[string]any)
		if message == nil {
			continue
		}
		author, _ := message["author"].(map[string]any)
		metadata, _ := message["metadata"].(map[string]any)
		content, _ := message["content"].(map[string]any)
		if author == nil || metadata == nil || content == nil {
			continue
		}
		// Keep only tool messages produced by the image_gen async task
		// with multimodal content.
		if role, _ := author["role"].(string); role != "tool" {
			continue
		}
		if asyncTaskType, _ := metadata["async_task_type"].(string); asyncTaskType != "image_gen" {
			continue
		}
		if contentType, _ := content["content_type"].(string); contentType != "multimodal_text" {
			continue
		}
		// image_gen_title doubles as the revised prompt for this image.
		prompt := ""
		if title, _ := metadata["image_gen_title"].(string); strings.TrimSpace(title) != "" {
			prompt = strings.TrimSpace(title)
		}
		item := openAIImageToolMessage{MessageID: messageID}
		if createTime, ok := message["create_time"].(float64); ok {
			item.CreateTime = createTime
		}
		// Pointers may appear either as asset_pointer objects or embedded
		// in plain string parts.
		parts, _ := content["parts"].([]any)
		for _, part := range parts {
			switch value := part.(type) {
			case map[string]any:
				if assetPointer, _ := value["asset_pointer"].(string); strings.TrimSpace(assetPointer) != "" {
					for _, pointer := range openAIImagePointerMatches([]byte(assetPointer)) {
						item.PointerInfos = append(item.PointerInfos, openAIImagePointerInfo{
							Pointer: pointer,
							Prompt:  prompt,
						})
					}
				}
			case string:
				for _, pointer := range openAIImagePointerMatches([]byte(value)) {
					item.PointerInfos = append(item.PointerInfos, openAIImagePointerInfo{
						Pointer: pointer,
						Prompt:  prompt,
					})
				}
			}
		}
		if len(item.PointerInfos) == 0 {
			continue
		}
		// Merge against nil to deduplicate within this message.
		item.PointerInfos = mergeOpenAIImagePointerInfos(nil, item.PointerInfos)
		out = append(out, item)
	}
	sort.Slice(out, func(i, j int) bool {
		return out[i].CreateTime < out[j].CreateTime
	})
	return out
}
// pollOpenAIImageConversation polls the conversation endpoint (3s interval,
// 90s deadline) for generated image assets. It returns immediately when a
// file-service pointer shows up; otherwise, once a tool message has been seen
// for at least 15s, it settles for whatever preview pointers exist. Transient
// "conversation not found" responses are retried; other HTTP errors abort.
// On deadline it returns the last error observed (which may be nil).
func pollOpenAIImageConversation(ctx context.Context, client *req.Client, headers http.Header, conversationID string) ([]openAIImagePointerInfo, error) {
	conversationID = strings.TrimSpace(conversationID)
	if conversationID == "" {
		return nil, nil
	}
	deadline := time.Now().Add(90 * time.Second)
	interval := 3 * time.Second
	previewWait := 15 * time.Second
	var (
		lastErr error
		// firstToolAt records when a tool message first appeared, gating
		// the preview-pointer fallback below.
		firstToolAt time.Time
	)
	for time.Now().Before(deadline) {
		resp, err := client.R().
			SetContext(ctx).
			SetHeaders(headerToMap(headers)).
			DisableAutoReadResponse().
			Get(fmt.Sprintf("https://chatgpt.com/backend-api/conversation/%s", conversationID))
		if err != nil {
			lastErr = err
		} else {
			if resp.StatusCode >= 200 && resp.StatusCode < 300 {
				body, readErr := io.ReadAll(resp.Body)
				_ = resp.Body.Close()
				if readErr != nil {
					lastErr = readErr
					goto waitNextPoll
				}
				// Combine pointers scraped from the raw body with those
				// extracted from tool messages in the mapping.
				pointers := mergeOpenAIImagePointerInfos(nil, collectOpenAIImagePointers(body))
				var decoded map[string]any
				if err := json.Unmarshal(body, &decoded); err == nil {
					if mapping, _ := decoded["mapping"].(map[string]any); len(mapping) > 0 {
						toolMessages := extractOpenAIImageToolMessages(mapping)
						if len(toolMessages) > 0 && firstToolAt.IsZero() {
							firstToolAt = time.Now()
						}
						for _, msg := range toolMessages {
							pointers = mergeOpenAIImagePointerInfos(pointers, msg.PointerInfos)
						}
					}
				}
				// Final assets found — done.
				if hasOpenAIFileServicePointerInfos(pointers) {
					return preferOpenAIFileServicePointerInfos(pointers), nil
				}
				// Fallback: accept preview pointers after previewWait.
				if len(pointers) > 0 && !firstToolAt.IsZero() && time.Since(firstToolAt) >= previewWait {
					return pointers, nil
				}
			} else {
				statusErr := newOpenAIImageStatusError(resp, "conversation poll failed")
				// A 404 "conversation not found" can be transient right
				// after creation; keep polling in that case only.
				if isOpenAIImageTransientConversationNotFoundError(statusErr) {
					lastErr = statusErr
					goto waitNextPoll
				}
				return nil, statusErr
			}
		}
	waitNextPoll:
		// Sleep until the next poll, honoring context cancellation.
		timer := time.NewTimer(interval)
		select {
		case <-ctx.Done():
			if !timer.Stop() {
				<-timer.C
			}
			return nil, ctx.Err()
		case <-timer.C:
		}
	}
	return nil, lastErr
}
// buildOpenAIImageResponse downloads every resolved image asset and assembles
// an OpenAI-images-shaped response body ({"created": ..., "data": [...]}).
// It returns the serialized JSON payload together with the number of images
// it contains; the first download failure aborts the whole response.
func buildOpenAIImageResponse(
	ctx context.Context,
	client *req.Client,
	headers http.Header,
	conversationID string,
	pointers []openAIImagePointerInfo,
) ([]byte, int, error) {
	type responseItem struct {
		B64JSON       string `json:"b64_json"`
		RevisedPrompt string `json:"revised_prompt,omitempty"`
	}
	data := make([]responseItem, 0, len(pointers))
	for _, info := range pointers {
		raw, err := resolveOpenAIImageBytes(ctx, client, headers, conversationID, info)
		if err != nil {
			return nil, 0, err
		}
		item := responseItem{
			B64JSON:       base64.StdEncoding.EncodeToString(raw),
			RevisedPrompt: info.Prompt,
		}
		data = append(data, item)
	}
	body, err := json.Marshal(map[string]any{
		"created": time.Now().Unix(),
		"data":    data,
	})
	if err != nil {
		return nil, 0, err
	}
	return body, len(data), nil
}
func resolveOpenAIImageBytes(
ctx context.Context,
client *req.Client,
......@@ -1852,17 +1125,6 @@ func isLikelyOpenAIImageDownloadURL(raw string) bool {
strings.Contains(lower, ".webp")
}
func detachOpenAIImageLifecycleContext(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
base := context.Background()
if ctx != nil {
base = context.WithoutCancel(ctx)
}
if timeout <= 0 {
return base, func() {}
}
return context.WithTimeout(base, timeout)
}
func fetchOpenAIImageDownloadURL(
ctx context.Context,
client *req.Client,
......@@ -1957,10 +1219,6 @@ func downloadOpenAIImageBytes(ctx context.Context, client *req.Client, headers h
return io.ReadAll(io.LimitReader(resp.Body, openAIImageMaxDownloadBytes))
}
// handleOpenAIImageBackendError converts a failed backend-api response into a
// typed status error with a generic fallback message.
func handleOpenAIImageBackendError(r *req.Response) error {
	const fallback = "backend-api request failed"
	return newOpenAIImageStatusError(r, fallback)
}
type openAIImageStatusError struct {
StatusCode int
Message string
......@@ -2028,23 +1286,6 @@ func newOpenAIImageStatusError(resp *req.Response, fallback string) error {
}
}
// newOpenAIImageSyntheticStatusError fabricates an openAIImageStatusError for
// failures detected locally (no upstream response). The message is sanitized
// and mirrored into a JSON {"detail": ...} body so downstream handling sees
// the same shape as real backend errors.
func newOpenAIImageSyntheticStatusError(statusCode int, message string, requestURL string) *openAIImageStatusError {
	cleaned := sanitizeUpstreamErrorMessage(strings.TrimSpace(message))
	if cleaned == "" {
		cleaned = "openai image backend request failed"
	}
	var body []byte
	if encoded, err := json.Marshal(map[string]string{"detail": cleaned}); err == nil {
		body = encoded
	}
	return &openAIImageStatusError{
		StatusCode:   statusCode,
		Message:      cleaned,
		ResponseBody: body,
		URL:          strings.TrimSpace(requestURL),
	}
}
func isOpenAIImageTransientConversationNotFoundError(err error) bool {
statusErr, ok := err.(*openAIImageStatusError)
if !ok || statusErr == nil || statusErr.StatusCode != http.StatusNotFound {
......@@ -2064,58 +1305,6 @@ func isOpenAIImageTransientConversationNotFoundError(err error) bool {
return strings.Contains(bodyMsg, "conversation") && strings.Contains(bodyMsg, "not found")
}
// wrapOpenAIImageBackendError records an openAIImageStatusError into the ops
// diagnostics on c and decides whether it should trigger account failover.
// Non-status errors pass through unchanged. When failover applies, the error
// is converted into an *UpstreamFailoverError (retryable on the same account
// only for pool-mode accounts with a retryable status, and never for
// unsupported sentinel challenges); otherwise the status error is returned.
func (s *OpenAIGatewayService) wrapOpenAIImageBackendError(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	err error,
) error {
	var statusErr *openAIImageStatusError
	if !errors.As(err, &statusErr) || statusErr == nil {
		return err
	}
	upstreamMsg := sanitizeUpstreamErrorMessage(statusErr.Message)
	// Always record the raw upstream error event for ops visibility.
	appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
		Platform:           account.Platform,
		AccountID:          account.ID,
		AccountName:        account.Name,
		UpstreamStatusCode: statusErr.StatusCode,
		UpstreamRequestID:  statusErr.RequestID,
		UpstreamURL:        safeUpstreamURL(statusErr.URL),
		Kind:               "request_error",
		Message:            upstreamMsg,
	})
	setOpsUpstreamError(c, statusErr.StatusCode, upstreamMsg, "")
	if s.shouldFailoverOpenAIUpstreamResponse(statusErr.StatusCode, upstreamMsg, statusErr.ResponseBody) {
		// Let the rate-limit service react (e.g. cool the account down).
		if s.rateLimitService != nil {
			s.rateLimitService.HandleUpstreamError(ctx, account, statusErr.StatusCode, statusErr.ResponseHeaders, statusErr.ResponseBody)
		}
		// A second event marks the failover decision itself.
		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
			Platform:           account.Platform,
			AccountID:          account.ID,
			AccountName:        account.Name,
			UpstreamStatusCode: statusErr.StatusCode,
			UpstreamRequestID:  statusErr.RequestID,
			UpstreamURL:        safeUpstreamURL(statusErr.URL),
			Kind:               "failover",
			Message:            upstreamMsg,
		})
		retryableOnSameAccount := account.IsPoolMode() && isPoolModeRetryableStatus(statusErr.StatusCode)
		// An unsupported sentinel challenge will not clear on retry.
		if strings.Contains(strings.ToLower(statusErr.Message), "unsupported challenge") {
			retryableOnSameAccount = false
		}
		return &UpstreamFailoverError{
			StatusCode:             statusErr.StatusCode,
			ResponseBody:           statusErr.ResponseBody,
			RetryableOnSameAccount: retryableOnSameAccount,
		}
	}
	return statusErr
}
func cloneHTTPHeader(src http.Header) http.Header {
dst := make(http.Header, len(src))
for key, values := range src {
......@@ -2140,110 +1329,6 @@ func headerToMap(header http.Header) map[string]string {
return result
}
func openAITimezoneOffsetMinutes() int {
_, offset := time.Now().Zone()
return offset / 60
}
func openAITimezoneName() string {
return time.Now().Location().String()
}
// generateOpenAIRequirementsToken builds the "p" payload for the sentinel
// chat-requirements call: a browser-fingerprint-shaped config array that is
// solved as a proof-of-work and prefixed with "gAAAAAC". Returns "" when no
// answer is found within the attempt budget. The individual config entries
// mimic values the ChatGPT web client reports.
func generateOpenAIRequirementsToken(userAgent string) string {
	config := []any{
		"core" + strconv.Itoa(3008),
		time.Now().UTC().Format(time.RFC1123),
		nil,
		0.123456,
		// Fall back to the default backend UA when none was supplied.
		coalesceOpenAIFileName(strings.TrimSpace(userAgent), openAIImageBackendUserAgent),
		nil,
		"prod-openai-images",
		"en-US",
		"en-US,en",
		0,
		"navigator.webdriver",
		"location",
		"document.body",
		float64(time.Now().UnixMilli()) / 1000,
		uuid.NewString(),
		"",
		8,
		time.Now().Unix(),
	}
	answer, solved := generateOpenAIChallengeAnswer(strconv.FormatInt(time.Now().UnixNano(), 10), openAIImageRequirementsDiff, config)
	if solved {
		return "gAAAAAC" + answer
	}
	return ""
}
// generateOpenAIChallengeAnswer brute-forces a sentinel proof-of-work answer.
// The config array is serialized in three fragments (elements [0:3], [4:9],
// [10:]); the elements at indexes 3 and 9 are replaced each iteration by the
// counter i and i>>1. The base64 of the assembled payload is accepted when
// SHA3-512(seed+payload) is byte-wise <= the decoded difficulty prefix.
// Returns ("", false) on an undecodable difficulty or after 100000 attempts.
func generateOpenAIChallengeAnswer(seed string, difficulty string, config []any) (string, bool) {
	diffBytes, err := hex.DecodeString(difficulty)
	if err != nil {
		return "", false
	}
	// Pre-serialize the invariant fragments once, outside the loop.
	p1 := []byte(jsonCompactSlice(config[:3], true))
	p2 := []byte(jsonCompactSlice(config[4:9], false))
	p3 := []byte(jsonCompactSlice(config[10:], false))
	seedBytes := []byte(seed)
	for i := 0; i < 100000; i++ {
		payload := fmt.Sprintf("%s%d,%s,%d,%s", p1, i, p2, i>>1, p3)
		encoded := base64.StdEncoding.EncodeToString([]byte(payload))
		sum := sha3.Sum512(append(seedBytes, []byte(encoded)...))
		if bytes.Compare(sum[:len(diffBytes)], diffBytes) <= 0 {
			return encoded, true
		}
	}
	return "", false
}
// jsonCompactSlice marshals values as a JSON array and strips one bracket so
// fragments can be spliced together: with trimClosing it drops the trailing
// "]", otherwise it drops the leading "[". Marshal errors are ignored on
// purpose — the inputs are always JSON-encodable literals.
func jsonCompactSlice(values []any, trimClosing bool) string {
	encoded, _ := json.Marshal(values)
	if trimClosing {
		return strings.TrimSuffix(string(encoded), "]")
	}
	return strings.TrimPrefix(string(encoded), "[")
}
// generateOpenAIProofToken answers a sentinel proof-of-work challenge by
// mutating the counter at index 3 of a fingerprint-shaped array until the hex
// of SHA3-512(seed+base64(config)), truncated to difficulty's length,
// compares lexicographically <= difficulty; the winning payload is returned
// as "gAAAAAB"+answer. Returns "" when the challenge is not required or the
// seed/difficulty are missing. After 100000 failed attempts it emits a
// fixed-prefix fallback token derived from the seed, mirroring the upstream
// client's fallback behavior.
func generateOpenAIProofToken(required bool, seed string, difficulty string, userAgent string) string {
	if !required || strings.TrimSpace(seed) == "" || strings.TrimSpace(difficulty) == "" {
		return ""
	}
	// Screen value varies with seed-length parity — mirrors the values the
	// upstream web client reports.
	screen := 3008
	if len(seed)%2 == 0 {
		screen = 4010
	}
	proofToken := []any{
		screen,
		time.Now().UTC().Format(time.RFC1123),
		nil,
		0,
		coalesceOpenAIFileName(strings.TrimSpace(userAgent), openAIImageBackendUserAgent),
		"https://chatgpt.com/",
		"dpl=openai-images",
		"en",
		"en-US",
		nil,
		"plugins[object PluginArray]",
		"_reactListening",
		"alert",
	}
	diffLen := len(difficulty)
	for i := 0; i < 100000; i++ {
		// Index 3 is the nonce being brute-forced.
		proofToken[3] = i
		raw, _ := json.Marshal(proofToken)
		encoded := base64.StdEncoding.EncodeToString(raw)
		sum := sha3.Sum512([]byte(seed + encoded))
		if strings.Compare(hex.EncodeToString(sum[:])[:diffLen], difficulty) <= 0 {
			return "gAAAAAB" + encoded
		}
	}
	fallbackBase := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%q", seed)))
	return "gAAAAABwQ8Lk5FbGpA2NcR9dShT6gYjU7VxZ4D" + fallbackBase
}
func dedupeStrings(values []string) []string {
if len(values) == 0 {
return nil
......
package service
import (
"bufio"
"bytes"
"context"
"encoding/base64"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// openAIResponsesImageResult is one image emitted by the Responses API image
// tool, together with the generation metadata reported alongside it.
type openAIResponsesImageResult struct {
	// Result holds the image payload — presumably base64-encoded; confirm
	// against the stream parser that populates it.
	Result        string
	RevisedPrompt string
	OutputFormat  string
	Size          string
	Background    string
	Quality       string
	Model         string
}
// openAIResponsesImageResultKey derives a dedup key for an image result:
// "format|payload" when image data is present, otherwise "item:" plus the
// originating item ID.
func openAIResponsesImageResultKey(itemID string, result openAIResponsesImageResult) string {
	payload := strings.TrimSpace(result.Result)
	if payload == "" {
		return "item:" + strings.TrimSpace(itemID)
	}
	return strings.TrimSpace(result.OutputFormat) + "|" + payload
}
// appendOpenAIResponsesImageResultDedup appends result to *results unless an
// entry with the same dedup key was already seen. It reports whether the
// result was actually appended; a nil results pointer is a no-op.
func appendOpenAIResponsesImageResultDedup(results *[]openAIResponsesImageResult, seen map[string]struct{}, itemID string, result openAIResponsesImageResult) bool {
	if results == nil {
		return false
	}
	if key := openAIResponsesImageResultKey(itemID, result); key != "" {
		if _, dup := seen[key]; dup {
			return false
		}
		seen[key] = struct{}{}
	}
	*results = append(*results, result)
	return true
}
// mergeOpenAIResponsesImageMeta copies every non-blank metadata field from
// src into dst, trimming whitespace. The image payload fields (Result,
// RevisedPrompt) are left untouched; a nil dst is a no-op.
func mergeOpenAIResponsesImageMeta(dst *openAIResponsesImageResult, src openAIResponsesImageResult) {
	if dst == nil {
		return
	}
	assign := func(target *string, value string) {
		if v := strings.TrimSpace(value); v != "" {
			*target = v
		}
	}
	assign(&dst.OutputFormat, src.OutputFormat)
	assign(&dst.Size, src.Size)
	assign(&dst.Background, src.Background)
	assign(&dst.Quality, src.Quality)
	assign(&dst.Model, src.Model)
}
// extractOpenAIResponsesImageMetaFromLifecycleEvent pulls image-tool metadata
// (output_format/size/background/quality/model of the first configured tool)
// plus the response created_at timestamp out of a response lifecycle SSE event
// (response.created / response.in_progress / response.completed). The bool is
// false for any other event type or when the "response" object is absent.
func extractOpenAIResponsesImageMetaFromLifecycleEvent(payload []byte) (openAIResponsesImageResult, int64, bool) {
	switch gjson.GetBytes(payload, "type").String() {
	case "response.created", "response.in_progress", "response.completed":
	default:
		return openAIResponsesImageResult{}, 0, false
	}
	response := gjson.GetBytes(payload, "response")
	if !response.Exists() {
		return openAIResponsesImageResult{}, 0, false
	}
	// Only the first tool entry is inspected; requests built by
	// buildOpenAIImagesResponsesRequest configure exactly one
	// image_generation tool.
	meta := openAIResponsesImageResult{
		OutputFormat: strings.TrimSpace(response.Get("tools.0.output_format").String()),
		Size:         strings.TrimSpace(response.Get("tools.0.size").String()),
		Background:   strings.TrimSpace(response.Get("tools.0.background").String()),
		Quality:      strings.TrimSpace(response.Get("tools.0.quality").String()),
		Model:        strings.TrimSpace(response.Get("tools.0.model").String()),
	}
	return meta, response.Get("created_at").Int(), true
}
// buildOpenAIImagesStreamPartialPayload renders the SSE payload for one
// partial-image chunk: the event type, timestamp, chunk index and base64 data,
// plus an inline data: URL when the client asked for response_format=url and
// any metadata fields that are already known.
func buildOpenAIImagesStreamPartialPayload(
	eventType string,
	b64 string,
	partialImageIndex int64,
	responseFormat string,
	createdAt int64,
	meta openAIResponsesImageResult,
) []byte {
	if createdAt <= 0 {
		// Upstream did not supply a timestamp; stamp with the current time.
		createdAt = time.Now().Unix()
	}
	payload := []byte(`{"type":"","created_at":0,"partial_image_index":0,"b64_json":""}`)
	payload, _ = sjson.SetBytes(payload, "type", eventType)
	payload, _ = sjson.SetBytes(payload, "created_at", createdAt)
	payload, _ = sjson.SetBytes(payload, "partial_image_index", partialImageIndex)
	payload, _ = sjson.SetBytes(payload, "b64_json", b64)
	if strings.EqualFold(strings.TrimSpace(responseFormat), "url") {
		// URL format is served as an inline data: URL rather than a hosted link.
		payload, _ = sjson.SetBytes(payload, "url", "data:"+openAIImageOutputMIMEType(meta.OutputFormat)+";base64,"+b64)
	}
	// Metadata fields are only emitted when known.
	if meta.Background != "" {
		payload, _ = sjson.SetBytes(payload, "background", meta.Background)
	}
	if meta.OutputFormat != "" {
		payload, _ = sjson.SetBytes(payload, "output_format", meta.OutputFormat)
	}
	if meta.Quality != "" {
		payload, _ = sjson.SetBytes(payload, "quality", meta.Quality)
	}
	if meta.Size != "" {
		payload, _ = sjson.SetBytes(payload, "size", meta.Size)
	}
	if meta.Model != "" {
		payload, _ = sjson.SetBytes(payload, "model", meta.Model)
	}
	return payload
}
// buildOpenAIImagesStreamCompletedPayload renders the terminal SSE payload for
// one finished image: event type, timestamp and the full base64 result, plus
// an inline data: URL when the client asked for response_format=url, any known
// metadata, and the raw upstream usage object when available.
func buildOpenAIImagesStreamCompletedPayload(
	eventType string,
	img openAIResponsesImageResult,
	responseFormat string,
	createdAt int64,
	usageRaw []byte,
) []byte {
	if createdAt <= 0 {
		// No upstream timestamp: stamp with the current time.
		createdAt = time.Now().Unix()
	}
	out := []byte(`{"type":"","created_at":0,"b64_json":""}`)
	out, _ = sjson.SetBytes(out, "type", eventType)
	out, _ = sjson.SetBytes(out, "created_at", createdAt)
	out, _ = sjson.SetBytes(out, "b64_json", img.Result)
	if strings.EqualFold(strings.TrimSpace(responseFormat), "url") {
		dataURL := "data:" + openAIImageOutputMIMEType(img.OutputFormat) + ";base64," + img.Result
		out, _ = sjson.SetBytes(out, "url", dataURL)
	}
	// Emit metadata fields only when known, in a fixed key order.
	for _, field := range []struct {
		path  string
		value string
	}{
		{"background", img.Background},
		{"output_format", img.OutputFormat},
		{"quality", img.Quality},
		{"size", img.Size},
		{"model", img.Model},
	} {
		if field.value != "" {
			out, _ = sjson.SetBytes(out, field.path, field.value)
		}
	}
	if len(usageRaw) > 0 && gjson.ValidBytes(usageRaw) {
		out, _ = sjson.SetRawBytes(out, "usage", usageRaw)
	}
	return out
}
// openAIImageOutputMIMEType maps an image_generation output_format value to a
// MIME type for building data: URLs. A value already containing "/" is treated
// as a full MIME type and passed through; bare format names map to
// image/{png,jpeg,webp}; empty or unrecognized values fall back to image/png.
func openAIImageOutputMIMEType(outputFormat string) string {
	// Normalize once up front. Previously only the switch arm trimmed, so a
	// padded full MIME type such as " image/gif " passed the "/" check and
	// leaked whitespace into the generated data: URL.
	outputFormat = strings.TrimSpace(outputFormat)
	if outputFormat == "" {
		return "image/png"
	}
	if strings.Contains(outputFormat, "/") {
		// Already a full MIME type (e.g. "image/webp"): pass through.
		return outputFormat
	}
	switch strings.ToLower(outputFormat) {
	case "png":
		return "image/png"
	case "jpg", "jpeg":
		return "image/jpeg"
	case "webp":
		return "image/webp"
	default:
		return "image/png"
	}
}
// openAIImageUploadToDataURL converts an uploaded image into an inline
// data: URL, sniffing the content type when the upload did not declare one.
// It fails when the upload carries no bytes.
func openAIImageUploadToDataURL(upload OpenAIImagesUpload) (string, error) {
	if len(upload.Data) == 0 {
		return "", fmt.Errorf("upload %q is empty", strings.TrimSpace(upload.FileName))
	}
	mimeType := strings.TrimSpace(upload.ContentType)
	if mimeType == "" {
		// No declared type: fall back to net/http content sniffing.
		mimeType = http.DetectContentType(upload.Data)
	}
	encoded := base64.StdEncoding.EncodeToString(upload.Data)
	return "data:" + mimeType + ";base64," + encoded, nil
}
// buildOpenAIImagesResponsesRequest translates a parsed Images API request
// into the upstream /responses payload that invokes the image_generation
// tool. Uploads and the optional mask are inlined as base64 data: URLs; the
// single tool entry carries the action ("generate" or "edit") plus every
// optional image option the caller supplied. Returns an error when the prompt
// is missing, an upload is empty, or an edit request has no input image.
func buildOpenAIImagesResponsesRequest(parsed *OpenAIImagesRequest, toolModel string) ([]byte, error) {
	if parsed == nil {
		return nil, fmt.Errorf("parsed images request is required")
	}
	prompt := strings.TrimSpace(parsed.Prompt)
	if prompt == "" {
		return nil, fmt.Errorf("prompt is required")
	}
	// Collect all input images: caller-provided URLs first, then uploads
	// converted to inline data: URLs.
	inputImages := make([]string, 0, len(parsed.InputImageURLs)+len(parsed.Uploads))
	for _, imageURL := range parsed.InputImageURLs {
		if trimmed := strings.TrimSpace(imageURL); trimmed != "" {
			inputImages = append(inputImages, trimmed)
		}
	}
	for _, upload := range parsed.Uploads {
		dataURL, err := openAIImageUploadToDataURL(upload)
		if err != nil {
			return nil, err
		}
		inputImages = append(inputImages, dataURL)
	}
	if parsed.IsEdits() && len(inputImages) == 0 {
		return nil, fmt.Errorf("image input is required")
	}
	// Base request skeleton; model and user input are filled in below.
	req := []byte(`{"instructions":"","stream":true,"reasoning":{"effort":"medium","summary":"auto"},"parallel_tool_calls":true,"include":["reasoning.encrypted_content"],"model":"","store":false,"tool_choice":{"type":"image_generation"}}`)
	req, _ = sjson.SetBytes(req, "model", openAIImagesResponsesMainModel)
	input := []byte(`[{"type":"message","role":"user","content":[{"type":"input_text","text":""}]}]`)
	input, _ = sjson.SetBytes(input, "0.content.0.text", prompt)
	for index, imageURL := range inputImages {
		part := []byte(`{"type":"input_image","image_url":""}`)
		part, _ = sjson.SetBytes(part, "image_url", imageURL)
		// index+1 because content slot 0 holds the prompt text part.
		input, _ = sjson.SetRawBytes(input, fmt.Sprintf("0.content.%d", index+1), part)
	}
	req, _ = sjson.SetRawBytes(req, "input", input)
	action := "generate"
	if parsed.IsEdits() {
		action = "edit"
	}
	tool := []byte(`{"type":"image_generation","action":"","model":""}`)
	tool, _ = sjson.SetBytes(tool, "action", action)
	tool, _ = sjson.SetBytes(tool, "model", strings.TrimSpace(toolModel))
	// Optional string-valued tool options: only non-blank values are emitted.
	for _, field := range []struct {
		path  string
		value string
	}{
		{path: "size", value: parsed.Size},
		{path: "quality", value: parsed.Quality},
		{path: "background", value: parsed.Background},
		{path: "output_format", value: parsed.OutputFormat},
		{path: "moderation", value: parsed.Moderation},
		{path: "style", value: parsed.Style},
	} {
		if trimmed := strings.TrimSpace(field.value); trimmed != "" {
			tool, _ = sjson.SetBytes(tool, field.path, trimmed)
		}
	}
	if parsed.OutputCompression != nil {
		tool, _ = sjson.SetBytes(tool, "output_compression", *parsed.OutputCompression)
	}
	if parsed.PartialImages != nil {
		tool, _ = sjson.SetBytes(tool, "partial_images", *parsed.PartialImages)
	}
	// An uploaded mask takes precedence over a mask URL.
	maskImageURL := strings.TrimSpace(parsed.MaskImageURL)
	if parsed.MaskUpload != nil {
		dataURL, err := openAIImageUploadToDataURL(*parsed.MaskUpload)
		if err != nil {
			return nil, err
		}
		maskImageURL = dataURL
	}
	if maskImageURL != "" {
		tool, _ = sjson.SetBytes(tool, "input_image_mask.image_url", maskImageURL)
	}
	req, _ = sjson.SetRawBytes(req, "tools", []byte(`[]`))
	req, _ = sjson.SetRawBytes(req, "tools.-1", tool)
	return req, nil
}
// extractOpenAIImagesFromResponsesCompleted collects every finished image from
// a response.completed SSE event. It returns the image entries in output
// order, the response created_at (current time when absent), the raw
// tool_usage.image_gen usage object when present, and the metadata of the
// first image. An error is returned only for a wrong event type.
func extractOpenAIImagesFromResponsesCompleted(payload []byte) ([]openAIResponsesImageResult, int64, []byte, openAIResponsesImageResult, error) {
	if gjson.GetBytes(payload, "type").String() != "response.completed" {
		return nil, 0, nil, openAIResponsesImageResult{}, fmt.Errorf("unexpected event type")
	}
	createdAt := gjson.GetBytes(payload, "response.created_at").Int()
	if createdAt <= 0 {
		createdAt = time.Now().Unix()
	}
	var (
		results   []openAIResponsesImageResult
		firstMeta openAIResponsesImageResult
	)
	output := gjson.GetBytes(payload, "response.output")
	if output.IsArray() {
		for _, item := range output.Array() {
			if item.Get("type").String() != "image_generation_call" {
				continue
			}
			result := strings.TrimSpace(item.Get("result").String())
			if result == "" {
				// A call item without image data carries nothing to return.
				continue
			}
			entry := openAIResponsesImageResult{
				Result:        result,
				RevisedPrompt: strings.TrimSpace(item.Get("revised_prompt").String()),
				OutputFormat:  strings.TrimSpace(item.Get("output_format").String()),
				Size:          strings.TrimSpace(item.Get("size").String()),
				Background:    strings.TrimSpace(item.Get("background").String()),
				Quality:       strings.TrimSpace(item.Get("quality").String()),
			}
			if len(results) == 0 {
				firstMeta = entry
			}
			results = append(results, entry)
		}
	}
	var usageRaw []byte
	if usage := gjson.GetBytes(payload, "response.tool_usage.image_gen"); usage.Exists() && usage.IsObject() {
		usageRaw = []byte(usage.Raw)
	}
	return results, createdAt, usageRaw, firstMeta, nil
}
// extractOpenAIImageFromResponsesOutputItemDone reads one image result out of
// a response.output_item.done SSE event. It returns (result, itemID, true)
// when the event carries an image_generation_call item with a non-empty
// result, (zero, "", false) for other item types or empty results, and an
// error only for a wrong event type.
func extractOpenAIImageFromResponsesOutputItemDone(payload []byte) (openAIResponsesImageResult, string, bool, error) {
	var zero openAIResponsesImageResult
	if gjson.GetBytes(payload, "type").String() != "response.output_item.done" {
		return zero, "", false, fmt.Errorf("unexpected event type")
	}
	item := gjson.GetBytes(payload, "item")
	if !item.Exists() || item.Get("type").String() != "image_generation_call" {
		return zero, "", false, nil
	}
	field := func(path string) string {
		return strings.TrimSpace(item.Get(path).String())
	}
	b64 := field("result")
	if b64 == "" {
		// An image call without data is not an error, just nothing to emit.
		return zero, "", false, nil
	}
	entry := openAIResponsesImageResult{
		Result:        b64,
		RevisedPrompt: field("revised_prompt"),
		OutputFormat:  field("output_format"),
		Size:          field("size"),
		Background:    field("background"),
		Quality:       field("quality"),
	}
	return entry, field("id"), true, nil
}
// collectOpenAIImagesFromResponsesBody scans a fully-buffered SSE response
// body for generated images. Images listed in the response.completed event are
// authoritative; images gathered from individual response.output_item.done
// events serve as a deduplicated fallback. It returns the images, the
// best-known created_at, the raw usage object, the metadata to report, and
// whether a response.completed event was seen at all.
func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageResult, int64, []byte, openAIResponsesImageResult, bool, error) {
	var (
		fallbackResults []openAIResponsesImageResult
		fallbackSeen    = make(map[string]struct{})
		createdAt       int64
		usageRaw        []byte
		foundFinal      bool
		responseMeta    openAIResponsesImageResult
	)
	for _, line := range bytes.Split(body, []byte("\n")) {
		line = bytes.TrimRight(line, "\r")
		data, ok := extractOpenAISSEDataLine(string(line))
		if !ok || data == "" || data == "[DONE]" {
			continue
		}
		payload := []byte(data)
		if !gjson.ValidBytes(payload) {
			continue
		}
		// Lifecycle events contribute tool metadata and a created_at stamp
		// even though they carry no image payload.
		if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(payload); ok {
			mergeOpenAIResponsesImageMeta(&responseMeta, meta)
			if eventCreatedAt > 0 {
				createdAt = eventCreatedAt
			}
		}
		switch gjson.GetBytes(payload, "type").String() {
		case "response.output_item.done":
			result, itemID, ok, err := extractOpenAIImageFromResponsesOutputItemDone(payload)
			if err != nil {
				return nil, 0, nil, openAIResponsesImageResult{}, false, err
			}
			if ok {
				mergeOpenAIResponsesImageMeta(&result, responseMeta)
				appendOpenAIResponsesImageResultDedup(&fallbackResults, fallbackSeen, itemID, result)
			}
		case "response.completed":
			results, completedAt, completedUsageRaw, firstMeta, err := extractOpenAIImagesFromResponsesCompleted(payload)
			if err != nil {
				return nil, 0, nil, openAIResponsesImageResult{}, false, err
			}
			foundFinal = true
			if completedAt > 0 {
				createdAt = completedAt
			}
			if len(completedUsageRaw) > 0 {
				usageRaw = completedUsageRaw
			}
			// Prefer the authoritative completed results; fall back to the
			// buffered per-item images when the final event listed none.
			if len(results) > 0 {
				mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta)
				return results, createdAt, usageRaw, firstMeta, true, nil
			}
			if len(fallbackResults) > 0 {
				firstMeta = fallbackResults[0]
				mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta)
				return fallbackResults, createdAt, usageRaw, firstMeta, true, nil
			}
		}
	}
	// No completed event (or it carried no images): report whatever per-item
	// images were collected, if any.
	if len(fallbackResults) > 0 {
		firstMeta := fallbackResults[0]
		mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta)
		return fallbackResults, createdAt, usageRaw, firstMeta, foundFinal, nil
	}
	return nil, createdAt, usageRaw, openAIResponsesImageResult{}, foundFinal, nil
}
// buildOpenAIImagesAPIResponse assembles an Images-API-compatible JSON body
// ({"created":...,"data":[...]}) from the collected results. Each image is
// emitted as b64_json, or as an inline data: URL when responseFormat is
// "url". Known top-level metadata and the raw usage object are appended.
func buildOpenAIImagesAPIResponse(
	results []openAIResponsesImageResult,
	createdAt int64,
	usageRaw []byte,
	firstMeta openAIResponsesImageResult,
	responseFormat string,
) ([]byte, error) {
	if createdAt <= 0 {
		createdAt = time.Now().Unix()
	}
	body := []byte(`{"created":0,"data":[]}`)
	body, _ = sjson.SetBytes(body, "created", createdAt)
	// Empty response_format defaults to b64_json.
	wantURL := strings.ToLower(strings.TrimSpace(responseFormat)) == "url"
	for _, img := range results {
		entry := []byte(`{}`)
		if wantURL {
			entry, _ = sjson.SetBytes(entry, "url", "data:"+openAIImageOutputMIMEType(img.OutputFormat)+";base64,"+img.Result)
		} else {
			entry, _ = sjson.SetBytes(entry, "b64_json", img.Result)
		}
		if img.RevisedPrompt != "" {
			entry, _ = sjson.SetBytes(entry, "revised_prompt", img.RevisedPrompt)
		}
		body, _ = sjson.SetRawBytes(body, "data.-1", entry)
	}
	// Top-level metadata, emitted only when known, in a fixed key order.
	for _, field := range []struct {
		path  string
		value string
	}{
		{"background", firstMeta.Background},
		{"output_format", firstMeta.OutputFormat},
		{"quality", firstMeta.Quality},
		{"size", firstMeta.Size},
		{"model", firstMeta.Model},
	} {
		if field.value != "" {
			body, _ = sjson.SetBytes(body, field.path, field.value)
		}
	}
	if len(usageRaw) > 0 && gjson.ValidBytes(usageRaw) {
		body, _ = sjson.SetRawBytes(body, "usage", usageRaw)
	}
	return body, nil
}
// openAIImagesStreamPrefix selects the SSE event-name prefix for an images
// request: "image_edit" for edit requests, "image_generation" otherwise
// (including a nil request).
func openAIImagesStreamPrefix(parsed *OpenAIImagesRequest) string {
	if parsed == nil || !parsed.IsEdits() {
		return "image_generation"
	}
	return "image_edit"
}
// buildOpenAIImagesStreamErrorBody renders an SSE error payload, substituting
// a generic message when the provided one is blank.
func buildOpenAIImagesStreamErrorBody(message string) []byte {
	if strings.TrimSpace(message) == "" {
		message = "upstream request failed"
	}
	payload := []byte(`{"type":"error","error":{"type":"upstream_error","message":""}}`)
	payload, _ = sjson.SetBytes(payload, "error.message", message)
	return payload
}
// writeOpenAIImagesStreamEvent emits one SSE frame — an optional "event:"
// line followed by the "data:" line — and flushes it to the client
// immediately so partial images arrive as they are produced.
func (s *OpenAIGatewayService) writeOpenAIImagesStreamEvent(c *gin.Context, flusher http.Flusher, eventName string, payload []byte) error {
	if strings.TrimSpace(eventName) != "" {
		_, err := fmt.Fprintf(c.Writer, "event: %s\n", eventName)
		if err != nil {
			return err
		}
	}
	if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", payload); err != nil {
		return err
	}
	flusher.Flush()
	return nil
}
// handleOpenAIImagesOAuthNonStreamingResponse consumes the upstream SSE body
// in full, extracts token usage and the generated images, and writes a
// standard Images API JSON response to the client. It returns the parsed
// usage and the number of images delivered.
func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse(
	resp *http.Response,
	c *gin.Context,
	responseFormat string,
	fallbackModel string,
) (OpenAIUsage, int, error) {
	body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
	if err != nil {
		return OpenAIUsage{}, 0, err
	}
	// Accumulate usage from every SSE data line before extracting images.
	var usage OpenAIUsage
	for _, line := range bytes.Split(body, []byte("\n")) {
		line = bytes.TrimRight(line, "\r")
		data, ok := extractOpenAISSEDataLine(string(line))
		if !ok || data == "" || data == "[DONE]" {
			continue
		}
		dataBytes := []byte(data)
		s.parseSSEUsageBytes(dataBytes, &usage)
	}
	results, createdAt, usageRaw, firstMeta, _, err := collectOpenAIImagesFromResponsesBody(body)
	if err != nil {
		return OpenAIUsage{}, 0, err
	}
	if len(results) == 0 {
		return OpenAIUsage{}, 0, fmt.Errorf("upstream did not return image output")
	}
	// Report the requested model when upstream did not echo one back.
	if strings.TrimSpace(firstMeta.Model) == "" {
		firstMeta.Model = strings.TrimSpace(fallbackModel)
	}
	responseBody, err := buildOpenAIImagesAPIResponse(results, createdAt, usageRaw, firstMeta, responseFormat)
	if err != nil {
		return OpenAIUsage{}, 0, err
	}
	responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
	c.Data(resp.StatusCode, "application/json; charset=utf-8", responseBody)
	return usage, len(results), nil
}
// handleOpenAIImagesOAuthStreamingResponse relays the upstream Responses SSE
// stream to the client as Images-API-style events. Partial-image chunks are
// forwarded as they arrive; finished images are buffered from
// response.output_item.done events and emitted (deduplicated) when
// response.completed arrives, or flushed as a best-effort fallback if the
// stream ends without a completed event. It returns the accumulated usage,
// the number of images emitted, and the time-to-first-token in milliseconds.
func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse(
	resp *http.Response,
	c *gin.Context,
	startTime time.Time,
	responseFormat string,
	streamPrefix string,
	fallbackModel string,
) (OpenAIUsage, int, *int, error) {
	responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Status(resp.StatusCode)
	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		return OpenAIUsage{}, 0, nil, fmt.Errorf("streaming is not supported by response writer")
	}
	format := strings.ToLower(strings.TrimSpace(responseFormat))
	if format == "" {
		format = "b64_json"
	}
	reader := bufio.NewReader(resp.Body)
	usage := OpenAIUsage{}
	imageCount := 0
	var firstTokenMs *int
	// emitted tracks images already sent to the client; pendingResults buffers
	// images seen via output_item.done until response.completed arrives.
	emitted := make(map[string]struct{})
	pendingResults := make([]openAIResponsesImageResult, 0, 1)
	pendingSeen := make(map[string]struct{})
	// streamMeta accumulates tool metadata; the requested model is the seed
	// in case upstream never echoes one.
	streamMeta := openAIResponsesImageResult{Model: strings.TrimSpace(fallbackModel)}
	var createdAt int64
	for {
		line, err := reader.ReadBytes('\n')
		if len(line) > 0 {
			trimmedLine := strings.TrimRight(string(line), "\r\n")
			data, ok := extractOpenAISSEDataLine(trimmedLine)
			if ok && data != "" && data != "[DONE]" {
				// First usable data line marks time-to-first-token.
				if firstTokenMs == nil {
					ms := int(time.Since(startTime).Milliseconds())
					firstTokenMs = &ms
				}
				dataBytes := []byte(data)
				s.parseSSEUsageBytes(dataBytes, &usage)
				if gjson.ValidBytes(dataBytes) {
					if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(dataBytes); ok {
						mergeOpenAIResponsesImageMeta(&streamMeta, meta)
						if eventCreatedAt > 0 {
							createdAt = eventCreatedAt
						}
					}
					switch gjson.GetBytes(dataBytes, "type").String() {
					case "response.image_generation_call.partial_image":
						// Forward partial previews immediately.
						b64 := strings.TrimSpace(gjson.GetBytes(dataBytes, "partial_image_b64").String())
						if b64 != "" {
							eventName := streamPrefix + ".partial_image"
							partialMeta := streamMeta
							mergeOpenAIResponsesImageMeta(&partialMeta, openAIResponsesImageResult{
								OutputFormat: strings.TrimSpace(gjson.GetBytes(dataBytes, "output_format").String()),
								Background:   strings.TrimSpace(gjson.GetBytes(dataBytes, "background").String()),
							})
							payload := buildOpenAIImagesStreamPartialPayload(
								eventName,
								b64,
								gjson.GetBytes(dataBytes, "partial_image_index").Int(),
								format,
								createdAt,
								partialMeta,
							)
							if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil {
								return OpenAIUsage{}, imageCount, firstTokenMs, writeErr
							}
						}
					case "response.output_item.done":
						// Buffer the finished image; emission happens on
						// response.completed (or at stream end as fallback).
						img, itemID, ok, extractErr := extractOpenAIImageFromResponsesOutputItemDone(dataBytes)
						if extractErr != nil {
							_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error()))
							return OpenAIUsage{}, imageCount, firstTokenMs, extractErr
						}
						if !ok {
							break // exits the switch case only, not the read loop
						}
						mergeOpenAIResponsesImageMeta(&streamMeta, img)
						mergeOpenAIResponsesImageMeta(&img, streamMeta)
						key := openAIResponsesImageResultKey(itemID, img)
						if _, exists := emitted[key]; exists {
							break
						}
						if _, exists := pendingSeen[key]; exists {
							break
						}
						pendingSeen[key] = struct{}{}
						pendingResults = append(pendingResults, img)
					case "response.completed":
						results, _, usageRaw, firstMeta, extractErr := extractOpenAIImagesFromResponsesCompleted(dataBytes)
						if extractErr != nil {
							_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error()))
							return OpenAIUsage{}, imageCount, firstTokenMs, extractErr
						}
						mergeOpenAIResponsesImageMeta(&streamMeta, firstMeta)
						// Union of completed-event results and buffered
						// per-item results, deduplicated.
						finalResults := make([]openAIResponsesImageResult, 0, len(results)+len(pendingResults))
						finalSeen := make(map[string]struct{})
						for _, img := range results {
							mergeOpenAIResponsesImageMeta(&img, streamMeta)
							appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img)
						}
						for _, img := range pendingResults {
							mergeOpenAIResponsesImageMeta(&img, streamMeta)
							appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img)
						}
						if len(finalResults) == 0 {
							err = fmt.Errorf("upstream did not return image output")
							_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(err.Error()))
							return OpenAIUsage{}, imageCount, firstTokenMs, err
						}
						eventName := streamPrefix + ".completed"
						for _, img := range finalResults {
							key := openAIResponsesImageResultKey("", img)
							if _, exists := emitted[key]; exists {
								continue
							}
							payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, usageRaw)
							if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil {
								return OpenAIUsage{}, imageCount, firstTokenMs, writeErr
							}
							emitted[key] = struct{}{}
						}
						imageCount = len(emitted)
						return usage, imageCount, firstTokenMs, nil
					}
				}
			}
		}
		if err == io.EOF {
			break
		}
		if err != nil {
			_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(err.Error()))
			return OpenAIUsage{}, imageCount, firstTokenMs, err
		}
	}
	// Stream ended without response.completed.
	if imageCount > 0 {
		return usage, imageCount, firstTokenMs, nil
	}
	if len(pendingResults) > 0 {
		// Best effort: flush buffered images (no usage object available).
		eventName := streamPrefix + ".completed"
		for _, img := range pendingResults {
			mergeOpenAIResponsesImageMeta(&img, streamMeta)
			key := openAIResponsesImageResultKey("", img)
			if _, exists := emitted[key]; exists {
				continue
			}
			payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, nil)
			if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil {
				return OpenAIUsage{}, imageCount, firstTokenMs, writeErr
			}
			emitted[key] = struct{}{}
		}
		imageCount = len(emitted)
		return usage, imageCount, firstTokenMs, nil
	}
	streamErr := fmt.Errorf("stream disconnected before image generation completed")
	_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(streamErr.Error()))
	return OpenAIUsage{}, imageCount, firstTokenMs, streamErr
}
// forwardOpenAIImagesOAuth serves an Images API request through an OAuth
// ChatGPT account by translating it into a Responses API image_generation
// call. It resolves the effective model (channel mapping wins; default
// "gpt-image-2"), forwards upstream, handles error/failover responses, and
// relays the result in either streaming or buffered form.
func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	parsed *OpenAIImagesRequest,
	channelMappedModel string,
) (*OpenAIForwardResult, error) {
	startTime := time.Now()
	requestModel := strings.TrimSpace(parsed.Model)
	if mapped := strings.TrimSpace(channelMappedModel); mapped != "" {
		requestModel = mapped
	}
	if requestModel == "" {
		requestModel = "gpt-image-2"
	}
	if err := validateOpenAIImagesModel(requestModel); err != nil {
		return nil, err
	}
	logger.LegacyPrintf(
		"service.openai_gateway",
		"[OpenAI] Images request routing request_model=%s endpoint=%s account_type=%s uploads=%d",
		requestModel,
		parsed.Endpoint,
		account.Type,
		len(parsed.Uploads),
	)
	// The Responses image tool produces a single image per call; n>1 is
	// accepted but downgraded with a warning.
	if parsed.N > 1 {
		logger.LegacyPrintf(
			"service.openai_gateway",
			"[Warning] Codex /responses image tool requested n=%d; falling back to n=1 request_model=%s endpoint=%s",
			parsed.N,
			requestModel,
			parsed.Endpoint,
		)
	}
	token, _, err := s.GetAccessToken(ctx, account)
	if err != nil {
		return nil, err
	}
	responsesBody, err := buildOpenAIImagesResponsesRequest(parsed, requestModel)
	if err != nil {
		return nil, err
	}
	setOpsUpstreamRequestBody(c, responsesBody)
	upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, parsed.StickySessionSeed(), false)
	if err != nil {
		return nil, err
	}
	// The upstream call is always SSE, even for buffered client responses.
	upstreamReq.Header.Set("Content-Type", "application/json")
	upstreamReq.Header.Set("Accept", "text/event-stream")
	proxyURL := ""
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}
	upstreamStart := time.Now()
	resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
	SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
	if err != nil {
		// Transport-level failure: record for ops and surface a sanitized error.
		safeErr := sanitizeUpstreamErrorMessage(err.Error())
		setOpsUpstreamError(c, 0, safeErr, "")
		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
			Platform:           account.Platform,
			AccountID:          account.ID,
			AccountName:        account.Name,
			UpstreamStatusCode: 0,
			UpstreamURL:        safeUpstreamURL(upstreamReq.URL.String()),
			Kind:               "request_error",
			Message:            safeErr,
		})
		return nil, fmt.Errorf("upstream request failed: %s", safeErr)
	}
	if resp.StatusCode >= 400 {
		// Buffer (capped at 2 MiB) so the body can be inspected and replayed.
		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
		_ = resp.Body.Close()
		resp.Body = io.NopCloser(bytes.NewReader(respBody))
		upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
		upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
		if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: resp.StatusCode,
				UpstreamRequestID:  resp.Header.Get("x-request-id"),
				UpstreamURL:        safeUpstreamURL(upstreamReq.URL.String()),
				Kind:               "failover",
				Message:            upstreamMsg,
			})
			s.handleFailoverSideEffects(ctx, resp, account)
			return nil, &UpstreamFailoverError{
				StatusCode:             resp.StatusCode,
				ResponseBody:           respBody,
				RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
			}
		}
		return s.handleErrorResponse(ctx, resp, c, account, responsesBody)
	}
	defer func() { _ = resp.Body.Close() }()
	var (
		usage        OpenAIUsage
		imageCount   int
		firstTokenMs *int
	)
	if parsed.Stream {
		usage, imageCount, firstTokenMs, err = s.handleOpenAIImagesOAuthStreamingResponse(resp, c, startTime, parsed.ResponseFormat, openAIImagesStreamPrefix(parsed), requestModel)
		if err != nil {
			return nil, err
		}
	} else {
		usage, imageCount, err = s.handleOpenAIImagesOAuthNonStreamingResponse(resp, c, parsed.ResponseFormat, requestModel)
		if err != nil {
			return nil, err
		}
	}
	if imageCount <= 0 {
		// Billing fallback: charge for the requested count when none counted.
		imageCount = parsed.N
	}
	return &OpenAIForwardResult{
		RequestID:       resp.Header.Get("x-request-id"),
		Usage:           usage,
		Model:           requestModel,
		UpstreamModel:   requestModel,
		Stream:          parsed.Stream,
		ResponseHeaders: resp.Header.Clone(),
		Duration:        time.Since(startTime),
		FirstTokenMs:    firstTokenMs,
		ImageCount:      imageCount,
		ImageSize:       parsed.SizeTier,
	}, nil
}
......@@ -3,13 +3,17 @@ package service
import (
"bytes"
"context"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/textproto"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_JSON(t *testing.T) {
......@@ -70,6 +74,58 @@ func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEdit(t *testing.T
require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
}
// A multipart /v1/images/edits request carrying a mask and native-only
// options (input_fidelity, output_compression, partial_images) must parse
// with every option populated and require the native capability tier.
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEditWithMaskAndNativeOptions(t *testing.T) {
	gin.SetMode(gin.TestMode)
	var body bytes.Buffer
	writer := multipart.NewWriter(&body)
	require.NoError(t, writer.WriteField("model", "gpt-image-2"))
	require.NoError(t, writer.WriteField("prompt", "replace foreground"))
	require.NoError(t, writer.WriteField("output_format", "png"))
	require.NoError(t, writer.WriteField("input_fidelity", "high"))
	require.NoError(t, writer.WriteField("output_compression", "80"))
	require.NoError(t, writer.WriteField("partial_images", "2"))
	// Source image part.
	imageHeader := make(textproto.MIMEHeader)
	imageHeader.Set("Content-Disposition", `form-data; name="image"; filename="source.png"`)
	imageHeader.Set("Content-Type", "image/png")
	imagePart, err := writer.CreatePart(imageHeader)
	require.NoError(t, err)
	_, err = imagePart.Write([]byte("source-image-bytes"))
	require.NoError(t, err)
	// Mask image part.
	maskHeader := make(textproto.MIMEHeader)
	maskHeader.Set("Content-Disposition", `form-data; name="mask"; filename="mask.png"`)
	maskHeader.Set("Content-Type", "image/png")
	maskPart, err := writer.CreatePart(maskHeader)
	require.NoError(t, err)
	_, err = maskPart.Write([]byte("mask-image-bytes"))
	require.NoError(t, err)
	require.NoError(t, writer.Close())
	req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body.Bytes()))
	req.Header.Set("Content-Type", writer.FormDataContentType())
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = req
	svc := &OpenAIGatewayService{}
	parsed, err := svc.ParseOpenAIImagesRequest(c, body.Bytes())
	require.NoError(t, err)
	require.NotNil(t, parsed)
	require.Len(t, parsed.Uploads, 1)
	require.NotNil(t, parsed.MaskUpload)
	require.True(t, parsed.HasMask)
	require.Equal(t, "png", parsed.OutputFormat)
	require.Equal(t, "high", parsed.InputFidelity)
	require.NotNil(t, parsed.OutputCompression)
	require.Equal(t, 80, *parsed.OutputCompression)
	require.NotNil(t, parsed.PartialImages)
	require.Equal(t, 2, *parsed.PartialImages)
	require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
}
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_PromptOnlyDefaultsRemainBasic(t *testing.T) {
gin.SetMode(gin.TestMode)
body := []byte(`{"prompt":"draw a cat"}`)
......@@ -121,6 +177,40 @@ func TestOpenAIGatewayServiceParseOpenAIImagesRequest_RejectsNonImageModel(t *te
require.ErrorContains(t, err, `images endpoint requires an image model, got "gpt-5.4"`)
}
// A JSON /v1/images/edits request referencing images and mask by URL (rather
// than uploads) must parse the URLs and native options and require the
// native capability tier.
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_JSONEditURLs(t *testing.T) {
	gin.SetMode(gin.TestMode)
	body := []byte(`{
		"model":"gpt-image-2",
		"prompt":"replace the background",
		"images":[{"image_url":"https://example.com/source.png"}],
		"mask":{"image_url":"https://example.com/mask.png"},
		"input_fidelity":"high",
		"output_compression":90,
		"partial_images":2,
		"response_format":"url"
	}`)
	req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = req
	svc := &OpenAIGatewayService{}
	parsed, err := svc.ParseOpenAIImagesRequest(c, body)
	require.NoError(t, err)
	require.NotNil(t, parsed)
	require.Equal(t, []string{"https://example.com/source.png"}, parsed.InputImageURLs)
	require.Equal(t, "https://example.com/mask.png", parsed.MaskImageURL)
	require.Equal(t, "high", parsed.InputFidelity)
	require.NotNil(t, parsed.OutputCompression)
	require.Equal(t, 90, *parsed.OutputCompression)
	require.NotNil(t, parsed.PartialImages)
	require.Equal(t, 2, *parsed.PartialImages)
	require.True(t, parsed.HasMask)
	require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
}
func TestCollectOpenAIImagePointers_RecognizesDirectAssets(t *testing.T) {
items := collectOpenAIImagePointers([]byte(`{
"revised_prompt": "cat astronaut",
......@@ -157,3 +247,472 @@ func TestResolveOpenAIImageBytes_PrefersInlineBase64(t *testing.T) {
require.NoError(t, err)
require.Equal(t, []byte("ABC"), data)
}
// An OAuth OpenAI account must be able to serve both the basic and the
// native image capability tiers.
func TestAccountSupportsOpenAIImageCapability_OAuthSupportsNative(t *testing.T) {
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth}
	require.True(t, account.SupportsOpenAIImageCapability(OpenAIImagesCapabilityBasic), "basic capability")
	require.True(t, account.SupportsOpenAIImageCapability(OpenAIImagesCapabilityNative), "native capability")
}
// openAIImageTestSSEEvent is one parsed server-sent-event frame: the optional
// "event:" name and the "data:" payload.
type openAIImageTestSSEEvent struct {
	Name string
	Data string
}

// parseOpenAIImageTestSSEEvents splits an SSE response body into frames on
// blank lines and extracts each frame's "event:" and "data:" fields. Frames
// with neither field are dropped; repeated fields keep the last value.
func parseOpenAIImageTestSSEEvents(body string) []openAIImageTestSSEEvent {
	frames := strings.Split(body, "\n\n")
	events := make([]openAIImageTestSSEEvent, 0, len(frames))
	for _, frame := range frames {
		frame = strings.TrimSpace(frame)
		if frame == "" {
			continue
		}
		var parsed openAIImageTestSSEEvent
		for _, line := range strings.Split(frame, "\n") {
			if rest, ok := strings.CutPrefix(line, "event: "); ok {
				parsed.Name = strings.TrimSpace(rest)
			} else if rest, ok := strings.CutPrefix(line, "data: "); ok {
				parsed.Data = strings.TrimSpace(rest)
			}
		}
		if parsed.Name == "" && parsed.Data == "" {
			continue
		}
		events = append(events, parsed)
	}
	return events
}

// findOpenAIImageTestSSEEvent returns the first parsed event whose Name
// matches, and whether one was found.
func findOpenAIImageTestSSEEvent(events []openAIImageTestSSEEvent, name string) (openAIImageTestSSEEvent, bool) {
	for _, candidate := range events {
		if candidate.Name == name {
			return candidate, true
		}
	}
	return openAIImageTestSSEEvent{}, false
}
// TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI verifies that a
// non-streaming /v1/images/generations call on an OpenAI OAuth account is
// forwarded to the ChatGPT Codex Responses API as a streaming image_generation
// tool call, and that the SSE result is converted back into an Images API
// JSON response with usage accounting.
func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) {
	gin.SetMode(gin.TestMode)
	body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","n":2}`)
	req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = req
	c.Set("api_key", &APIKey{ID: 42})
	svc := &OpenAIGatewayService{}
	parsed, err := svc.ParseOpenAIImagesRequest(c, body)
	require.NoError(t, err)
	// Scripted upstream: one response.completed event carrying a single image
	// plus token usage details, terminated by [DONE].
	upstream := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header: http.Header{
				"Content-Type": []string{"text/event-stream"},
				"X-Request-Id": []string{"req_img_123"},
			},
			Body: io.NopCloser(strings.NewReader(
				"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000000,\"usage\":{\"input_tokens\":11,\"output_tokens\":22,\"input_tokens_details\":{\"cached_tokens\":3},\"output_tokens_details\":{\"image_tokens\":7}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"aGVsbG8=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" +
					"data: [DONE]\n\n",
			)),
		},
	}
	svc.httpUpstream = upstream
	account := &Account{
		ID:       1,
		Name:     "openai-oauth",
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token":       "token-123",
			"chatgpt_account_id": "acct-123",
		},
	}
	result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
	require.NoError(t, err)
	require.NotNil(t, result)
	// Usage accounting is extracted from the response.completed payload.
	require.Equal(t, "gpt-image-2", result.Model)
	require.Equal(t, "gpt-image-2", result.UpstreamModel)
	require.Equal(t, 1, result.ImageCount)
	require.Equal(t, 11, result.Usage.InputTokens)
	require.Equal(t, 22, result.Usage.OutputTokens)
	require.Equal(t, 7, result.Usage.ImageOutputTokens)
	// Upstream request must target the Codex endpoint with OAuth headers and a
	// streaming image_generation tool definition.
	require.NotNil(t, upstream.lastReq)
	require.Equal(t, chatgptCodexURL, upstream.lastReq.URL.String())
	require.Equal(t, "chatgpt.com", upstream.lastReq.Host)
	require.Equal(t, "application/json", upstream.lastReq.Header.Get("Content-Type"))
	require.Equal(t, "text/event-stream", upstream.lastReq.Header.Get("Accept"))
	require.Equal(t, "acct-123", upstream.lastReq.Header.Get("chatgpt-account-id"))
	require.Equal(t, "responses=experimental", upstream.lastReq.Header.Get("OpenAI-Beta"))
	require.Equal(t, openAIImagesResponsesMainModel, gjson.GetBytes(upstream.lastBody, "model").String())
	require.True(t, gjson.GetBytes(upstream.lastBody, "stream").Bool())
	require.Equal(t, "image_generation", gjson.GetBytes(upstream.lastBody, "tools.0.type").String())
	require.Equal(t, "generate", gjson.GetBytes(upstream.lastBody, "tools.0.action").String())
	require.Equal(t, "gpt-image-2", gjson.GetBytes(upstream.lastBody, "tools.0.model").String())
	require.Equal(t, "1024x1024", gjson.GetBytes(upstream.lastBody, "tools.0.size").String())
	require.Equal(t, "high", gjson.GetBytes(upstream.lastBody, "tools.0.quality").String())
	// The request's n=2 must not be forwarded into the tool definition.
	require.False(t, gjson.GetBytes(upstream.lastBody, "tools.0.n").Exists())
	require.Equal(t, "draw a cat", gjson.GetBytes(upstream.lastBody, "input.0.content.0.text").String())
	// Downstream response is rewritten into the Images API shape.
	require.Equal(t, http.StatusOK, rec.Code)
	require.Equal(t, "gpt-image-2", gjson.Get(rec.Body.String(), "model").String())
	require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
	require.Equal(t, "draw a cat", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String())
}
// TestOpenAIGatewayServiceForwardImages_OAuthStreamingTransformsEvents verifies
// that a streaming generations request has upstream Responses SSE events
// rewritten into Images API "image_generation.*" events, with tool metadata
// (model/output_format/quality/size/background) merged into each event.
func TestOpenAIGatewayServiceForwardImages_OAuthStreamingTransformsEvents(t *testing.T) {
	gin.SetMode(gin.TestMode)
	body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`)
	req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = req
	svc := &OpenAIGatewayService{}
	parsed, err := svc.ParseOpenAIImagesRequest(c, body)
	require.NoError(t, err)
	// Scripted upstream stream: created (carries tool metadata), one partial
	// image, completed (final image + usage), then [DONE].
	upstream := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header: http.Header{
				"Content-Type": []string{"text/event-stream"},
				"X-Request-Id": []string{"req_img_stream"},
			},
			Body: io.NopCloser(strings.NewReader(
				"data: {\"type\":\"response.created\",\"response\":{\"created_at\":1710000001,\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"auto\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" +
					"data: {\"type\":\"response.image_generation_call.partial_image\",\"partial_image_b64\":\"cGFydGlhbA==\",\"partial_image_index\":0,\"output_format\":\"png\",\"background\":\"auto\"}\n\n" +
					"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000001,\"usage\":{\"input_tokens\":5,\"output_tokens\":9,\"output_tokens_details\":{\"image_tokens\":4}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"auto\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}],\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZmluYWw=\",\"output_format\":\"png\"}]}}\n\n" +
					"data: [DONE]\n\n",
			)),
		},
	}
	svc.httpUpstream = upstream
	account := &Account{
		ID:       2,
		Name:     "openai-oauth",
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "token-123",
		},
	}
	result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
	require.NoError(t, err)
	require.NotNil(t, result)
	require.True(t, result.Stream)
	require.Equal(t, 1, result.ImageCount)
	events := parseOpenAIImageTestSSEEvents(rec.Body.String())
	// Partial event: base64 payload plus a data: URL, with metadata coming
	// from the tool definition seen in response.created.
	partial, ok := findOpenAIImageTestSSEEvent(events, "image_generation.partial_image")
	require.True(t, ok)
	require.Equal(t, "image_generation.partial_image", gjson.Get(partial.Data, "type").String())
	require.Equal(t, int64(1710000001), gjson.Get(partial.Data, "created_at").Int())
	require.Equal(t, "cGFydGlhbA==", gjson.Get(partial.Data, "b64_json").String())
	require.Equal(t, "data:image/png;base64,cGFydGlhbA==", gjson.Get(partial.Data, "url").String())
	require.Equal(t, "gpt-image-2", gjson.Get(partial.Data, "model").String())
	require.Equal(t, "png", gjson.Get(partial.Data, "output_format").String())
	require.Equal(t, "high", gjson.Get(partial.Data, "quality").String())
	require.Equal(t, "1024x1024", gjson.Get(partial.Data, "size").String())
	require.Equal(t, "auto", gjson.Get(partial.Data, "background").String())
	// Completed event: final image, image_gen usage passthrough, and no
	// revised_prompt because the upstream output omitted it.
	completed, ok := findOpenAIImageTestSSEEvent(events, "image_generation.completed")
	require.True(t, ok)
	require.Equal(t, "image_generation.completed", gjson.Get(completed.Data, "type").String())
	require.Equal(t, int64(1710000001), gjson.Get(completed.Data, "created_at").Int())
	require.Equal(t, "ZmluYWw=", gjson.Get(completed.Data, "b64_json").String())
	require.Equal(t, "data:image/png;base64,ZmluYWw=", gjson.Get(completed.Data, "url").String())
	require.Equal(t, "gpt-image-2", gjson.Get(completed.Data, "model").String())
	require.Equal(t, "png", gjson.Get(completed.Data, "output_format").String())
	require.Equal(t, "high", gjson.Get(completed.Data, "quality").String())
	require.Equal(t, "1024x1024", gjson.Get(completed.Data, "size").String())
	require.Equal(t, "auto", gjson.Get(completed.Data, "background").String())
	require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw)
	require.False(t, gjson.Get(completed.Data, "revised_prompt").Exists())
}
// TestOpenAIGatewayServiceForwardImages_OAuthEditsMultipartUsesResponsesAPI
// verifies that a multipart /v1/images/edits request (image + mask file parts)
// becomes a Responses API "edit" tool call whose binaries are inlined as
// data: URLs, and that the input_fidelity field is stripped from the tool.
func TestOpenAIGatewayServiceForwardImages_OAuthEditsMultipartUsesResponsesAPI(t *testing.T) {
	gin.SetMode(gin.TestMode)
	// Assemble a multipart form: text fields plus "image" and "mask" files.
	var body bytes.Buffer
	writer := multipart.NewWriter(&body)
	require.NoError(t, writer.WriteField("model", "gpt-image-2"))
	require.NoError(t, writer.WriteField("prompt", "replace background with aurora"))
	require.NoError(t, writer.WriteField("input_fidelity", "high"))
	require.NoError(t, writer.WriteField("output_format", "webp"))
	require.NoError(t, writer.WriteField("quality", "high"))
	imageHeader := make(textproto.MIMEHeader)
	imageHeader.Set("Content-Disposition", `form-data; name="image"; filename="source.png"`)
	imageHeader.Set("Content-Type", "image/png")
	imagePart, err := writer.CreatePart(imageHeader)
	require.NoError(t, err)
	_, err = imagePart.Write([]byte("png-image-content"))
	require.NoError(t, err)
	maskHeader := make(textproto.MIMEHeader)
	maskHeader.Set("Content-Disposition", `form-data; name="mask"; filename="mask.png"`)
	maskHeader.Set("Content-Type", "image/png")
	maskPart, err := writer.CreatePart(maskHeader)
	require.NoError(t, err)
	_, err = maskPart.Write([]byte("png-mask-content"))
	require.NoError(t, err)
	require.NoError(t, writer.Close())
	req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body.Bytes()))
	req.Header.Set("Content-Type", writer.FormDataContentType())
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = req
	c.Set("api_key", &APIKey{ID: 100})
	svc := &OpenAIGatewayService{}
	parsed, err := svc.ParseOpenAIImagesRequest(c, body.Bytes())
	require.NoError(t, err)
	// Scripted upstream: a single completed event carrying the edited image.
	upstream := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header: http.Header{
				"Content-Type": []string{"text/event-stream"},
				"X-Request-Id": []string{"req_img_edit_123"},
			},
			Body: io.NopCloser(strings.NewReader(
				"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000002,\"usage\":{\"input_tokens\":13,\"output_tokens\":21,\"output_tokens_details\":{\"image_tokens\":8}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZWRpdGVk\",\"revised_prompt\":\"replace background with aurora\",\"output_format\":\"webp\",\"quality\":\"high\"}]}}\n\n" +
					"data: [DONE]\n\n",
			)),
		},
	}
	svc.httpUpstream = upstream
	account := &Account{
		ID:       3,
		Name:     "openai-oauth",
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "token-123",
		},
	}
	result, err := svc.ForwardImages(context.Background(), c, account, body.Bytes(), parsed, "")
	require.NoError(t, err)
	require.NotNil(t, result)
	require.Equal(t, 1, result.ImageCount)
	// The tool call uses the "edit" action; uploaded bytes become data: URLs
	// (image in the input content, mask in tools.0.input_image_mask).
	require.Equal(t, "gpt-image-2", gjson.GetBytes(upstream.lastBody, "tools.0.model").String())
	require.Equal(t, "edit", gjson.GetBytes(upstream.lastBody, "tools.0.action").String())
	require.False(t, gjson.GetBytes(upstream.lastBody, "tools.0.input_fidelity").Exists())
	require.Equal(t, "webp", gjson.GetBytes(upstream.lastBody, "tools.0.output_format").String())
	require.True(t, strings.HasPrefix(gjson.GetBytes(upstream.lastBody, "input.0.content.1.image_url").String(), "data:image/png;base64,"))
	require.True(t, strings.HasPrefix(gjson.GetBytes(upstream.lastBody, "tools.0.input_image_mask.image_url").String(), "data:image/png;base64,"))
	require.Equal(t, "replace background with aurora", gjson.GetBytes(upstream.lastBody, "input.0.content.0.text").String())
	require.Equal(t, "ZWRpdGVk", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
	require.Equal(t, "replace background with aurora", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String())
}
// TestOpenAIGatewayServiceForwardImages_OAuthEditsStreamingTransformsEvents
// verifies that a streaming JSON edits request (image/mask given as remote
// URLs) produces "image_edit.*" SSE events with tool metadata merged in, while
// the remote URLs are forwarded to the upstream tool call unchanged.
func TestOpenAIGatewayServiceForwardImages_OAuthEditsStreamingTransformsEvents(t *testing.T) {
	gin.SetMode(gin.TestMode)
	body := []byte(`{
		"model":"gpt-image-2",
		"prompt":"replace background with aurora",
		"images":[{"image_url":"https://example.com/source.png"}],
		"mask":{"image_url":"https://example.com/mask.png"},
		"stream":true,
		"response_format":"url"
	}`)
	req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = req
	svc := &OpenAIGatewayService{}
	parsed, err := svc.ParseOpenAIImagesRequest(c, body)
	require.NoError(t, err)
	// Scripted upstream stream: created (tool metadata), one partial image,
	// completed (final image + usage), then [DONE].
	upstream := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header: http.Header{
				"Content-Type": []string{"text/event-stream"},
			},
			Body: io.NopCloser(strings.NewReader(
				"data: {\"type\":\"response.created\",\"response\":{\"created_at\":1710000003,\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"transparent\",\"output_format\":\"webp\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" +
					"data: {\"type\":\"response.image_generation_call.partial_image\",\"partial_image_b64\":\"cGFydGlhbA==\",\"partial_image_index\":0,\"output_format\":\"webp\",\"background\":\"transparent\"}\n\n" +
					"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000003,\"usage\":{\"input_tokens\":7,\"output_tokens\":10,\"output_tokens_details\":{\"image_tokens\":5}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"transparent\",\"output_format\":\"webp\",\"quality\":\"high\",\"size\":\"1024x1024\"}],\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZWRpdGVk\",\"revised_prompt\":\"replace background with aurora\",\"output_format\":\"webp\"}]}}\n\n" +
					"data: [DONE]\n\n",
			)),
		},
	}
	svc.httpUpstream = upstream
	account := &Account{
		ID:       4,
		Name:     "openai-oauth",
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "token-123",
		},
	}
	result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
	require.NoError(t, err)
	require.NotNil(t, result)
	require.Equal(t, 1, result.ImageCount)
	// Remote URLs are forwarded verbatim instead of being inlined.
	require.Equal(t, "edit", gjson.GetBytes(upstream.lastBody, "tools.0.action").String())
	require.Equal(t, "https://example.com/source.png", gjson.GetBytes(upstream.lastBody, "input.0.content.1.image_url").String())
	require.Equal(t, "https://example.com/mask.png", gjson.GetBytes(upstream.lastBody, "tools.0.input_image_mask.image_url").String())
	events := parseOpenAIImageTestSSEEvents(rec.Body.String())
	// Partial event carries the webp data: URL and the created tool metadata.
	partial, ok := findOpenAIImageTestSSEEvent(events, "image_edit.partial_image")
	require.True(t, ok)
	require.Equal(t, "image_edit.partial_image", gjson.Get(partial.Data, "type").String())
	require.Equal(t, int64(1710000003), gjson.Get(partial.Data, "created_at").Int())
	require.Equal(t, "cGFydGlhbA==", gjson.Get(partial.Data, "b64_json").String())
	require.Equal(t, "data:image/webp;base64,cGFydGlhbA==", gjson.Get(partial.Data, "url").String())
	require.Equal(t, "gpt-image-2", gjson.Get(partial.Data, "model").String())
	require.Equal(t, "webp", gjson.Get(partial.Data, "output_format").String())
	require.Equal(t, "high", gjson.Get(partial.Data, "quality").String())
	require.Equal(t, "1024x1024", gjson.Get(partial.Data, "size").String())
	require.Equal(t, "transparent", gjson.Get(partial.Data, "background").String())
	// Completed event carries the final image, usage, and no revised_prompt.
	completed, ok := findOpenAIImageTestSSEEvent(events, "image_edit.completed")
	require.True(t, ok)
	require.Equal(t, "image_edit.completed", gjson.Get(completed.Data, "type").String())
	require.Equal(t, int64(1710000003), gjson.Get(completed.Data, "created_at").Int())
	require.Equal(t, "ZWRpdGVk", gjson.Get(completed.Data, "b64_json").String())
	require.Equal(t, "data:image/webp;base64,ZWRpdGVk", gjson.Get(completed.Data, "url").String())
	require.Equal(t, "gpt-image-2", gjson.Get(completed.Data, "model").String())
	require.Equal(t, "webp", gjson.Get(completed.Data, "output_format").String())
	require.Equal(t, "high", gjson.Get(completed.Data, "quality").String())
	require.Equal(t, "1024x1024", gjson.Get(completed.Data, "size").String())
	require.Equal(t, "transparent", gjson.Get(completed.Data, "background").String())
	require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw)
	require.False(t, gjson.Get(completed.Data, "revised_prompt").Exists())
}
// TestBuildOpenAIImagesResponsesRequest_DowngradesMultipleImagesToSingle
// verifies that a generations request with n=2 yields a tool definition with
// no "n" field, while model and prompt pass through.
func TestBuildOpenAIImagesResponsesRequest_DowngradesMultipleImagesToSingle(t *testing.T) {
	parsed := &OpenAIImagesRequest{
		Endpoint: openAIImagesGenerationsEndpoint,
		Model:    "gpt-image-2",
		Prompt:   "draw a cat",
		N:        2,
	}
	body, err := buildOpenAIImagesResponsesRequest(parsed, "gpt-image-2")
	require.NoError(t, err)
	require.NotNil(t, body)
	require.False(t, gjson.GetBytes(body, "tools.0.n").Exists())
	require.Equal(t, "gpt-image-2", gjson.GetBytes(body, "tools.0.model").String())
	require.Equal(t, "draw a cat", gjson.GetBytes(body, "input.0.content.0.text").String())
}
// TestBuildOpenAIImagesResponsesRequest_StripsInputFidelity verifies that the
// input_fidelity field of an edits request is not forwarded into the tool
// definition, while the edit action is preserved.
func TestBuildOpenAIImagesResponsesRequest_StripsInputFidelity(t *testing.T) {
	parsed := &OpenAIImagesRequest{
		Endpoint:      openAIImagesEditsEndpoint,
		Model:         "gpt-image-2",
		Prompt:        "replace background",
		InputFidelity: "high",
		InputImageURLs: []string{
			"https://example.com/source.png",
		},
	}
	body, err := buildOpenAIImagesResponsesRequest(parsed, "gpt-image-2")
	require.NoError(t, err)
	require.NotNil(t, body)
	require.False(t, gjson.GetBytes(body, "tools.0.input_fidelity").Exists())
	require.Equal(t, "edit", gjson.GetBytes(body, "tools.0.action").String())
}
// TestCollectOpenAIImagesFromResponsesBody_FallsBackToOutputItemDone verifies
// that when response.completed carries an empty output array, the image is
// recovered from an earlier response.output_item.done event instead.
func TestCollectOpenAIImagesFromResponsesBody_FallsBackToOutputItemDone(t *testing.T) {
	body := []byte(
		"data: {\"type\":\"response.created\",\"response\":{\"created_at\":1710000004}}\n\n" +
			"data: {\"type\":\"response.output_item.done\",\"item\":{\"id\":\"ig_123\",\"type\":\"image_generation_call\",\"result\":\"aGVsbG8=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\",\"quality\":\"high\"}}\n\n" +
			"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000004,\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[]}}\n\n" +
			"data: [DONE]\n\n",
	)
	results, createdAt, usageRaw, firstMeta, foundFinal, err := collectOpenAIImagesFromResponsesBody(body)
	require.NoError(t, err)
	require.True(t, foundFinal)
	require.Equal(t, int64(1710000004), createdAt)
	require.Len(t, results, 1)
	require.Equal(t, "aGVsbG8=", results[0].Result)
	require.Equal(t, "draw a cat", results[0].RevisedPrompt)
	require.Equal(t, "png", firstMeta.OutputFormat)
	require.JSONEq(t, `{"images":1}`, string(usageRaw))
}
// TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesOutputItemDoneFallback
// verifies that the streaming path still emits a completed event — and no
// error event — when the final image arrives only via
// response.output_item.done and response.completed has an empty output array.
func TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesOutputItemDoneFallback(t *testing.T) {
	gin.SetMode(gin.TestMode)
	body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`)
	req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = req
	svc := &OpenAIGatewayService{}
	parsed, err := svc.ParseOpenAIImagesRequest(c, body)
	require.NoError(t, err)
	// Scripted upstream: image only in output_item.done; completed has output=[].
	upstream := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header: http.Header{
				"Content-Type": []string{"text/event-stream"},
				"X-Request-Id": []string{"req_img_stream_output_item_done"},
			},
			Body: io.NopCloser(strings.NewReader(
				"data: {\"type\":\"response.output_item.done\",\"item\":{\"id\":\"ig_123\",\"type\":\"image_generation_call\",\"result\":\"ZmluYWw=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\"}}\n\n" +
					"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000005,\"usage\":{\"input_tokens\":5,\"output_tokens\":9,\"output_tokens_details\":{\"image_tokens\":4}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[]}}\n\n" +
					"data: [DONE]\n\n",
			)),
		},
	}
	svc.httpUpstream = upstream
	account := &Account{
		ID:       5,
		Name:     "openai-oauth",
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "token-123",
		},
	}
	result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
	require.NoError(t, err)
	require.NotNil(t, result)
	require.True(t, result.Stream)
	require.Equal(t, 1, result.ImageCount)
	events := parseOpenAIImageTestSSEEvents(rec.Body.String())
	completed, ok := findOpenAIImageTestSSEEvent(events, "image_generation.completed")
	require.True(t, ok)
	require.Equal(t, "image_generation.completed", gjson.Get(completed.Data, "type").String())
	require.Equal(t, int64(1710000005), gjson.Get(completed.Data, "created_at").Int())
	require.Equal(t, "ZmluYWw=", gjson.Get(completed.Data, "b64_json").String())
	require.Equal(t, "data:image/png;base64,ZmluYWw=", gjson.Get(completed.Data, "url").String())
	require.Equal(t, "gpt-image-2", gjson.Get(completed.Data, "model").String())
	require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw)
	require.NotContains(t, rec.Body.String(), "event: error")
}
......@@ -794,6 +794,13 @@ func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing {
}
}
// GPT-5.5 回退到 GPT-5.4 定价
if strings.HasPrefix(model, "gpt-5.5") {
logger.With(zap.String("component", "service.pricing")).
Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.4(static)"))
return openAIGPT54FallbackPricing
}
if strings.HasPrefix(model, "gpt-5.4-mini") {
logger.With(zap.String("component", "service.pricing")).
Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.4-mini(static)"))
......
package service
import (
"bytes"
"context"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"strconv"
......@@ -23,6 +25,7 @@ type RateLimitService struct {
geminiQuotaService *GeminiQuotaService
tempUnschedCache TempUnschedCache
timeoutCounterCache TimeoutCounterCache
openAI403CounterCache OpenAI403CounterCache
settingService *SettingService
tokenCacheInvalidator TokenCacheInvalidator
usageCacheMu sync.RWMutex
......@@ -52,6 +55,12 @@ type geminiUsageTotalsBatchProvider interface {
// geminiPrecheckCacheTTL: cache TTL for Gemini precheck results
// (consumed elsewhere in this service — see usages in this file).
const geminiPrecheckCacheTTL = time.Minute

const (
	// openAI403CooldownMinutesDefault is how long, in minutes, an OpenAI
	// account is made temporarily unschedulable after a 403 that is still
	// below the disable threshold.
	openAI403CooldownMinutesDefault = 10
	// openAI403DisableThreshold is the consecutive-403 count at which the
	// account is disabled (SetError) instead of cooled down.
	openAI403DisableThreshold = 3
	// openAI403CounterWindowMinutes is the rolling window, in minutes, passed
	// to the 403 counter cache when incrementing.
	openAI403CounterWindowMinutes = 180
)
// NewRateLimitService 创建RateLimitService实例
func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogRepository, cfg *config.Config, geminiQuotaService *GeminiQuotaService, tempUnschedCache TempUnschedCache) *RateLimitService {
return &RateLimitService{
......@@ -69,6 +78,11 @@ func (s *RateLimitService) SetTimeoutCounterCache(cache TimeoutCounterCache) {
s.timeoutCounterCache = cache
}
// SetOpenAI403CounterCache injects the consecutive-403 counter for OpenAI
// accounts (optional dependency; when absent, 403 handling disables the
// account immediately instead of applying a cooldown).
func (s *RateLimitService) SetOpenAI403CounterCache(cache OpenAI403CounterCache) {
	s.openAI403CounterCache = cache
}
// SetSettingService 设置系统设置服务(可选依赖)
func (s *RateLimitService) SetSettingService(settingService *SettingService) {
s.settingService = settingService
......@@ -655,6 +669,30 @@ func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account
slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg)
}
// buildForbiddenErrorMessage assembles a human-readable 403 message. It prefers
// the upstream-provided message, then a compacted (and truncated) response
// body, and finally the caller-supplied fallback text, always prefixed with
// the given label.
func buildForbiddenErrorMessage(prefix string, upstreamMsg string, responseBody []byte, fallback string) string {
	prefix = strings.TrimSpace(prefix)
	if prefix != "" {
		// TrimSpace guarantees no trailing whitespace, so a single separator
		// space is always appended here.
		prefix += " "
	}

	if detail := strings.TrimSpace(upstreamMsg); detail != "" {
		return prefix + detail
	}

	trimmed := bytes.TrimSpace(responseBody)
	if len(trimmed) == 0 {
		return prefix + fallback
	}
	// Compact valid JSON bodies before truncation so the log stays dense.
	if json.Valid(trimmed) {
		var compacted bytes.Buffer
		if json.Compact(&compacted, trimmed) == nil {
			return prefix + truncateForLog(compacted.Bytes(), 512)
		}
	}
	return prefix + truncateForLog(trimmed, 512)
}
// handle403 处理 403 Forbidden 错误
// Antigravity 平台区分 validation/violation/generic 三种类型,均 SetError 永久禁用;
// 其他平台保持原有 SetError 行为。
......@@ -662,13 +700,62 @@ func (s *RateLimitService) handle403(ctx context.Context, account *Account, upst
if account.Platform == PlatformAntigravity {
return s.handleAntigravity403(ctx, account, upstreamMsg, responseBody)
}
if account.Platform == PlatformOpenAI {
return s.handleOpenAI403(ctx, account, upstreamMsg, responseBody)
}
// 非 Antigravity 平台:保持原有行为
msg := "Access forbidden (403): account may be suspended or lack permissions"
if upstreamMsg != "" {
msg = "Access forbidden (403): " + upstreamMsg
msg := buildForbiddenErrorMessage(
"Access forbidden (403):",
upstreamMsg,
responseBody,
"account may be suspended or lack permissions",
)
s.handleAuthError(ctx, account, msg)
return true
}
// handleOpenAI403 handles 403 responses for OpenAI accounts. While the
// consecutive-403 count stays below openAI403DisableThreshold the account is
// only made temporarily unschedulable; at the threshold (or whenever the
// counter/cooldown machinery is unavailable or fails) the account is disabled
// via handleAuthError. Always returns true.
func (s *RateLimitService) handleOpenAI403(ctx context.Context, account *Account, upstreamMsg string, responseBody []byte) (shouldDisable bool) {
	baseMsg := buildForbiddenErrorMessage(
		"Access forbidden (403):",
		upstreamMsg,
		responseBody,
		"account may be suspended or lack permissions",
	)

	// Without a counter cache we cannot track consecutive failures; fall back
	// to disabling immediately (legacy behavior).
	if s.openAI403CounterCache == nil {
		s.handleAuthError(ctx, account, baseMsg)
		return true
	}

	count, err := s.openAI403CounterCache.IncrementOpenAI403Count(ctx, account.ID, openAI403CounterWindowMinutes)
	if err != nil {
		slog.Warn("openai_403_increment_failed", "account_id", account.ID, "error", err)
		s.handleAuthError(ctx, account, baseMsg)
		return true
	}

	if count >= openAI403DisableThreshold {
		s.handleAuthError(ctx, account, fmt.Sprintf("%s | consecutive_403=%d/%d", baseMsg, count, openAI403DisableThreshold))
		return true
	}

	// Below the threshold: park the account for a short cooldown instead of
	// disabling it outright.
	cooldownUntil := time.Now().Add(openAI403CooldownMinutesDefault * time.Minute)
	reason := fmt.Sprintf("OpenAI 403 temporary cooldown (%d/%d): %s", count, openAI403DisableThreshold, baseMsg)
	if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, cooldownUntil, reason); err != nil {
		slog.Warn("openai_403_set_temp_unschedulable_failed", "account_id", account.ID, "error", err)
		s.handleAuthError(ctx, account, baseMsg)
		return true
	}

	slog.Warn(
		"openai_403_temp_unschedulable",
		"account_id", account.ID,
		"until", cooldownUntil,
		"count", count,
		"threshold", openAI403DisableThreshold,
	)
	return true
}
// handleAntigravity403 处理 Antigravity 平台的 403 错误
......@@ -681,10 +768,12 @@ func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Ac
switch fbType {
case forbiddenTypeValidation:
// VALIDATION_REQUIRED: 永久禁用,需人工去 Google 验证后手动恢复
msg := "Validation required (403): account needs Google verification"
if upstreamMsg != "" {
msg = "Validation required (403): " + upstreamMsg
}
msg := buildForbiddenErrorMessage(
"Validation required (403):",
upstreamMsg,
responseBody,
"account needs Google verification",
)
if validationURL := extractValidationURL(string(responseBody)); validationURL != "" {
msg += " | validation_url: " + validationURL
}
......@@ -693,19 +782,23 @@ func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Ac
case forbiddenTypeViolation:
// 违规封号: 永久禁用,需人工处理
msg := "Account violation (403): terms of service violation"
if upstreamMsg != "" {
msg = "Account violation (403): " + upstreamMsg
}
msg := buildForbiddenErrorMessage(
"Account violation (403):",
upstreamMsg,
responseBody,
"terms of service violation",
)
s.handleAuthError(ctx, account, msg)
return true
default:
// 通用 403: 保持原有行为
msg := "Access forbidden (403): account may be suspended or lack permissions"
if upstreamMsg != "" {
msg = "Access forbidden (403): " + upstreamMsg
}
msg := buildForbiddenErrorMessage(
"Access forbidden (403):",
upstreamMsg,
responseBody,
"account may be suspended or lack permissions",
)
s.handleAuthError(ctx, account, msg)
return true
}
......@@ -1221,9 +1314,19 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64)
slog.Warn("temp_unsched_cache_delete_failed", "account_id", accountID, "error", err)
}
}
s.ResetOpenAI403Counter(ctx, accountID)
return nil
}
// ResetOpenAI403Counter clears the consecutive-403 counter for the given
// account. It is a no-op when the service or counter cache is absent, or the
// account ID is invalid; reset failures are logged and never propagated.
func (s *RateLimitService) ResetOpenAI403Counter(ctx context.Context, accountID int64) {
	if s == nil || s.openAI403CounterCache == nil || accountID <= 0 {
		return
	}
	err := s.openAI403CounterCache.ResetOpenAI403Count(ctx, accountID)
	if err != nil {
		slog.Warn("openai_403_reset_failed", "account_id", accountID, "error", err)
	}
}
// RecoverAccountState 按需恢复账号的可恢复运行时状态。
func (s *RateLimitService) RecoverAccountState(ctx context.Context, accountID int64, options AccountRecoveryOptions) (*SuccessfulTestRecoveryResult, error) {
account, err := s.accountRepo.GetByID(ctx, accountID)
......@@ -1250,6 +1353,9 @@ func (s *RateLimitService) RecoverAccountState(ctx context.Context, accountID in
}
result.ClearedRateLimit = true
}
if result.ClearedError || result.ClearedRateLimit {
s.ResetOpenAI403Counter(ctx, accountID)
}
return result, nil
}
......
......@@ -20,6 +20,7 @@ type rateLimitAccountRepoStub struct {
updateCredentialsCalls int
lastCredentials map[string]any
lastErrorMsg string
lastTempReason string
}
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
......@@ -30,6 +31,7 @@ func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, error
// SetTempUnschedulable records the call for later assertions (call count and
// most recent reason) and always succeeds.
func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
	r.tempCalls++
	r.lastTempReason = reason
	return nil
}
......@@ -44,6 +46,29 @@ type tokenCacheInvalidatorRecorder struct {
err error
}
// openAI403CounterCacheStub scripts the OpenAI 403 counter cache for tests.
type openAI403CounterCacheStub struct {
	counts     []int64 // successive values returned by IncrementOpenAI403Count; empty means "return 1"
	resetCalls []int64 // account IDs passed to ResetOpenAI403Count
	err        error   // when set, IncrementOpenAI403Count fails with this error
}
// IncrementOpenAI403Count returns the next scripted count — defaulting to 1
// once the script is exhausted — or the configured error.
func (s *openAI403CounterCacheStub) IncrementOpenAI403Count(_ context.Context, _ int64, _ int) (int64, error) {
	switch {
	case s.err != nil:
		return 0, s.err
	case len(s.counts) == 0:
		return 1, nil
	}
	// Pop the next scripted value from the front.
	next := s.counts[0]
	s.counts = s.counts[1:]
	return next, nil
}
// ResetOpenAI403Count records the account ID so tests can assert that the
// counter was reset; it never fails.
func (s *openAI403CounterCacheStub) ResetOpenAI403Count(_ context.Context, accountID int64) error {
	s.resetCalls = append(s.resetCalls, accountID)
	return nil
}
func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, account *Account) error {
r.accounts = append(r.accounts, account)
return r.err
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment