Commit d8aff3a7 authored by ius's avatar ius
Browse files

Merge origin/main into fix/account-extra-scheduler-pressure-20260311

parents 2b30e3b6 c0110cb5
......@@ -84,10 +84,12 @@ var DefaultAntigravityModelMapping = map[string]string{
"claude-haiku-4-5": "claude-sonnet-4-5",
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
// Gemini 2.5 白名单
"gemini-2.5-flash": "gemini-2.5-flash",
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
"gemini-2.5-pro": "gemini-2.5-pro",
"gemini-2.5-flash": "gemini-2.5-flash",
"gemini-2.5-flash-image": "gemini-2.5-flash-image",
"gemini-2.5-flash-image-preview": "gemini-2.5-flash-image",
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
"gemini-2.5-pro": "gemini-2.5-pro",
// Gemini 3 白名单
"gemini-3-flash": "gemini-3-flash",
"gemini-3-pro-high": "gemini-3-pro-high",
......
......@@ -6,6 +6,8 @@ func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T)
t.Parallel()
cases := map[string]string{
"gemini-2.5-flash-image": "gemini-2.5-flash-image",
"gemini-2.5-flash-image-preview": "gemini-2.5-flash-image",
"gemini-3.1-flash-image": "gemini-3.1-flash-image",
"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
"gemini-3-pro-image": "gemini-3.1-flash-image",
......
......@@ -628,6 +628,7 @@ func (h *AccountHandler) Delete(c *gin.Context) {
// TestAccountRequest represents the request body for testing an account
type TestAccountRequest struct {
ModelID string `json:"model_id"`
Prompt string `json:"prompt"`
}
type SyncFromCRSRequest struct {
......@@ -658,7 +659,7 @@ func (h *AccountHandler) Test(c *gin.Context) {
_ = c.ShouldBindJSON(&req)
// Use AccountTestService to test the account with SSE streaming
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID); err != nil {
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID, req.Prompt); err != nil {
// Error already sent via SSE, just log
return
}
......
......@@ -249,11 +249,12 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
}
}
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
trend, hit, err := h.getUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get usage trend")
return
}
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
response.Success(c, gin.H{
"trend": trend,
......@@ -321,11 +322,12 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
}
}
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get model statistics")
return
}
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
response.Success(c, gin.H{
"models": stats,
......@@ -391,11 +393,12 @@ func (h *DashboardHandler) GetGroupStats(c *gin.Context) {
}
}
stats, err := h.dashboardService.GetGroupStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
stats, hit, err := h.getGroupStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get group statistics")
return
}
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
response.Success(c, gin.H{
"groups": stats,
......@@ -416,11 +419,12 @@ func (h *DashboardHandler) GetAPIKeyUsageTrend(c *gin.Context) {
limit = 5
}
trend, err := h.dashboardService.GetAPIKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
trend, hit, err := h.getAPIKeyUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, limit)
if err != nil {
response.Error(c, 500, "Failed to get API key usage trend")
return
}
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
response.Success(c, gin.H{
"trend": trend,
......@@ -442,11 +446,12 @@ func (h *DashboardHandler) GetUserUsageTrend(c *gin.Context) {
limit = 12
}
trend, err := h.dashboardService.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
trend, hit, err := h.getUserUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, limit)
if err != nil {
response.Error(c, 500, "Failed to get user usage trend")
return
}
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
response.Success(c, gin.H{
"trend": trend,
......
package admin
import (
"context"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type dashboardUsageRepoCacheProbe struct {
service.UsageLogRepository
trendCalls atomic.Int32
usersTrendCalls atomic.Int32
}
func (r *dashboardUsageRepoCacheProbe) GetUsageTrendWithFilters(
ctx context.Context,
startTime, endTime time.Time,
granularity string,
userID, apiKeyID, accountID, groupID int64,
model string,
requestType *int16,
stream *bool,
billingType *int8,
) ([]usagestats.TrendDataPoint, error) {
r.trendCalls.Add(1)
return []usagestats.TrendDataPoint{{
Date: "2026-03-11",
Requests: 1,
TotalTokens: 2,
Cost: 3,
ActualCost: 4,
}}, nil
}
func (r *dashboardUsageRepoCacheProbe) GetUserUsageTrend(
ctx context.Context,
startTime, endTime time.Time,
granularity string,
limit int,
) ([]usagestats.UserUsageTrendPoint, error) {
r.usersTrendCalls.Add(1)
return []usagestats.UserUsageTrendPoint{{
Date: "2026-03-11",
UserID: 1,
Email: "cache@test.dev",
Requests: 2,
Tokens: 20,
Cost: 2,
ActualCost: 1,
}}, nil
}
func resetDashboardReadCachesForTest() {
dashboardTrendCache = newSnapshotCache(30 * time.Second)
dashboardUsersTrendCache = newSnapshotCache(30 * time.Second)
dashboardAPIKeysTrendCache = newSnapshotCache(30 * time.Second)
dashboardModelStatsCache = newSnapshotCache(30 * time.Second)
dashboardGroupStatsCache = newSnapshotCache(30 * time.Second)
dashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second)
}
func TestDashboardHandler_GetUsageTrend_UsesCache(t *testing.T) {
t.Cleanup(resetDashboardReadCachesForTest)
resetDashboardReadCachesForTest()
gin.SetMode(gin.TestMode)
repo := &dashboardUsageRepoCacheProbe{}
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
handler := NewDashboardHandler(dashboardSvc, nil)
router := gin.New()
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
req1 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day", nil)
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
require.Equal(t, http.StatusOK, rec1.Code)
require.Equal(t, "miss", rec1.Header().Get("X-Snapshot-Cache"))
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day", nil)
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
require.Equal(t, http.StatusOK, rec2.Code)
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
require.Equal(t, int32(1), repo.trendCalls.Load())
}
func TestDashboardHandler_GetUserUsageTrend_UsesCache(t *testing.T) {
t.Cleanup(resetDashboardReadCachesForTest)
resetDashboardReadCachesForTest()
gin.SetMode(gin.TestMode)
repo := &dashboardUsageRepoCacheProbe{}
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
handler := NewDashboardHandler(dashboardSvc, nil)
router := gin.New()
router.GET("/admin/dashboard/users-trend", handler.GetUserUsageTrend)
req1 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day&limit=8", nil)
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
require.Equal(t, http.StatusOK, rec1.Code)
require.Equal(t, "miss", rec1.Header().Get("X-Snapshot-Cache"))
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day&limit=8", nil)
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
require.Equal(t, http.StatusOK, rec2.Code)
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
require.Equal(t, int32(1), repo.usersTrendCalls.Load())
}
package admin
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
)
var (
dashboardTrendCache = newSnapshotCache(30 * time.Second)
dashboardModelStatsCache = newSnapshotCache(30 * time.Second)
dashboardGroupStatsCache = newSnapshotCache(30 * time.Second)
dashboardUsersTrendCache = newSnapshotCache(30 * time.Second)
dashboardAPIKeysTrendCache = newSnapshotCache(30 * time.Second)
)
type dashboardTrendCacheKey struct {
StartTime string `json:"start_time"`
EndTime string `json:"end_time"`
Granularity string `json:"granularity"`
UserID int64 `json:"user_id"`
APIKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"`
GroupID int64 `json:"group_id"`
Model string `json:"model"`
RequestType *int16 `json:"request_type"`
Stream *bool `json:"stream"`
BillingType *int8 `json:"billing_type"`
}
type dashboardModelGroupCacheKey struct {
StartTime string `json:"start_time"`
EndTime string `json:"end_time"`
UserID int64 `json:"user_id"`
APIKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"`
GroupID int64 `json:"group_id"`
RequestType *int16 `json:"request_type"`
Stream *bool `json:"stream"`
BillingType *int8 `json:"billing_type"`
}
type dashboardEntityTrendCacheKey struct {
StartTime string `json:"start_time"`
EndTime string `json:"end_time"`
Granularity string `json:"granularity"`
Limit int `json:"limit"`
}
func cacheStatusValue(hit bool) string {
if hit {
return "hit"
}
return "miss"
}
func mustMarshalDashboardCacheKey(value any) string {
raw, err := json.Marshal(value)
if err != nil {
return ""
}
return string(raw)
}
func snapshotPayloadAs[T any](payload any) (T, error) {
typed, ok := payload.(T)
if !ok {
var zero T
return zero, fmt.Errorf("unexpected cache payload type %T", payload)
}
return typed, nil
}
func (h *DashboardHandler) getUsageTrendCached(
ctx context.Context,
startTime, endTime time.Time,
granularity string,
userID, apiKeyID, accountID, groupID int64,
model string,
requestType *int16,
stream *bool,
billingType *int8,
) ([]usagestats.TrendDataPoint, bool, error) {
key := mustMarshalDashboardCacheKey(dashboardTrendCacheKey{
StartTime: startTime.UTC().Format(time.RFC3339),
EndTime: endTime.UTC().Format(time.RFC3339),
Granularity: granularity,
UserID: userID,
APIKeyID: apiKeyID,
AccountID: accountID,
GroupID: groupID,
Model: model,
RequestType: requestType,
Stream: stream,
BillingType: billingType,
})
entry, hit, err := dashboardTrendCache.GetOrLoad(key, func() (any, error) {
return h.dashboardService.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
})
if err != nil {
return nil, hit, err
}
trend, err := snapshotPayloadAs[[]usagestats.TrendDataPoint](entry.Payload)
return trend, hit, err
}
func (h *DashboardHandler) getModelStatsCached(
ctx context.Context,
startTime, endTime time.Time,
userID, apiKeyID, accountID, groupID int64,
requestType *int16,
stream *bool,
billingType *int8,
) ([]usagestats.ModelStat, bool, error) {
key := mustMarshalDashboardCacheKey(dashboardModelGroupCacheKey{
StartTime: startTime.UTC().Format(time.RFC3339),
EndTime: endTime.UTC().Format(time.RFC3339),
UserID: userID,
APIKeyID: apiKeyID,
AccountID: accountID,
GroupID: groupID,
RequestType: requestType,
Stream: stream,
BillingType: billingType,
})
entry, hit, err := dashboardModelStatsCache.GetOrLoad(key, func() (any, error) {
return h.dashboardService.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
})
if err != nil {
return nil, hit, err
}
stats, err := snapshotPayloadAs[[]usagestats.ModelStat](entry.Payload)
return stats, hit, err
}
func (h *DashboardHandler) getGroupStatsCached(
ctx context.Context,
startTime, endTime time.Time,
userID, apiKeyID, accountID, groupID int64,
requestType *int16,
stream *bool,
billingType *int8,
) ([]usagestats.GroupStat, bool, error) {
key := mustMarshalDashboardCacheKey(dashboardModelGroupCacheKey{
StartTime: startTime.UTC().Format(time.RFC3339),
EndTime: endTime.UTC().Format(time.RFC3339),
UserID: userID,
APIKeyID: apiKeyID,
AccountID: accountID,
GroupID: groupID,
RequestType: requestType,
Stream: stream,
BillingType: billingType,
})
entry, hit, err := dashboardGroupStatsCache.GetOrLoad(key, func() (any, error) {
return h.dashboardService.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
})
if err != nil {
return nil, hit, err
}
stats, err := snapshotPayloadAs[[]usagestats.GroupStat](entry.Payload)
return stats, hit, err
}
func (h *DashboardHandler) getAPIKeyUsageTrendCached(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, bool, error) {
key := mustMarshalDashboardCacheKey(dashboardEntityTrendCacheKey{
StartTime: startTime.UTC().Format(time.RFC3339),
EndTime: endTime.UTC().Format(time.RFC3339),
Granularity: granularity,
Limit: limit,
})
entry, hit, err := dashboardAPIKeysTrendCache.GetOrLoad(key, func() (any, error) {
return h.dashboardService.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
})
if err != nil {
return nil, hit, err
}
trend, err := snapshotPayloadAs[[]usagestats.APIKeyUsageTrendPoint](entry.Payload)
return trend, hit, err
}
func (h *DashboardHandler) getUserUsageTrendCached(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, bool, error) {
key := mustMarshalDashboardCacheKey(dashboardEntityTrendCacheKey{
StartTime: startTime.UTC().Format(time.RFC3339),
EndTime: endTime.UTC().Format(time.RFC3339),
Granularity: granularity,
Limit: limit,
})
entry, hit, err := dashboardUsersTrendCache.GetOrLoad(key, func() (any, error) {
return h.dashboardService.GetUserUsageTrend(ctx, startTime, endTime, granularity, limit)
})
if err != nil {
return nil, hit, err
}
trend, err := snapshotPayloadAs[[]usagestats.UserUsageTrendPoint](entry.Payload)
return trend, hit, err
}
package admin
import (
"context"
"encoding/json"
"errors"
"net/http"
"strconv"
"strings"
......@@ -111,20 +113,45 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
})
cacheKey := string(keyRaw)
if cached, ok := dashboardSnapshotV2Cache.Get(cacheKey); ok {
if cached.ETag != "" {
c.Header("ETag", cached.ETag)
c.Header("Vary", "If-None-Match")
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
c.Status(http.StatusNotModified)
return
}
}
c.Header("X-Snapshot-Cache", "hit")
response.Success(c, cached.Payload)
cached, hit, err := dashboardSnapshotV2Cache.GetOrLoad(cacheKey, func() (any, error) {
return h.buildSnapshotV2Response(
c.Request.Context(),
startTime,
endTime,
granularity,
filters,
includeStats,
includeTrend,
includeModels,
includeGroups,
includeUsersTrend,
usersTrendLimit,
)
})
if err != nil {
response.Error(c, 500, err.Error())
return
}
if cached.ETag != "" {
c.Header("ETag", cached.ETag)
c.Header("Vary", "If-None-Match")
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
c.Status(http.StatusNotModified)
return
}
}
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
response.Success(c, cached.Payload)
}
func (h *DashboardHandler) buildSnapshotV2Response(
ctx context.Context,
startTime, endTime time.Time,
granularity string,
filters *dashboardSnapshotV2Filters,
includeStats, includeTrend, includeModels, includeGroups, includeUsersTrend bool,
usersTrendLimit int,
) (*dashboardSnapshotV2Response, error) {
resp := &dashboardSnapshotV2Response{
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
StartDate: startTime.Format("2006-01-02"),
......@@ -133,10 +160,9 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
}
if includeStats {
stats, err := h.dashboardService.GetDashboardStats(c.Request.Context())
stats, err := h.dashboardService.GetDashboardStats(ctx)
if err != nil {
response.Error(c, 500, "Failed to get dashboard statistics")
return
return nil, errors.New("failed to get dashboard statistics")
}
resp.Stats = &dashboardSnapshotV2Stats{
DashboardStats: *stats,
......@@ -145,8 +171,8 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
}
if includeTrend {
trend, err := h.dashboardService.GetUsageTrendWithFilters(
c.Request.Context(),
trend, _, err := h.getUsageTrendCached(
ctx,
startTime,
endTime,
granularity,
......@@ -160,15 +186,14 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
filters.BillingType,
)
if err != nil {
response.Error(c, 500, "Failed to get usage trend")
return
return nil, errors.New("failed to get usage trend")
}
resp.Trend = trend
}
if includeModels {
models, err := h.dashboardService.GetModelStatsWithFilters(
c.Request.Context(),
models, _, err := h.getModelStatsCached(
ctx,
startTime,
endTime,
filters.UserID,
......@@ -180,15 +205,14 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
filters.BillingType,
)
if err != nil {
response.Error(c, 500, "Failed to get model statistics")
return
return nil, errors.New("failed to get model statistics")
}
resp.Models = models
}
if includeGroups {
groups, err := h.dashboardService.GetGroupStatsWithFilters(
c.Request.Context(),
groups, _, err := h.getGroupStatsCached(
ctx,
startTime,
endTime,
filters.UserID,
......@@ -200,34 +224,20 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
filters.BillingType,
)
if err != nil {
response.Error(c, 500, "Failed to get group statistics")
return
return nil, errors.New("failed to get group statistics")
}
resp.Groups = groups
}
if includeUsersTrend {
usersTrend, err := h.dashboardService.GetUserUsageTrend(
c.Request.Context(),
startTime,
endTime,
granularity,
usersTrendLimit,
)
usersTrend, _, err := h.getUserUsageTrendCached(ctx, startTime, endTime, granularity, usersTrendLimit)
if err != nil {
response.Error(c, 500, "Failed to get user usage trend")
return
return nil, errors.New("failed to get user usage trend")
}
resp.UsersTrend = usersTrend
}
cached := dashboardSnapshotV2Cache.Set(cacheKey, resp)
if cached.ETag != "" {
c.Header("ETag", cached.ETag)
c.Header("Vary", "If-None-Match")
}
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, resp)
return resp, nil
}
func parseDashboardSnapshotV2Filters(c *gin.Context) (*dashboardSnapshotV2Filters, error) {
......
......@@ -23,6 +23,13 @@ var validOpsAlertMetricTypes = []string{
"cpu_usage_percent",
"memory_usage_percent",
"concurrency_queue_depth",
"group_available_accounts",
"group_available_ratio",
"group_rate_limit_ratio",
"account_rate_limited_count",
"account_error_count",
"account_error_ratio",
"overload_account_count",
}
var validOpsAlertMetricTypeSet = func() map[string]struct{} {
......@@ -82,7 +89,10 @@ func isPercentOrRateMetric(metricType string) bool {
"error_rate",
"upstream_error_rate",
"cpu_usage_percent",
"memory_usage_percent":
"memory_usage_percent",
"group_available_ratio",
"group_rate_limit_ratio",
"account_error_ratio":
return true
default:
return false
......
......@@ -7,6 +7,8 @@ import (
"strings"
"sync"
"time"
"golang.org/x/sync/singleflight"
)
type snapshotCacheEntry struct {
......@@ -19,6 +21,12 @@ type snapshotCache struct {
mu sync.RWMutex
ttl time.Duration
items map[string]snapshotCacheEntry
sf singleflight.Group
}
type snapshotCacheLoadResult struct {
Entry snapshotCacheEntry
Hit bool
}
func newSnapshotCache(ttl time.Duration) *snapshotCache {
......@@ -70,6 +78,41 @@ func (c *snapshotCache) Set(key string, payload any) snapshotCacheEntry {
return entry
}
func (c *snapshotCache) GetOrLoad(key string, load func() (any, error)) (snapshotCacheEntry, bool, error) {
if load == nil {
return snapshotCacheEntry{}, false, nil
}
if entry, ok := c.Get(key); ok {
return entry, true, nil
}
if c == nil || key == "" {
payload, err := load()
if err != nil {
return snapshotCacheEntry{}, false, err
}
return c.Set(key, payload), false, nil
}
value, err, _ := c.sf.Do(key, func() (any, error) {
if entry, ok := c.Get(key); ok {
return snapshotCacheLoadResult{Entry: entry, Hit: true}, nil
}
payload, err := load()
if err != nil {
return nil, err
}
return snapshotCacheLoadResult{Entry: c.Set(key, payload), Hit: false}, nil
})
if err != nil {
return snapshotCacheEntry{}, false, err
}
result, ok := value.(snapshotCacheLoadResult)
if !ok {
return snapshotCacheEntry{}, false, nil
}
return result.Entry, result.Hit, nil
}
func buildETagFromAny(payload any) string {
raw, err := json.Marshal(payload)
if err != nil {
......
......@@ -3,6 +3,8 @@
package admin
import (
"sync"
"sync/atomic"
"testing"
"time"
......@@ -95,6 +97,61 @@ func TestBuildETagFromAny_UnmarshalablePayload(t *testing.T) {
require.Empty(t, etag)
}
func TestSnapshotCache_GetOrLoad_MissThenHit(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
var loads atomic.Int32
entry, hit, err := c.GetOrLoad("key1", func() (any, error) {
loads.Add(1)
return map[string]string{"hello": "world"}, nil
})
require.NoError(t, err)
require.False(t, hit)
require.NotEmpty(t, entry.ETag)
require.Equal(t, int32(1), loads.Load())
entry2, hit, err := c.GetOrLoad("key1", func() (any, error) {
loads.Add(1)
return map[string]string{"unexpected": "value"}, nil
})
require.NoError(t, err)
require.True(t, hit)
require.Equal(t, entry.ETag, entry2.ETag)
require.Equal(t, int32(1), loads.Load())
}
func TestSnapshotCache_GetOrLoad_ConcurrentSingleflight(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
var loads atomic.Int32
start := make(chan struct{})
const callers = 8
errCh := make(chan error, callers)
var wg sync.WaitGroup
wg.Add(callers)
for range callers {
go func() {
defer wg.Done()
<-start
_, _, err := c.GetOrLoad("shared", func() (any, error) {
loads.Add(1)
time.Sleep(20 * time.Millisecond)
return "value", nil
})
errCh <- err
}()
}
close(start)
wg.Wait()
close(errCh)
for err := range errCh {
require.NoError(t, err)
}
require.Equal(t, int32(1), loads.Load())
}
func TestParseBoolQueryWithDefault(t *testing.T) {
tests := []struct {
name string
......
......@@ -216,6 +216,37 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
})
}
// ResetSubscriptionQuotaRequest represents the reset quota request
type ResetSubscriptionQuotaRequest struct {
Daily bool `json:"daily"`
Weekly bool `json:"weekly"`
}
// ResetQuota resets daily and/or weekly usage for a subscription.
// POST /api/v1/admin/subscriptions/:id/reset-quota
func (h *SubscriptionHandler) ResetQuota(c *gin.Context) {
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid subscription ID")
return
}
var req ResetSubscriptionQuotaRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !req.Daily && !req.Weekly {
response.BadRequest(c, "At least one of 'daily' or 'weekly' must be true")
return
}
sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.UserSubscriptionFromServiceAdmin(sub))
}
// Revoke handles revoking a subscription
// DELETE /api/v1/admin/subscriptions/:id
func (h *SubscriptionHandler) Revoke(c *gin.Context) {
......
package handler
import (
"context"
"errors"
"net/http"
"time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
// ChatCompletions handles OpenAI Chat Completions API requests.
// POST /v1/chat/completions
func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
streamStarted := false
defer h.recoverResponsesPanic(c, &streamStarted)
requestStart := time.Now()
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.openai_gateway.chat_completions",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
if !h.ensureResponsesDependencies(c, reqLog) {
return
}
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
if !gjson.ValidBytes(body) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
modelResult := gjson.GetBytes(body, "model")
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
reqModel := modelResult.String()
reqStream := gjson.GetBytes(body, "stream").Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
setOpsRequestContext(c, reqModel, reqStream, body)
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
}
subscription, _ := middleware2.GetSubscriptionFromContext(c)
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
routingStart := time.Now()
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
if !acquired {
return
}
if userReleaseFunc != nil {
defer userReleaseFunc()
}
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
promptCacheKey := h.gatewayService.ExtractSessionID(c, body)
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
sameAccountRetryCount := make(map[int64]int)
var lastFailoverErr *service.UpstreamFailoverError
for {
c.Set("openai_chat_completions_fallback_model", "")
reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
c.Request.Context(),
apiKey.GroupID,
"",
sessionHash,
reqModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
)
if err != nil {
reqLog.Warn("openai_chat_completions.account_select_failed",
zap.Error(err),
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if len(failedAccountIDs) == 0 {
defaultModel := ""
if apiKey.Group != nil {
defaultModel = apiKey.Group.DefaultMappedModel
}
if defaultModel != "" && defaultModel != reqModel {
reqLog.Info("openai_chat_completions.fallback_to_default_model",
zap.String("default_mapped_model", defaultModel),
)
selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
c.Request.Context(),
apiKey.GroupID,
"",
sessionHash,
defaultModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
)
if err == nil && selection != nil {
c.Set("openai_chat_completions_fallback_model", defaultModel)
}
}
if err != nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
}
} else {
if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
} else {
h.handleStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted)
}
return
}
}
if selection == nil || selection.Account == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
account := selection.Account
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
reqLog.Debug("openai_chat_completions.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
_ = scheduleDecision
setOpsSelectedAccount(c, account.ID, account.Platform)
accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
if !acquired {
return
}
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now()
defaultMappedModel := ""
if apiKey.Group != nil {
defaultMappedModel = apiKey.Group.DefaultMappedModel
}
if fallbackModel := c.GetString("openai_chat_completions_fallback_model"); fallbackModel != "" {
defaultMappedModel = fallbackModel
}
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil {
accountReleaseFunc()
}
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
responseLatencyMs := forwardDurationMs
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
}
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
if err == nil && result != nil && result.FirstTokenMs != nil {
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
// Pool mode: retry on the same account
if failoverErr.RetryableOnSameAccount {
retryLimit := account.GetPoolModeRetryCount()
if sameAccountRetryCount[account.ID] < retryLimit {
sameAccountRetryCount[account.ID]++
reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("retry_limit", retryLimit),
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
)
select {
case <-c.Request.Context().Done():
return
case <-time.After(sameAccountRetryDelay):
}
continue
}
}
h.gatewayService.RecordOpenAIAccountSwitch()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, streamStarted)
return
}
switchCount++
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
continue
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Warn("openai_chat_completions.forward_failed",
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
)
return
}
if result != nil {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
} else {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
}
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.chat_completions"),
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.String("model", reqModel),
zap.Int64("account_id", account.ID),
).Error("openai_chat_completions.record_usage_failed", zap.Error(err))
}
})
reqLog.Debug("openai_chat_completions.request_completed",
zap.Int64("account_id", account.ID),
zap.Int("switch_count", switchCount),
)
return
}
}
......@@ -31,6 +31,7 @@ const (
const (
opsErrorLogTimeout = 5 * time.Second
opsErrorLogDrainTimeout = 10 * time.Second
opsErrorLogBatchWindow = 200 * time.Millisecond
opsErrorLogMinWorkerCount = 4
opsErrorLogMaxWorkerCount = 32
......@@ -38,6 +39,7 @@ const (
opsErrorLogQueueSizePerWorker = 128
opsErrorLogMinQueueSize = 256
opsErrorLogMaxQueueSize = 8192
opsErrorLogBatchSize = 32
)
type opsErrorLogJob struct {
......@@ -82,27 +84,82 @@ func startOpsErrorLogWorkers() {
for i := 0; i < workerCount; i++ {
go func() {
defer opsErrorLogWorkersWg.Done()
for job := range opsErrorLogQueue {
opsErrorLogQueueLen.Add(-1)
if job.ops == nil || job.entry == nil {
continue
for {
job, ok := <-opsErrorLogQueue
if !ok {
return
}
func() {
defer func() {
if r := recover(); r != nil {
log.Printf("[OpsErrorLogger] worker panic: %v\n%s", r, debug.Stack())
opsErrorLogQueueLen.Add(-1)
batch := make([]opsErrorLogJob, 0, opsErrorLogBatchSize)
batch = append(batch, job)
timer := time.NewTimer(opsErrorLogBatchWindow)
batchLoop:
for len(batch) < opsErrorLogBatchSize {
select {
case nextJob, ok := <-opsErrorLogQueue:
if !ok {
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
flushOpsErrorLogBatch(batch)
return
}
}()
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
_ = job.ops.RecordError(ctx, job.entry, nil)
cancel()
opsErrorLogProcessed.Add(1)
}()
opsErrorLogQueueLen.Add(-1)
batch = append(batch, nextJob)
case <-timer.C:
break batchLoop
}
}
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
flushOpsErrorLogBatch(batch)
}
}()
}
}
func flushOpsErrorLogBatch(batch []opsErrorLogJob) {
if len(batch) == 0 {
return
}
defer func() {
if r := recover(); r != nil {
log.Printf("[OpsErrorLogger] worker panic: %v\n%s", r, debug.Stack())
}
}()
grouped := make(map[*service.OpsService][]*service.OpsInsertErrorLogInput, len(batch))
var processed int64
for _, job := range batch {
if job.ops == nil || job.entry == nil {
continue
}
grouped[job.ops] = append(grouped[job.ops], job.entry)
processed++
}
if processed == 0 {
return
}
for opsSvc, entries := range grouped {
if opsSvc == nil || len(entries) == 0 {
continue
}
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
_ = opsSvc.RecordErrorBatch(ctx, entries)
cancel()
}
opsErrorLogProcessed.Add(processed)
}
func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput) {
if ops == nil || entry == nil {
return
......
......@@ -159,6 +159,8 @@ var claudeModels = []modelDef{
// Antigravity 支持的 Gemini 模型
var geminiModels = []modelDef{
{ID: "gemini-2.5-flash", DisplayName: "Gemini 2.5 Flash", CreatedAt: "2025-01-01T00:00:00Z"},
{ID: "gemini-2.5-flash-image", DisplayName: "Gemini 2.5 Flash Image", CreatedAt: "2025-01-01T00:00:00Z"},
{ID: "gemini-2.5-flash-image-preview", DisplayName: "Gemini 2.5 Flash Image Preview", CreatedAt: "2025-01-01T00:00:00Z"},
{ID: "gemini-2.5-flash-lite", DisplayName: "Gemini 2.5 Flash Lite", CreatedAt: "2025-01-01T00:00:00Z"},
{ID: "gemini-2.5-flash-thinking", DisplayName: "Gemini 2.5 Flash Thinking", CreatedAt: "2025-01-01T00:00:00Z"},
{ID: "gemini-3-flash", DisplayName: "Gemini 3 Flash", CreatedAt: "2025-06-01T00:00:00Z"},
......
......@@ -13,6 +13,8 @@ func TestDefaultModels_ContainsNewAndLegacyImageModels(t *testing.T) {
requiredIDs := []string{
"claude-opus-4-6-thinking",
"gemini-2.5-flash-image",
"gemini-2.5-flash-image-preview",
"gemini-3.1-flash-image",
"gemini-3.1-flash-image-preview",
"gemini-3-pro-image", // legacy compatibility
......
package apicompat
import (
"encoding/json"
"fmt"
)
// ChatCompletionsToResponses converts a Chat Completions request into a
// Responses API request. The upstream always streams, so Stream is forced to
// true. store is always false and reasoning.encrypted_content is always
// included so that the response translator has full context.
func ChatCompletionsToResponses(req *ChatCompletionsRequest) (*ResponsesRequest, error) {
input, err := convertChatMessagesToResponsesInput(req.Messages)
if err != nil {
return nil, err
}
inputJSON, err := json.Marshal(input)
if err != nil {
return nil, err
}
out := &ResponsesRequest{
Model: req.Model,
Input: inputJSON,
Temperature: req.Temperature,
TopP: req.TopP,
Stream: true, // upstream always streams
Include: []string{"reasoning.encrypted_content"},
ServiceTier: req.ServiceTier,
}
storeFalse := false
out.Store = &storeFalse
// max_tokens / max_completion_tokens → max_output_tokens, prefer max_completion_tokens
maxTokens := 0
if req.MaxTokens != nil {
maxTokens = *req.MaxTokens
}
if req.MaxCompletionTokens != nil {
maxTokens = *req.MaxCompletionTokens
}
if maxTokens > 0 {
v := maxTokens
if v < minMaxOutputTokens {
v = minMaxOutputTokens
}
out.MaxOutputTokens = &v
}
// reasoning_effort → reasoning.effort + reasoning.summary="auto"
if req.ReasoningEffort != "" {
out.Reasoning = &ResponsesReasoning{
Effort: req.ReasoningEffort,
Summary: "auto",
}
}
// tools[] and legacy functions[] → ResponsesTool[]
if len(req.Tools) > 0 || len(req.Functions) > 0 {
out.Tools = convertChatToolsToResponses(req.Tools, req.Functions)
}
// tool_choice: already compatible format — pass through directly.
// Legacy function_call needs mapping.
if len(req.ToolChoice) > 0 {
out.ToolChoice = req.ToolChoice
} else if len(req.FunctionCall) > 0 {
tc, err := convertChatFunctionCallToToolChoice(req.FunctionCall)
if err != nil {
return nil, fmt.Errorf("convert function_call: %w", err)
}
out.ToolChoice = tc
}
return out, nil
}
// convertChatMessagesToResponsesInput converts the Chat Completions messages
// array into a Responses API input items array.
func convertChatMessagesToResponsesInput(msgs []ChatMessage) ([]ResponsesInputItem, error) {
var out []ResponsesInputItem
for _, m := range msgs {
items, err := chatMessageToResponsesItems(m)
if err != nil {
return nil, err
}
out = append(out, items...)
}
return out, nil
}
// chatMessageToResponsesItems converts a single ChatMessage into one or more
// ResponsesInputItem values.
func chatMessageToResponsesItems(m ChatMessage) ([]ResponsesInputItem, error) {
switch m.Role {
case "system":
return chatSystemToResponses(m)
case "user":
return chatUserToResponses(m)
case "assistant":
return chatAssistantToResponses(m)
case "tool":
return chatToolToResponses(m)
case "function":
return chatFunctionToResponses(m)
default:
return chatUserToResponses(m)
}
}
// chatSystemToResponses converts a system message.
func chatSystemToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
text, err := parseChatContent(m.Content)
if err != nil {
return nil, err
}
content, err := json.Marshal(text)
if err != nil {
return nil, err
}
return []ResponsesInputItem{{Role: "system", Content: content}}, nil
}
// chatUserToResponses converts a user message, handling both plain strings and
// multi-modal content arrays.
func chatUserToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
// Try plain string first.
var s string
if err := json.Unmarshal(m.Content, &s); err == nil {
content, _ := json.Marshal(s)
return []ResponsesInputItem{{Role: "user", Content: content}}, nil
}
var parts []ChatContentPart
if err := json.Unmarshal(m.Content, &parts); err != nil {
return nil, fmt.Errorf("parse user content: %w", err)
}
var responseParts []ResponsesContentPart
for _, p := range parts {
switch p.Type {
case "text":
if p.Text != "" {
responseParts = append(responseParts, ResponsesContentPart{
Type: "input_text",
Text: p.Text,
})
}
case "image_url":
if p.ImageURL != nil && p.ImageURL.URL != "" {
responseParts = append(responseParts, ResponsesContentPart{
Type: "input_image",
ImageURL: p.ImageURL.URL,
})
}
}
}
content, err := json.Marshal(responseParts)
if err != nil {
return nil, err
}
return []ResponsesInputItem{{Role: "user", Content: content}}, nil
}
// chatAssistantToResponses converts an assistant message. If there is both
// text content and tool_calls, the text is emitted as an assistant message
// first, then each tool_call becomes a function_call item. If the content is
// empty/nil and there are tool_calls, only function_call items are emitted.
func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
var items []ResponsesInputItem
// Emit assistant message with output_text if content is non-empty.
if len(m.Content) > 0 {
var s string
if err := json.Unmarshal(m.Content, &s); err == nil && s != "" {
parts := []ResponsesContentPart{{Type: "output_text", Text: s}}
partsJSON, err := json.Marshal(parts)
if err != nil {
return nil, err
}
items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON})
}
}
// Emit one function_call item per tool_call.
for _, tc := range m.ToolCalls {
args := tc.Function.Arguments
if args == "" {
args = "{}"
}
items = append(items, ResponsesInputItem{
Type: "function_call",
CallID: tc.ID,
Name: tc.Function.Name,
Arguments: args,
ID: tc.ID,
})
}
return items, nil
}
// chatToolToResponses converts a tool result message (role=tool) into a
// function_call_output item.
func chatToolToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
output, err := parseChatContent(m.Content)
if err != nil {
return nil, err
}
if output == "" {
output = "(empty)"
}
return []ResponsesInputItem{{
Type: "function_call_output",
CallID: m.ToolCallID,
Output: output,
}}, nil
}
// chatFunctionToResponses converts a legacy function result message
// (role=function) into a function_call_output item. The Name field is used as
// call_id since legacy function calls do not carry a separate call_id.
func chatFunctionToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
output, err := parseChatContent(m.Content)
if err != nil {
return nil, err
}
if output == "" {
output = "(empty)"
}
return []ResponsesInputItem{{
Type: "function_call_output",
CallID: m.Name,
Output: output,
}}, nil
}
// parseChatContent returns the string value of a ChatMessage Content field.
// Content must be a JSON string. Returns "" if content is null or empty.
func parseChatContent(raw json.RawMessage) (string, error) {
if len(raw) == 0 {
return "", nil
}
var s string
if err := json.Unmarshal(raw, &s); err != nil {
return "", fmt.Errorf("parse content as string: %w", err)
}
return s, nil
}
// convertChatToolsToResponses maps Chat Completions tool definitions and legacy
// function definitions to Responses API tool definitions.
func convertChatToolsToResponses(tools []ChatTool, functions []ChatFunction) []ResponsesTool {
var out []ResponsesTool
for _, t := range tools {
if t.Type != "function" || t.Function == nil {
continue
}
rt := ResponsesTool{
Type: "function",
Name: t.Function.Name,
Description: t.Function.Description,
Parameters: t.Function.Parameters,
Strict: t.Function.Strict,
}
out = append(out, rt)
}
// Legacy functions[] are treated as function-type tools.
for _, f := range functions {
rt := ResponsesTool{
Type: "function",
Name: f.Name,
Description: f.Description,
Parameters: f.Parameters,
Strict: f.Strict,
}
out = append(out, rt)
}
return out
}
// convertChatFunctionCallToToolChoice maps the legacy function_call field to a
// Responses API tool_choice value.
//
// "auto" → "auto"
// "none" → "none"
// {"name":"X"} → {"type":"function","function":{"name":"X"}}
func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage, error) {
// Try string first ("auto", "none", etc.) — pass through as-is.
var s string
if err := json.Unmarshal(raw, &s); err == nil {
return json.Marshal(s)
}
// Object form: {"name":"X"}
var obj struct {
Name string `json:"name"`
}
if err := json.Unmarshal(raw, &obj); err != nil {
return nil, err
}
return json.Marshal(map[string]any{
"type": "function",
"function": map[string]string{"name": obj.Name},
})
}
package apicompat
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"time"
)
// ---------------------------------------------------------------------------
// Non-streaming: ResponsesResponse → ChatCompletionsResponse
// ---------------------------------------------------------------------------
// ResponsesToChatCompletions converts a Responses API response into a Chat
// Completions response. Text output items are concatenated into
// choices[0].message.content; function_call items become tool_calls.
func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatCompletionsResponse {
id := resp.ID
if id == "" {
id = generateChatCmplID()
}
out := &ChatCompletionsResponse{
ID: id,
Object: "chat.completion",
Created: time.Now().Unix(),
Model: model,
}
var contentText string
var toolCalls []ChatToolCall
for _, item := range resp.Output {
switch item.Type {
case "message":
for _, part := range item.Content {
if part.Type == "output_text" && part.Text != "" {
contentText += part.Text
}
}
case "function_call":
toolCalls = append(toolCalls, ChatToolCall{
ID: item.CallID,
Type: "function",
Function: ChatFunctionCall{
Name: item.Name,
Arguments: item.Arguments,
},
})
case "reasoning":
for _, s := range item.Summary {
if s.Type == "summary_text" && s.Text != "" {
contentText += s.Text
}
}
case "web_search_call":
// silently consumed — results already incorporated into text output
}
}
msg := ChatMessage{Role: "assistant"}
if len(toolCalls) > 0 {
msg.ToolCalls = toolCalls
}
if contentText != "" {
raw, _ := json.Marshal(contentText)
msg.Content = raw
}
finishReason := responsesStatusToChatFinishReason(resp.Status, resp.IncompleteDetails, toolCalls)
out.Choices = []ChatChoice{{
Index: 0,
Message: msg,
FinishReason: finishReason,
}}
if resp.Usage != nil {
usage := &ChatUsage{
PromptTokens: resp.Usage.InputTokens,
CompletionTokens: resp.Usage.OutputTokens,
TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens,
}
if resp.Usage.InputTokensDetails != nil && resp.Usage.InputTokensDetails.CachedTokens > 0 {
usage.PromptTokensDetails = &ChatTokenDetails{
CachedTokens: resp.Usage.InputTokensDetails.CachedTokens,
}
}
out.Usage = usage
}
return out
}
func responsesStatusToChatFinishReason(status string, details *ResponsesIncompleteDetails, toolCalls []ChatToolCall) string {
switch status {
case "incomplete":
if details != nil && details.Reason == "max_output_tokens" {
return "length"
}
return "stop"
case "completed":
if len(toolCalls) > 0 {
return "tool_calls"
}
return "stop"
default:
return "stop"
}
}
// ---------------------------------------------------------------------------
// Streaming: ResponsesStreamEvent → []ChatCompletionsChunk (stateful converter)
// ---------------------------------------------------------------------------
// ResponsesEventToChatState tracks state for converting a sequence of Responses
// SSE events into Chat Completions SSE chunks.
type ResponsesEventToChatState struct {
ID string
Model string
Created int64
SentRole bool
SawToolCall bool
SawText bool
Finalized bool // true after finish chunk has been emitted
NextToolCallIndex int // next sequential tool_call index to assign
OutputIndexToToolIndex map[int]int // Responses output_index → Chat tool_calls index
IncludeUsage bool
Usage *ChatUsage
}
// NewResponsesEventToChatState returns an initialised stream state.
func NewResponsesEventToChatState() *ResponsesEventToChatState {
return &ResponsesEventToChatState{
ID: generateChatCmplID(),
Created: time.Now().Unix(),
OutputIndexToToolIndex: make(map[int]int),
}
}
// ResponsesEventToChatChunks converts a single Responses SSE event into zero
// or more Chat Completions chunks, updating state as it goes.
func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
switch evt.Type {
case "response.created":
return resToChatHandleCreated(evt, state)
case "response.output_text.delta":
return resToChatHandleTextDelta(evt, state)
case "response.output_item.added":
return resToChatHandleOutputItemAdded(evt, state)
case "response.function_call_arguments.delta":
return resToChatHandleFuncArgsDelta(evt, state)
case "response.reasoning_summary_text.delta":
return resToChatHandleReasoningDelta(evt, state)
case "response.completed", "response.incomplete", "response.failed":
return resToChatHandleCompleted(evt, state)
default:
return nil
}
}
// FinalizeResponsesChatStream emits a final chunk with finish_reason if the
// stream ended without a proper completion event (e.g. upstream disconnect).
// It is idempotent: if a completion event already emitted the finish chunk,
// this returns nil.
func FinalizeResponsesChatStream(state *ResponsesEventToChatState) []ChatCompletionsChunk {
if state.Finalized {
return nil
}
state.Finalized = true
finishReason := "stop"
if state.SawToolCall {
finishReason = "tool_calls"
}
chunks := []ChatCompletionsChunk{makeChatFinishChunk(state, finishReason)}
if state.IncludeUsage && state.Usage != nil {
chunks = append(chunks, ChatCompletionsChunk{
ID: state.ID,
Object: "chat.completion.chunk",
Created: state.Created,
Model: state.Model,
Choices: []ChatChunkChoice{},
Usage: state.Usage,
})
}
return chunks
}
// ChatChunkToSSE formats a ChatCompletionsChunk as an SSE data line.
func ChatChunkToSSE(chunk ChatCompletionsChunk) (string, error) {
data, err := json.Marshal(chunk)
if err != nil {
return "", err
}
return fmt.Sprintf("data: %s\n\n", data), nil
}
// --- internal handlers ---
func resToChatHandleCreated(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
if evt.Response != nil {
if evt.Response.ID != "" {
state.ID = evt.Response.ID
}
if state.Model == "" && evt.Response.Model != "" {
state.Model = evt.Response.Model
}
}
// Emit the role chunk.
if state.SentRole {
return nil
}
state.SentRole = true
role := "assistant"
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Role: role})}
}
func resToChatHandleTextDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
if evt.Delta == "" {
return nil
}
state.SawText = true
content := evt.Delta
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})}
}
func resToChatHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
if evt.Item == nil || evt.Item.Type != "function_call" {
return nil
}
state.SawToolCall = true
idx := state.NextToolCallIndex
state.OutputIndexToToolIndex[evt.OutputIndex] = idx
state.NextToolCallIndex++
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{
ToolCalls: []ChatToolCall{{
Index: &idx,
ID: evt.Item.CallID,
Type: "function",
Function: ChatFunctionCall{
Name: evt.Item.Name,
},
}},
})}
}
func resToChatHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
if evt.Delta == "" {
return nil
}
idx, ok := state.OutputIndexToToolIndex[evt.OutputIndex]
if !ok {
return nil
}
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{
ToolCalls: []ChatToolCall{{
Index: &idx,
Function: ChatFunctionCall{
Arguments: evt.Delta,
},
}},
})}
}
func resToChatHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
if evt.Delta == "" {
return nil
}
content := evt.Delta
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})}
}
func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
state.Finalized = true
finishReason := "stop"
if evt.Response != nil {
if evt.Response.Usage != nil {
u := evt.Response.Usage
usage := &ChatUsage{
PromptTokens: u.InputTokens,
CompletionTokens: u.OutputTokens,
TotalTokens: u.InputTokens + u.OutputTokens,
}
if u.InputTokensDetails != nil && u.InputTokensDetails.CachedTokens > 0 {
usage.PromptTokensDetails = &ChatTokenDetails{
CachedTokens: u.InputTokensDetails.CachedTokens,
}
}
state.Usage = usage
}
switch evt.Response.Status {
case "incomplete":
if evt.Response.IncompleteDetails != nil && evt.Response.IncompleteDetails.Reason == "max_output_tokens" {
finishReason = "length"
}
case "completed":
if state.SawToolCall {
finishReason = "tool_calls"
}
}
} else if state.SawToolCall {
finishReason = "tool_calls"
}
var chunks []ChatCompletionsChunk
chunks = append(chunks, makeChatFinishChunk(state, finishReason))
if state.IncludeUsage && state.Usage != nil {
chunks = append(chunks, ChatCompletionsChunk{
ID: state.ID,
Object: "chat.completion.chunk",
Created: state.Created,
Model: state.Model,
Choices: []ChatChunkChoice{},
Usage: state.Usage,
})
}
return chunks
}
func makeChatDeltaChunk(state *ResponsesEventToChatState, delta ChatDelta) ChatCompletionsChunk {
return ChatCompletionsChunk{
ID: state.ID,
Object: "chat.completion.chunk",
Created: state.Created,
Model: state.Model,
Choices: []ChatChunkChoice{{
Index: 0,
Delta: delta,
FinishReason: nil,
}},
}
}
func makeChatFinishChunk(state *ResponsesEventToChatState, finishReason string) ChatCompletionsChunk {
empty := ""
return ChatCompletionsChunk{
ID: state.ID,
Object: "chat.completion.chunk",
Created: state.Created,
Model: state.Model,
Choices: []ChatChunkChoice{{
Index: 0,
Delta: ChatDelta{Content: &empty},
FinishReason: &finishReason,
}},
}
}
// generateChatCmplID returns a "chatcmpl-" prefixed random hex ID.
func generateChatCmplID() string {
b := make([]byte, 12)
_, _ = rand.Read(b)
return "chatcmpl-" + hex.EncodeToString(b)
}
......@@ -329,6 +329,148 @@ type ResponsesStreamEvent struct {
SequenceNumber int `json:"sequence_number,omitempty"`
}
// ---------------------------------------------------------------------------
// OpenAI Chat Completions API types
// ---------------------------------------------------------------------------
// ChatCompletionsRequest is the request body for POST /v1/chat/completions.
type ChatCompletionsRequest struct {
Model string `json:"model"`
Messages []ChatMessage `json:"messages"`
MaxTokens *int `json:"max_tokens,omitempty"`
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *ChatStreamOptions `json:"stream_options,omitempty"`
Tools []ChatTool `json:"tools,omitempty"`
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"` // "low" | "medium" | "high"
ServiceTier string `json:"service_tier,omitempty"`
Stop json.RawMessage `json:"stop,omitempty"` // string or []string
// Legacy function calling (deprecated but still supported)
Functions []ChatFunction `json:"functions,omitempty"`
FunctionCall json.RawMessage `json:"function_call,omitempty"`
}
// ChatStreamOptions configures streaming behavior.
type ChatStreamOptions struct {
IncludeUsage bool `json:"include_usage,omitempty"`
}
// ChatMessage is a single message in the Chat Completions conversation.
type ChatMessage struct {
Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function"
Content json.RawMessage `json:"content,omitempty"`
Name string `json:"name,omitempty"`
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
// Legacy function calling
FunctionCall *ChatFunctionCall `json:"function_call,omitempty"`
}
// ChatContentPart is a typed content part in a multi-modal message.
type ChatContentPart struct {
Type string `json:"type"` // "text" | "image_url"
Text string `json:"text,omitempty"`
ImageURL *ChatImageURL `json:"image_url,omitempty"`
}
// ChatImageURL contains the URL for an image content part.
type ChatImageURL struct {
URL string `json:"url"`
Detail string `json:"detail,omitempty"` // "auto" | "low" | "high"
}
// ChatTool describes a tool available to the model.
type ChatTool struct {
Type string `json:"type"` // "function"
Function *ChatFunction `json:"function,omitempty"`
}
// ChatFunction describes a function tool definition.
type ChatFunction struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Parameters json.RawMessage `json:"parameters,omitempty"`
Strict *bool `json:"strict,omitempty"`
}
// ChatToolCall represents a tool call made by the assistant.
// Index is only populated in streaming chunks (omitted in non-streaming responses).
type ChatToolCall struct {
Index *int `json:"index,omitempty"`
ID string `json:"id,omitempty"`
Type string `json:"type,omitempty"` // "function"
Function ChatFunctionCall `json:"function"`
}
// ChatFunctionCall contains the function name and arguments.
type ChatFunctionCall struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}
// ChatCompletionsResponse is the non-streaming response from POST /v1/chat/completions.
type ChatCompletionsResponse struct {
ID string `json:"id"`
Object string `json:"object"` // "chat.completion"
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatChoice `json:"choices"`
Usage *ChatUsage `json:"usage,omitempty"`
SystemFingerprint string `json:"system_fingerprint,omitempty"`
ServiceTier string `json:"service_tier,omitempty"`
}
// ChatChoice is a single completion choice.
type ChatChoice struct {
Index int `json:"index"`
Message ChatMessage `json:"message"`
FinishReason string `json:"finish_reason"` // "stop" | "length" | "tool_calls" | "content_filter"
}
// ChatUsage holds token counts in Chat Completions format.
type ChatUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
PromptTokensDetails *ChatTokenDetails `json:"prompt_tokens_details,omitempty"`
}
// ChatTokenDetails provides a breakdown of token usage.
type ChatTokenDetails struct {
CachedTokens int `json:"cached_tokens,omitempty"`
}
// ChatCompletionsChunk is a single streaming chunk from POST /v1/chat/completions.
type ChatCompletionsChunk struct {
ID string `json:"id"`
Object string `json:"object"` // "chat.completion.chunk"
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatChunkChoice `json:"choices"`
Usage *ChatUsage `json:"usage,omitempty"`
SystemFingerprint string `json:"system_fingerprint,omitempty"`
ServiceTier string `json:"service_tier,omitempty"`
}
// ChatChunkChoice is a single choice in a streaming chunk.
type ChatChunkChoice struct {
Index int `json:"index"`
Delta ChatDelta `json:"delta"`
FinishReason *string `json:"finish_reason"` // pointer: null when not final
}
// ChatDelta carries incremental content in a streaming chunk.
type ChatDelta struct {
Role string `json:"role,omitempty"`
Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
}
// ---------------------------------------------------------------------------
// Shared constants
// ---------------------------------------------------------------------------
......
......@@ -18,10 +18,12 @@ func DefaultModels() []Model {
return []Model{
{Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.5-flash-image", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-3.1-flash-image", SupportedGenerationMethods: methods},
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment