Commit 31660c4c authored by Ethan0x0000's avatar Ethan0x0000
Browse files

feat(handler): add Responses/ChatCompletions handlers on GatewayHandler

New HTTP handlers for Anthropic platform groups accepting OpenAI-format
endpoints:

- GatewayHandler.Responses: /v1/responses for non-OpenAI groups
- GatewayHandler.ChatCompletions: /v1/chat/completions for non-OpenAI groups

Both handlers include:
- Claude Code only restriction (403 reject when claude_code_only enabled,
  since these endpoints are never Claude Code clients)
- Full auth → billing → user/account concurrency → failover loop
- Ops error/endpoint context propagation
- Async usage recording via worker pool

Error responses use each endpoint's native format (Responses API format
for /v1/responses, CC format for /v1/chat/completions).
parent 4321adab
package handler
import (
"context"
"errors"
"net/http"
"time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
// ChatCompletions handles OpenAI Chat Completions API endpoint for Anthropic platform groups.
// POST /v1/chat/completions
// This converts Chat Completions requests to Anthropic format (via Responses format chain),
// forwards to Anthropic upstream, and converts responses back to Chat Completions format.
func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
streamStarted := false
requestStart := time.Now()
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.chatCompletionsErrorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.chatCompletionsErrorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.gateway.chat_completions",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
// Read request body
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.chatCompletionsErrorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
setOpsRequestContext(c, "", false, body)
// Validate JSON
if !gjson.ValidBytes(body) {
h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// Extract model and stream
modelResult := gjson.GetBytes(body, "model")
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
reqModel := modelResult.String()
reqStream := gjson.GetBytes(body, "stream").Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// Claude Code only restriction
if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly {
h.chatCompletionsErrorResponse(c, http.StatusForbidden, "permission_error",
"This group is restricted to Claude Code clients (/v1/messages only)")
return
}
// Error passthrough binding
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
}
subscription, _ := middleware2.GetSubscriptionFromContext(c)
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
// 1. Acquire user concurrency slot
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
waitCounted := false
if err != nil {
reqLog.Warn("gateway.cc.user_wait_counter_increment_failed", zap.Error(err))
} else if !canWait {
h.chatCompletionsErrorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
if err == nil && canWait {
waitCounted = true
}
defer func() {
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
}
}()
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
if err != nil {
reqLog.Warn("gateway.cc.user_slot_acquire_failed", zap.Error(err))
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
waitCounted = false
}
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
// 2. Re-check billing
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("gateway.cc.billing_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
h.chatCompletionsErrorResponse(c, status, code, message)
return
}
// Parse request for session hash
parsedReq, _ := service.ParseGatewayRequest(body, "chat_completions")
if parsedReq == nil {
parsedReq = &service.ParsedRequest{Model: reqModel, Stream: reqStream, Body: body}
}
parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c),
UserAgent: c.GetHeader("User-Agent"),
APIKeyID: apiKey.ID,
}
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 3. Account selection + failover loop
fs := NewFailoverState(h.maxAccountSwitches, false)
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "")
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
return
}
action := fs.HandleSelectionExhausted(c.Request.Context())
switch action {
case FailoverContinue:
continue
case FailoverCanceled:
return
default:
if fs.LastFailoverErr != nil {
h.handleCCFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
} else {
h.chatCompletionsErrorResponse(c, http.StatusBadGateway, "server_error", "All available accounts exhausted")
}
return
}
}
account := selection.Account
setOpsSelectedAccount(c, account.ID, account.Platform)
// 4. Acquire account concurrency slot
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
return
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
reqLog.Warn("gateway.cc.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
}
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
// 5. Forward request
writerSizeBeforeForward := c.Writer.Size()
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, parsedReq)
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
if c.Writer.Size() != writerSizeBeforeForward {
h.handleCCFailoverExhausted(c, failoverErr, true)
return
}
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
switch action {
case FailoverContinue:
continue
case FailoverExhausted:
h.handleCCFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
return
case FailoverCanceled:
return
}
}
h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Error("gateway.cc.forward_failed",
zap.Int64("account_id", account.ID),
zap.Error(err),
)
return
}
// 6. Record usage
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
}); err != nil {
reqLog.Error("gateway.cc.record_usage_failed",
zap.Int64("account_id", account.ID),
zap.Error(err),
)
}
})
return
}
}
// chatCompletionsErrorResponse writes an error in OpenAI Chat Completions format.
func (h *GatewayHandler) chatCompletionsErrorResponse(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{
"error": gin.H{
"type": errType,
"message": message,
},
})
}
// handleCCFailoverExhausted writes a failover-exhausted error in CC format.
func (h *GatewayHandler) handleCCFailoverExhausted(c *gin.Context, lastErr *service.UpstreamFailoverError, streamStarted bool) {
if streamStarted {
return
}
statusCode := http.StatusBadGateway
if lastErr != nil && lastErr.StatusCode > 0 {
statusCode = lastErr.StatusCode
}
h.chatCompletionsErrorResponse(c, statusCode, "server_error", "All available accounts exhausted")
}
package handler
import (
"context"
"errors"
"net/http"
"time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
// Responses handles OpenAI Responses API endpoint for Anthropic platform groups.
// POST /v1/responses
// This converts Responses API requests to Anthropic format, forwards to Anthropic
// upstream, and converts responses back to Responses format.
func (h *GatewayHandler) Responses(c *gin.Context) {
streamStarted := false
requestStart := time.Now()
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.responsesErrorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.responsesErrorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.gateway.responses",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
// Read request body
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.responsesErrorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
setOpsRequestContext(c, "", false, body)
// Validate JSON
if !gjson.ValidBytes(body) {
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// Extract model and stream using gjson (like OpenAI handler)
modelResult := gjson.GetBytes(body, "model")
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
reqModel := modelResult.String()
reqStream := gjson.GetBytes(body, "stream").Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// Claude Code only restriction:
// /v1/responses is never a Claude Code endpoint.
// When claude_code_only is enabled, this endpoint is rejected.
// The existing service-layer checkClaudeCodeRestriction handles degradation
// to fallback groups when the Forward path calls SelectAccountForModelWithExclusions.
// Here we just reject at handler level since /v1/responses clients can't be Claude Code.
if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly {
h.responsesErrorResponse(c, http.StatusForbidden, "permission_error",
"This group is restricted to Claude Code clients (/v1/messages only)")
return
}
// Error passthrough binding
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
}
subscription, _ := middleware2.GetSubscriptionFromContext(c)
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
// 1. Acquire user concurrency slot
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
waitCounted := false
if err != nil {
reqLog.Warn("gateway.responses.user_wait_counter_increment_failed", zap.Error(err))
} else if !canWait {
h.responsesErrorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
if err == nil && canWait {
waitCounted = true
}
defer func() {
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
}
}()
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
if err != nil {
reqLog.Warn("gateway.responses.user_slot_acquire_failed", zap.Error(err))
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
waitCounted = false
}
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
// 2. Re-check billing
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("gateway.responses.billing_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
h.responsesErrorResponse(c, status, code, message)
return
}
// Parse request for session hash
parsedReq, _ := service.ParseGatewayRequest(body, "responses")
if parsedReq == nil {
parsedReq = &service.ParsedRequest{Model: reqModel, Stream: reqStream, Body: body}
}
parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c),
UserAgent: c.GetHeader("User-Agent"),
APIKeyID: apiKey.ID,
}
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 3. Account selection + failover loop
fs := NewFailoverState(h.maxAccountSwitches, false)
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "")
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
return
}
action := fs.HandleSelectionExhausted(c.Request.Context())
switch action {
case FailoverContinue:
continue
case FailoverCanceled:
return
default:
if fs.LastFailoverErr != nil {
h.handleResponsesFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
} else {
h.responsesErrorResponse(c, http.StatusBadGateway, "server_error", "All available accounts exhausted")
}
return
}
}
account := selection.Account
setOpsSelectedAccount(c, account.ID, account.Platform)
// 4. Acquire account concurrency slot
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
return
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
reqLog.Warn("gateway.responses.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
}
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
// 5. Forward request
writerSizeBeforeForward := c.Writer.Size()
result, err := h.gatewayService.ForwardAsResponses(c.Request.Context(), c, account, body, parsedReq)
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
// Can't failover if streaming content already sent
if c.Writer.Size() != writerSizeBeforeForward {
h.handleResponsesFailoverExhausted(c, failoverErr, true)
return
}
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
switch action {
case FailoverContinue:
continue
case FailoverExhausted:
h.handleResponsesFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
return
case FailoverCanceled:
return
}
}
h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Error("gateway.responses.forward_failed",
zap.Int64("account_id", account.ID),
zap.Error(err),
)
return
}
// 6. Record usage
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
}); err != nil {
reqLog.Error("gateway.responses.record_usage_failed",
zap.Int64("account_id", account.ID),
zap.Error(err),
)
}
})
return
}
}
// responsesErrorResponse writes an error in OpenAI Responses API format.
func (h *GatewayHandler) responsesErrorResponse(c *gin.Context, status int, code, message string) {
c.JSON(status, gin.H{
"error": gin.H{
"code": code,
"message": message,
},
})
}
// handleResponsesFailoverExhausted writes a failover-exhausted error in Responses format.
func (h *GatewayHandler) handleResponsesFailoverExhausted(c *gin.Context, lastErr *service.UpstreamFailoverError, streamStarted bool) {
if streamStarted {
return // Can't write error after stream started
}
statusCode := http.StatusBadGateway
if lastErr != nil && lastErr.StatusCode > 0 {
statusCode = lastErr.StatusCode
}
h.responsesErrorResponse(c, statusCode, "server_error", "All available accounts exhausted")
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment