Unverified Commit fa68cbad authored by InCerryGit's avatar InCerryGit Committed by GitHub
Browse files

Merge branch 'Wei-Shaw:main' into main

parents 995ef134 0f033930
...@@ -61,6 +61,9 @@ temp/ ...@@ -61,6 +61,9 @@ temp/
deploy/install.sh deploy/install.sh
deploy/sub2api.service deploy/sub2api.service
deploy/sub2api-sudoers deploy/sub2api-sudoers
deploy/data/
deploy/postgres_data/
deploy/redis_data/
# GoReleaser # GoReleaser
.goreleaser.yaml .goreleaser.yaml
......
...@@ -114,6 +114,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -114,6 +114,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient) oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
openAIOAuthClient := repository.NewOpenAIOAuthClient() openAIOAuthClient := repository.NewOpenAIOAuthClient()
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient) openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
openAIOAuthService.SetPrivacyClientFactory(privacyClientFactory)
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig) geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient() geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
driveClient := repository.NewGeminiDriveClient() driveClient := repository.NewGeminiDriveClient()
......
...@@ -352,7 +352,7 @@ func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, acc ...@@ -352,7 +352,7 @@ func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, acc
pageSize := dataPageCap pageSize := dataPageCap
var out []service.Account var out []service.Account
for { for {
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0) items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0, "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
...@@ -219,6 +219,7 @@ func (h *AccountHandler) List(c *gin.Context) { ...@@ -219,6 +219,7 @@ func (h *AccountHandler) List(c *gin.Context) {
accountType := c.Query("type") accountType := c.Query("type")
status := c.Query("status") status := c.Query("status")
search := c.Query("search") search := c.Query("search")
privacyMode := strings.TrimSpace(c.Query("privacy_mode"))
// 标准化和验证 search 参数 // 标准化和验证 search 参数
search = strings.TrimSpace(search) search = strings.TrimSpace(search)
if len(search) > 100 { if len(search) > 100 {
...@@ -244,7 +245,7 @@ func (h *AccountHandler) List(c *gin.Context) { ...@@ -244,7 +245,7 @@ func (h *AccountHandler) List(c *gin.Context) {
} }
} }
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID) accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID, privacyMode)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -1936,7 +1937,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) { ...@@ -1936,7 +1937,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
accounts := make([]*service.Account, 0) accounts := make([]*service.Account, 0)
if len(req.AccountIDs) == 0 { if len(req.AccountIDs) == 0 {
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0) allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0, "")
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
......
...@@ -187,7 +187,7 @@ func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int ...@@ -187,7 +187,7 @@ func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int
return nil return nil
} }
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) { func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, int64, error) {
return s.accounts, int64(len(s.accounts)), nil return s.accounts, int64(len(s.accounts)), nil
} }
......
...@@ -233,11 +233,27 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -233,11 +233,27 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
if req.DefaultBalance < 0 { if req.DefaultBalance < 0 {
req.DefaultBalance = 0 req.DefaultBalance = 0
} }
req.SMTPHost = strings.TrimSpace(req.SMTPHost)
req.SMTPUsername = strings.TrimSpace(req.SMTPUsername)
req.SMTPPassword = strings.TrimSpace(req.SMTPPassword)
req.SMTPFrom = strings.TrimSpace(req.SMTPFrom)
req.SMTPFromName = strings.TrimSpace(req.SMTPFromName)
if req.SMTPPort <= 0 { if req.SMTPPort <= 0 {
req.SMTPPort = 587 req.SMTPPort = 587
} }
req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions) req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions)
// SMTP 配置保护:如果请求中 smtp_host 为空但数据库中已有配置,则保留已有 SMTP 配置
// 防止前端加载设置失败时空表单覆盖已保存的 SMTP 配置
if req.SMTPHost == "" && previousSettings.SMTPHost != "" {
req.SMTPHost = previousSettings.SMTPHost
req.SMTPPort = previousSettings.SMTPPort
req.SMTPUsername = previousSettings.SMTPUsername
req.SMTPFrom = previousSettings.SMTPFrom
req.SMTPFromName = previousSettings.SMTPFromName
req.SMTPUseTLS = previousSettings.SMTPUseTLS
}
// Turnstile 参数验证 // Turnstile 参数验证
if req.TurnstileEnabled { if req.TurnstileEnabled {
// 检查必填字段 // 检查必填字段
...@@ -881,7 +897,7 @@ func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool { ...@@ -881,7 +897,7 @@ func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool {
// TestSMTPRequest 测试SMTP连接请求 // TestSMTPRequest 测试SMTP连接请求
type TestSMTPRequest struct { type TestSMTPRequest struct {
SMTPHost string `json:"smtp_host" binding:"required"` SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"` SMTPPort int `json:"smtp_port"`
SMTPUsername string `json:"smtp_username"` SMTPUsername string `json:"smtp_username"`
SMTPPassword string `json:"smtp_password"` SMTPPassword string `json:"smtp_password"`
...@@ -897,18 +913,35 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) { ...@@ -897,18 +913,35 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) {
return return
} }
if req.SMTPPort <= 0 { req.SMTPHost = strings.TrimSpace(req.SMTPHost)
req.SMTPPort = 587 req.SMTPUsername = strings.TrimSpace(req.SMTPUsername)
var savedConfig *service.SMTPConfig
if cfg, err := h.emailService.GetSMTPConfig(c.Request.Context()); err == nil && cfg != nil {
savedConfig = cfg
} }
// 如果未提供密码,从数据库获取已保存的密码 if req.SMTPHost == "" && savedConfig != nil {
password := req.SMTPPassword req.SMTPHost = savedConfig.Host
if password == "" { }
savedConfig, err := h.emailService.GetSMTPConfig(c.Request.Context()) if req.SMTPPort <= 0 {
if err == nil && savedConfig != nil { if savedConfig != nil && savedConfig.Port > 0 {
password = savedConfig.Password req.SMTPPort = savedConfig.Port
} else {
req.SMTPPort = 587
} }
} }
if req.SMTPUsername == "" && savedConfig != nil {
req.SMTPUsername = savedConfig.Username
}
password := strings.TrimSpace(req.SMTPPassword)
if password == "" && savedConfig != nil {
password = savedConfig.Password
}
if req.SMTPHost == "" {
response.BadRequest(c, "SMTP host is required")
return
}
config := &service.SMTPConfig{ config := &service.SMTPConfig{
Host: req.SMTPHost, Host: req.SMTPHost,
...@@ -930,7 +963,7 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) { ...@@ -930,7 +963,7 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) {
// SendTestEmailRequest 发送测试邮件请求 // SendTestEmailRequest 发送测试邮件请求
type SendTestEmailRequest struct { type SendTestEmailRequest struct {
Email string `json:"email" binding:"required,email"` Email string `json:"email" binding:"required,email"`
SMTPHost string `json:"smtp_host" binding:"required"` SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"` SMTPPort int `json:"smtp_port"`
SMTPUsername string `json:"smtp_username"` SMTPUsername string `json:"smtp_username"`
SMTPPassword string `json:"smtp_password"` SMTPPassword string `json:"smtp_password"`
...@@ -948,18 +981,43 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) { ...@@ -948,18 +981,43 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
return return
} }
if req.SMTPPort <= 0 { req.SMTPHost = strings.TrimSpace(req.SMTPHost)
req.SMTPPort = 587 req.SMTPUsername = strings.TrimSpace(req.SMTPUsername)
req.SMTPFrom = strings.TrimSpace(req.SMTPFrom)
req.SMTPFromName = strings.TrimSpace(req.SMTPFromName)
var savedConfig *service.SMTPConfig
if cfg, err := h.emailService.GetSMTPConfig(c.Request.Context()); err == nil && cfg != nil {
savedConfig = cfg
} }
// 如果未提供密码,从数据库获取已保存的密码 if req.SMTPHost == "" && savedConfig != nil {
password := req.SMTPPassword req.SMTPHost = savedConfig.Host
if password == "" { }
savedConfig, err := h.emailService.GetSMTPConfig(c.Request.Context()) if req.SMTPPort <= 0 {
if err == nil && savedConfig != nil { if savedConfig != nil && savedConfig.Port > 0 {
password = savedConfig.Password req.SMTPPort = savedConfig.Port
} else {
req.SMTPPort = 587
} }
} }
if req.SMTPUsername == "" && savedConfig != nil {
req.SMTPUsername = savedConfig.Username
}
password := strings.TrimSpace(req.SMTPPassword)
if password == "" && savedConfig != nil {
password = savedConfig.Password
}
if req.SMTPFrom == "" && savedConfig != nil {
req.SMTPFrom = savedConfig.From
}
if req.SMTPFromName == "" && savedConfig != nil {
req.SMTPFromName = savedConfig.FromName
}
if req.SMTPHost == "" {
response.BadRequest(c, "SMTP host is required")
return
}
config := &service.SMTPConfig{ config := &service.SMTPConfig{
Host: req.SMTPHost, Host: req.SMTPHost,
......
...@@ -178,6 +178,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -178,6 +178,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled())) c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled()))
setOpsRequestContext(c, reqModel, reqStream, body) setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 验证 model 必填 // 验证 model 必填
if reqModel == "" { if reqModel == "" {
...@@ -1396,6 +1397,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { ...@@ -1396,6 +1397,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
} }
setOpsRequestContext(c, parsedReq.Model, parsedReq.Stream, body) setOpsRequestContext(c, parsedReq.Model, parsedReq.Stream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(parsedReq.Stream, false)))
// 获取订阅信息(可能为nil) // 获取订阅信息(可能为nil)
subscription, _ := middleware2.GetSubscriptionFromContext(c) subscription, _ := middleware2.GetSubscriptionFromContext(c)
......
package handler
import (
"context"
"errors"
"net/http"
"time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
// ChatCompletions handles OpenAI Chat Completions API endpoint for Anthropic platform groups.
// POST /v1/chat/completions
// This converts Chat Completions requests to Anthropic format (via Responses format chain),
// forwards to Anthropic upstream, and converts responses back to Chat Completions format.
func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
streamStarted := false
requestStart := time.Now()
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.chatCompletionsErrorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.chatCompletionsErrorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.gateway.chat_completions",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
// Read request body
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.chatCompletionsErrorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
setOpsRequestContext(c, "", false, body)
// Validate JSON
if !gjson.ValidBytes(body) {
h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// Extract model and stream
modelResult := gjson.GetBytes(body, "model")
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
reqModel := modelResult.String()
reqStream := gjson.GetBytes(body, "stream").Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// Claude Code only restriction
if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly {
h.chatCompletionsErrorResponse(c, http.StatusForbidden, "permission_error",
"This group is restricted to Claude Code clients (/v1/messages only)")
return
}
// Error passthrough binding
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
}
subscription, _ := middleware2.GetSubscriptionFromContext(c)
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
// 1. Acquire user concurrency slot
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
waitCounted := false
if err != nil {
reqLog.Warn("gateway.cc.user_wait_counter_increment_failed", zap.Error(err))
} else if !canWait {
h.chatCompletionsErrorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
if err == nil && canWait {
waitCounted = true
}
defer func() {
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
}
}()
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
if err != nil {
reqLog.Warn("gateway.cc.user_slot_acquire_failed", zap.Error(err))
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
waitCounted = false
}
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
// 2. Re-check billing
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("gateway.cc.billing_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
h.chatCompletionsErrorResponse(c, status, code, message)
return
}
// Parse request for session hash
parsedReq, _ := service.ParseGatewayRequest(body, "chat_completions")
if parsedReq == nil {
parsedReq = &service.ParsedRequest{Model: reqModel, Stream: reqStream, Body: body}
}
parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c),
UserAgent: c.GetHeader("User-Agent"),
APIKeyID: apiKey.ID,
}
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 3. Account selection + failover loop
fs := NewFailoverState(h.maxAccountSwitches, false)
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "")
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
return
}
action := fs.HandleSelectionExhausted(c.Request.Context())
switch action {
case FailoverContinue:
continue
case FailoverCanceled:
return
default:
if fs.LastFailoverErr != nil {
h.handleCCFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
} else {
h.chatCompletionsErrorResponse(c, http.StatusBadGateway, "server_error", "All available accounts exhausted")
}
return
}
}
account := selection.Account
setOpsSelectedAccount(c, account.ID, account.Platform)
// 4. Acquire account concurrency slot
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
return
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
reqLog.Warn("gateway.cc.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
}
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
// 5. Forward request
writerSizeBeforeForward := c.Writer.Size()
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, parsedReq)
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
if c.Writer.Size() != writerSizeBeforeForward {
h.handleCCFailoverExhausted(c, failoverErr, true)
return
}
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
switch action {
case FailoverContinue:
continue
case FailoverExhausted:
h.handleCCFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
return
case FailoverCanceled:
return
}
}
h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Error("gateway.cc.forward_failed",
zap.Int64("account_id", account.ID),
zap.Error(err),
)
return
}
// 6. Record usage
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
}); err != nil {
reqLog.Error("gateway.cc.record_usage_failed",
zap.Int64("account_id", account.ID),
zap.Error(err),
)
}
})
return
}
}
// chatCompletionsErrorResponse writes an error in OpenAI Chat Completions format.
func (h *GatewayHandler) chatCompletionsErrorResponse(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{
"error": gin.H{
"type": errType,
"message": message,
},
})
}
// handleCCFailoverExhausted writes a failover-exhausted error in CC format.
func (h *GatewayHandler) handleCCFailoverExhausted(c *gin.Context, lastErr *service.UpstreamFailoverError, streamStarted bool) {
if streamStarted {
return
}
statusCode := http.StatusBadGateway
if lastErr != nil && lastErr.StatusCode > 0 {
statusCode = lastErr.StatusCode
}
h.chatCompletionsErrorResponse(c, statusCode, "server_error", "All available accounts exhausted")
}
package handler
import (
"context"
"errors"
"net/http"
"time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
// Responses handles OpenAI Responses API endpoint for Anthropic platform groups.
// POST /v1/responses
// This converts Responses API requests to Anthropic format, forwards to Anthropic
// upstream, and converts responses back to Responses format.
func (h *GatewayHandler) Responses(c *gin.Context) {
streamStarted := false
requestStart := time.Now()
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.responsesErrorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.responsesErrorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.gateway.responses",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
// Read request body
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.responsesErrorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
setOpsRequestContext(c, "", false, body)
// Validate JSON
if !gjson.ValidBytes(body) {
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// Extract model and stream using gjson (like OpenAI handler)
modelResult := gjson.GetBytes(body, "model")
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
reqModel := modelResult.String()
reqStream := gjson.GetBytes(body, "stream").Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// Claude Code only restriction:
// /v1/responses is never a Claude Code endpoint.
// When claude_code_only is enabled, this endpoint is rejected.
// The existing service-layer checkClaudeCodeRestriction handles degradation
// to fallback groups when the Forward path calls SelectAccountForModelWithExclusions.
// Here we just reject at handler level since /v1/responses clients can't be Claude Code.
if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly {
h.responsesErrorResponse(c, http.StatusForbidden, "permission_error",
"This group is restricted to Claude Code clients (/v1/messages only)")
return
}
// Error passthrough binding
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
}
subscription, _ := middleware2.GetSubscriptionFromContext(c)
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
// 1. Acquire user concurrency slot
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
waitCounted := false
if err != nil {
reqLog.Warn("gateway.responses.user_wait_counter_increment_failed", zap.Error(err))
} else if !canWait {
h.responsesErrorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
if err == nil && canWait {
waitCounted = true
}
defer func() {
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
}
}()
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
if err != nil {
reqLog.Warn("gateway.responses.user_slot_acquire_failed", zap.Error(err))
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
waitCounted = false
}
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
// 2. Re-check billing
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("gateway.responses.billing_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
h.responsesErrorResponse(c, status, code, message)
return
}
// Parse request for session hash
parsedReq, _ := service.ParseGatewayRequest(body, "responses")
if parsedReq == nil {
parsedReq = &service.ParsedRequest{Model: reqModel, Stream: reqStream, Body: body}
}
parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c),
UserAgent: c.GetHeader("User-Agent"),
APIKeyID: apiKey.ID,
}
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 3. Account selection + failover loop
fs := NewFailoverState(h.maxAccountSwitches, false)
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "")
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
return
}
action := fs.HandleSelectionExhausted(c.Request.Context())
switch action {
case FailoverContinue:
continue
case FailoverCanceled:
return
default:
if fs.LastFailoverErr != nil {
h.handleResponsesFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
} else {
h.responsesErrorResponse(c, http.StatusBadGateway, "server_error", "All available accounts exhausted")
}
return
}
}
account := selection.Account
setOpsSelectedAccount(c, account.ID, account.Platform)
// 4. Acquire account concurrency slot
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
return
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
reqLog.Warn("gateway.responses.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
}
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
// 5. Forward request
writerSizeBeforeForward := c.Writer.Size()
result, err := h.gatewayService.ForwardAsResponses(c.Request.Context(), c, account, body, parsedReq)
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
// Can't failover if streaming content already sent
if c.Writer.Size() != writerSizeBeforeForward {
h.handleResponsesFailoverExhausted(c, failoverErr, true)
return
}
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
switch action {
case FailoverContinue:
continue
case FailoverExhausted:
h.handleResponsesFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
return
case FailoverCanceled:
return
}
}
h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Error("gateway.responses.forward_failed",
zap.Int64("account_id", account.ID),
zap.Error(err),
)
return
}
// 6. Record usage
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
}); err != nil {
reqLog.Error("gateway.responses.record_usage_failed",
zap.Int64("account_id", account.ID),
zap.Error(err),
)
}
})
return
}
}
// responsesErrorResponse writes an error in OpenAI Responses API format.
func (h *GatewayHandler) responsesErrorResponse(c *gin.Context, status int, code, message string) {
c.JSON(status, gin.H{
"error": gin.H{
"code": code,
"message": message,
},
})
}
// handleResponsesFailoverExhausted writes a failover-exhausted error in Responses format.
func (h *GatewayHandler) handleResponsesFailoverExhausted(c *gin.Context, lastErr *service.UpstreamFailoverError, streamStarted bool) {
if streamStarted {
return // Can't write error after stream started
}
statusCode := http.StatusBadGateway
if lastErr != nil && lastErr.StatusCode > 0 {
statusCode = lastErr.StatusCode
}
h.responsesErrorResponse(c, statusCode, "server_error", "All available accounts exhausted")
}
...@@ -182,6 +182,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -182,6 +182,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
} }
setOpsRequestContext(c, modelName, stream, body) setOpsRequestContext(c, modelName, stream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
// Get subscription (may be nil) // Get subscription (may be nil)
subscription, _ := middleware.GetSubscriptionFromContext(c) subscription, _ := middleware.GetSubscriptionFromContext(c)
......
...@@ -77,6 +77,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -77,6 +77,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
setOpsRequestContext(c, reqModel, reqStream, body) setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
if h.errorPassthroughService != nil { if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService) service.BindErrorPassthroughService(c, h.errorPassthroughService)
......
...@@ -183,6 +183,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -183,6 +183,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
} }
setOpsRequestContext(c, reqModel, reqStream, body) setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。 // 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
if !h.validateFunctionCallOutputRequest(c, body, reqLog) { if !h.validateFunctionCallOutputRequest(c, body, reqLog) {
...@@ -545,6 +546,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { ...@@ -545,6 +546,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
setOpsRequestContext(c, reqModel, reqStream, body) setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。 // 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
if h.errorPassthroughService != nil { if h.errorPassthroughService != nil {
...@@ -1096,6 +1098,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { ...@@ -1096,6 +1098,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
zap.String("previous_response_id_kind", previousResponseIDKind), zap.String("previous_response_id_kind", previousResponseIDKind),
) )
setOpsRequestContext(c, reqModel, true, firstMessage) setOpsRequestContext(c, reqModel, true, firstMessage)
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
var currentUserRelease func() var currentUserRelease func()
var currentAccountRelease func() var currentAccountRelease func()
......
...@@ -27,6 +27,9 @@ const ( ...@@ -27,6 +27,9 @@ const (
opsRequestBodyKey = "ops_request_body" opsRequestBodyKey = "ops_request_body"
opsAccountIDKey = "ops_account_id" opsAccountIDKey = "ops_account_id"
opsUpstreamModelKey = "ops_upstream_model"
opsRequestTypeKey = "ops_request_type"
// 错误过滤匹配常量 — shouldSkipOpsErrorLog 和错误分类共用 // 错误过滤匹配常量 — shouldSkipOpsErrorLog 和错误分类共用
opsErrContextCanceled = "context canceled" opsErrContextCanceled = "context canceled"
opsErrNoAvailableAccounts = "no available accounts" opsErrNoAvailableAccounts = "no available accounts"
...@@ -345,6 +348,18 @@ func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody ...@@ -345,6 +348,18 @@ func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody
} }
} }
// setOpsEndpointContext stores upstream model and request type for ops error logging.
// Called by handlers after model mapping and request type determination.
func setOpsEndpointContext(c *gin.Context, upstreamModel string, requestType int16) {
if c == nil {
return
}
if upstreamModel = strings.TrimSpace(upstreamModel); upstreamModel != "" {
c.Set(opsUpstreamModelKey, upstreamModel)
}
c.Set(opsRequestTypeKey, requestType)
}
func attachOpsRequestBodyToEntry(c *gin.Context, entry *service.OpsInsertErrorLogInput) { func attachOpsRequestBodyToEntry(c *gin.Context, entry *service.OpsInsertErrorLogInput) {
if c == nil || entry == nil { if c == nil || entry == nil {
return return
...@@ -628,7 +643,30 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { ...@@ -628,7 +643,30 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
} }
return "" return ""
}(), }(),
Stream: stream, Stream: stream,
InboundEndpoint: GetInboundEndpoint(c),
UpstreamEndpoint: GetUpstreamEndpoint(c, platform),
RequestedModel: modelName,
UpstreamModel: func() string {
if v, ok := c.Get(opsUpstreamModelKey); ok {
if s, ok := v.(string); ok {
return strings.TrimSpace(s)
}
}
return ""
}(),
RequestType: func() *int16 {
if v, ok := c.Get(opsRequestTypeKey); ok {
switch t := v.(type) {
case int16:
return &t
case int:
v16 := int16(t)
return &v16
}
}
return nil
}(),
UserAgent: c.GetHeader("User-Agent"), UserAgent: c.GetHeader("User-Agent"),
ErrorPhase: "upstream", ErrorPhase: "upstream",
...@@ -756,7 +794,30 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { ...@@ -756,7 +794,30 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
} }
return "" return ""
}(), }(),
Stream: stream, Stream: stream,
InboundEndpoint: GetInboundEndpoint(c),
UpstreamEndpoint: GetUpstreamEndpoint(c, platform),
RequestedModel: modelName,
UpstreamModel: func() string {
if v, ok := c.Get(opsUpstreamModelKey); ok {
if s, ok := v.(string); ok {
return strings.TrimSpace(s)
}
}
return ""
}(),
RequestType: func() *int16 {
if v, ok := c.Get(opsRequestTypeKey); ok {
switch t := v.(type) {
case int16:
return &t
case int:
v16 := int16(t)
return &v16
}
}
return nil
}(),
UserAgent: c.GetHeader("User-Agent"), UserAgent: c.GetHeader("User-Agent"),
ErrorPhase: phase, ErrorPhase: phase,
......
...@@ -274,3 +274,48 @@ func TestNormalizeOpsErrorType(t *testing.T) { ...@@ -274,3 +274,48 @@ func TestNormalizeOpsErrorType(t *testing.T) {
}) })
} }
} }
func TestSetOpsEndpointContext_SetsContextKeys(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
setOpsEndpointContext(c, "claude-3-5-sonnet-20241022", int16(2)) // stream
v, ok := c.Get(opsUpstreamModelKey)
require.True(t, ok)
vStr, ok := v.(string)
require.True(t, ok)
require.Equal(t, "claude-3-5-sonnet-20241022", vStr)
rt, ok := c.Get(opsRequestTypeKey)
require.True(t, ok)
rtVal, ok := rt.(int16)
require.True(t, ok)
require.Equal(t, int16(2), rtVal)
}
func TestSetOpsEndpointContext_EmptyModelNotStored(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
setOpsEndpointContext(c, "", int16(1))
_, ok := c.Get(opsUpstreamModelKey)
require.False(t, ok, "empty upstream model should not be stored")
rt, ok := c.Get(opsRequestTypeKey)
require.True(t, ok)
rtVal, ok := rt.(int16)
require.True(t, ok)
require.Equal(t, int16(1), rtVal)
}
func TestSetOpsEndpointContext_NilContext(t *testing.T) {
require.NotPanics(t, func() {
setOpsEndpointContext(nil, "model", int16(1))
})
}
...@@ -2072,7 +2072,7 @@ func (r *stubAccountRepoForHandler) Delete(context.Context, int64) error ...@@ -2072,7 +2072,7 @@ func (r *stubAccountRepoForHandler) Delete(context.Context, int64) error
func (r *stubAccountRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { func (r *stubAccountRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64) ([]service.Account, *pagination.PaginationResult, error) { func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64, string) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (r *stubAccountRepoForHandler) ListByGroup(context.Context, int64) ([]service.Account, error) { func (r *stubAccountRepoForHandler) ListByGroup(context.Context, int64) ([]service.Account, error) {
......
...@@ -159,6 +159,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -159,6 +159,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
} }
setOpsRequestContext(c, reqModel, clientStream, body) setOpsRequestContext(c, reqModel, clientStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(clientStream, false)))
platform := "" platform := ""
if forced, ok := middleware2.GetForcePlatformFromContext(c); ok { if forced, ok := middleware2.GetForcePlatformFromContext(c); ok {
......
...@@ -130,7 +130,7 @@ func (r *stubAccountRepo) Delete(ctx context.Context, id int64) error ...@@ -130,7 +130,7 @@ func (r *stubAccountRepo) Delete(ctx context.Context, id int64) error
func (r *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { func (r *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) { func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (r *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) { func (r *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
......
package apicompat
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"time"
)
// ---------------------------------------------------------------------------
// Non-streaming: AnthropicResponse → ResponsesResponse
// ---------------------------------------------------------------------------
// AnthropicToResponsesResponse converts an Anthropic Messages response into a
// Responses API response. This is the reverse of ResponsesToAnthropic and
// enables Anthropic upstream responses to be returned in OpenAI Responses format.
func AnthropicToResponsesResponse(resp *AnthropicResponse) *ResponsesResponse {
id := resp.ID
if id == "" {
id = generateResponsesID()
}
out := &ResponsesResponse{
ID: id,
Object: "response",
Model: resp.Model,
}
var outputs []ResponsesOutput
var msgParts []ResponsesContentPart
for _, block := range resp.Content {
switch block.Type {
case "thinking":
if block.Thinking != "" {
outputs = append(outputs, ResponsesOutput{
Type: "reasoning",
ID: generateItemID(),
Summary: []ResponsesSummary{{
Type: "summary_text",
Text: block.Thinking,
}},
})
}
case "text":
if block.Text != "" {
msgParts = append(msgParts, ResponsesContentPart{
Type: "output_text",
Text: block.Text,
})
}
case "tool_use":
args := "{}"
if len(block.Input) > 0 {
args = string(block.Input)
}
outputs = append(outputs, ResponsesOutput{
Type: "function_call",
ID: generateItemID(),
CallID: toResponsesCallID(block.ID),
Name: block.Name,
Arguments: args,
Status: "completed",
})
}
}
// Assemble message output item from text parts
if len(msgParts) > 0 {
outputs = append(outputs, ResponsesOutput{
Type: "message",
ID: generateItemID(),
Role: "assistant",
Content: msgParts,
Status: "completed",
})
}
if len(outputs) == 0 {
outputs = append(outputs, ResponsesOutput{
Type: "message",
ID: generateItemID(),
Role: "assistant",
Content: []ResponsesContentPart{{Type: "output_text", Text: ""}},
Status: "completed",
})
}
out.Output = outputs
// Map stop_reason → status
out.Status = anthropicStopReasonToResponsesStatus(resp.StopReason, resp.Content)
if out.Status == "incomplete" {
out.IncompleteDetails = &ResponsesIncompleteDetails{Reason: "max_output_tokens"}
}
// Usage
out.Usage = &ResponsesUsage{
InputTokens: resp.Usage.InputTokens,
OutputTokens: resp.Usage.OutputTokens,
TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens,
}
if resp.Usage.CacheReadInputTokens > 0 {
out.Usage.InputTokensDetails = &ResponsesInputTokensDetails{
CachedTokens: resp.Usage.CacheReadInputTokens,
}
}
return out
}
// anthropicStopReasonToResponsesStatus maps Anthropic stop_reason to Responses status.
func anthropicStopReasonToResponsesStatus(stopReason string, blocks []AnthropicContentBlock) string {
switch stopReason {
case "max_tokens":
return "incomplete"
case "end_turn", "tool_use", "stop_sequence":
return "completed"
default:
return "completed"
}
}
// ---------------------------------------------------------------------------
// Streaming: AnthropicStreamEvent → []ResponsesStreamEvent (stateful converter)
// ---------------------------------------------------------------------------
// AnthropicEventToResponsesState tracks state for converting a sequence of
// Anthropic SSE events into Responses SSE events.
type AnthropicEventToResponsesState struct {
ResponseID string
Model string
Created int64
SequenceNumber int
// CreatedSent tracks whether response.created has been emitted.
CreatedSent bool
// CompletedSent tracks whether the terminal event has been emitted.
CompletedSent bool
// Current output tracking
OutputIndex int
CurrentItemID string
CurrentItemType string // "message" | "function_call" | "reasoning"
// For message output: accumulate text parts
ContentIndex int
// For function_call: track per-output info
CurrentCallID string
CurrentName string
// Usage from message_delta
InputTokens int
OutputTokens int
CacheReadInputTokens int
}
// NewAnthropicEventToResponsesState returns an initialised stream state.
func NewAnthropicEventToResponsesState() *AnthropicEventToResponsesState {
return &AnthropicEventToResponsesState{
Created: time.Now().Unix(),
}
}
// AnthropicEventToResponsesEvents converts a single Anthropic SSE event into
// zero or more Responses SSE events, updating state as it goes.
func AnthropicEventToResponsesEvents(
evt *AnthropicStreamEvent,
state *AnthropicEventToResponsesState,
) []ResponsesStreamEvent {
switch evt.Type {
case "message_start":
return anthToResHandleMessageStart(evt, state)
case "content_block_start":
return anthToResHandleContentBlockStart(evt, state)
case "content_block_delta":
return anthToResHandleContentBlockDelta(evt, state)
case "content_block_stop":
return anthToResHandleContentBlockStop(evt, state)
case "message_delta":
return anthToResHandleMessageDelta(evt, state)
case "message_stop":
return anthToResHandleMessageStop(state)
default:
return nil
}
}
// FinalizeAnthropicResponsesStream emits synthetic termination events if the
// stream ended without a proper message_stop.
func FinalizeAnthropicResponsesStream(state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
if !state.CreatedSent || state.CompletedSent {
return nil
}
var events []ResponsesStreamEvent
// Close any open item
events = append(events, closeCurrentResponsesItem(state)...)
// Emit response.completed
events = append(events, makeResponsesCompletedEvent(state, "completed", nil))
state.CompletedSent = true
return events
}
// ResponsesEventToSSE formats a ResponsesStreamEvent as an SSE data line.
func ResponsesEventToSSE(evt ResponsesStreamEvent) (string, error) {
data, err := json.Marshal(evt)
if err != nil {
return "", err
}
return fmt.Sprintf("event: %s\ndata: %s\n\n", evt.Type, data), nil
}
// --- internal handlers ---
func anthToResHandleMessageStart(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
if evt.Message != nil {
state.ResponseID = evt.Message.ID
if state.Model == "" {
state.Model = evt.Message.Model
}
if evt.Message.Usage.InputTokens > 0 {
state.InputTokens = evt.Message.Usage.InputTokens
}
}
if state.CreatedSent {
return nil
}
state.CreatedSent = true
// Emit response.created
return []ResponsesStreamEvent{makeResponsesCreatedEvent(state)}
}
func anthToResHandleContentBlockStart(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
if evt.ContentBlock == nil {
return nil
}
var events []ResponsesStreamEvent
switch evt.ContentBlock.Type {
case "thinking":
state.CurrentItemID = generateItemID()
state.CurrentItemType = "reasoning"
state.ContentIndex = 0
events = append(events, makeResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{
OutputIndex: state.OutputIndex,
Item: &ResponsesOutput{
Type: "reasoning",
ID: state.CurrentItemID,
},
}))
case "text":
// If we don't have an open message item, open one
if state.CurrentItemType != "message" {
state.CurrentItemID = generateItemID()
state.CurrentItemType = "message"
state.ContentIndex = 0
events = append(events, makeResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{
OutputIndex: state.OutputIndex,
Item: &ResponsesOutput{
Type: "message",
ID: state.CurrentItemID,
Role: "assistant",
Status: "in_progress",
},
}))
}
case "tool_use":
// Close previous item if any
events = append(events, closeCurrentResponsesItem(state)...)
state.CurrentItemID = generateItemID()
state.CurrentItemType = "function_call"
state.CurrentCallID = toResponsesCallID(evt.ContentBlock.ID)
state.CurrentName = evt.ContentBlock.Name
events = append(events, makeResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{
OutputIndex: state.OutputIndex,
Item: &ResponsesOutput{
Type: "function_call",
ID: state.CurrentItemID,
CallID: state.CurrentCallID,
Name: state.CurrentName,
Status: "in_progress",
},
}))
}
return events
}
func anthToResHandleContentBlockDelta(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
if evt.Delta == nil {
return nil
}
switch evt.Delta.Type {
case "text_delta":
if evt.Delta.Text == "" {
return nil
}
return []ResponsesStreamEvent{makeResponsesEvent(state, "response.output_text.delta", &ResponsesStreamEvent{
OutputIndex: state.OutputIndex,
ContentIndex: state.ContentIndex,
Delta: evt.Delta.Text,
ItemID: state.CurrentItemID,
})}
case "thinking_delta":
if evt.Delta.Thinking == "" {
return nil
}
return []ResponsesStreamEvent{makeResponsesEvent(state, "response.reasoning_summary_text.delta", &ResponsesStreamEvent{
OutputIndex: state.OutputIndex,
SummaryIndex: 0,
Delta: evt.Delta.Thinking,
ItemID: state.CurrentItemID,
})}
case "input_json_delta":
if evt.Delta.PartialJSON == "" {
return nil
}
return []ResponsesStreamEvent{makeResponsesEvent(state, "response.function_call_arguments.delta", &ResponsesStreamEvent{
OutputIndex: state.OutputIndex,
Delta: evt.Delta.PartialJSON,
ItemID: state.CurrentItemID,
CallID: state.CurrentCallID,
Name: state.CurrentName,
})}
case "signature_delta":
// Anthropic signature deltas have no Responses equivalent; skip
return nil
}
return nil
}
func anthToResHandleContentBlockStop(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
switch state.CurrentItemType {
case "reasoning":
// Emit reasoning summary done + output item done
events := []ResponsesStreamEvent{
makeResponsesEvent(state, "response.reasoning_summary_text.done", &ResponsesStreamEvent{
OutputIndex: state.OutputIndex,
SummaryIndex: 0,
ItemID: state.CurrentItemID,
}),
}
events = append(events, closeCurrentResponsesItem(state)...)
return events
case "function_call":
// Emit function_call_arguments.done + output item done
events := []ResponsesStreamEvent{
makeResponsesEvent(state, "response.function_call_arguments.done", &ResponsesStreamEvent{
OutputIndex: state.OutputIndex,
ItemID: state.CurrentItemID,
CallID: state.CurrentCallID,
Name: state.CurrentName,
}),
}
events = append(events, closeCurrentResponsesItem(state)...)
return events
case "message":
// Emit output_text.done (text block is done, but message item stays open for potential more blocks)
return []ResponsesStreamEvent{
makeResponsesEvent(state, "response.output_text.done", &ResponsesStreamEvent{
OutputIndex: state.OutputIndex,
ContentIndex: state.ContentIndex,
ItemID: state.CurrentItemID,
}),
}
}
return nil
}
func anthToResHandleMessageDelta(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
// Update usage
if evt.Usage != nil {
state.OutputTokens = evt.Usage.OutputTokens
if evt.Usage.CacheReadInputTokens > 0 {
state.CacheReadInputTokens = evt.Usage.CacheReadInputTokens
}
}
return nil
}
func anthToResHandleMessageStop(state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
if state.CompletedSent {
return nil
}
var events []ResponsesStreamEvent
// Close any open item
events = append(events, closeCurrentResponsesItem(state)...)
// Determine status
status := "completed"
var incompleteDetails *ResponsesIncompleteDetails
// Emit response.completed
events = append(events, makeResponsesCompletedEvent(state, status, incompleteDetails))
state.CompletedSent = true
return events
}
// --- helper functions ---
func closeCurrentResponsesItem(state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
if state.CurrentItemType == "" {
return nil
}
itemType := state.CurrentItemType
itemID := state.CurrentItemID
// Reset
state.CurrentItemType = ""
state.CurrentItemID = ""
state.CurrentCallID = ""
state.CurrentName = ""
state.OutputIndex++
state.ContentIndex = 0
return []ResponsesStreamEvent{makeResponsesEvent(state, "response.output_item.done", &ResponsesStreamEvent{
OutputIndex: state.OutputIndex - 1, // Use the index before increment
Item: &ResponsesOutput{
Type: itemType,
ID: itemID,
Status: "completed",
},
})}
}
func makeResponsesCreatedEvent(state *AnthropicEventToResponsesState) ResponsesStreamEvent {
seq := state.SequenceNumber
state.SequenceNumber++
return ResponsesStreamEvent{
Type: "response.created",
SequenceNumber: seq,
Response: &ResponsesResponse{
ID: state.ResponseID,
Object: "response",
Model: state.Model,
Status: "in_progress",
Output: []ResponsesOutput{},
},
}
}
func makeResponsesCompletedEvent(
state *AnthropicEventToResponsesState,
status string,
incompleteDetails *ResponsesIncompleteDetails,
) ResponsesStreamEvent {
seq := state.SequenceNumber
state.SequenceNumber++
usage := &ResponsesUsage{
InputTokens: state.InputTokens,
OutputTokens: state.OutputTokens,
TotalTokens: state.InputTokens + state.OutputTokens,
}
if state.CacheReadInputTokens > 0 {
usage.InputTokensDetails = &ResponsesInputTokensDetails{
CachedTokens: state.CacheReadInputTokens,
}
}
return ResponsesStreamEvent{
Type: "response.completed",
SequenceNumber: seq,
Response: &ResponsesResponse{
ID: state.ResponseID,
Object: "response",
Model: state.Model,
Status: status,
Output: []ResponsesOutput{}, // Simplified; full output tracking would add complexity
Usage: usage,
IncompleteDetails: incompleteDetails,
},
}
}
func makeResponsesEvent(state *AnthropicEventToResponsesState, eventType string, template *ResponsesStreamEvent) ResponsesStreamEvent {
seq := state.SequenceNumber
state.SequenceNumber++
evt := *template
evt.Type = eventType
evt.SequenceNumber = seq
return evt
}
func generateResponsesID() string {
b := make([]byte, 12)
_, _ = rand.Read(b)
return "resp_" + hex.EncodeToString(b)
}
func generateItemID() string {
b := make([]byte, 12)
_, _ = rand.Read(b)
return "item_" + hex.EncodeToString(b)
}
package apicompat
import (
"encoding/json"
"fmt"
"strings"
)
// ResponsesToAnthropicRequest converts a Responses API request into an
// Anthropic Messages request. This is the reverse of AnthropicToResponses and
// enables Anthropic platform groups to accept OpenAI Responses API requests
// by converting them to the native /v1/messages format before forwarding upstream.
func ResponsesToAnthropicRequest(req *ResponsesRequest) (*AnthropicRequest, error) {
system, messages, err := convertResponsesInputToAnthropic(req.Input)
if err != nil {
return nil, err
}
out := &AnthropicRequest{
Model: req.Model,
Messages: messages,
Temperature: req.Temperature,
TopP: req.TopP,
Stream: req.Stream,
}
if len(system) > 0 {
out.System = system
}
// max_output_tokens → max_tokens
if req.MaxOutputTokens != nil && *req.MaxOutputTokens > 0 {
out.MaxTokens = *req.MaxOutputTokens
}
if out.MaxTokens == 0 {
// Anthropic requires max_tokens; default to a sensible value.
out.MaxTokens = 8192
}
// Convert tools
if len(req.Tools) > 0 {
out.Tools = convertResponsesToAnthropicTools(req.Tools)
}
// Convert tool_choice (reverse of convertAnthropicToolChoiceToResponses)
if len(req.ToolChoice) > 0 {
tc, err := convertResponsesToAnthropicToolChoice(req.ToolChoice)
if err != nil {
return nil, fmt.Errorf("convert tool_choice: %w", err)
}
out.ToolChoice = tc
}
// reasoning.effort → output_config.effort + thinking
if req.Reasoning != nil && req.Reasoning.Effort != "" {
effort := mapResponsesEffortToAnthropic(req.Reasoning.Effort)
out.OutputConfig = &AnthropicOutputConfig{Effort: effort}
// Enable thinking for non-low efforts
if effort != "low" {
out.Thinking = &AnthropicThinking{
Type: "enabled",
BudgetTokens: defaultThinkingBudget(effort),
}
}
}
return out, nil
}
// defaultThinkingBudget returns a sensible thinking budget based on effort level.
func defaultThinkingBudget(effort string) int {
switch effort {
case "low":
return 1024
case "medium":
return 4096
case "high":
return 10240
case "max":
return 32768
default:
return 10240
}
}
// mapResponsesEffortToAnthropic converts OpenAI Responses reasoning effort to
// Anthropic effort levels. Reverse of mapAnthropicEffortToResponses.
//
// low → low
// medium → medium
// high → high
// xhigh → max
func mapResponsesEffortToAnthropic(effort string) string {
if effort == "xhigh" {
return "max"
}
return effort // low→low, medium→medium, high→high, unknown→passthrough
}
// convertResponsesInputToAnthropic extracts system prompt and messages from
// a Responses API input array. Returns the system as raw JSON (for Anthropic's
// polymorphic system field) and a list of Anthropic messages.
func convertResponsesInputToAnthropic(inputRaw json.RawMessage) (json.RawMessage, []AnthropicMessage, error) {
// Try as plain string input.
var inputStr string
if err := json.Unmarshal(inputRaw, &inputStr); err == nil {
content, _ := json.Marshal(inputStr)
return nil, []AnthropicMessage{{Role: "user", Content: content}}, nil
}
var items []ResponsesInputItem
if err := json.Unmarshal(inputRaw, &items); err != nil {
return nil, nil, fmt.Errorf("parse responses input: %w", err)
}
var system json.RawMessage
var messages []AnthropicMessage
for _, item := range items {
switch {
case item.Role == "system":
// System prompt → Anthropic system field
text := extractTextFromContent(item.Content)
if text != "" {
system, _ = json.Marshal(text)
}
case item.Type == "function_call":
// function_call → assistant message with tool_use block
input := json.RawMessage("{}")
if item.Arguments != "" {
input = json.RawMessage(item.Arguments)
}
block := AnthropicContentBlock{
Type: "tool_use",
ID: fromResponsesCallIDToAnthropic(item.CallID),
Name: item.Name,
Input: input,
}
blockJSON, _ := json.Marshal([]AnthropicContentBlock{block})
messages = append(messages, AnthropicMessage{
Role: "assistant",
Content: blockJSON,
})
case item.Type == "function_call_output":
// function_call_output → user message with tool_result block
outputContent := item.Output
if outputContent == "" {
outputContent = "(empty)"
}
contentJSON, _ := json.Marshal(outputContent)
block := AnthropicContentBlock{
Type: "tool_result",
ToolUseID: fromResponsesCallIDToAnthropic(item.CallID),
Content: contentJSON,
}
blockJSON, _ := json.Marshal([]AnthropicContentBlock{block})
messages = append(messages, AnthropicMessage{
Role: "user",
Content: blockJSON,
})
case item.Role == "user":
content, err := convertResponsesUserToAnthropicContent(item.Content)
if err != nil {
return nil, nil, err
}
messages = append(messages, AnthropicMessage{
Role: "user",
Content: content,
})
case item.Role == "assistant":
content, err := convertResponsesAssistantToAnthropicContent(item.Content)
if err != nil {
return nil, nil, err
}
messages = append(messages, AnthropicMessage{
Role: "assistant",
Content: content,
})
default:
// Unknown role/type — attempt as user message
if item.Content != nil {
messages = append(messages, AnthropicMessage{
Role: "user",
Content: item.Content,
})
}
}
}
// Merge consecutive same-role messages (Anthropic requires alternating roles)
messages = mergeConsecutiveMessages(messages)
return system, messages, nil
}
// extractTextFromContent extracts text from a content field that may be a
// plain string or an array of content parts.
func extractTextFromContent(raw json.RawMessage) string {
if len(raw) == 0 {
return ""
}
var s string
if err := json.Unmarshal(raw, &s); err == nil {
return s
}
var parts []ResponsesContentPart
if err := json.Unmarshal(raw, &parts); err == nil {
var texts []string
for _, p := range parts {
if (p.Type == "input_text" || p.Type == "output_text" || p.Type == "text") && p.Text != "" {
texts = append(texts, p.Text)
}
}
return strings.Join(texts, "\n\n")
}
return ""
}
// convertResponsesUserToAnthropicContent converts a Responses user message
// content field into Anthropic content blocks JSON.
func convertResponsesUserToAnthropicContent(raw json.RawMessage) (json.RawMessage, error) {
if len(raw) == 0 {
return json.Marshal("") // empty string content
}
// Try plain string.
var s string
if err := json.Unmarshal(raw, &s); err == nil {
return json.Marshal(s)
}
// Array of content parts → Anthropic content blocks.
var parts []ResponsesContentPart
if err := json.Unmarshal(raw, &parts); err != nil {
// Pass through as-is if we can't parse
return raw, nil
}
var blocks []AnthropicContentBlock
for _, p := range parts {
switch p.Type {
case "input_text", "text":
if p.Text != "" {
blocks = append(blocks, AnthropicContentBlock{
Type: "text",
Text: p.Text,
})
}
case "input_image":
src := dataURIToAnthropicImageSource(p.ImageURL)
if src != nil {
blocks = append(blocks, AnthropicContentBlock{
Type: "image",
Source: src,
})
}
}
}
if len(blocks) == 0 {
return json.Marshal("")
}
return json.Marshal(blocks)
}
// convertResponsesAssistantToAnthropicContent converts a Responses assistant
// message content field into Anthropic content blocks JSON.
func convertResponsesAssistantToAnthropicContent(raw json.RawMessage) (json.RawMessage, error) {
if len(raw) == 0 {
return json.Marshal([]AnthropicContentBlock{{Type: "text", Text: ""}})
}
// Try plain string.
var s string
if err := json.Unmarshal(raw, &s); err == nil {
return json.Marshal([]AnthropicContentBlock{{Type: "text", Text: s}})
}
// Array of content parts → Anthropic content blocks.
var parts []ResponsesContentPart
if err := json.Unmarshal(raw, &parts); err != nil {
return raw, nil
}
var blocks []AnthropicContentBlock
for _, p := range parts {
switch p.Type {
case "output_text", "text":
if p.Text != "" {
blocks = append(blocks, AnthropicContentBlock{
Type: "text",
Text: p.Text,
})
}
}
}
if len(blocks) == 0 {
blocks = append(blocks, AnthropicContentBlock{Type: "text", Text: ""})
}
return json.Marshal(blocks)
}
// fromResponsesCallIDToAnthropic converts an OpenAI function call ID back to
// Anthropic format. Reverses toResponsesCallID.
func fromResponsesCallIDToAnthropic(id string) string {
// If it has our "fc_" prefix wrapping a known Anthropic prefix, strip it
if after, ok := strings.CutPrefix(id, "fc_"); ok {
if strings.HasPrefix(after, "toolu_") || strings.HasPrefix(after, "call_") {
return after
}
}
// Generate a synthetic Anthropic tool ID
if !strings.HasPrefix(id, "toolu_") && !strings.HasPrefix(id, "call_") {
return "toolu_" + id
}
return id
}
// dataURIToAnthropicImageSource parses a data URI into an AnthropicImageSource.
func dataURIToAnthropicImageSource(dataURI string) *AnthropicImageSource {
if !strings.HasPrefix(dataURI, "data:") {
return nil
}
// Format: data:<media_type>;base64,<data>
rest := strings.TrimPrefix(dataURI, "data:")
semicolonIdx := strings.Index(rest, ";")
if semicolonIdx < 0 {
return nil
}
mediaType := rest[:semicolonIdx]
rest = rest[semicolonIdx+1:]
if !strings.HasPrefix(rest, "base64,") {
return nil
}
data := strings.TrimPrefix(rest, "base64,")
return &AnthropicImageSource{
Type: "base64",
MediaType: mediaType,
Data: data,
}
}
// mergeConsecutiveMessages merges consecutive messages with the same role
// because Anthropic requires alternating user/assistant turns.
func mergeConsecutiveMessages(messages []AnthropicMessage) []AnthropicMessage {
if len(messages) <= 1 {
return messages
}
var merged []AnthropicMessage
for _, msg := range messages {
if len(merged) == 0 || merged[len(merged)-1].Role != msg.Role {
merged = append(merged, msg)
continue
}
// Same role — merge content arrays
last := &merged[len(merged)-1]
lastBlocks := parseContentBlocks(last.Content)
newBlocks := parseContentBlocks(msg.Content)
combined := append(lastBlocks, newBlocks...)
last.Content, _ = json.Marshal(combined)
}
return merged
}
// parseContentBlocks attempts to parse content as []AnthropicContentBlock.
// If it's a string, wraps it in a text block.
func parseContentBlocks(raw json.RawMessage) []AnthropicContentBlock {
var blocks []AnthropicContentBlock
if err := json.Unmarshal(raw, &blocks); err == nil {
return blocks
}
var s string
if err := json.Unmarshal(raw, &s); err == nil {
return []AnthropicContentBlock{{Type: "text", Text: s}}
}
return nil
}
// convertResponsesToAnthropicTools maps Responses API tools to Anthropic format.
// Reverse of convertAnthropicToolsToResponses.
func convertResponsesToAnthropicTools(tools []ResponsesTool) []AnthropicTool {
var out []AnthropicTool
for _, t := range tools {
switch t.Type {
case "web_search":
out = append(out, AnthropicTool{
Type: "web_search_20250305",
Name: "web_search",
})
case "function":
out = append(out, AnthropicTool{
Name: t.Name,
Description: t.Description,
InputSchema: normalizeAnthropicInputSchema(t.Parameters),
})
default:
// Pass through unknown tool types
out = append(out, AnthropicTool{
Type: t.Type,
Name: t.Name,
Description: t.Description,
InputSchema: t.Parameters,
})
}
}
return out
}
// normalizeAnthropicInputSchema ensures the input_schema has a "type" field.
func normalizeAnthropicInputSchema(schema json.RawMessage) json.RawMessage {
if len(schema) == 0 || string(schema) == "null" {
return json.RawMessage(`{"type":"object","properties":{}}`)
}
return schema
}
// convertResponsesToAnthropicToolChoice maps Responses tool_choice to Anthropic format.
// Reverse of convertAnthropicToolChoiceToResponses.
//
// "auto" → {"type":"auto"}
// "required" → {"type":"any"}
// "none" → {"type":"none"}
// {"type":"function","function":{"name":"X"}} → {"type":"tool","name":"X"}
func convertResponsesToAnthropicToolChoice(raw json.RawMessage) (json.RawMessage, error) {
// Try as string first
var s string
if err := json.Unmarshal(raw, &s); err == nil {
switch s {
case "auto":
return json.Marshal(map[string]string{"type": "auto"})
case "required":
return json.Marshal(map[string]string{"type": "any"})
case "none":
return json.Marshal(map[string]string{"type": "none"})
default:
return raw, nil
}
}
// Try as object with type=function
var tc struct {
Type string `json:"type"`
Function struct {
Name string `json:"name"`
} `json:"function"`
}
if err := json.Unmarshal(raw, &tc); err == nil && tc.Type == "function" && tc.Function.Name != "" {
return json.Marshal(map[string]string{
"type": "tool",
"name": tc.Function.Name,
})
}
// Pass through unknown
return raw, nil
}
...@@ -270,6 +270,7 @@ type OpenAIAuthClaims struct { ...@@ -270,6 +270,7 @@ type OpenAIAuthClaims struct {
ChatGPTUserID string `json:"chatgpt_user_id"` ChatGPTUserID string `json:"chatgpt_user_id"`
ChatGPTPlanType string `json:"chatgpt_plan_type"` ChatGPTPlanType string `json:"chatgpt_plan_type"`
UserID string `json:"user_id"` UserID string `json:"user_id"`
POID string `json:"poid"` // organization ID in access_token JWT
Organizations []OrganizationClaim `json:"organizations"` Organizations []OrganizationClaim `json:"organizations"`
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment