Unverified Commit 17c3cb24 authored by 程序猿MT's avatar 程序猿MT Committed by GitHub
Browse files

Merge branch 'Wei-Shaw:main' into main

parents a413fa3b 7af1bdbf
...@@ -143,3 +143,56 @@ jobs: ...@@ -143,3 +143,56 @@ jobs:
repository: ${{ secrets.DOCKERHUB_USERNAME }}/sub2api repository: ${{ secrets.DOCKERHUB_USERNAME }}/sub2api
short-description: "Sub2API - AI API Gateway Platform" short-description: "Sub2API - AI API Gateway Platform"
readme-filepath: ./deploy/DOCKER.md readme-filepath: ./deploy/DOCKER.md
# Send a Telegram release announcement. Best-effort: `continue-on-error`
# keeps a failed notification from failing the release job.
- name: Send Telegram Notification
  # No step-level `if:` here: the `secrets` context is not available in
  # `jobs.<job_id>.steps.if`, so the "are the secrets configured?" check is
  # performed at the top of the script instead.
  env:
    TELEGRAM_BOT_TOKEN: ${{ secrets.TELEGRAM_BOT_TOKEN }}
    TELEGRAM_CHAT_ID: ${{ secrets.TELEGRAM_CHAT_ID }}
    DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
    # The tag message is free-form text. It is handed over through the
    # environment rather than interpolated into the script source, so quotes
    # or shell metacharacters in it cannot break or inject into the script.
    TAG_MESSAGE: ${{ steps.tag_message.outputs.message }}
  continue-on-error: true
  run: |
    # Skip quietly when the Telegram secrets are not configured.
    if [ -z "${TELEGRAM_BOT_TOKEN}" ] || [ -z "${TELEGRAM_CHAT_ID}" ]; then
      echo "Telegram secrets not configured; skipping notification."
      exit 0
    fi
    TAG_NAME=${GITHUB_REF#refs/tags/}
    VERSION=${TAG_NAME#v}
    REPO="${GITHUB_REPOSITORY}"
    DOCKER_IMAGE="${DOCKERHUB_USERNAME}/sub2api"
    # Cap the tag message length (Telegram limits a message to 4096
    # characters; leave room for the fixed header and footer).
    if [ ${#TAG_MESSAGE} -gt 3500 ]; then
      TAG_MESSAGE="${TAG_MESSAGE:0:3500}..."
    fi
    # Build the message body.
    MESSAGE="🚀 *Sub2API 新版本发布!*"$'\n'$'\n'
    MESSAGE+="📦 版本号: \`${VERSION}\`"$'\n'$'\n'
    # Append the release notes taken from the tag message, if any.
    if [ -n "$TAG_MESSAGE" ]; then
      MESSAGE+="${TAG_MESSAGE}"$'\n'$'\n'
    fi
    MESSAGE+="🐳 *Docker 部署:*"$'\n'
    MESSAGE+="\`\`\`bash"$'\n'
    MESSAGE+="docker pull ${DOCKER_IMAGE}:${TAG_NAME}"$'\n'
    MESSAGE+="docker pull ${DOCKER_IMAGE}:latest"$'\n'
    MESSAGE+="\`\`\`"$'\n'$'\n'
    MESSAGE+="🔗 *相关链接:*"$'\n'
    MESSAGE+="• [GitHub Release](https://github.com/${REPO}/releases/tag/${TAG_NAME})"$'\n'
    MESSAGE+="• [Docker Hub](https://hub.docker.com/r/${DOCKER_IMAGE})"$'\n'$'\n'
    MESSAGE+="#Sub2API #Release #${TAG_NAME//./_}"
    # Send the message; jq builds the JSON payload so the text is escaped correctly.
    curl -s -X POST "https://api.telegram.org/bot${TELEGRAM_BOT_TOKEN}/sendMessage" \
      -H "Content-Type: application/json" \
      -d "$(jq -n \
        --arg chat_id "${TELEGRAM_CHAT_ID}" \
        --arg text "${MESSAGE}" \
        '{
          chat_id: $chat_id,
          text: $text,
          parse_mode: "Markdown",
          disable_web_page_preview: true
        }')"
...@@ -3,6 +3,7 @@ package handler ...@@ -3,6 +3,7 @@ package handler
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"log" "log"
...@@ -127,66 +128,158 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -127,66 +128,158 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
platform = apiKey.Group.Platform platform = apiKey.Group.Platform
} }
// 选择支持该模型的账号
var account *service.Account
if platform == service.PlatformGemini { if platform == service.PlatformGemini {
account, err = h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model) const maxAccountSwitches = 3
} else { switchCount := 0
account, err = h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model) failedAccountIDs := make(map[int64]struct{})
} lastFailoverStatus := 0
if err != nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) for {
return account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs)
} if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
// 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if req.Stream {
sendMockWarmupStream(c, req.Model)
} else {
sendMockWarmupResponse(c, req.Model)
}
return
}
// 3. 获取账号并发槽位
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
// 转发请求
result, err := h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
lastFailoverStatus = failoverErr.StatusCode
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
continue
}
// 错误响应已在Forward中处理,这里只记录日志
log.Printf("Forward request failed: %v", err)
return
}
// 检查预热请求拦截(在账号选择后、转发前检查) // 异步记录使用量(subscription已在函数开头获取)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { go func(result *service.ForwardResult, usedAccount *service.Account) {
if req.Stream { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
sendMockWarmupStream(c, req.Model) defer cancel()
} else { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
sendMockWarmupResponse(c, req.Model) Result: result,
ApiKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account)
return
} }
return
} }
// 3. 获取账号并发槽位 const maxAccountSwitches = 3
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted) switchCount := 0
if err != nil { failedAccountIDs := make(map[int64]struct{})
log.Printf("Account concurrency acquire failed: %v", err) lastFailoverStatus := 0
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if accountReleaseFunc != nil {
defer accountReleaseFunc()
}
// 转发请求 for {
var result *service.ForwardResult // 选择支持该模型的账号
if platform == service.PlatformGemini { account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs)
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body) if err != nil {
} else { if len(failedAccountIDs) == 0 {
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, body) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
} return
if err != nil { }
// 错误响应已在Forward中处理,这里只记录日志 h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
log.Printf("Forward request failed: %v", err) return
return }
}
// 异步记录使用量(subscription已在函数开头获取) // 检查预热请求拦截(在账号选择后、转发前检查)
go func() { if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) if req.Stream {
defer cancel() sendMockWarmupStream(c, req.Model)
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ } else {
Result: result, sendMockWarmupResponse(c, req.Model)
ApiKey: apiKey, }
User: apiKey.User, return
Account: account,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
} }
}()
// 3. 获取账号并发槽位
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
// 转发请求
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
lastFailoverStatus = failoverErr.StatusCode
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
continue
}
// 错误响应已在Forward中处理,这里只记录日志
log.Printf("Forward request failed: %v", err)
return
}
// 异步记录使用量(subscription已在函数开头获取)
go func(result *service.ForwardResult, usedAccount *service.Account) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ApiKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account)
return
}
} }
// Models handles listing available models // Models handles listing available models
...@@ -314,6 +407,28 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT ...@@ -314,6 +407,28 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
} }
// handleFailoverExhausted writes the final error for a request whose account
// failover has ended. statusCode is the upstream status taken from the last
// failover error; it is translated by mapUpstreamError and then emitted via
// handleStreamingAwareError, which receives streamStarted unchanged.
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
	httpStatus, kind, text := h.mapUpstreamError(statusCode)
	h.handleStreamingAwareError(c, httpStatus, kind, text, streamStarted)
}
// mapUpstreamError translates an upstream status code into the triple sent to
// the client: (HTTP status, error type, human-readable message).
//
// Only 429 and 529 keep a distinct client-facing status and error type; every
// other code, known or not, is reported as 502 Bad Gateway / "upstream_error".
func (h *GatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
	const upstreamErr = "upstream_error"

	switch statusCode {
	case http.StatusUnauthorized:
		return http.StatusBadGateway, upstreamErr, "Upstream authentication failed, please contact administrator"
	case http.StatusForbidden:
		return http.StatusBadGateway, upstreamErr, "Upstream access forbidden, please contact administrator"
	case http.StatusTooManyRequests:
		return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
	case 529: // no net/http constant exists for this status
		return http.StatusServiceUnavailable, "overloaded_error", "Upstream service overloaded, please retry later"
	case http.StatusInternalServerError,
		http.StatusBadGateway,
		http.StatusServiceUnavailable,
		http.StatusGatewayTimeout:
		return http.StatusBadGateway, upstreamErr, "Upstream service temporarily unavailable"
	}
	// Anything unrecognised falls through to a generic bad-gateway answer.
	return http.StatusBadGateway, upstreamErr, "Upstream request failed"
}
// handleStreamingAwareError handles errors that may occur after streaming has started // handleStreamingAwareError handles errors that may occur after streaming has started
func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) { func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted { if streamStarted {
......
...@@ -2,6 +2,7 @@ package handler ...@@ -2,6 +2,7 @@ package handler
import ( import (
"context" "context"
"errors"
"io" "io"
"log" "log"
"net/http" "net/http"
...@@ -158,44 +159,69 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -158,44 +159,69 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 3) select account (sticky session based on request body) // 3) select account (sticky session based on request body)
sessionHash := h.gatewayService.GenerateSessionHash(body) sessionHash := h.gatewayService.GenerateSessionHash(body)
account, err := h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, modelName) const maxAccountSwitches = 3
if err != nil { switchCount := 0
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) failedAccountIDs := make(map[int64]struct{})
return lastFailoverStatus := 0
}
for {
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, modelName, failedAccountIDs)
if err != nil {
if len(failedAccountIDs) == 0 {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
handleGeminiFailoverExhausted(c, lastFailoverStatus)
return
}
// 4) account concurrency slot // 4) account concurrency slot
accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted) accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted)
if err != nil { if err != nil {
googleError(c, http.StatusTooManyRequests, err.Error()) googleError(c, http.StatusTooManyRequests, err.Error())
return return
} }
if accountReleaseFunc != nil {
defer accountReleaseFunc()
}
// 5) forward (writes response to client) // 5) forward (writes response to client)
result, err := h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body) result, err := h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body)
if err != nil { if accountReleaseFunc != nil {
// ForwardNative already wrote the response accountReleaseFunc()
log.Printf("Gemini native forward failed: %v", err) }
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode
handleGeminiFailoverExhausted(c, lastFailoverStatus)
return
}
lastFailoverStatus = failoverErr.StatusCode
switchCount++
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
continue
}
// ForwardNative already wrote the response
log.Printf("Gemini native forward failed: %v", err)
return
}
// 6) record usage async
go func(result *service.ForwardResult, usedAccount *service.Account) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ApiKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account)
return return
} }
// 6) record usage async
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ApiKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}()
} }
func parseGeminiModelAction(rest string) (model string, action string, err error) { func parseGeminiModelAction(rest string) (model string, action string, err error) {
...@@ -217,6 +243,28 @@ func parseGeminiModelAction(rest string) (model string, action string, err error ...@@ -217,6 +243,28 @@ func parseGeminiModelAction(rest string) (model string, action string, err error
return "", "", &pathParseError{"invalid model action path"} return "", "", &pathParseError{"invalid model action path"}
} }
// handleGeminiFailoverExhausted writes the final Google-style error for a
// Gemini request whose account failover has ended. statusCode is the upstream
// status from the last failover error; mapGeminiUpstreamError converts it to
// the status and message passed to googleError.
func handleGeminiFailoverExhausted(c *gin.Context, statusCode int) {
	httpStatus, text := mapGeminiUpstreamError(statusCode)
	googleError(c, httpStatus, text)
}
// mapGeminiUpstreamError translates an upstream status code into the
// (HTTP status, message) pair reported to Gemini clients.
//
// 429 is passed through as 429 and 529 becomes 503; every other code, known
// or not, is reported as 502 Bad Gateway.
func mapGeminiUpstreamError(statusCode int) (int, string) {
	switch statusCode {
	case http.StatusUnauthorized:
		return http.StatusBadGateway, "Upstream authentication failed, please contact administrator"
	case http.StatusForbidden:
		return http.StatusBadGateway, "Upstream access forbidden, please contact administrator"
	case http.StatusTooManyRequests:
		return http.StatusTooManyRequests, "Upstream rate limit exceeded, please retry later"
	case 529: // no net/http constant exists for this status
		return http.StatusServiceUnavailable, "Upstream service overloaded, please retry later"
	case http.StatusInternalServerError,
		http.StatusBadGateway,
		http.StatusServiceUnavailable,
		http.StatusGatewayTimeout:
		return http.StatusBadGateway, "Upstream service temporarily unavailable"
	}
	// Anything unrecognised falls through to a generic bad-gateway answer.
	return http.StatusBadGateway, "Upstream request failed"
}
type pathParseError struct{ msg string } type pathParseError struct{ msg string }
func (e *pathParseError) Error() string { return e.msg } func (e *pathParseError) Error() string { return e.msg }
......
...@@ -3,6 +3,7 @@ package handler ...@@ -3,6 +3,7 @@ package handler
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"log" "log"
...@@ -127,49 +128,74 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -127,49 +128,74 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Generate session hash (from header for OpenAI) // Generate session hash (from header for OpenAI)
sessionHash := h.gatewayService.GenerateSessionHash(c) sessionHash := h.gatewayService.GenerateSessionHash(c)
// Select account supporting the requested model const maxAccountSwitches = 3
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel) switchCount := 0
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel) failedAccountIDs := make(map[int64]struct{})
if err != nil { lastFailoverStatus := 0
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
// 3. Acquire account concurrency slot for {
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted) // Select account supporting the requested model
if err != nil { log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
log.Printf("Account concurrency acquire failed: %v", err) account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
h.handleConcurrencyError(c, err, "account", streamStarted) if err != nil {
return log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
} if len(failedAccountIDs) == 0 {
if accountReleaseFunc != nil { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
defer accountReleaseFunc() return
} }
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
// Forward request // 3. Acquire account concurrency slot
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body) accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
if err != nil { if err != nil {
// Error response already handled in Forward, just log log.Printf("Account concurrency acquire failed: %v", err)
log.Printf("Forward request failed: %v", err) h.handleConcurrencyError(c, err, "account", streamStarted)
return return
} }
// Async record usage // Forward request
go func() { result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) if accountReleaseFunc != nil {
defer cancel() accountReleaseFunc()
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ }
Result: result, if err != nil {
ApiKey: apiKey, var failoverErr *service.UpstreamFailoverError
User: apiKey.User, if errors.As(err, &failoverErr) {
Account: account, failedAccountIDs[account.ID] = struct{}{}
Subscription: subscription, if switchCount >= maxAccountSwitches {
}); err != nil { lastFailoverStatus = failoverErr.StatusCode
log.Printf("Record usage failed: %v", err) h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
lastFailoverStatus = failoverErr.StatusCode
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
continue
}
// Error response already handled in Forward, just log
log.Printf("Forward request failed: %v", err)
return
} }
}()
// Async record usage
go func(result *service.OpenAIForwardResult, usedAccount *service.Account) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
ApiKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account)
return
}
} }
// handleConcurrencyError handles concurrency-related errors with proper 429 response // handleConcurrencyError handles concurrency-related errors with proper 429 response
...@@ -178,6 +204,28 @@ func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, ...@@ -178,6 +204,28 @@ func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error,
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
} }
// handleFailoverExhausted writes the final error for an OpenAI request whose
// account failover has ended. statusCode is the upstream status taken from the
// last failover error; it is translated by mapUpstreamError and then emitted
// via handleStreamingAwareError, which receives streamStarted unchanged.
func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
	httpStatus, kind, text := h.mapUpstreamError(statusCode)
	h.handleStreamingAwareError(c, httpStatus, kind, text, streamStarted)
}
// mapUpstreamError translates an upstream status code into the triple sent to
// the client: (HTTP status, error type, human-readable message).
//
// 429 is reported as 429 / "rate_limit_error" and 529 as 503; every other
// code, known or not, is reported as 502 Bad Gateway. All cases except 429 use
// the "upstream_error" type.
func (h *OpenAIGatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
	const upstreamErr = "upstream_error"

	switch statusCode {
	case http.StatusUnauthorized:
		return http.StatusBadGateway, upstreamErr, "Upstream authentication failed, please contact administrator"
	case http.StatusForbidden:
		return http.StatusBadGateway, upstreamErr, "Upstream access forbidden, please contact administrator"
	case http.StatusTooManyRequests:
		return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
	case 529: // no net/http constant exists for this status
		return http.StatusServiceUnavailable, upstreamErr, "Upstream service overloaded, please retry later"
	case http.StatusInternalServerError,
		http.StatusBadGateway,
		http.StatusServiceUnavailable,
		http.StatusGatewayTimeout:
		return http.StatusBadGateway, upstreamErr, "Upstream service temporarily unavailable"
	}
	// Anything unrecognised falls through to a generic bad-gateway answer.
	return http.StatusBadGateway, upstreamErr, "Upstream request failed"
}
// handleStreamingAwareError handles errors that may occur after streaming has started // handleStreamingAwareError handles errors that may occur after streaming has started
func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) { func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted { if streamStarted {
......
...@@ -199,16 +199,20 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod ...@@ -199,16 +199,20 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
client := s.clientFactory(proxyURL) client := s.clientFactory(proxyURL)
formData := url.Values{} // 使用 JSON 格式(与 ExchangeCodeForToken 保持一致)
formData.Set("grant_type", "refresh_token") // Anthropic OAuth API 期望 JSON 格式的请求体
formData.Set("refresh_token", refreshToken) reqBody := map[string]any{
formData.Set("client_id", oauth.ClientID) "grant_type": "refresh_token",
"refresh_token": refreshToken,
"client_id": oauth.ClientID,
}
var tokenResp oauth.TokenResponse var tokenResp oauth.TokenResponse
resp, err := client.R(). resp, err := client.R().
SetContext(ctx). SetContext(ctx).
SetFormDataFromValues(formData). SetHeader("Content-Type", "application/json").
SetBody(reqBody).
SetSuccessResult(&tokenResp). SetSuccessResult(&tokenResp).
Post(s.tokenURL) Post(s.tokenURL)
......
...@@ -6,7 +6,6 @@ import ( ...@@ -6,7 +6,6 @@ import (
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"strings" "strings"
"testing" "testing"
...@@ -34,7 +33,6 @@ type requestCapture struct { ...@@ -34,7 +33,6 @@ type requestCapture struct {
method string method string
cookies []*http.Cookie cookies []*http.Cookie
body []byte body []byte
formValues url.Values
bodyJSON map[string]any bodyJSON map[string]any
contentType string contentType string
} }
...@@ -282,24 +280,53 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() { ...@@ -282,24 +280,53 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
validate func(captured requestCapture) validate func(captured requestCapture)
}{ }{
{ {
name: "sends_form", name: "sends_json_format",
handler: func(w http.ResponseWriter, r *http.Request) { handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{AccessToken: "at2", TokenType: "bearer", ExpiresIn: 3600}) _ = json.NewEncoder(w).Encode(oauth.TokenResponse{
AccessToken: "new_access_token",
TokenType: "bearer",
ExpiresIn: 28800,
RefreshToken: "new_refresh_token",
Scope: "user:profile user:inference",
})
},
wantResp: &oauth.TokenResponse{
AccessToken: "new_access_token",
RefreshToken: "new_refresh_token",
}, },
wantResp: &oauth.TokenResponse{AccessToken: "at2"},
validate: func(captured requestCapture) { validate: func(captured requestCapture) {
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST") require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
require.Equal(s.T(), "refresh_token", captured.formValues.Get("grant_type")) // 验证使用 JSON 格式(不是 form 格式)
require.Equal(s.T(), "rt", captured.formValues.Get("refresh_token")) require.True(s.T(), strings.HasPrefix(captured.contentType, "application/json"),
require.Equal(s.T(), oauth.ClientID, captured.formValues.Get("client_id")) "expected JSON content-type, got: %s", captured.contentType)
// 验证 JSON body 内容
require.Equal(s.T(), "refresh_token", captured.bodyJSON["grant_type"])
require.Equal(s.T(), "rt", captured.bodyJSON["refresh_token"])
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
},
},
{
name: "returns_new_refresh_token",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
AccessToken: "at",
TokenType: "bearer",
ExpiresIn: 28800,
RefreshToken: "rotated_rt", // Anthropic rotates refresh tokens
})
},
wantResp: &oauth.TokenResponse{
AccessToken: "at",
RefreshToken: "rotated_rt",
}, },
}, },
{ {
name: "non_200_returns_error", name: "non_200_returns_error",
handler: func(w http.ResponseWriter, r *http.Request) { handler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte("unauthorized")) _, _ = w.Write([]byte(`{"error":"invalid_grant"}`))
}, },
wantErr: true, wantErr: true,
}, },
...@@ -311,8 +338,9 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() { ...@@ -311,8 +338,9 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.method = r.Method captured.method = r.Method
captured.contentType = r.Header.Get("Content-Type")
captured.body, _ = io.ReadAll(r.Body) captured.body, _ = io.ReadAll(r.Body)
captured.formValues, _ = url.ParseQuery(string(captured.body)) _ = json.Unmarshal(captured.body, &captured.bodyJSON)
tt.handler(w, r) tt.handler(w, r)
})) }))
defer s.srv.Close() defer s.srv.Close()
...@@ -331,6 +359,7 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() { ...@@ -331,6 +359,7 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
require.NoError(s.T(), err) require.NoError(s.T(), err)
require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken) require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken)
require.Equal(s.T(), tt.wantResp.RefreshToken, resp.RefreshToken)
if tt.validate != nil { if tt.validate != nil {
tt.validate(captured) tt.validate(captured)
} }
......
...@@ -925,8 +925,38 @@ func (r *stubUsageLogRepo) GetUserModelStats(ctx context.Context, userID int64, ...@@ -925,8 +925,38 @@ func (r *stubUsageLogRepo) GetUserModelStats(ctx context.Context, userID int64,
func (r *stubUsageLogRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) { func (r *stubUsageLogRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
logs := r.userLogs[filters.UserID] logs := r.userLogs[filters.UserID]
total := int64(len(logs))
out := paginateLogs(logs, params) // Apply filters
var filtered []service.UsageLog
for _, log := range logs {
// Apply ApiKeyID filter
if filters.ApiKeyID > 0 && log.ApiKeyID != filters.ApiKeyID {
continue
}
// Apply Model filter
if filters.Model != "" && log.Model != filters.Model {
continue
}
// Apply Stream filter
if filters.Stream != nil && log.Stream != *filters.Stream {
continue
}
// Apply BillingType filter
if filters.BillingType != nil && log.BillingType != *filters.BillingType {
continue
}
// Apply time range filters
if filters.StartTime != nil && log.CreatedAt.Before(*filters.StartTime) {
continue
}
if filters.EndTime != nil && log.CreatedAt.After(*filters.EndTime) {
continue
}
filtered = append(filtered, log)
}
total := int64(len(filtered))
out := paginateLogs(filtered, params)
return out, paginationResult(total, params), nil return out, paginationResult(total, params), nil
} }
......
package service package service
import "time" import (
"strconv"
"time"
)
type Account struct { type Account struct {
ID int64 ID int64
...@@ -82,12 +85,25 @@ func (a *Account) GetCredential(key string) string { ...@@ -82,12 +85,25 @@ func (a *Account) GetCredential(key string) string {
if a.Credentials == nil { if a.Credentials == nil {
return "" return ""
} }
if v, ok := a.Credentials[key]; ok { v, ok := a.Credentials[key]
if s, ok := v.(string); ok { if !ok || v == nil {
return s return ""
} }
// 支持多种类型(兼容历史数据中 expires_at 等字段可能是数字或字符串)
switch val := v.(type) {
case string:
return val
case float64:
// JSON 解析后数字默认为 float64
return strconv.FormatInt(int64(val), 10)
case int64:
return strconv.FormatInt(val, 10)
case int:
return strconv.Itoa(val)
default:
return ""
} }
return ""
} }
func (a *Account) GetModelMapping() map[string]string { func (a *Account) GetModelMapping() map[string]string {
......
...@@ -208,20 +208,23 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount ...@@ -208,20 +208,23 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
account.Status = *req.Status account.Status = *req.Status
} }
if err := s.accountRepo.Update(ctx, account); err != nil { // 先验证分组是否存在(在任何写操作之前)
return nil, fmt.Errorf("update account: %w", err)
}
// 更新分组绑定
if req.GroupIDs != nil { if req.GroupIDs != nil {
// 验证分组是否存在
for _, groupID := range *req.GroupIDs { for _, groupID := range *req.GroupIDs {
_, err := s.groupRepo.GetByID(ctx, groupID) _, err := s.groupRepo.GetByID(ctx, groupID)
if err != nil { if err != nil {
return nil, fmt.Errorf("get group: %w", err) return nil, fmt.Errorf("get group: %w", err)
} }
} }
}
// 执行更新
if err := s.accountRepo.Update(ctx, account); err != nil {
return nil, fmt.Errorf("update account: %w", err)
}
// 绑定分组
if req.GroupIDs != nil {
if err := s.accountRepo.BindGroups(ctx, account.ID, *req.GroupIDs); err != nil { if err := s.accountRepo.BindGroups(ctx, account.ID, *req.GroupIDs); err != nil {
return nil, fmt.Errorf("bind groups: %w", err) return nil, fmt.Errorf("bind groups: %w", err)
} }
......
...@@ -652,11 +652,20 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U ...@@ -652,11 +652,20 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
account.Status = input.Status account.Status = input.Status
} }
// 先验证分组是否存在(在任何写操作之前)
if input.GroupIDs != nil {
for _, groupID := range *input.GroupIDs {
if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil {
return nil, fmt.Errorf("get group: %w", err)
}
}
}
if err := s.accountRepo.Update(ctx, account); err != nil { if err := s.accountRepo.Update(ctx, account); err != nil {
return nil, err return nil, err
} }
// 更新分组绑定 // 绑定分组
if input.GroupIDs != nil { if input.GroupIDs != nil {
if err := s.accountRepo.BindGroups(ctx, account.ID, *input.GroupIDs); err != nil { if err := s.accountRepo.BindGroups(ctx, account.ID, *input.GroupIDs); err != nil {
return nil, err return nil, err
......
...@@ -81,6 +81,15 @@ type ForwardResult struct { ...@@ -81,6 +81,15 @@ type ForwardResult struct {
FirstTokenMs *int // 首字时间(流式请求) FirstTokenMs *int // 首字时间(流式请求)
} }
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
type UpstreamFailoverError struct {
StatusCode int
}
func (e *UpstreamFailoverError) Error() string {
return fmt.Sprintf("upstream error: %d (failover)", e.StatusCode)
}
// GatewayService handles API gateway operations // GatewayService handles API gateway operations
type GatewayService struct { type GatewayService struct {
accountRepo AccountRepository accountRepo AccountRepository
...@@ -274,19 +283,26 @@ func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sess ...@@ -274,19 +283,26 @@ func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sess
// SelectAccountForModel 选择支持指定模型的账号(粘性会话+优先级+模型映射) // SelectAccountForModel 选择支持指定模型的账号(粘性会话+优先级+模型映射)
func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) { func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil)
}
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 1. 查询粘性会话 // 1. 查询粘性会话
if sessionHash != "" { if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
if err == nil && accountID > 0 { if err == nil && accountID > 0 {
account, err := s.accountRepo.GetByID(ctx, accountID) if _, excluded := excludedIDs[accountID]; !excluded {
// 使用IsSchedulable代替IsActive,确保限流/过载账号不会被选中 account, err := s.accountRepo.GetByID(ctx, accountID)
// 同时检查模型支持 // 使用IsSchedulable代替IsActive,确保限流/过载账号不会被选中
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) { // 同时检查模型支持
// 续期粘性会话 if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil { // 续期粘性会话
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
return account, nil
} }
return account, nil
} }
} }
} }
...@@ -307,6 +323,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int ...@@ -307,6 +323,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
var selected *Account var selected *Account
for i := range accounts { for i := range accounts {
acc := &accounts[i] acc := &accounts[i]
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// 检查模型支持 // 检查模型支持
if requestedModel != "" && !acc.IsModelSupported(requestedModel) { if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue continue
...@@ -394,6 +413,16 @@ func (s *GatewayService) shouldRetryUpstreamError(account *Account, statusCode i ...@@ -394,6 +413,16 @@ func (s *GatewayService) shouldRetryUpstreamError(account *Account, statusCode i
return !account.ShouldHandleErrorCode(statusCode) return !account.ShouldHandleErrorCode(statusCode)
} }
// shouldFailoverUpstreamError determines whether an upstream error should trigger account failover.
func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool {
switch statusCode {
case 401, 403, 429, 529:
return true
default:
return statusCode >= 500
}
}
// Forward 转发请求到Claude API // Forward 转发请求到Claude API
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) { func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
startTime := time.Now() startTime := time.Now()
...@@ -478,9 +507,19 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -478,9 +507,19 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 处理重试耗尽的情况 // 处理重试耗尽的情况
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
if s.shouldFailoverUpstreamError(resp.StatusCode) {
s.handleRetryExhaustedSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
return s.handleRetryExhaustedError(ctx, resp, c, account) return s.handleRetryExhaustedError(ctx, resp, c, account)
} }
// 处理可切换账号的错误
if resp.StatusCode >= 400 && s.shouldFailoverUpstreamError(resp.StatusCode) {
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
// 处理错误响应(不可重试的错误) // 处理错误响应(不可重试的错误)
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
return s.handleErrorResponse(ctx, resp, c, account) return s.handleErrorResponse(ctx, resp, c, account)
...@@ -692,10 +731,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res ...@@ -692,10 +731,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
} }
// handleRetryExhaustedError 处理重试耗尽后的错误 func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, resp *http.Response, account *Account) {
// OAuth 403:标记账号异常
// API Key 未配置错误码:仅返回错误,不标记账号
func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
statusCode := resp.StatusCode statusCode := resp.StatusCode
...@@ -707,6 +743,18 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht ...@@ -707,6 +743,18 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
// API Key 未配置错误码:不标记账号状态 // API Key 未配置错误码:不标记账号状态
log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetries) log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetries)
} }
}
func (s *GatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
body, _ := io.ReadAll(resp.Body)
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
// handleRetryExhaustedError 处理重试耗尽后的错误
// OAuth 403:标记账号异常
// API Key 未配置错误码:仅返回错误,不标记账号
func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
s.handleRetryExhaustedSideEffects(ctx, resp, account)
// 返回统一的重试耗尽错误响应 // 返回统一的重试耗尽错误响应
c.JSON(http.StatusBadGateway, gin.H{ c.JSON(http.StatusBadGateway, gin.H{
...@@ -717,7 +765,7 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht ...@@ -717,7 +765,7 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
}, },
}) })
return nil, fmt.Errorf("upstream error: %d (retries exhausted)", statusCode) return nil, fmt.Errorf("upstream error: %d (retries exhausted)", resp.StatusCode)
} }
// streamingResult 流式响应结果 // streamingResult 流式响应结果
......
...@@ -62,14 +62,20 @@ func (s *GeminiMessagesCompatService) GetTokenProvider() *GeminiTokenProvider { ...@@ -62,14 +62,20 @@ func (s *GeminiMessagesCompatService) GetTokenProvider() *GeminiTokenProvider {
} }
func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) { func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil)
}
func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
cacheKey := "gemini:" + sessionHash cacheKey := "gemini:" + sessionHash
if sessionHash != "" { if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey) accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey)
if err == nil && accountID > 0 { if err == nil && accountID > 0 {
account, err := s.accountRepo.GetByID(ctx, accountID) if _, excluded := excludedIDs[accountID]; !excluded {
if err == nil && account.IsSchedulable() && account.Platform == PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) { account, err := s.accountRepo.GetByID(ctx, accountID)
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL) if err == nil && account.IsSchedulable() && account.Platform == PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
return account, nil _ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
return account, nil
}
} }
} }
} }
...@@ -88,6 +94,9 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, ...@@ -88,6 +94,9 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
var selected *Account var selected *Account
for i := range accounts { for i := range accounts {
acc := &accounts[i] acc := &accounts[i]
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
if requestedModel != "" && !acc.IsModelSupported(requestedModel) { if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue continue
} }
...@@ -425,6 +434,9 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex ...@@ -425,6 +434,9 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
return nil, s.writeGeminiMappedError(c, resp.StatusCode, respBody) return nil, s.writeGeminiMappedError(c, resp.StatusCode, respBody)
} }
...@@ -724,6 +736,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. ...@@ -724,6 +736,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}, nil }, nil
} }
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
respBody = unwrapIfNeeded(isOAuth, respBody) respBody = unwrapIfNeeded(isOAuth, respBody)
contentType := resp.Header.Get("Content-Type") contentType := resp.Header.Get("Content-Type")
if contentType == "" { if contentType == "" {
...@@ -795,6 +811,15 @@ func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Ac ...@@ -795,6 +811,15 @@ func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Ac
} }
} }
func (s *GeminiMessagesCompatService) shouldFailoverGeminiUpstreamError(statusCode int) bool {
switch statusCode {
case 401, 403, 429, 529:
return true
default:
return statusCode >= 500
}
}
func sleepGeminiBackoff(attempt int) { func sleepGeminiBackoff(attempt int) {
delay := geminiRetryBaseDelay * time.Duration(1<<uint(attempt-1)) delay := geminiRetryBaseDelay * time.Duration(1<<uint(attempt-1))
if delay > geminiRetryMaxDelay { if delay > geminiRetryMaxDelay {
......
...@@ -129,15 +129,22 @@ func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64 ...@@ -129,15 +129,22 @@ func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64
// SelectAccountForModel selects an account supporting the requested model // SelectAccountForModel selects an account supporting the requested model
func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) { func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil)
}
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 1. Check sticky session // 1. Check sticky session
if sessionHash != "" { if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash) accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash)
if err == nil && accountID > 0 { if err == nil && accountID > 0 {
account, err := s.accountRepo.GetByID(ctx, accountID) if _, excluded := excludedIDs[accountID]; !excluded {
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) { account, err := s.accountRepo.GetByID(ctx, accountID)
// Refresh sticky session TTL if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
_ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL) // Refresh sticky session TTL
return account, nil _ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL)
return account, nil
}
} }
} }
} }
...@@ -158,6 +165,9 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI ...@@ -158,6 +165,9 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
var selected *Account var selected *Account
for i := range accounts { for i := range accounts {
acc := &accounts[i] acc := &accounts[i]
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// Check model support // Check model support
if requestedModel != "" && !acc.IsModelSupported(requestedModel) { if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue continue
...@@ -221,6 +231,20 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Acco ...@@ -221,6 +231,20 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Acco
} }
} }
func (s *OpenAIGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
switch statusCode {
case 401, 403, 429, 529:
return true
default:
return statusCode >= 500
}
}
func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
body, _ := io.ReadAll(resp.Body)
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
// Forward forwards request to OpenAI API // Forward forwards request to OpenAI API
func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*OpenAIForwardResult, error) { func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*OpenAIForwardResult, error) {
startTime := time.Now() startTime := time.Now()
...@@ -288,6 +312,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco ...@@ -288,6 +312,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
// Handle error response // Handle error response
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if s.shouldFailoverUpstreamError(resp.StatusCode) {
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
return s.handleErrorResponse(ctx, resp, c, account) return s.handleErrorResponse(ctx, resp, c, account)
} }
......
...@@ -280,10 +280,12 @@ ...@@ -280,10 +280,12 @@
import { ref, watch, nextTick } from 'vue' import { ref, watch, nextTick } from 'vue'
import { useI18n } from 'vue-i18n' import { useI18n } from 'vue-i18n'
import Modal from '@/components/common/Modal.vue' import Modal from '@/components/common/Modal.vue'
import { useClipboard } from '@/composables/useClipboard'
import { adminAPI } from '@/api/admin' import { adminAPI } from '@/api/admin'
import type { Account, ClaudeModel } from '@/types' import type { Account, ClaudeModel } from '@/types'
const { t } = useI18n() const { t } = useI18n()
const { copyToClipboard } = useClipboard()
interface OutputLine { interface OutputLine {
text: string text: string
...@@ -501,6 +503,6 @@ const handleEvent = (event: { ...@@ -501,6 +503,6 @@ const handleEvent = (event: {
const copyOutput = () => { const copyOutput = () => {
const text = outputLines.value.map((l) => l.text).join('\n') const text = outputLines.value.map((l) => l.text).join('\n')
navigator.clipboard.writeText(text) copyToClipboard(text, t('admin.accounts.outputCopied'))
} }
</script> </script>
...@@ -119,7 +119,7 @@ ...@@ -119,7 +119,7 @@
import { ref, computed, h, watch, type Component } from 'vue' import { ref, computed, h, watch, type Component } from 'vue'
import { useI18n } from 'vue-i18n' import { useI18n } from 'vue-i18n'
import Modal from '@/components/common/Modal.vue' import Modal from '@/components/common/Modal.vue'
import { useAppStore } from '@/stores/app' import { useClipboard } from '@/composables/useClipboard'
import type { GroupPlatform } from '@/types' import type { GroupPlatform } from '@/types'
interface Props { interface Props {
...@@ -150,7 +150,7 @@ const props = defineProps<Props>() ...@@ -150,7 +150,7 @@ const props = defineProps<Props>()
const emit = defineEmits<Emits>() const emit = defineEmits<Emits>()
const { t } = useI18n() const { t } = useI18n()
const appStore = useAppStore() const { copyToClipboard: clipboardCopy } = useClipboard()
const copiedIndex = ref<number | null>(null) const copiedIndex = ref<number | null>(null)
const activeTab = ref<string>('unix') const activeTab = ref<string>('unix')
...@@ -340,14 +340,12 @@ ${key('requires_openai_auth')} ${operator('=')} ${keyword('true')}` ...@@ -340,14 +340,12 @@ ${key('requires_openai_auth')} ${operator('=')} ${keyword('true')}`
} }
const copyContent = async (content: string, index: number) => { const copyContent = async (content: string, index: number) => {
try { const success = await clipboardCopy(content, t('keys.copied'))
await navigator.clipboard.writeText(content) if (success) {
copiedIndex.value = index copiedIndex.value = index
setTimeout(() => { setTimeout(() => {
copiedIndex.value = null copiedIndex.value = null
}, 2000) }, 2000)
} catch (error) {
appStore.showError(t('common.copyFailed'))
} }
} }
</script> </script>
import { ref } from 'vue' import { ref } from 'vue'
import { useAppStore } from '@/stores/app' import { useAppStore } from '@/stores/app'
/**
* 检测是否支持 Clipboard API(需要安全上下文:HTTPS/localhost)
*/
function isClipboardSupported(): boolean {
return !!(navigator.clipboard && window.isSecureContext)
}
/**
* 降级方案:使用 textarea + execCommand
* 使用 textarea 而非 input,以正确处理多行文本
*/
function fallbackCopy(text: string): boolean {
const textarea = document.createElement('textarea')
textarea.value = text
textarea.style.cssText = 'position:fixed;left:-9999px;top:-9999px'
document.body.appendChild(textarea)
textarea.select()
try {
return document.execCommand('copy')
} finally {
document.body.removeChild(textarea)
}
}
export function useClipboard() { export function useClipboard() {
const appStore = useAppStore() const appStore = useAppStore()
const copied = ref(false) const copied = ref(false)
const copyToClipboard = async (text: string, successMessage = 'Copied to clipboard') => { const copyToClipboard = async (
text: string,
successMessage = 'Copied to clipboard'
): Promise<boolean> => {
if (!text) return false if (!text) return false
try { let success = false
await navigator.clipboard.writeText(text)
copied.value = true if (isClipboardSupported()) {
appStore.showSuccess(successMessage) try {
setTimeout(() => { await navigator.clipboard.writeText(text)
copied.value = false success = true
}, 2000) } catch {
return true success = fallbackCopy(text)
} catch { }
// Fallback for older browsers } else {
const input = document.createElement('input') success = fallbackCopy(text)
input.value = text }
document.body.appendChild(input)
input.select() if (success) {
document.execCommand('copy')
document.body.removeChild(input)
copied.value = true copied.value = true
appStore.showSuccess(successMessage) appStore.showSuccess(successMessage)
setTimeout(() => { setTimeout(() => {
copied.value = false copied.value = false
}, 2000) }, 2000)
return true } else {
appStore.showError('Copy failed')
} }
}
return { return success
copied,
copyToClipboard
} }
return { copied, copyToClipboard }
} }
...@@ -418,6 +418,7 @@ ...@@ -418,6 +418,7 @@
import { ref, reactive, computed, onMounted } from 'vue' import { ref, reactive, computed, onMounted } from 'vue'
import { useI18n } from 'vue-i18n' import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app' import { useAppStore } from '@/stores/app'
import { useClipboard } from '@/composables/useClipboard'
import { adminAPI } from '@/api/admin' import { adminAPI } from '@/api/admin'
import { formatDateTime } from '@/utils/format' import { formatDateTime } from '@/utils/format'
import type { RedeemCode, RedeemCodeType, Group } from '@/types' import type { RedeemCode, RedeemCodeType, Group } from '@/types'
...@@ -431,6 +432,7 @@ import Select from '@/components/common/Select.vue' ...@@ -431,6 +432,7 @@ import Select from '@/components/common/Select.vue'
const { t } = useI18n() const { t } = useI18n()
const appStore = useAppStore() const appStore = useAppStore()
const { copyToClipboard: clipboardCopy } = useClipboard()
const showGenerateDialog = ref(false) const showGenerateDialog = ref(false)
const showResultDialog = ref(false) const showResultDialog = ref(false)
...@@ -618,15 +620,12 @@ const handleGenerateCodes = async () => { ...@@ -618,15 +620,12 @@ const handleGenerateCodes = async () => {
} }
const copyToClipboard = async (text: string) => { const copyToClipboard = async (text: string) => {
try { const success = await clipboardCopy(text, t('admin.redeem.copied'))
await navigator.clipboard.writeText(text) if (success) {
copiedCode.value = text copiedCode.value = text
setTimeout(() => { setTimeout(() => {
copiedCode.value = null copiedCode.value = null
}, 2000) }, 2000)
} catch (error) {
appStore.showError(t('admin.redeem.failedToCopy'))
console.error('Error copying to clipboard:', error)
} }
} }
......
...@@ -1173,6 +1173,7 @@ ...@@ -1173,6 +1173,7 @@
import { ref, reactive, computed, onMounted } from 'vue' import { ref, reactive, computed, onMounted } from 'vue'
import { useI18n } from 'vue-i18n' import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app' import { useAppStore } from '@/stores/app'
import { useClipboard } from '@/composables/useClipboard'
import { formatDateTime } from '@/utils/format' import { formatDateTime } from '@/utils/format'
const { t } = useI18n() const { t } = useI18n()
...@@ -1191,6 +1192,7 @@ import Select from '@/components/common/Select.vue' ...@@ -1191,6 +1192,7 @@ import Select from '@/components/common/Select.vue'
import GroupBadge from '@/components/common/GroupBadge.vue' import GroupBadge from '@/components/common/GroupBadge.vue'
const appStore = useAppStore() const appStore = useAppStore()
const { copyToClipboard: clipboardCopy } = useClipboard()
const columns = computed<Column[]>(() => [ const columns = computed<Column[]>(() => [
{ key: 'email', label: t('admin.users.columns.user'), sortable: true }, { key: 'email', label: t('admin.users.columns.user'), sortable: true },
...@@ -1312,27 +1314,23 @@ const generateEditPassword = () => { ...@@ -1312,27 +1314,23 @@ const generateEditPassword = () => {
const copyPassword = async () => { const copyPassword = async () => {
if (!createForm.password) return if (!createForm.password) return
try { const success = await clipboardCopy(createForm.password, t('admin.users.passwordCopied'))
await navigator.clipboard.writeText(createForm.password) if (success) {
passwordCopied.value = true passwordCopied.value = true
setTimeout(() => { setTimeout(() => {
passwordCopied.value = false passwordCopied.value = false
}, 2000) }, 2000)
} catch (error) {
appStore.showError(t('common.copyFailed'))
} }
} }
const copyEditPassword = async () => { const copyEditPassword = async () => {
if (!editForm.password) return if (!editForm.password) return
try { const success = await clipboardCopy(editForm.password, t('admin.users.passwordCopied'))
await navigator.clipboard.writeText(editForm.password) if (success) {
editPasswordCopied.value = true editPasswordCopied.value = true
setTimeout(() => { setTimeout(() => {
editPasswordCopied.value = false editPasswordCopied.value = false
}, 2000) }, 2000)
} catch (error) {
appStore.showError(t('common.copyFailed'))
} }
} }
......
...@@ -493,6 +493,7 @@ ...@@ -493,6 +493,7 @@
import { ref, computed, onMounted, onUnmounted, type ComponentPublicInstance } from 'vue' import { ref, computed, onMounted, onUnmounted, type ComponentPublicInstance } from 'vue'
import { useI18n } from 'vue-i18n' import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app' import { useAppStore } from '@/stores/app'
import { useClipboard } from '@/composables/useClipboard'
const { t } = useI18n() const { t } = useI18n()
import { keysAPI, authAPI, usageAPI, userGroupsAPI } from '@/api' import { keysAPI, authAPI, usageAPI, userGroupsAPI } from '@/api'
...@@ -520,6 +521,7 @@ interface GroupOption { ...@@ -520,6 +521,7 @@ interface GroupOption {
} }
const appStore = useAppStore() const appStore = useAppStore()
const { copyToClipboard: clipboardCopy } = useClipboard()
const columns = computed<Column[]>(() => [ const columns = computed<Column[]>(() => [
{ key: 'name', label: t('common.name'), sortable: true }, { key: 'name', label: t('common.name'), sortable: true },
...@@ -616,14 +618,12 @@ const maskKey = (key: string): string => { ...@@ -616,14 +618,12 @@ const maskKey = (key: string): string => {
} }
const copyToClipboard = async (text: string, keyId: number) => { const copyToClipboard = async (text: string, keyId: number) => {
try { const success = await clipboardCopy(text, t('keys.copied'))
await navigator.clipboard.writeText(text) if (success) {
copiedKeyId.value = keyId copiedKeyId.value = keyId
setTimeout(() => { setTimeout(() => {
copiedKeyId.value = null copiedKeyId.value = null
}, 2000) }, 2000)
} catch (error) {
appStore.showError(t('common.copyFailed'))
} }
} }
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment