"frontend/src/i18n/git@web.lueluesay.top:chenxi/sub2api.git" did not exist on "9d698d9306d7c656c0ac48ba8bf3091dfc77f843"
Commit bb399e56 authored by Wang Lvyuan's avatar Wang Lvyuan
Browse files

merge: resolve upstream main conflicts for bulk OpenAI passthrough

parents 73d72651 0f033930
...@@ -61,6 +61,9 @@ temp/ ...@@ -61,6 +61,9 @@ temp/
deploy/install.sh deploy/install.sh
deploy/sub2api.service deploy/sub2api.service
deploy/sub2api-sudoers deploy/sub2api-sudoers
deploy/data/
deploy/postgres_data/
deploy/redis_data/
# GoReleaser # GoReleaser
.goreleaser.yaml .goreleaser.yaml
......
...@@ -114,6 +114,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -114,6 +114,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient) oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
openAIOAuthClient := repository.NewOpenAIOAuthClient() openAIOAuthClient := repository.NewOpenAIOAuthClient()
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient) openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
openAIOAuthService.SetPrivacyClientFactory(privacyClientFactory)
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig) geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient() geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
driveClient := repository.NewGeminiDriveClient() driveClient := repository.NewGeminiDriveClient()
......
...@@ -352,7 +352,7 @@ func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, acc ...@@ -352,7 +352,7 @@ func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, acc
pageSize := dataPageCap pageSize := dataPageCap
var out []service.Account var out []service.Account
for { for {
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0) items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0, "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
...@@ -219,6 +219,7 @@ func (h *AccountHandler) List(c *gin.Context) { ...@@ -219,6 +219,7 @@ func (h *AccountHandler) List(c *gin.Context) {
accountType := c.Query("type") accountType := c.Query("type")
status := c.Query("status") status := c.Query("status")
search := c.Query("search") search := c.Query("search")
privacyMode := strings.TrimSpace(c.Query("privacy_mode"))
// 标准化和验证 search 参数 // 标准化和验证 search 参数
search = strings.TrimSpace(search) search = strings.TrimSpace(search)
if len(search) > 100 { if len(search) > 100 {
...@@ -244,7 +245,7 @@ func (h *AccountHandler) List(c *gin.Context) { ...@@ -244,7 +245,7 @@ func (h *AccountHandler) List(c *gin.Context) {
} }
} }
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID) accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID, privacyMode)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -1936,7 +1937,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) { ...@@ -1936,7 +1937,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
accounts := make([]*service.Account, 0) accounts := make([]*service.Account, 0)
if len(req.AccountIDs) == 0 { if len(req.AccountIDs) == 0 {
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0) allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0, "")
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
......
...@@ -187,7 +187,7 @@ func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int ...@@ -187,7 +187,7 @@ func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int
return nil return nil
} }
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) { func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, int64, error) {
return s.accounts, int64(len(s.accounts)), nil return s.accounts, int64(len(s.accounts)), nil
} }
......
...@@ -110,6 +110,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { ...@@ -110,6 +110,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
SoraClientEnabled: settings.SoraClientEnabled, SoraClientEnabled: settings.SoraClientEnabled,
CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems), CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
DefaultConcurrency: settings.DefaultConcurrency, DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance, DefaultBalance: settings.DefaultBalance,
DefaultSubscriptions: defaultSubscriptions, DefaultSubscriptions: defaultSubscriptions,
...@@ -176,6 +177,7 @@ type UpdateSettingsRequest struct { ...@@ -176,6 +177,7 @@ type UpdateSettingsRequest struct {
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"` PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
SoraClientEnabled bool `json:"sora_client_enabled"` SoraClientEnabled bool `json:"sora_client_enabled"`
CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"` CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"`
// 默认配置 // 默认配置
DefaultConcurrency int `json:"default_concurrency"` DefaultConcurrency int `json:"default_concurrency"`
...@@ -231,11 +233,27 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -231,11 +233,27 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
if req.DefaultBalance < 0 { if req.DefaultBalance < 0 {
req.DefaultBalance = 0 req.DefaultBalance = 0
} }
req.SMTPHost = strings.TrimSpace(req.SMTPHost)
req.SMTPUsername = strings.TrimSpace(req.SMTPUsername)
req.SMTPPassword = strings.TrimSpace(req.SMTPPassword)
req.SMTPFrom = strings.TrimSpace(req.SMTPFrom)
req.SMTPFromName = strings.TrimSpace(req.SMTPFromName)
if req.SMTPPort <= 0 { if req.SMTPPort <= 0 {
req.SMTPPort = 587 req.SMTPPort = 587
} }
req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions) req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions)
// SMTP 配置保护:如果请求中 smtp_host 为空但数据库中已有配置,则保留已有 SMTP 配置
// 防止前端加载设置失败时空表单覆盖已保存的 SMTP 配置
if req.SMTPHost == "" && previousSettings.SMTPHost != "" {
req.SMTPHost = previousSettings.SMTPHost
req.SMTPPort = previousSettings.SMTPPort
req.SMTPUsername = previousSettings.SMTPUsername
req.SMTPFrom = previousSettings.SMTPFrom
req.SMTPFromName = previousSettings.SMTPFromName
req.SMTPUseTLS = previousSettings.SMTPUseTLS
}
// Turnstile 参数验证 // Turnstile 参数验证
if req.TurnstileEnabled { if req.TurnstileEnabled {
// 检查必填字段 // 检查必填字段
...@@ -417,6 +435,55 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -417,6 +435,55 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
customMenuJSON = string(menuBytes) customMenuJSON = string(menuBytes)
} }
// 自定义端点验证
const (
maxCustomEndpoints = 10
maxEndpointNameLen = 50
maxEndpointURLLen = 2048
maxEndpointDescriptionLen = 200
)
customEndpointsJSON := previousSettings.CustomEndpoints
if req.CustomEndpoints != nil {
endpoints := *req.CustomEndpoints
if len(endpoints) > maxCustomEndpoints {
response.BadRequest(c, "Too many custom endpoints (max 10)")
return
}
for _, ep := range endpoints {
if strings.TrimSpace(ep.Name) == "" {
response.BadRequest(c, "Custom endpoint name is required")
return
}
if len(ep.Name) > maxEndpointNameLen {
response.BadRequest(c, "Custom endpoint name is too long (max 50 characters)")
return
}
if strings.TrimSpace(ep.Endpoint) == "" {
response.BadRequest(c, "Custom endpoint URL is required")
return
}
if len(ep.Endpoint) > maxEndpointURLLen {
response.BadRequest(c, "Custom endpoint URL is too long (max 2048 characters)")
return
}
if err := config.ValidateAbsoluteHTTPURL(strings.TrimSpace(ep.Endpoint)); err != nil {
response.BadRequest(c, "Custom endpoint URL must be an absolute http(s) URL")
return
}
if len(ep.Description) > maxEndpointDescriptionLen {
response.BadRequest(c, "Custom endpoint description is too long (max 200 characters)")
return
}
}
endpointBytes, err := json.Marshal(endpoints)
if err != nil {
response.BadRequest(c, "Failed to serialize custom endpoints")
return
}
customEndpointsJSON = string(endpointBytes)
}
// Ops metrics collector interval validation (seconds). // Ops metrics collector interval validation (seconds).
if req.OpsMetricsIntervalSeconds != nil { if req.OpsMetricsIntervalSeconds != nil {
v := *req.OpsMetricsIntervalSeconds v := *req.OpsMetricsIntervalSeconds
...@@ -495,6 +562,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -495,6 +562,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
PurchaseSubscriptionURL: purchaseURL, PurchaseSubscriptionURL: purchaseURL,
SoraClientEnabled: req.SoraClientEnabled, SoraClientEnabled: req.SoraClientEnabled,
CustomMenuItems: customMenuJSON, CustomMenuItems: customMenuJSON,
CustomEndpoints: customEndpointsJSON,
DefaultConcurrency: req.DefaultConcurrency, DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance, DefaultBalance: req.DefaultBalance,
DefaultSubscriptions: defaultSubscriptions, DefaultSubscriptions: defaultSubscriptions,
...@@ -592,6 +660,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -592,6 +660,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL, PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
SoraClientEnabled: updatedSettings.SoraClientEnabled, SoraClientEnabled: updatedSettings.SoraClientEnabled,
CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems), CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints),
DefaultConcurrency: updatedSettings.DefaultConcurrency, DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance, DefaultBalance: updatedSettings.DefaultBalance,
DefaultSubscriptions: updatedDefaultSubscriptions, DefaultSubscriptions: updatedDefaultSubscriptions,
...@@ -828,7 +897,7 @@ func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool { ...@@ -828,7 +897,7 @@ func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool {
// TestSMTPRequest 测试SMTP连接请求 // TestSMTPRequest 测试SMTP连接请求
type TestSMTPRequest struct { type TestSMTPRequest struct {
SMTPHost string `json:"smtp_host" binding:"required"` SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"` SMTPPort int `json:"smtp_port"`
SMTPUsername string `json:"smtp_username"` SMTPUsername string `json:"smtp_username"`
SMTPPassword string `json:"smtp_password"` SMTPPassword string `json:"smtp_password"`
...@@ -844,18 +913,35 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) { ...@@ -844,18 +913,35 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) {
return return
} }
if req.SMTPPort <= 0 { req.SMTPHost = strings.TrimSpace(req.SMTPHost)
req.SMTPPort = 587 req.SMTPUsername = strings.TrimSpace(req.SMTPUsername)
var savedConfig *service.SMTPConfig
if cfg, err := h.emailService.GetSMTPConfig(c.Request.Context()); err == nil && cfg != nil {
savedConfig = cfg
} }
// 如果未提供密码,从数据库获取已保存的密码 if req.SMTPHost == "" && savedConfig != nil {
password := req.SMTPPassword req.SMTPHost = savedConfig.Host
if password == "" { }
savedConfig, err := h.emailService.GetSMTPConfig(c.Request.Context()) if req.SMTPPort <= 0 {
if err == nil && savedConfig != nil { if savedConfig != nil && savedConfig.Port > 0 {
password = savedConfig.Password req.SMTPPort = savedConfig.Port
} else {
req.SMTPPort = 587
} }
} }
if req.SMTPUsername == "" && savedConfig != nil {
req.SMTPUsername = savedConfig.Username
}
password := strings.TrimSpace(req.SMTPPassword)
if password == "" && savedConfig != nil {
password = savedConfig.Password
}
if req.SMTPHost == "" {
response.BadRequest(c, "SMTP host is required")
return
}
config := &service.SMTPConfig{ config := &service.SMTPConfig{
Host: req.SMTPHost, Host: req.SMTPHost,
...@@ -877,7 +963,7 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) { ...@@ -877,7 +963,7 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) {
// SendTestEmailRequest 发送测试邮件请求 // SendTestEmailRequest 发送测试邮件请求
type SendTestEmailRequest struct { type SendTestEmailRequest struct {
Email string `json:"email" binding:"required,email"` Email string `json:"email" binding:"required,email"`
SMTPHost string `json:"smtp_host" binding:"required"` SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"` SMTPPort int `json:"smtp_port"`
SMTPUsername string `json:"smtp_username"` SMTPUsername string `json:"smtp_username"`
SMTPPassword string `json:"smtp_password"` SMTPPassword string `json:"smtp_password"`
...@@ -895,18 +981,43 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) { ...@@ -895,18 +981,43 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
return return
} }
if req.SMTPPort <= 0 { req.SMTPHost = strings.TrimSpace(req.SMTPHost)
req.SMTPPort = 587 req.SMTPUsername = strings.TrimSpace(req.SMTPUsername)
req.SMTPFrom = strings.TrimSpace(req.SMTPFrom)
req.SMTPFromName = strings.TrimSpace(req.SMTPFromName)
var savedConfig *service.SMTPConfig
if cfg, err := h.emailService.GetSMTPConfig(c.Request.Context()); err == nil && cfg != nil {
savedConfig = cfg
} }
// 如果未提供密码,从数据库获取已保存的密码 if req.SMTPHost == "" && savedConfig != nil {
password := req.SMTPPassword req.SMTPHost = savedConfig.Host
if password == "" { }
savedConfig, err := h.emailService.GetSMTPConfig(c.Request.Context()) if req.SMTPPort <= 0 {
if err == nil && savedConfig != nil { if savedConfig != nil && savedConfig.Port > 0 {
password = savedConfig.Password req.SMTPPort = savedConfig.Port
} else {
req.SMTPPort = 587
} }
} }
if req.SMTPUsername == "" && savedConfig != nil {
req.SMTPUsername = savedConfig.Username
}
password := strings.TrimSpace(req.SMTPPassword)
if password == "" && savedConfig != nil {
password = savedConfig.Password
}
if req.SMTPFrom == "" && savedConfig != nil {
req.SMTPFrom = savedConfig.From
}
if req.SMTPFromName == "" && savedConfig != nil {
req.SMTPFromName = savedConfig.FromName
}
if req.SMTPHost == "" {
response.BadRequest(c, "SMTP host is required")
return
}
config := &service.SMTPConfig{ config := &service.SMTPConfig{
Host: req.SMTPHost, Host: req.SMTPHost,
......
...@@ -15,6 +15,13 @@ type CustomMenuItem struct { ...@@ -15,6 +15,13 @@ type CustomMenuItem struct {
SortOrder int `json:"sort_order"` SortOrder int `json:"sort_order"`
} }
// CustomEndpoint represents an admin-configured API endpoint for quick copy.
// The struct is exchanged as JSON using the keys given in the field tags.
type CustomEndpoint struct {
// Name is carried under the JSON key "name".
Name string `json:"name"`
// Endpoint is carried under the JSON key "endpoint".
// NOTE(review): presumably an absolute http(s) URL — confirm against the settings validation.
Endpoint string `json:"endpoint"`
// Description is carried under the JSON key "description".
Description string `json:"description"`
}
// SystemSettings represents the admin settings API response payload. // SystemSettings represents the admin settings API response payload.
type SystemSettings struct { type SystemSettings struct {
RegistrationEnabled bool `json:"registration_enabled"` RegistrationEnabled bool `json:"registration_enabled"`
...@@ -56,6 +63,7 @@ type SystemSettings struct { ...@@ -56,6 +63,7 @@ type SystemSettings struct {
PurchaseSubscriptionURL string `json:"purchase_subscription_url"` PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
SoraClientEnabled bool `json:"sora_client_enabled"` SoraClientEnabled bool `json:"sora_client_enabled"`
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
DefaultConcurrency int `json:"default_concurrency"` DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"` DefaultBalance float64 `json:"default_balance"`
...@@ -114,6 +122,7 @@ type PublicSettings struct { ...@@ -114,6 +122,7 @@ type PublicSettings struct {
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"` PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
SoraClientEnabled bool `json:"sora_client_enabled"` SoraClientEnabled bool `json:"sora_client_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"` BackendModeEnabled bool `json:"backend_mode_enabled"`
...@@ -218,3 +227,17 @@ func ParseUserVisibleMenuItems(raw string) []CustomMenuItem { ...@@ -218,3 +227,17 @@ func ParseUserVisibleMenuItems(raw string) []CustomMenuItem {
} }
return filtered return filtered
} }
// ParseCustomEndpoints parses a JSON string into a slice of CustomEndpoint.
//
// The input is trimmed of surrounding whitespace first. A non-nil, empty slice
// is returned when the input is empty, is the literal "[]", fails to decode,
// or decodes to a nil slice (for example the JSON literal "null"). Callers can
// therefore rely on the result never being nil, so it always encodes as a JSON
// array rather than null.
func ParseCustomEndpoints(raw string) []CustomEndpoint {
	raw = strings.TrimSpace(raw)
	if raw == "" || raw == "[]" {
		return []CustomEndpoint{}
	}
	var items []CustomEndpoint
	// json.Unmarshal accepts "null" without error and leaves the slice nil;
	// treat that the same as invalid input so the non-nil contract holds.
	if err := json.Unmarshal([]byte(raw), &items); err != nil || items == nil {
		return []CustomEndpoint{}
	}
	return items
}
...@@ -178,6 +178,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -178,6 +178,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled())) c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled()))
setOpsRequestContext(c, reqModel, reqStream, body) setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 验证 model 必填 // 验证 model 必填
if reqModel == "" { if reqModel == "" {
...@@ -1396,6 +1397,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { ...@@ -1396,6 +1397,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
} }
setOpsRequestContext(c, parsedReq.Model, parsedReq.Stream, body) setOpsRequestContext(c, parsedReq.Model, parsedReq.Stream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(parsedReq.Stream, false)))
// 获取订阅信息(可能为nil) // 获取订阅信息(可能为nil)
subscription, _ := middleware2.GetSubscriptionFromContext(c) subscription, _ := middleware2.GetSubscriptionFromContext(c)
......
package handler
import (
"context"
"errors"
"net/http"
"time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
// ChatCompletions handles OpenAI Chat Completions API endpoint for Anthropic platform groups.
// POST /v1/chat/completions
// This converts Chat Completions requests to Anthropic format (via Responses format chain),
// forwards to Anthropic upstream, and converts responses back to Chat Completions format.
//
// Processing order, as implemented below:
//  1. Resolve the API key and auth subject from the gin context.
//  2. Read the body, check it is valid JSON, and require a non-empty string "model".
//  3. Reject groups flagged ClaudeCodeOnly.
//  4. Acquire a user concurrency slot (with wait-queue accounting).
//  5. Re-check billing eligibility.
//  6. Loop: select an account, acquire its slot if needed, forward the request,
//     and either fail over to another account or finish.
//  7. On success, submit an asynchronous usage-record task.
//
// All error responses produced directly by this handler go through
// chatCompletionsErrorResponse, handleCCFailoverExhausted, handleConcurrencyError
// or ensureForwardErrorResponse.
func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
// streamStarted is passed by pointer to the slot-acquisition helpers and by
// value to the error writers below.
streamStarted := false
requestStart := time.Now()
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.chatCompletionsErrorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.chatCompletionsErrorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.gateway.chat_completions",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
// Read request body
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
// A body-size limit violation gets a 413 with the limit in the message;
// any other read failure is reported as a generic 400.
if maxErr, ok := extractMaxBytesError(err); ok {
h.chatCompletionsErrorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
// First ops-context call happens before validation, with an empty model and
// stream=false; it is repeated below once model and stream are known.
setOpsRequestContext(c, "", false, body)
// Validate JSON
if !gjson.ValidBytes(body) {
h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// Extract model and stream
modelResult := gjson.GetBytes(body, "model")
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
reqModel := modelResult.String()
reqStream := gjson.GetBytes(body, "stream").Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// Claude Code only restriction
if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly {
h.chatCompletionsErrorResponse(c, http.StatusForbidden, "permission_error",
"This group is restricted to Claude Code clients (/v1/messages only)")
return
}
// Error passthrough binding
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
}
// The boolean result is discarded; subscription is passed on as returned.
subscription, _ := middleware2.GetSubscriptionFromContext(c)
// Time elapsed since handler entry is recorded under the auth-latency key.
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
// 1. Acquire user concurrency slot
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
// waitCounted records whether the wait counter was actually incremented, so
// that it is decremented exactly once (either right after the slot is
// acquired or by the deferred function on an early return).
waitCounted := false
if err != nil {
// A counter failure is only logged; the request continues without wait accounting.
reqLog.Warn("gateway.cc.user_wait_counter_increment_failed", zap.Error(err))
} else if !canWait {
h.chatCompletionsErrorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
if err == nil && canWait {
waitCounted = true
}
defer func() {
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
}
}()
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
if err != nil {
reqLog.Warn("gateway.cc.user_slot_acquire_failed", zap.Error(err))
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
// The slot is held now, so leave the wait queue immediately and clear the
// flag to stop the deferred function from decrementing a second time.
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
waitCounted = false
}
// NOTE(review): wrapReleaseOnDone is defined elsewhere; presumably it also
// releases the slot when the request context is done — confirm.
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
// 2. Re-check billing
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("gateway.cc.billing_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
h.chatCompletionsErrorResponse(c, status, code, message)
return
}
// Parse request for session hash
// The parse error is discarded; when no parsed request comes back, a minimal
// one is built from the model, stream flag and raw body extracted above.
parsedReq, _ := service.ParseGatewayRequest(body, "chat_completions")
if parsedReq == nil {
parsedReq = &service.ParsedRequest{Model: reqModel, Stream: reqStream, Body: body}
}
parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c),
UserAgent: c.GetHeader("User-Agent"),
APIKeyID: apiKey.ID,
}
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 3. Account selection + failover loop
fs := NewFailoverState(h.maxAccountSwitches, false)
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "")
if err != nil {
// No account has failed yet, so this is the very first selection:
// report 503 with the selection error appended to the message.
if len(fs.FailedAccountIDs) == 0 {
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
return
}
// Otherwise at least one account already failed; let the failover
// state decide whether to retry selection, stop silently, or give up.
action := fs.HandleSelectionExhausted(c.Request.Context())
switch action {
case FailoverContinue:
continue
case FailoverCanceled:
return
default:
if fs.LastFailoverErr != nil {
h.handleCCFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
} else {
h.chatCompletionsErrorResponse(c, http.StatusBadGateway, "server_error", "All available accounts exhausted")
}
return
}
}
account := selection.Account
setOpsSelectedAccount(c, account.ID, account.Platform)
// 4. Acquire account concurrency slot
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
// Selection did not come with a slot; without a wait plan there is
// nothing to wait on, so the request is rejected with 503.
if selection.WaitPlan == nil {
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
return
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
reqLog.Warn("gateway.cc.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
}
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
// 5. Forward request
// Remember the writer size so it can be compared after the forward to see
// whether anything was written to the client during this attempt.
writerSizeBeforeForward := c.Writer.Size()
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, parsedReq)
// The account slot is released as soon as the forward returns, before any
// error handling, so every path below (including continue) runs without it.
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
// The writer size changed during the forward, so no further failover is
// attempted. Passing true makes handleCCFailoverExhausted return
// without writing anything.
if c.Writer.Size() != writerSizeBeforeForward {
h.handleCCFailoverExhausted(c, failoverErr, true)
return
}
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
switch action {
case FailoverContinue:
continue
case FailoverExhausted:
h.handleCCFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
return
case FailoverCanceled:
return
}
}
// Reached for non-failover errors, and for any failover action not
// matched by the switch above.
h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Error("gateway.cc.forward_failed",
zap.Int64("account_id", account.ID),
zap.Error(err),
)
return
}
// 6. Record usage
// Everything taken from the gin context is read here, before the task is
// submitted, so the closure below only uses these captured values.
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
}); err != nil {
// A usage-recording failure is only logged.
reqLog.Error("gateway.cc.record_usage_failed",
zap.Int64("account_id", account.ID),
zap.Error(err),
)
}
})
// Successful forward: leave the failover loop and the handler.
return
}
}
// chatCompletionsErrorResponse replies with the given HTTP status and a JSON
// body shaped as {"error": {"type": errType, "message": message}}, the error
// envelope used for Chat Completions responses.
func (h *GatewayHandler) chatCompletionsErrorResponse(c *gin.Context, status int, errType, message string) {
	detail := gin.H{
		"type":    errType,
		"message": message,
	}
	c.JSON(status, gin.H{"error": detail})
}
// handleCCFailoverExhausted reports that every candidate account has been
// tried, using the Chat Completions error envelope.
//
// Nothing is written when streamStarted is true. Otherwise the response status
// is the positive StatusCode carried by lastErr, falling back to 502 Bad
// Gateway when lastErr is nil or carries no positive status.
func (h *GatewayHandler) handleCCFailoverExhausted(c *gin.Context, lastErr *service.UpstreamFailoverError, streamStarted bool) {
	if !streamStarted {
		code := http.StatusBadGateway
		if lastErr != nil {
			if upstream := lastErr.StatusCode; upstream > 0 {
				code = upstream
			}
		}
		h.chatCompletionsErrorResponse(c, code, "server_error", "All available accounts exhausted")
	}
}
package handler
import (
"context"
"errors"
"net/http"
"time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
// Responses handles OpenAI Responses API endpoint for Anthropic platform groups.
// POST /v1/responses
// This converts Responses API requests to Anthropic format, forwards to Anthropic
// upstream, and converts responses back to Responses format.
func (h *GatewayHandler) Responses(c *gin.Context) {
streamStarted := false
requestStart := time.Now()
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.responsesErrorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.responsesErrorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.gateway.responses",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
// Read request body
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.responsesErrorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
setOpsRequestContext(c, "", false, body)
// Validate JSON
if !gjson.ValidBytes(body) {
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// Extract model and stream using gjson (like OpenAI handler)
modelResult := gjson.GetBytes(body, "model")
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
reqModel := modelResult.String()
reqStream := gjson.GetBytes(body, "stream").Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// Claude Code only restriction:
// /v1/responses is never a Claude Code endpoint.
// When claude_code_only is enabled, this endpoint is rejected.
// The existing service-layer checkClaudeCodeRestriction handles degradation
// to fallback groups when the Forward path calls SelectAccountForModelWithExclusions.
// Here we just reject at handler level since /v1/responses clients can't be Claude Code.
if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly {
h.responsesErrorResponse(c, http.StatusForbidden, "permission_error",
"This group is restricted to Claude Code clients (/v1/messages only)")
return
}
// Error passthrough binding
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
}
subscription, _ := middleware2.GetSubscriptionFromContext(c)
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
// 1. Acquire user concurrency slot
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
waitCounted := false
if err != nil {
reqLog.Warn("gateway.responses.user_wait_counter_increment_failed", zap.Error(err))
} else if !canWait {
h.responsesErrorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
if err == nil && canWait {
waitCounted = true
}
defer func() {
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
}
}()
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
if err != nil {
reqLog.Warn("gateway.responses.user_slot_acquire_failed", zap.Error(err))
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
waitCounted = false
}
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
// 2. Re-check billing
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("gateway.responses.billing_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
h.responsesErrorResponse(c, status, code, message)
return
}
// Parse request for session hash
parsedReq, _ := service.ParseGatewayRequest(body, "responses")
if parsedReq == nil {
parsedReq = &service.ParsedRequest{Model: reqModel, Stream: reqStream, Body: body}
}
parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c),
UserAgent: c.GetHeader("User-Agent"),
APIKeyID: apiKey.ID,
}
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 3. Account selection + failover loop
fs := NewFailoverState(h.maxAccountSwitches, false)
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "")
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
return
}
action := fs.HandleSelectionExhausted(c.Request.Context())
switch action {
case FailoverContinue:
continue
case FailoverCanceled:
return
default:
if fs.LastFailoverErr != nil {
h.handleResponsesFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
} else {
h.responsesErrorResponse(c, http.StatusBadGateway, "server_error", "All available accounts exhausted")
}
return
}
}
account := selection.Account
setOpsSelectedAccount(c, account.ID, account.Platform)
// 4. Acquire account concurrency slot
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
return
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
reqLog.Warn("gateway.responses.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
}
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
// 5. Forward request
writerSizeBeforeForward := c.Writer.Size()
result, err := h.gatewayService.ForwardAsResponses(c.Request.Context(), c, account, body, parsedReq)
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
// Can't failover if streaming content already sent
if c.Writer.Size() != writerSizeBeforeForward {
h.handleResponsesFailoverExhausted(c, failoverErr, true)
return
}
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
switch action {
case FailoverContinue:
continue
case FailoverExhausted:
h.handleResponsesFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
return
case FailoverCanceled:
return
}
}
h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Error("gateway.responses.forward_failed",
zap.Int64("account_id", account.ID),
zap.Error(err),
)
return
}
// 6. Record usage
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
}); err != nil {
reqLog.Error("gateway.responses.record_usage_failed",
zap.Int64("account_id", account.ID),
zap.Error(err),
)
}
})
return
}
}
// responsesErrorResponse replies with the given HTTP status and a JSON body of
// the shape {"error": {"code": <code>, "message": <message>}}.
func (h *GatewayHandler) responsesErrorResponse(c *gin.Context, status int, code, message string) {
	detail := gin.H{
		"code":    code,
		"message": message,
	}
	c.JSON(status, gin.H{"error": detail})
}
// handleResponsesFailoverExhausted reports, in Responses error format, that
// every candidate account has been exhausted.
//
// Nothing is written when streamStarted is true, because an error body cannot
// be sent once the stream has begun. Otherwise the response status is the
// positive StatusCode carried by lastErr, or 502 Bad Gateway when lastErr is
// nil or carries no positive status.
func (h *GatewayHandler) handleResponsesFailoverExhausted(c *gin.Context, lastErr *service.UpstreamFailoverError, streamStarted bool) {
	if streamStarted {
		return
	}

	status := http.StatusBadGateway
	if lastErr != nil {
		if upstream := lastErr.StatusCode; upstream > 0 {
			status = upstream
		}
	}
	h.responsesErrorResponse(c, status, "server_error", "All available accounts exhausted")
}
...@@ -182,6 +182,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -182,6 +182,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
} }
setOpsRequestContext(c, modelName, stream, body) setOpsRequestContext(c, modelName, stream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
// Get subscription (may be nil) // Get subscription (may be nil)
subscription, _ := middleware.GetSubscriptionFromContext(c) subscription, _ := middleware.GetSubscriptionFromContext(c)
......
...@@ -77,6 +77,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -77,6 +77,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
setOpsRequestContext(c, reqModel, reqStream, body) setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
if h.errorPassthroughService != nil { if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService) service.BindErrorPassthroughService(c, h.errorPassthroughService)
......
...@@ -183,6 +183,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -183,6 +183,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
} }
setOpsRequestContext(c, reqModel, reqStream, body) setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。 // 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
if !h.validateFunctionCallOutputRequest(c, body, reqLog) { if !h.validateFunctionCallOutputRequest(c, body, reqLog) {
...@@ -545,6 +546,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { ...@@ -545,6 +546,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
setOpsRequestContext(c, reqModel, reqStream, body) setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。 // 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
if h.errorPassthroughService != nil { if h.errorPassthroughService != nil {
...@@ -1096,6 +1098,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { ...@@ -1096,6 +1098,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
zap.String("previous_response_id_kind", previousResponseIDKind), zap.String("previous_response_id_kind", previousResponseIDKind),
) )
setOpsRequestContext(c, reqModel, true, firstMessage) setOpsRequestContext(c, reqModel, true, firstMessage)
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
var currentUserRelease func() var currentUserRelease func()
var currentAccountRelease func() var currentAccountRelease func()
......
...@@ -27,6 +27,9 @@ const ( ...@@ -27,6 +27,9 @@ const (
opsRequestBodyKey = "ops_request_body" opsRequestBodyKey = "ops_request_body"
opsAccountIDKey = "ops_account_id" opsAccountIDKey = "ops_account_id"
opsUpstreamModelKey = "ops_upstream_model"
opsRequestTypeKey = "ops_request_type"
// 错误过滤匹配常量 — shouldSkipOpsErrorLog 和错误分类共用 // 错误过滤匹配常量 — shouldSkipOpsErrorLog 和错误分类共用
opsErrContextCanceled = "context canceled" opsErrContextCanceled = "context canceled"
opsErrNoAvailableAccounts = "no available accounts" opsErrNoAvailableAccounts = "no available accounts"
...@@ -345,6 +348,18 @@ func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody ...@@ -345,6 +348,18 @@ func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody
} }
} }
// setOpsEndpointContext records the upstream model and the request type on the
// gin context so they can be read back later for ops error logging.
//
// The upstream model is whitespace-trimmed and stored under opsUpstreamModelKey
// only when the trimmed value is non-empty. The request type is always stored
// under opsRequestTypeKey. A nil context is a no-op.
func setOpsEndpointContext(c *gin.Context, upstreamModel string, requestType int16) {
	if c == nil {
		return
	}

	trimmed := strings.TrimSpace(upstreamModel)
	if trimmed != "" {
		c.Set(opsUpstreamModelKey, trimmed)
	}
	c.Set(opsRequestTypeKey, requestType)
}
func attachOpsRequestBodyToEntry(c *gin.Context, entry *service.OpsInsertErrorLogInput) { func attachOpsRequestBodyToEntry(c *gin.Context, entry *service.OpsInsertErrorLogInput) {
if c == nil || entry == nil { if c == nil || entry == nil {
return return
...@@ -628,7 +643,30 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { ...@@ -628,7 +643,30 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
} }
return "" return ""
}(), }(),
Stream: stream, Stream: stream,
InboundEndpoint: GetInboundEndpoint(c),
UpstreamEndpoint: GetUpstreamEndpoint(c, platform),
RequestedModel: modelName,
UpstreamModel: func() string {
if v, ok := c.Get(opsUpstreamModelKey); ok {
if s, ok := v.(string); ok {
return strings.TrimSpace(s)
}
}
return ""
}(),
RequestType: func() *int16 {
if v, ok := c.Get(opsRequestTypeKey); ok {
switch t := v.(type) {
case int16:
return &t
case int:
v16 := int16(t)
return &v16
}
}
return nil
}(),
UserAgent: c.GetHeader("User-Agent"), UserAgent: c.GetHeader("User-Agent"),
ErrorPhase: "upstream", ErrorPhase: "upstream",
...@@ -756,7 +794,30 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { ...@@ -756,7 +794,30 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
} }
return "" return ""
}(), }(),
Stream: stream, Stream: stream,
InboundEndpoint: GetInboundEndpoint(c),
UpstreamEndpoint: GetUpstreamEndpoint(c, platform),
RequestedModel: modelName,
UpstreamModel: func() string {
if v, ok := c.Get(opsUpstreamModelKey); ok {
if s, ok := v.(string); ok {
return strings.TrimSpace(s)
}
}
return ""
}(),
RequestType: func() *int16 {
if v, ok := c.Get(opsRequestTypeKey); ok {
switch t := v.(type) {
case int16:
return &t
case int:
v16 := int16(t)
return &v16
}
}
return nil
}(),
UserAgent: c.GetHeader("User-Agent"), UserAgent: c.GetHeader("User-Agent"),
ErrorPhase: phase, ErrorPhase: phase,
......
...@@ -274,3 +274,48 @@ func TestNormalizeOpsErrorType(t *testing.T) { ...@@ -274,3 +274,48 @@ func TestNormalizeOpsErrorType(t *testing.T) {
}) })
} }
} }
// TestSetOpsEndpointContext_SetsContextKeys verifies that a non-empty upstream
// model and the request type are both stored on the gin context, as a string
// and an int16 respectively.
func TestSetOpsEndpointContext_SetsContextKeys(t *testing.T) {
	gin.SetMode(gin.TestMode)
	ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
	ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)

	setOpsEndpointContext(ctx, "claude-3-5-sonnet-20241022", int16(2)) // stream

	storedModel, found := ctx.Get(opsUpstreamModelKey)
	require.True(t, found)
	modelStr, isString := storedModel.(string)
	require.True(t, isString)
	require.Equal(t, "claude-3-5-sonnet-20241022", modelStr)

	storedType, found := ctx.Get(opsRequestTypeKey)
	require.True(t, found)
	typeVal, isInt16 := storedType.(int16)
	require.True(t, isInt16)
	require.Equal(t, int16(2), typeVal)
}
// TestSetOpsEndpointContext_EmptyModelNotStored verifies that an empty upstream
// model leaves opsUpstreamModelKey unset while the request type is still
// recorded as an int16.
func TestSetOpsEndpointContext_EmptyModelNotStored(t *testing.T) {
	gin.SetMode(gin.TestMode)
	ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
	ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)

	setOpsEndpointContext(ctx, "", int16(1))

	_, modelFound := ctx.Get(opsUpstreamModelKey)
	require.False(t, modelFound, "empty upstream model should not be stored")

	storedType, typeFound := ctx.Get(opsRequestTypeKey)
	require.True(t, typeFound)
	typeVal, isInt16 := storedType.(int16)
	require.True(t, isInt16)
	require.Equal(t, int16(1), typeVal)
}
func TestSetOpsEndpointContext_NilContext(t *testing.T) {
require.NotPanics(t, func() {
setOpsEndpointContext(nil, "model", int16(1))
})
}
...@@ -52,6 +52,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { ...@@ -52,6 +52,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems), CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
SoraClientEnabled: settings.SoraClientEnabled, SoraClientEnabled: settings.SoraClientEnabled,
BackendModeEnabled: settings.BackendModeEnabled, BackendModeEnabled: settings.BackendModeEnabled,
......
...@@ -2072,7 +2072,7 @@ func (r *stubAccountRepoForHandler) Delete(context.Context, int64) error ...@@ -2072,7 +2072,7 @@ func (r *stubAccountRepoForHandler) Delete(context.Context, int64) error
func (r *stubAccountRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { func (r *stubAccountRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64) ([]service.Account, *pagination.PaginationResult, error) { func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64, string) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (r *stubAccountRepoForHandler) ListByGroup(context.Context, int64) ([]service.Account, error) { func (r *stubAccountRepoForHandler) ListByGroup(context.Context, int64) ([]service.Account, error) {
......
...@@ -159,6 +159,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -159,6 +159,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
} }
setOpsRequestContext(c, reqModel, clientStream, body) setOpsRequestContext(c, reqModel, clientStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(clientStream, false)))
platform := "" platform := ""
if forced, ok := middleware2.GetForcePlatformFromContext(c); ok { if forced, ok := middleware2.GetForcePlatformFromContext(c); ok {
......
...@@ -130,7 +130,7 @@ func (r *stubAccountRepo) Delete(ctx context.Context, id int64) error ...@@ -130,7 +130,7 @@ func (r *stubAccountRepo) Delete(ctx context.Context, id int64) error
func (r *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { func (r *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) { func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (r *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) { func (r *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
......
package apicompat
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"time"
)
// ---------------------------------------------------------------------------
// Non-streaming: AnthropicResponse → ResponsesResponse
// ---------------------------------------------------------------------------
// AnthropicToResponsesResponse builds a Responses API response from an
// Anthropic Messages response, so an Anthropic upstream reply can be returned
// in OpenAI Responses format.
//
// Content blocks are mapped as follows:
//   - non-empty "thinking" blocks become "reasoning" items with one summary;
//   - non-empty "text" blocks are gathered into a single assistant "message"
//     item, appended after all other items;
//   - "tool_use" blocks become completed "function_call" items, with "{}" as
//     the arguments when the block has no input.
//
// If nothing was produced, a single assistant message with one empty
// output_text part is emitted. A missing response ID is replaced by a generated
// one. Status comes from the stop reason; an "incomplete" status is annotated
// with the reason "max_output_tokens". Cached-token details are attached only
// when the cache-read count is positive.
func AnthropicToResponsesResponse(resp *AnthropicResponse) *ResponsesResponse {
	result := &ResponsesResponse{
		ID:     resp.ID,
		Object: "response",
		Model:  resp.Model,
	}
	if result.ID == "" {
		result.ID = generateResponsesID()
	}

	// newMessage wraps content parts in a completed assistant message item.
	newMessage := func(parts []ResponsesContentPart) ResponsesOutput {
		return ResponsesOutput{
			Type:    "message",
			ID:      generateItemID(),
			Role:    "assistant",
			Content: parts,
			Status:  "completed",
		}
	}

	var (
		items     []ResponsesOutput
		textParts []ResponsesContentPart
	)
	for _, block := range resp.Content {
		switch {
		case block.Type == "thinking" && block.Thinking != "":
			summary := ResponsesSummary{Type: "summary_text", Text: block.Thinking}
			items = append(items, ResponsesOutput{
				Type:    "reasoning",
				ID:      generateItemID(),
				Summary: []ResponsesSummary{summary},
			})
		case block.Type == "text" && block.Text != "":
			textParts = append(textParts, ResponsesContentPart{
				Type: "output_text",
				Text: block.Text,
			})
		case block.Type == "tool_use":
			arguments := "{}"
			if len(block.Input) > 0 {
				arguments = string(block.Input)
			}
			items = append(items, ResponsesOutput{
				Type:      "function_call",
				ID:        generateItemID(),
				CallID:    toResponsesCallID(block.ID),
				Name:      block.Name,
				Arguments: arguments,
				Status:    "completed",
			})
		}
	}

	switch {
	case len(textParts) > 0:
		// All text parts share one message item, placed last.
		items = append(items, newMessage(textParts))
	case len(items) == 0:
		// Never return an empty output list: fall back to an empty message.
		items = append(items, newMessage([]ResponsesContentPart{{Type: "output_text", Text: ""}}))
	}
	result.Output = items

	result.Status = anthropicStopReasonToResponsesStatus(resp.StopReason, resp.Content)
	if result.Status == "incomplete" {
		result.IncompleteDetails = &ResponsesIncompleteDetails{Reason: "max_output_tokens"}
	}

	usage := &ResponsesUsage{
		InputTokens:  resp.Usage.InputTokens,
		OutputTokens: resp.Usage.OutputTokens,
		TotalTokens:  resp.Usage.InputTokens + resp.Usage.OutputTokens,
	}
	if cached := resp.Usage.CacheReadInputTokens; cached > 0 {
		usage.InputTokensDetails = &ResponsesInputTokensDetails{CachedTokens: cached}
	}
	result.Usage = usage

	return result
}
// anthropicStopReasonToResponsesStatus translates an Anthropic stop_reason into
// a Responses status: "max_tokens" yields "incomplete", and every other value
// (including "end_turn", "tool_use", "stop_sequence" and unknown or empty
// reasons) yields "completed". The blocks argument is not consulted.
func anthropicStopReasonToResponsesStatus(stopReason string, blocks []AnthropicContentBlock) string {
	if stopReason == "max_tokens" {
		return "incomplete"
	}
	return "completed"
}
// ---------------------------------------------------------------------------
// Streaming: AnthropicStreamEvent → []ResponsesStreamEvent (stateful converter)
// ---------------------------------------------------------------------------
// AnthropicEventToResponsesState carries the mutable state needed to convert a
// sequence of Anthropic SSE events into Responses SSE events. A single value is
// passed to every conversion call for one stream and is updated in place.
type AnthropicEventToResponsesState struct {
// ResponseID is the response ID placed in response.created and
// response.completed; it is taken from the message_start event.
ResponseID string
// Model is the model name placed in emitted responses. It is filled from
// message_start only when it is still empty.
Model string
// Created is set to the Unix time at construction by
// NewAnthropicEventToResponsesState.
Created int64
// SequenceNumber is the sequence number assigned to the next emitted event;
// it is incremented once per event.
SequenceNumber int
// CreatedSent tracks whether response.created has been emitted.
CreatedSent bool
// CompletedSent tracks whether the terminal event has been emitted.
CompletedSent bool
// Current output tracking. OutputIndex is the index of the item currently
// being produced and advances each time an item is closed; CurrentItemType
// is empty when no item is open.
OutputIndex int
CurrentItemID string
CurrentItemType string // "message" | "function_call" | "reasoning"
// ContentIndex is attached to output_text delta/done events for message
// items; it is reset to 0 whenever an item is opened or closed.
ContentIndex int
// For function_call items: the call ID and function name of the open item,
// repeated on argument delta/done events.
CurrentCallID string
CurrentName string
// Usage counters reported in response.completed. InputTokens is read from
// message_start; OutputTokens and CacheReadInputTokens are read from
// message_delta.
InputTokens int
OutputTokens int
CacheReadInputTokens int
}
// NewAnthropicEventToResponsesState allocates a fresh stream state whose
// Created field holds the current Unix time; every other field is zero.
func NewAnthropicEventToResponsesState() *AnthropicEventToResponsesState {
	state := new(AnthropicEventToResponsesState)
	state.Created = time.Now().Unix()
	return state
}
// AnthropicEventToResponsesEvents translates one Anthropic SSE event into zero
// or more Responses SSE events, mutating state along the way. The event type
// selects the handler; unrecognised event types produce no events.
func AnthropicEventToResponsesEvents(
	evt *AnthropicStreamEvent,
	state *AnthropicEventToResponsesState,
) []ResponsesStreamEvent {
	var converted []ResponsesStreamEvent
	switch kind := evt.Type; kind {
	case "message_start":
		converted = anthToResHandleMessageStart(evt, state)
	case "content_block_start":
		converted = anthToResHandleContentBlockStart(evt, state)
	case "content_block_delta":
		converted = anthToResHandleContentBlockDelta(evt, state)
	case "content_block_stop":
		converted = anthToResHandleContentBlockStop(evt, state)
	case "message_delta":
		converted = anthToResHandleMessageDelta(evt, state)
	case "message_stop":
		converted = anthToResHandleMessageStop(state)
	}
	return converted
}
// FinalizeAnthropicResponsesStream produces synthetic terminal events for a
// stream that ended without a message_stop. It returns nil when
// response.created was never sent or the terminal event was already sent.
// Otherwise it closes any open output item, emits response.completed with
// status "completed", and marks the state as completed.
func FinalizeAnthropicResponsesStream(state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
	if state.CompletedSent || !state.CreatedSent {
		return nil
	}

	// The item-closing event (if any) must be sequenced before the terminal one.
	events := closeCurrentResponsesItem(state)
	events = append(events, makeResponsesCompletedEvent(state, "completed", nil))
	state.CompletedSent = true
	return events
}
// ResponsesEventToSSE renders a ResponsesStreamEvent as one SSE frame: an
// "event:" line holding the event type, a "data:" line holding the event's JSON
// encoding, and a terminating blank line. It returns the JSON marshalling
// error, with an empty string, if encoding fails.
func ResponsesEventToSSE(evt ResponsesStreamEvent) (string, error) {
	const frameFormat = "event: %s\ndata: %s\n\n"

	payload, err := json.Marshal(evt)
	if err != nil {
		return "", err
	}
	frame := fmt.Sprintf(frameFormat, evt.Type, payload)
	return frame, nil
}
// --- internal handlers ---
// anthToResHandleMessageStart handles the Anthropic message_start event.
//
// It copies the message ID, the model (only if the state has none yet) and a
// positive input-token count from the event into state. If no response ID is
// available afterwards — the event carried no message or an empty ID — one is
// generated, as the non-streaming conversion does, so response.created and
// response.completed never carry an empty ID. An empty upstream ID does not
// overwrite an ID already present in state.
//
// It returns a single response.created event the first time it is called for a
// stream, and nil on any later call.
func anthToResHandleMessageStart(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
	if msg := evt.Message; msg != nil {
		if msg.ID != "" {
			state.ResponseID = msg.ID
		}
		if state.Model == "" {
			state.Model = msg.Model
		}
		if msg.Usage.InputTokens > 0 {
			state.InputTokens = msg.Usage.InputTokens
		}
	}
	if state.ResponseID == "" {
		// Fall back to a locally generated ID rather than emitting "".
		state.ResponseID = generateResponsesID()
	}

	if state.CreatedSent {
		return nil
	}
	state.CreatedSent = true
	return []ResponsesStreamEvent{makeResponsesCreatedEvent(state)}
}
// anthToResHandleContentBlockStart handles an Anthropic content_block_start
// event by opening the matching Responses output item and emitting
// response.output_item.added for it:
//
//   - "thinking" opens a "reasoning" item;
//   - "text" opens an assistant "message" item, unless a message item is
//     already open, in which case the block joins it and nothing is emitted;
//   - "tool_use" opens a "function_call" item carrying the call ID and name.
//
// Before any new item is opened, an item that is still open is closed first
// (emitting response.output_item.done and advancing OutputIndex). Without this,
// a message item left open after its text block would never be closed when a
// thinking block followed, and the new item would reuse its output index.
//
// A nil ContentBlock or an unrecognised block type yields no events.
func anthToResHandleContentBlockStart(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
	block := evt.ContentBlock
	if block == nil {
		return nil
	}

	var events []ResponsesStreamEvent
	switch block.Type {
	case "thinking":
		// A message item stays open after content_block_stop; close it (or any
		// other open item) before starting the reasoning item.
		events = append(events, closeCurrentResponsesItem(state)...)
		state.CurrentItemID = generateItemID()
		state.CurrentItemType = "reasoning"
		state.ContentIndex = 0
		events = append(events, makeResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{
			OutputIndex: state.OutputIndex,
			Item: &ResponsesOutput{
				Type: "reasoning",
				ID:   state.CurrentItemID,
			},
		}))
	case "text":
		// Consecutive text blocks share one message item; only open a new one
		// when no message item is currently open.
		if state.CurrentItemType != "message" {
			// Close a still-open non-message item; no-op when none is open.
			events = append(events, closeCurrentResponsesItem(state)...)
			state.CurrentItemID = generateItemID()
			state.CurrentItemType = "message"
			state.ContentIndex = 0
			events = append(events, makeResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{
				OutputIndex: state.OutputIndex,
				Item: &ResponsesOutput{
					Type:   "message",
					ID:     state.CurrentItemID,
					Role:   "assistant",
					Status: "in_progress",
				},
			}))
		}
	case "tool_use":
		// Close the previous item, if any, before opening the function call.
		events = append(events, closeCurrentResponsesItem(state)...)
		state.CurrentItemID = generateItemID()
		state.CurrentItemType = "function_call"
		state.CurrentCallID = toResponsesCallID(block.ID)
		state.CurrentName = block.Name
		events = append(events, makeResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{
			OutputIndex: state.OutputIndex,
			Item: &ResponsesOutput{
				Type:   "function_call",
				ID:     state.CurrentItemID,
				CallID: state.CurrentCallID,
				Name:   state.CurrentName,
				Status: "in_progress",
			},
		}))
	}
	return events
}
// anthToResHandleContentBlockDelta maps an Anthropic content_block_delta event
// onto the corresponding Responses delta event for the currently open item:
//
//   - text_delta       -> response.output_text.delta
//   - thinking_delta   -> response.reasoning_summary_text.delta
//   - input_json_delta -> response.function_call_arguments.delta
//
// Empty deltas, a nil Delta, signature_delta (which has no Responses
// equivalent) and unrecognised delta types all yield no events.
func anthToResHandleContentBlockDelta(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
	delta := evt.Delta
	if delta == nil {
		return nil
	}

	var (
		eventType string
		payload   ResponsesStreamEvent
	)
	switch delta.Type {
	case "text_delta":
		if delta.Text == "" {
			return nil
		}
		eventType = "response.output_text.delta"
		payload = ResponsesStreamEvent{
			OutputIndex:  state.OutputIndex,
			ContentIndex: state.ContentIndex,
			Delta:        delta.Text,
			ItemID:       state.CurrentItemID,
		}
	case "thinking_delta":
		if delta.Thinking == "" {
			return nil
		}
		eventType = "response.reasoning_summary_text.delta"
		payload = ResponsesStreamEvent{
			OutputIndex:  state.OutputIndex,
			SummaryIndex: 0,
			Delta:        delta.Thinking,
			ItemID:       state.CurrentItemID,
		}
	case "input_json_delta":
		if delta.PartialJSON == "" {
			return nil
		}
		eventType = "response.function_call_arguments.delta"
		payload = ResponsesStreamEvent{
			OutputIndex: state.OutputIndex,
			Delta:       delta.PartialJSON,
			ItemID:      state.CurrentItemID,
			CallID:      state.CurrentCallID,
			Name:        state.CurrentName,
		}
	default:
		// Includes signature_delta, which is deliberately dropped.
		return nil
	}

	return []ResponsesStreamEvent{makeResponsesEvent(state, eventType, &payload)}
}
// anthToResHandleContentBlockStop handles an Anthropic content_block_stop event
// according to the type of the currently open item:
//
//   - reasoning: emits response.reasoning_summary_text.done, then closes the item;
//   - function_call: emits response.function_call_arguments.done, then closes
//     the item;
//   - message: emits response.output_text.done only; the message item stays
//     open so that further blocks can join it.
//
// With no open item it returns nil. The evt argument is not consulted.
func anthToResHandleContentBlockStop(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
	// Fields common to every "done" event below.
	done := ResponsesStreamEvent{
		OutputIndex: state.OutputIndex,
		ItemID:      state.CurrentItemID,
	}

	var doneType string
	switch state.CurrentItemType {
	case "reasoning":
		doneType = "response.reasoning_summary_text.done"
		done.SummaryIndex = 0
	case "function_call":
		doneType = "response.function_call_arguments.done"
		done.CallID = state.CurrentCallID
		done.Name = state.CurrentName
	case "message":
		done.ContentIndex = state.ContentIndex
		return []ResponsesStreamEvent{makeResponsesEvent(state, "response.output_text.done", &done)}
	default:
		return nil
	}

	// The "done" event is sequenced before the item-closing event.
	events := []ResponsesStreamEvent{makeResponsesEvent(state, doneType, &done)}
	events = append(events, closeCurrentResponsesItem(state)...)
	return events
}
// anthToResHandleMessageDelta records usage from an Anthropic message_delta
// event: the output-token count is always copied when usage is present, and the
// cache-read count only when it is positive. It never emits events.
func anthToResHandleMessageDelta(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
	if usage := evt.Usage; usage != nil {
		state.OutputTokens = usage.OutputTokens
		if cached := usage.CacheReadInputTokens; cached > 0 {
			state.CacheReadInputTokens = cached
		}
	}
	return nil
}
// anthToResHandleMessageStop handles the Anthropic message_stop event. Unless
// the terminal event was already sent, it closes any open output item, emits
// response.completed and marks the state as completed.
//
// The status is always "completed" with no incomplete details: the stream
// state does not record a stop reason, so none is consulted here.
func anthToResHandleMessageStop(state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
	if state.CompletedSent {
		return nil
	}

	// The item-closing event (if any) must be sequenced before the terminal one.
	events := closeCurrentResponsesItem(state)
	events = append(events, makeResponsesCompletedEvent(state, "completed", nil))
	state.CompletedSent = true
	return events
}
// --- helper functions ---
// closeCurrentResponsesItem closes the currently open output item, if any. It
// clears the per-item fields in state, advances OutputIndex, resets
// ContentIndex, and returns one response.output_item.done event that carries
// the closed item's type, ID and the output index it occupied. It returns nil
// when no item is open.
func closeCurrentResponsesItem(state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
	if state.CurrentItemType == "" {
		return nil
	}

	// Capture what is being closed before the state is reset.
	closedIndex := state.OutputIndex
	closedItem := &ResponsesOutput{
		Type:   state.CurrentItemType,
		ID:     state.CurrentItemID,
		Status: "completed",
	}

	state.CurrentItemType, state.CurrentItemID = "", ""
	state.CurrentCallID, state.CurrentName = "", ""
	state.OutputIndex = closedIndex + 1
	state.ContentIndex = 0

	done := makeResponsesEvent(state, "response.output_item.done", &ResponsesStreamEvent{
		OutputIndex: closedIndex,
		Item:        closedItem,
	})
	return []ResponsesStreamEvent{done}
}
// makeResponsesCreatedEvent builds the response.created event. It consumes the
// next sequence number and carries an in-progress response with the state's ID
// and model and an empty (non-nil) output list.
func makeResponsesCreatedEvent(state *AnthropicEventToResponsesState) ResponsesStreamEvent {
	inProgress := &ResponsesResponse{
		ID:     state.ResponseID,
		Object: "response",
		Model:  state.Model,
		Status: "in_progress",
		Output: []ResponsesOutput{},
	}

	created := ResponsesStreamEvent{
		Type:           "response.created",
		SequenceNumber: state.SequenceNumber,
		Response:       inProgress,
	}
	state.SequenceNumber++
	return created
}
// makeResponsesCompletedEvent builds the terminal response.completed event. It
// consumes the next sequence number and carries the given status and
// incomplete details, together with the usage accumulated in state. Total
// tokens are input plus output; cached-token details are attached only when
// the cache-read count is positive. The output list is always empty — items
// are not accumulated while streaming.
func makeResponsesCompletedEvent(
	state *AnthropicEventToResponsesState,
	status string,
	incompleteDetails *ResponsesIncompleteDetails,
) ResponsesStreamEvent {
	sequence := state.SequenceNumber
	state.SequenceNumber++

	totals := &ResponsesUsage{
		InputTokens:  state.InputTokens,
		OutputTokens: state.OutputTokens,
		TotalTokens:  state.InputTokens + state.OutputTokens,
	}
	if cached := state.CacheReadInputTokens; cached > 0 {
		totals.InputTokensDetails = &ResponsesInputTokensDetails{CachedTokens: cached}
	}

	final := &ResponsesResponse{
		ID:                state.ResponseID,
		Object:            "response",
		Model:             state.Model,
		Status:            status,
		Output:            []ResponsesOutput{}, // Simplified; full output tracking would add complexity
		Usage:             totals,
		IncompleteDetails: incompleteDetails,
	}
	return ResponsesStreamEvent{
		Type:           "response.completed",
		SequenceNumber: sequence,
		Response:       final,
	}
}
// makeResponsesEvent returns a copy of template stamped with the given event
// type and the next sequence number, which it consumes from state. The
// template itself is not modified.
func makeResponsesEvent(state *AnthropicEventToResponsesState, eventType string, template *ResponsesStreamEvent) ResponsesStreamEvent {
	stamped := *template
	stamped.Type, stamped.SequenceNumber = eventType, state.SequenceNumber
	state.SequenceNumber++
	return stamped
}
// generateResponsesID returns a response identifier of the form "resp_"
// followed by 24 lowercase hex characters (12 bytes read from crypto/rand).
// The error from rand.Read is discarded.
func generateResponsesID() string {
	var raw [12]byte
	_, _ = rand.Read(raw[:])
	return "resp_" + hex.EncodeToString(raw[:])
}
// generateItemID returns an output-item identifier of the form "item_"
// followed by 24 lowercase hex characters (12 bytes read from crypto/rand).
// The error from rand.Read is discarded.
func generateItemID() string {
	var raw [12]byte
	_, _ = rand.Read(raw[:])
	return "item_" + hex.EncodeToString(raw[:])
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment