"vscode:/vscode.git/clone" did not exist on "dd5978f2224ce6646f4cdb5e3693f7059f26ab90"
Commit 1de18b89 authored by Wang Lvyuan's avatar Wang Lvyuan
Browse files

merge: sync upstream/main before PR

parents 882518c1 9f6ab6b8
package service
import (
"encoding/json"
"regexp"
"strings"
)
// NewMetadataFormatMinVersion is the minimum Claude Code version that uses
// JSON-formatted metadata.user_id instead of the legacy concatenated string.
const NewMetadataFormatMinVersion = "2.1.78"
// ParsedUserID represents the components extracted from a metadata.user_id value.
type ParsedUserID struct {
DeviceID string // 64-char hex (or arbitrary client id)
AccountUUID string // may be empty
SessionID string // UUID
IsNewFormat bool // true if the original was JSON format
}
// legacyUserIDRegex matches the legacy user_id format:
//
// user_{64hex}_account_{optional_uuid}_session_{uuid}
var legacyUserIDRegex = regexp.MustCompile(`^user_([a-fA-F0-9]{64})_account_([a-fA-F0-9-]*)_session_([a-fA-F0-9-]{36})$`)
// jsonUserID is the JSON structure for the new metadata.user_id format.
type jsonUserID struct {
DeviceID string `json:"device_id"`
AccountUUID string `json:"account_uuid"`
SessionID string `json:"session_id"`
}
// ParseMetadataUserID parses a metadata.user_id string in either format.
// Returns nil if the input cannot be parsed.
func ParseMetadataUserID(raw string) *ParsedUserID {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil
}
// Try JSON format first (starts with '{')
if raw[0] == '{' {
var j jsonUserID
if err := json.Unmarshal([]byte(raw), &j); err != nil {
return nil
}
if j.DeviceID == "" || j.SessionID == "" {
return nil
}
return &ParsedUserID{
DeviceID: j.DeviceID,
AccountUUID: j.AccountUUID,
SessionID: j.SessionID,
IsNewFormat: true,
}
}
// Try legacy format
matches := legacyUserIDRegex.FindStringSubmatch(raw)
if matches == nil {
return nil
}
return &ParsedUserID{
DeviceID: matches[1],
AccountUUID: matches[2],
SessionID: matches[3],
IsNewFormat: false,
}
}
// FormatMetadataUserID builds a metadata.user_id string in the format
// appropriate for the given CLI version. Components are the rewritten values
// (not necessarily the originals).
func FormatMetadataUserID(deviceID, accountUUID, sessionID, uaVersion string) string {
if IsNewMetadataFormatVersion(uaVersion) {
b, _ := json.Marshal(jsonUserID{
DeviceID: deviceID,
AccountUUID: accountUUID,
SessionID: sessionID,
})
return string(b)
}
// Legacy format
return "user_" + deviceID + "_account_" + accountUUID + "_session_" + sessionID
}
// IsNewMetadataFormatVersion returns true if the given CLI version uses the
// new JSON metadata.user_id format (>= 2.1.78).
func IsNewMetadataFormatVersion(version string) bool {
if version == "" {
return false
}
return CompareVersions(version, NewMetadataFormatMinVersion) >= 0
}
// ExtractCLIVersion extracts the Claude Code version from a User-Agent string.
// Returns "" if the UA doesn't match the expected pattern.
func ExtractCLIVersion(ua string) string {
matches := claudeCodeUAVersionPattern.FindStringSubmatch(ua)
if len(matches) >= 2 {
return matches[1]
}
return ""
}
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// ============ ParseMetadataUserID Tests ============
func TestParseMetadataUserID_LegacyFormat_WithoutAccountUUID(t *testing.T) {
raw := "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000"
parsed := ParseMetadataUserID(raw)
require.NotNil(t, parsed)
require.Equal(t, "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2", parsed.DeviceID)
require.Equal(t, "", parsed.AccountUUID)
require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", parsed.SessionID)
require.False(t, parsed.IsNewFormat)
}
func TestParseMetadataUserID_LegacyFormat_WithAccountUUID(t *testing.T) {
raw := "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account_550e8400-e29b-41d4-a716-446655440000_session_123e4567-e89b-12d3-a456-426614174000"
parsed := ParseMetadataUserID(raw)
require.NotNil(t, parsed)
require.Equal(t, "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2", parsed.DeviceID)
require.Equal(t, "550e8400-e29b-41d4-a716-446655440000", parsed.AccountUUID)
require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", parsed.SessionID)
require.False(t, parsed.IsNewFormat)
}
func TestParseMetadataUserID_JSONFormat_WithoutAccountUUID(t *testing.T) {
raw := `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`
parsed := ParseMetadataUserID(raw)
require.NotNil(t, parsed)
require.Equal(t, "d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677", parsed.DeviceID)
require.Equal(t, "", parsed.AccountUUID)
require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", parsed.SessionID)
require.True(t, parsed.IsNewFormat)
}
func TestParseMetadataUserID_JSONFormat_WithAccountUUID(t *testing.T) {
raw := `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"550e8400-e29b-41d4-a716-446655440000","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`
parsed := ParseMetadataUserID(raw)
require.NotNil(t, parsed)
require.Equal(t, "d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677", parsed.DeviceID)
require.Equal(t, "550e8400-e29b-41d4-a716-446655440000", parsed.AccountUUID)
require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", parsed.SessionID)
require.True(t, parsed.IsNewFormat)
}
func TestParseMetadataUserID_InvalidInputs(t *testing.T) {
tests := []struct {
name string
raw string
}{
{"empty string", ""},
{"whitespace only", " "},
{"random text", "not-a-valid-user-id"},
{"partial legacy format", "session_123e4567-e89b-12d3-a456-426614174000"},
{"invalid JSON", `{"device_id":}`},
{"JSON missing device_id", `{"account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`},
{"JSON missing session_id", `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":""}`},
{"JSON empty device_id", `{"device_id":"","account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`},
{"JSON empty session_id", `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"","session_id":""}`},
{"legacy format short hex", "user_a1b2c3d4_account__session_123e4567-e89b-12d3-a456-426614174000"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Nil(t, ParseMetadataUserID(tt.raw), "should return nil for: %s", tt.raw)
})
}
}
func TestParseMetadataUserID_HexCaseInsensitive(t *testing.T) {
// Legacy format should accept both upper and lower case hex
rawUpper := "user_A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2_account__session_123e4567-e89b-12d3-a456-426614174000"
parsed := ParseMetadataUserID(rawUpper)
require.NotNil(t, parsed, "legacy format should accept uppercase hex")
require.Equal(t, "A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2", parsed.DeviceID)
}
// ============ FormatMetadataUserID Tests ============
func TestFormatMetadataUserID_LegacyVersion(t *testing.T) {
result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "acc-uuid", "sess-uuid", "2.1.77")
require.Equal(t, "user_deadbeef00112233445566778899aabbccddeeff0011223344556677_account_acc-uuid_session_sess-uuid", result)
}
func TestFormatMetadataUserID_NewVersion(t *testing.T) {
result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "acc-uuid", "sess-uuid", "2.1.78")
require.Equal(t, `{"device_id":"deadbeef00112233445566778899aabbccddeeff0011223344556677","account_uuid":"acc-uuid","session_id":"sess-uuid"}`, result)
}
func TestFormatMetadataUserID_EmptyVersion_Legacy(t *testing.T) {
result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "", "sess-uuid", "")
require.Equal(t, "user_deadbeef00112233445566778899aabbccddeeff0011223344556677_account__session_sess-uuid", result)
}
func TestFormatMetadataUserID_EmptyAccountUUID(t *testing.T) {
// Legacy format with empty account UUID → double underscore
result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "", "sess-uuid", "2.1.22")
require.Contains(t, result, "_account__session_")
// New format with empty account UUID → empty string in JSON
result = FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "", "sess-uuid", "2.1.78")
require.Contains(t, result, `"account_uuid":""`)
}
// ============ IsNewMetadataFormatVersion Tests ============
func TestIsNewMetadataFormatVersion(t *testing.T) {
tests := []struct {
version string
want bool
}{
{"", false},
{"2.1.77", false},
{"2.1.78", true},
{"2.1.79", true},
{"2.2.0", true},
{"3.0.0", true},
{"2.0.100", false},
{"1.9.99", false},
}
for _, tt := range tests {
t.Run(tt.version, func(t *testing.T) {
require.Equal(t, tt.want, IsNewMetadataFormatVersion(tt.version))
})
}
}
// ============ Round-trip Tests ============
func TestParseFormat_RoundTrip_Legacy(t *testing.T) {
deviceID := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
accountUUID := "550e8400-e29b-41d4-a716-446655440000"
sessionID := "123e4567-e89b-12d3-a456-426614174000"
formatted := FormatMetadataUserID(deviceID, accountUUID, sessionID, "2.1.22")
parsed := ParseMetadataUserID(formatted)
require.NotNil(t, parsed)
require.Equal(t, deviceID, parsed.DeviceID)
require.Equal(t, accountUUID, parsed.AccountUUID)
require.Equal(t, sessionID, parsed.SessionID)
require.False(t, parsed.IsNewFormat)
}
func TestParseFormat_RoundTrip_JSON(t *testing.T) {
deviceID := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
accountUUID := "550e8400-e29b-41d4-a716-446655440000"
sessionID := "123e4567-e89b-12d3-a456-426614174000"
formatted := FormatMetadataUserID(deviceID, accountUUID, sessionID, "2.1.78")
parsed := ParseMetadataUserID(formatted)
require.NotNil(t, parsed)
require.Equal(t, deviceID, parsed.DeviceID)
require.Equal(t, accountUUID, parsed.AccountUUID)
require.Equal(t, sessionID, parsed.SessionID)
require.True(t, parsed.IsNewFormat)
}
func TestParseFormat_RoundTrip_EmptyAccountUUID(t *testing.T) {
deviceID := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
sessionID := "123e4567-e89b-12d3-a456-426614174000"
// Legacy round-trip with empty account UUID
formatted := FormatMetadataUserID(deviceID, "", sessionID, "2.1.22")
parsed := ParseMetadataUserID(formatted)
require.NotNil(t, parsed)
require.Equal(t, deviceID, parsed.DeviceID)
require.Equal(t, "", parsed.AccountUUID)
require.Equal(t, sessionID, parsed.SessionID)
// JSON round-trip with empty account UUID
formatted = FormatMetadataUserID(deviceID, "", sessionID, "2.1.78")
parsed = ParseMetadataUserID(formatted)
require.NotNil(t, parsed)
require.Equal(t, deviceID, parsed.DeviceID)
require.Equal(t, "", parsed.AccountUUID)
require.Equal(t, sessionID, parsed.SessionID)
}
...@@ -277,12 +277,13 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse( ...@@ -277,12 +277,13 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
c.JSON(http.StatusOK, chatResp) c.JSON(http.StatusOK, chatResp)
return &OpenAIForwardResult{ return &OpenAIForwardResult{
RequestID: requestID, RequestID: requestID,
Usage: usage, Usage: usage,
Model: originalModel, Model: originalModel,
BillingModel: mappedModel, BillingModel: mappedModel,
Stream: false, UpstreamModel: mappedModel,
Duration: time.Since(startTime), Stream: false,
Duration: time.Since(startTime),
}, nil }, nil
} }
...@@ -324,13 +325,14 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( ...@@ -324,13 +325,14 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
resultWithUsage := func() *OpenAIForwardResult { resultWithUsage := func() *OpenAIForwardResult {
return &OpenAIForwardResult{ return &OpenAIForwardResult{
RequestID: requestID, RequestID: requestID,
Usage: usage, Usage: usage,
Model: originalModel, Model: originalModel,
BillingModel: mappedModel, BillingModel: mappedModel,
Stream: true, UpstreamModel: mappedModel,
Duration: time.Since(startTime), Stream: true,
FirstTokenMs: firstTokenMs, Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
} }
} }
......
...@@ -299,12 +299,13 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse( ...@@ -299,12 +299,13 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
c.JSON(http.StatusOK, anthropicResp) c.JSON(http.StatusOK, anthropicResp)
return &OpenAIForwardResult{ return &OpenAIForwardResult{
RequestID: requestID, RequestID: requestID,
Usage: usage, Usage: usage,
Model: originalModel, Model: originalModel,
BillingModel: mappedModel, BillingModel: mappedModel,
Stream: false, UpstreamModel: mappedModel,
Duration: time.Since(startTime), Stream: false,
Duration: time.Since(startTime),
}, nil }, nil
} }
...@@ -347,13 +348,14 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( ...@@ -347,13 +348,14 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
// resultWithUsage builds the final result snapshot. // resultWithUsage builds the final result snapshot.
resultWithUsage := func() *OpenAIForwardResult { resultWithUsage := func() *OpenAIForwardResult {
return &OpenAIForwardResult{ return &OpenAIForwardResult{
RequestID: requestID, RequestID: requestID,
Usage: usage, Usage: usage,
Model: originalModel, Model: originalModel,
BillingModel: mappedModel, BillingModel: mappedModel,
Stream: true, UpstreamModel: mappedModel,
Duration: time.Since(startTime), Stream: true,
FirstTokenMs: firstTokenMs, Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
} }
} }
......
...@@ -846,7 +846,7 @@ func TestExtractOpenAIServiceTierFromBody(t *testing.T) { ...@@ -846,7 +846,7 @@ func TestExtractOpenAIServiceTierFromBody(t *testing.T) {
require.Nil(t, extractOpenAIServiceTierFromBody(nil)) require.Nil(t, extractOpenAIServiceTierFromBody(nil))
} }
func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *testing.T) { func TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetadataFields(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{} userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{} subRepo := &openAIRecordUsageSubRepoStub{}
...@@ -859,6 +859,7 @@ func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *te ...@@ -859,6 +859,7 @@ func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *te
RequestID: "resp_billing_model_override", RequestID: "resp_billing_model_override",
BillingModel: "gpt-5.1-codex", BillingModel: "gpt-5.1-codex",
Model: "gpt-5.1", Model: "gpt-5.1",
UpstreamModel: "gpt-5.1-codex",
ServiceTier: &serviceTier, ServiceTier: &serviceTier,
ReasoningEffort: &reasoning, ReasoningEffort: &reasoning,
Usage: OpenAIUsage{ Usage: OpenAIUsage{
...@@ -877,7 +878,9 @@ func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *te ...@@ -877,7 +878,9 @@ func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *te
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog) require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "gpt-5.1-codex", usageRepo.lastLog.Model) require.Equal(t, "gpt-5.1", usageRepo.lastLog.Model)
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
require.Equal(t, "gpt-5.1-codex", *usageRepo.lastLog.UpstreamModel)
require.NotNil(t, usageRepo.lastLog.ServiceTier) require.NotNil(t, usageRepo.lastLog.ServiceTier)
require.Equal(t, serviceTier, *usageRepo.lastLog.ServiceTier) require.Equal(t, serviceTier, *usageRepo.lastLog.ServiceTier)
require.NotNil(t, usageRepo.lastLog.ReasoningEffort) require.NotNil(t, usageRepo.lastLog.ReasoningEffort)
......
...@@ -216,6 +216,9 @@ type OpenAIForwardResult struct { ...@@ -216,6 +216,9 @@ type OpenAIForwardResult struct {
// This is set by the Anthropic Messages conversion path where // This is set by the Anthropic Messages conversion path where
// the mapped upstream model differs from the client-facing model. // the mapped upstream model differs from the client-facing model.
BillingModel string BillingModel string
// UpstreamModel is the actual model sent to the upstream provider after mapping.
// Empty when no mapping was applied (requested model was used as-is).
UpstreamModel string
// ServiceTier records the OpenAI Responses API service tier, e.g. "priority" / "flex". // ServiceTier records the OpenAI Responses API service tier, e.g. "priority" / "flex".
// Nil means the request did not specify a recognized tier. // Nil means the request did not specify a recognized tier.
ServiceTier *string ServiceTier *string
...@@ -2128,6 +2131,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco ...@@ -2128,6 +2131,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
firstTokenMs, firstTokenMs,
wsAttempts, wsAttempts,
) )
wsResult.UpstreamModel = mappedModel
return wsResult, nil return wsResult, nil
} }
s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr) s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr)
...@@ -2263,6 +2267,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco ...@@ -2263,6 +2267,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
RequestID: resp.Header.Get("x-request-id"), RequestID: resp.Header.Get("x-request-id"),
Usage: *usage, Usage: *usage,
Model: originalModel, Model: originalModel,
UpstreamModel: mappedModel,
ServiceTier: serviceTier, ServiceTier: serviceTier,
ReasoningEffort: reasoningEffort, ReasoningEffort: reasoningEffort,
Stream: reqStream, Stream: reqStream,
...@@ -4134,7 +4139,8 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec ...@@ -4134,7 +4139,8 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
APIKeyID: apiKey.ID, APIKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
RequestID: requestID, RequestID: requestID,
Model: billingModel, Model: result.Model,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ServiceTier: result.ServiceTier, ServiceTier: result.ServiceTier,
ReasoningEffort: result.ReasoningEffort, ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
...@@ -4700,11 +4706,3 @@ func normalizeOpenAIReasoningEffort(raw string) string { ...@@ -4700,11 +4706,3 @@ func normalizeOpenAIReasoningEffort(raw string) string {
return "" return ""
} }
} }
func optionalTrimmedStringPtr(raw string) *string {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return nil
}
return &trimmed
}
//go:build unit
package service
import (
"context"
"encoding/json"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// errSettingRepo: a SettingRepository that always returns errors on read
// ---------------------------------------------------------------------------
type errSettingRepo struct {
mockSettingRepo // embed the existing mock from backup_service_test.go
readErr error
}
func (r *errSettingRepo) GetValue(_ context.Context, _ string) (string, error) {
return "", r.readErr
}
func (r *errSettingRepo) Get(_ context.Context, _ string) (*Setting, error) {
return nil, r.readErr
}
// ---------------------------------------------------------------------------
// overloadAccountRepoStub: records SetOverloaded calls
// ---------------------------------------------------------------------------
type overloadAccountRepoStub struct {
mockAccountRepoForGemini
overloadCalls int
lastOverloadID int64
lastOverloadEnd time.Time
}
func (r *overloadAccountRepoStub) SetOverloaded(_ context.Context, id int64, until time.Time) error {
r.overloadCalls++
r.lastOverloadID = id
r.lastOverloadEnd = until
return nil
}
// ===========================================================================
// SettingService: GetOverloadCooldownSettings
// ===========================================================================
func TestGetOverloadCooldownSettings_DefaultsWhenNotSet(t *testing.T) {
repo := newMockSettingRepo()
svc := NewSettingService(repo, &config.Config{})
settings, err := svc.GetOverloadCooldownSettings(context.Background())
require.NoError(t, err)
require.True(t, settings.Enabled)
require.Equal(t, 10, settings.CooldownMinutes)
}
func TestGetOverloadCooldownSettings_ReadsFromDB(t *testing.T) {
repo := newMockSettingRepo()
data, _ := json.Marshal(OverloadCooldownSettings{Enabled: false, CooldownMinutes: 30})
repo.data[SettingKeyOverloadCooldownSettings] = string(data)
svc := NewSettingService(repo, &config.Config{})
settings, err := svc.GetOverloadCooldownSettings(context.Background())
require.NoError(t, err)
require.False(t, settings.Enabled)
require.Equal(t, 30, settings.CooldownMinutes)
}
func TestGetOverloadCooldownSettings_ClampsMinValue(t *testing.T) {
repo := newMockSettingRepo()
data, _ := json.Marshal(OverloadCooldownSettings{Enabled: true, CooldownMinutes: 0})
repo.data[SettingKeyOverloadCooldownSettings] = string(data)
svc := NewSettingService(repo, &config.Config{})
settings, err := svc.GetOverloadCooldownSettings(context.Background())
require.NoError(t, err)
require.Equal(t, 1, settings.CooldownMinutes)
}
func TestGetOverloadCooldownSettings_ClampsMaxValue(t *testing.T) {
repo := newMockSettingRepo()
data, _ := json.Marshal(OverloadCooldownSettings{Enabled: true, CooldownMinutes: 999})
repo.data[SettingKeyOverloadCooldownSettings] = string(data)
svc := NewSettingService(repo, &config.Config{})
settings, err := svc.GetOverloadCooldownSettings(context.Background())
require.NoError(t, err)
require.Equal(t, 120, settings.CooldownMinutes)
}
func TestGetOverloadCooldownSettings_InvalidJSON_ReturnsDefaults(t *testing.T) {
repo := newMockSettingRepo()
repo.data[SettingKeyOverloadCooldownSettings] = "not-json"
svc := NewSettingService(repo, &config.Config{})
settings, err := svc.GetOverloadCooldownSettings(context.Background())
require.NoError(t, err)
require.True(t, settings.Enabled)
require.Equal(t, 10, settings.CooldownMinutes)
}
func TestGetOverloadCooldownSettings_EmptyValue_ReturnsDefaults(t *testing.T) {
repo := newMockSettingRepo()
repo.data[SettingKeyOverloadCooldownSettings] = ""
svc := NewSettingService(repo, &config.Config{})
settings, err := svc.GetOverloadCooldownSettings(context.Background())
require.NoError(t, err)
require.True(t, settings.Enabled)
require.Equal(t, 10, settings.CooldownMinutes)
}
// ===========================================================================
// SettingService: SetOverloadCooldownSettings
// ===========================================================================
func TestSetOverloadCooldownSettings_Success(t *testing.T) {
repo := newMockSettingRepo()
svc := NewSettingService(repo, &config.Config{})
err := svc.SetOverloadCooldownSettings(context.Background(), &OverloadCooldownSettings{
Enabled: false,
CooldownMinutes: 25,
})
require.NoError(t, err)
// Verify round-trip
settings, err := svc.GetOverloadCooldownSettings(context.Background())
require.NoError(t, err)
require.False(t, settings.Enabled)
require.Equal(t, 25, settings.CooldownMinutes)
}
func TestSetOverloadCooldownSettings_RejectsNil(t *testing.T) {
svc := NewSettingService(newMockSettingRepo(), &config.Config{})
err := svc.SetOverloadCooldownSettings(context.Background(), nil)
require.Error(t, err)
}
func TestSetOverloadCooldownSettings_EnabledRejectsOutOfRange(t *testing.T) {
svc := NewSettingService(newMockSettingRepo(), &config.Config{})
for _, minutes := range []int{0, -1, 121, 999} {
err := svc.SetOverloadCooldownSettings(context.Background(), &OverloadCooldownSettings{
Enabled: true, CooldownMinutes: minutes,
})
require.Error(t, err, "should reject enabled=true + cooldown_minutes=%d", minutes)
require.Contains(t, err.Error(), "cooldown_minutes must be between 1-120")
}
}
func TestSetOverloadCooldownSettings_DisabledNormalizesOutOfRange(t *testing.T) {
repo := newMockSettingRepo()
svc := NewSettingService(repo, &config.Config{})
// enabled=false + cooldown_minutes=0 应该保存成功,值被归一化为10
err := svc.SetOverloadCooldownSettings(context.Background(), &OverloadCooldownSettings{
Enabled: false, CooldownMinutes: 0,
})
require.NoError(t, err, "disabled with invalid minutes should NOT be rejected")
// 验证持久化后读回来的值
settings, err := svc.GetOverloadCooldownSettings(context.Background())
require.NoError(t, err)
require.False(t, settings.Enabled)
require.Equal(t, 10, settings.CooldownMinutes, "should be normalized to default")
}
func TestSetOverloadCooldownSettings_AcceptsBoundaries(t *testing.T) {
svc := NewSettingService(newMockSettingRepo(), &config.Config{})
for _, minutes := range []int{1, 60, 120} {
err := svc.SetOverloadCooldownSettings(context.Background(), &OverloadCooldownSettings{
Enabled: true, CooldownMinutes: minutes,
})
require.NoError(t, err, "should accept cooldown_minutes=%d", minutes)
}
}
// ===========================================================================
// RateLimitService: handle529 behaviour
// ===========================================================================
func TestHandle529_EnabledFromDB_PausesAccount(t *testing.T) {
accountRepo := &overloadAccountRepoStub{}
settingRepo := newMockSettingRepo()
data, _ := json.Marshal(OverloadCooldownSettings{Enabled: true, CooldownMinutes: 15})
settingRepo.data[SettingKeyOverloadCooldownSettings] = string(data)
settingSvc := NewSettingService(settingRepo, &config.Config{})
svc := NewRateLimitService(accountRepo, nil, &config.Config{}, nil, nil)
svc.SetSettingService(settingSvc)
account := &Account{ID: 42, Platform: PlatformAnthropic, Type: AccountTypeOAuth}
before := time.Now()
svc.handle529(context.Background(), account)
require.Equal(t, 1, accountRepo.overloadCalls)
require.Equal(t, int64(42), accountRepo.lastOverloadID)
require.WithinDuration(t, before.Add(15*time.Minute), accountRepo.lastOverloadEnd, 2*time.Second)
}
func TestHandle529_DisabledFromDB_SkipsAccount(t *testing.T) {
accountRepo := &overloadAccountRepoStub{}
settingRepo := newMockSettingRepo()
data, _ := json.Marshal(OverloadCooldownSettings{Enabled: false, CooldownMinutes: 15})
settingRepo.data[SettingKeyOverloadCooldownSettings] = string(data)
settingSvc := NewSettingService(settingRepo, &config.Config{})
svc := NewRateLimitService(accountRepo, nil, &config.Config{}, nil, nil)
svc.SetSettingService(settingSvc)
account := &Account{ID: 42, Platform: PlatformAnthropic, Type: AccountTypeOAuth}
svc.handle529(context.Background(), account)
require.Equal(t, 0, accountRepo.overloadCalls, "should NOT pause when disabled")
}
func TestHandle529_NilSettingService_FallsBackToConfig(t *testing.T) {
accountRepo := &overloadAccountRepoStub{}
cfg := &config.Config{}
cfg.RateLimit.OverloadCooldownMinutes = 20
svc := NewRateLimitService(accountRepo, nil, cfg, nil, nil)
// NOT calling SetSettingService — remains nil
account := &Account{ID: 77, Platform: PlatformAnthropic, Type: AccountTypeOAuth}
before := time.Now()
svc.handle529(context.Background(), account)
require.Equal(t, 1, accountRepo.overloadCalls)
require.WithinDuration(t, before.Add(20*time.Minute), accountRepo.lastOverloadEnd, 2*time.Second)
}
func TestHandle529_NilSettingService_ZeroConfig_DefaultsTen(t *testing.T) {
accountRepo := &overloadAccountRepoStub{}
svc := NewRateLimitService(accountRepo, nil, &config.Config{}, nil, nil)
account := &Account{ID: 88, Platform: PlatformAnthropic, Type: AccountTypeOAuth}
before := time.Now()
svc.handle529(context.Background(), account)
require.Equal(t, 1, accountRepo.overloadCalls)
require.WithinDuration(t, before.Add(10*time.Minute), accountRepo.lastOverloadEnd, 2*time.Second)
}
func TestHandle529_DBReadError_FallsBackToConfig(t *testing.T) {
accountRepo := &overloadAccountRepoStub{}
errRepo := &errSettingRepo{readErr: context.DeadlineExceeded}
errRepo.data = make(map[string]string)
cfg := &config.Config{}
cfg.RateLimit.OverloadCooldownMinutes = 7
settingSvc := NewSettingService(errRepo, cfg)
svc := NewRateLimitService(accountRepo, nil, cfg, nil, nil)
svc.SetSettingService(settingSvc)
account := &Account{ID: 99, Platform: PlatformAnthropic, Type: AccountTypeOAuth}
before := time.Now()
svc.handle529(context.Background(), account)
require.Equal(t, 1, accountRepo.overloadCalls)
require.WithinDuration(t, before.Add(7*time.Minute), accountRepo.lastOverloadEnd, 2*time.Second)
}
// ===========================================================================
// Model: defaults & JSON round-trip
// ===========================================================================
func TestDefaultOverloadCooldownSettings(t *testing.T) {
d := DefaultOverloadCooldownSettings()
require.True(t, d.Enabled)
require.Equal(t, 10, d.CooldownMinutes)
}
func TestOverloadCooldownSettings_JSONRoundTrip(t *testing.T) {
original := OverloadCooldownSettings{Enabled: false, CooldownMinutes: 42}
data, err := json.Marshal(original)
require.NoError(t, err)
var decoded OverloadCooldownSettings
require.NoError(t, json.Unmarshal(data, &decoded))
require.Equal(t, original, decoded)
// Verify JSON uses snake_case field names
var raw map[string]any
require.NoError(t, json.Unmarshal(data, &raw))
_, hasEnabled := raw["enabled"]
_, hasCooldown := raw["cooldown_minutes"]
require.True(t, hasEnabled, "JSON must use 'enabled'")
require.True(t, hasCooldown, "JSON must use 'cooldown_minutes'")
}
...@@ -1023,11 +1023,34 @@ func parseOpenAIRateLimitResetTime(body []byte) *int64 { ...@@ -1023,11 +1023,34 @@ func parseOpenAIRateLimitResetTime(body []byte) *int64 {
} }
// handle529 处理529过载错误 // handle529 处理529过载错误
// 根据配置设置过载冷却时 // 根据配置决定是否暂停账号调度及冷却时
func (s *RateLimitService) handle529(ctx context.Context, account *Account) { func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
cooldownMinutes := s.cfg.RateLimit.OverloadCooldownMinutes var settings *OverloadCooldownSettings
if s.settingService != nil {
var err error
settings, err = s.settingService.GetOverloadCooldownSettings(ctx)
if err != nil {
slog.Warn("overload_settings_read_failed", "account_id", account.ID, "error", err)
settings = nil
}
}
// 回退到配置文件
if settings == nil {
cooldown := s.cfg.RateLimit.OverloadCooldownMinutes
if cooldown <= 0 {
cooldown = 10
}
settings = &OverloadCooldownSettings{Enabled: true, CooldownMinutes: cooldown}
}
if !settings.Enabled {
slog.Info("account_529_ignored", "account_id", account.ID, "reason", "overload_cooldown_disabled")
return
}
cooldownMinutes := settings.CooldownMinutes
if cooldownMinutes <= 0 { if cooldownMinutes <= 0 {
cooldownMinutes = 10 // 默认10分钟 cooldownMinutes = 10
} }
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute) until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
......
...@@ -1172,6 +1172,57 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf ...@@ -1172,6 +1172,57 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf
return effective, nil return effective, nil
} }
// GetOverloadCooldownSettings 获取529过载冷却配置
func (s *SettingService) GetOverloadCooldownSettings(ctx context.Context) (*OverloadCooldownSettings, error) {
value, err := s.settingRepo.GetValue(ctx, SettingKeyOverloadCooldownSettings)
if err != nil {
if errors.Is(err, ErrSettingNotFound) {
return DefaultOverloadCooldownSettings(), nil
}
return nil, fmt.Errorf("get overload cooldown settings: %w", err)
}
if value == "" {
return DefaultOverloadCooldownSettings(), nil
}
var settings OverloadCooldownSettings
if err := json.Unmarshal([]byte(value), &settings); err != nil {
return DefaultOverloadCooldownSettings(), nil
}
// 修正配置值范围
if settings.CooldownMinutes < 1 {
settings.CooldownMinutes = 1
}
if settings.CooldownMinutes > 120 {
settings.CooldownMinutes = 120
}
return &settings, nil
}
// SetOverloadCooldownSettings 设置529过载冷却配置
func (s *SettingService) SetOverloadCooldownSettings(ctx context.Context, settings *OverloadCooldownSettings) error {
if settings == nil {
return fmt.Errorf("settings cannot be nil")
}
// 禁用时修正为合法值即可,不拒绝请求
if settings.CooldownMinutes < 1 || settings.CooldownMinutes > 120 {
if settings.Enabled {
return fmt.Errorf("cooldown_minutes must be between 1-120")
}
settings.CooldownMinutes = 10 // 禁用状态下归一化为默认值
}
data, err := json.Marshal(settings)
if err != nil {
return fmt.Errorf("marshal overload cooldown settings: %w", err)
}
return s.settingRepo.Set(ctx, SettingKeyOverloadCooldownSettings, string(data))
}
// GetStreamTimeoutSettings 获取流超时处理配置 // GetStreamTimeoutSettings 获取流超时处理配置
func (s *SettingService) GetStreamTimeoutSettings(ctx context.Context) (*StreamTimeoutSettings, error) { func (s *SettingService) GetStreamTimeoutSettings(ctx context.Context) (*StreamTimeoutSettings, error) {
value, err := s.settingRepo.GetValue(ctx, SettingKeyStreamTimeoutSettings) value, err := s.settingRepo.GetValue(ctx, SettingKeyStreamTimeoutSettings)
......
...@@ -222,6 +222,22 @@ type BetaPolicySettings struct { ...@@ -222,6 +222,22 @@ type BetaPolicySettings struct {
Rules []BetaPolicyRule `json:"rules"` Rules []BetaPolicyRule `json:"rules"`
} }
// OverloadCooldownSettings 529过载冷却配置
type OverloadCooldownSettings struct {
// Enabled 是否在收到529时暂停账号调度
Enabled bool `json:"enabled"`
// CooldownMinutes 冷却时长(分钟)
CooldownMinutes int `json:"cooldown_minutes"`
}
// DefaultOverloadCooldownSettings 返回默认的过载冷却配置(启用,10分钟)
func DefaultOverloadCooldownSettings() *OverloadCooldownSettings {
return &OverloadCooldownSettings{
Enabled: true,
CooldownMinutes: 10,
}
}
// DefaultBetaPolicySettings 返回默认的 Beta 策略配置 // DefaultBetaPolicySettings 返回默认的 Beta 策略配置
func DefaultBetaPolicySettings() *BetaPolicySettings { func DefaultBetaPolicySettings() *BetaPolicySettings {
return &BetaPolicySettings{ return &BetaPolicySettings{
......
...@@ -52,8 +52,8 @@ func (r *stubGroupRepoForQuota) ListActiveByPlatform(context.Context, string) ([ ...@@ -52,8 +52,8 @@ func (r *stubGroupRepoForQuota) ListActiveByPlatform(context.Context, string) ([
func (r *stubGroupRepoForQuota) ExistsByName(context.Context, string) (bool, error) { func (r *stubGroupRepoForQuota) ExistsByName(context.Context, string) (bool, error) {
return false, nil return false, nil
} }
func (r *stubGroupRepoForQuota) GetAccountCount(context.Context, int64) (int64, error) { func (r *stubGroupRepoForQuota) GetAccountCount(context.Context, int64) (int64, int64, error) {
return 0, nil return 0, 0, nil
} }
func (r *stubGroupRepoForQuota) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { func (r *stubGroupRepoForQuota) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
return 0, nil return 0, nil
......
...@@ -40,7 +40,7 @@ func (groupRepoNoop) ListActiveByPlatform(context.Context, string) ([]Group, err ...@@ -40,7 +40,7 @@ func (groupRepoNoop) ListActiveByPlatform(context.Context, string) ([]Group, err
func (groupRepoNoop) ExistsByName(context.Context, string) (bool, error) { func (groupRepoNoop) ExistsByName(context.Context, string) (bool, error) {
panic("unexpected ExistsByName call") panic("unexpected ExistsByName call")
} }
func (groupRepoNoop) GetAccountCount(context.Context, int64) (int64, error) { func (groupRepoNoop) GetAccountCount(context.Context, int64) (int64, int64, error) {
panic("unexpected GetAccountCount call") panic("unexpected GetAccountCount call")
} }
func (groupRepoNoop) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { func (groupRepoNoop) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
...@@ -92,7 +92,7 @@ func (userSubRepoNoop) ListActiveByUserID(context.Context, int64) ([]UserSubscri ...@@ -92,7 +92,7 @@ func (userSubRepoNoop) ListActiveByUserID(context.Context, int64) ([]UserSubscri
func (userSubRepoNoop) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error) { func (userSubRepoNoop) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error) {
panic("unexpected ListByGroupID call") panic("unexpected ListByGroupID call")
} }
func (userSubRepoNoop) List(context.Context, pagination.PaginationParams, *int64, *int64, string, string, string) ([]UserSubscription, *pagination.PaginationResult, error) { func (userSubRepoNoop) List(context.Context, pagination.PaginationParams, *int64, *int64, string, string, string, string) ([]UserSubscription, *pagination.PaginationResult, error) {
panic("unexpected List call") panic("unexpected List call")
} }
func (userSubRepoNoop) ExistsByUserIDAndGroupID(context.Context, int64, int64) (bool, error) { func (userSubRepoNoop) ExistsByUserIDAndGroupID(context.Context, int64, int64) (bool, error) {
......
...@@ -634,9 +634,9 @@ func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupI ...@@ -634,9 +634,9 @@ func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupI
} }
// List 获取所有订阅(分页,支持筛选和排序) // List 获取所有订阅(分页,支持筛选和排序)
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error) { func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
subs, pag, err := s.userSubRepo.List(ctx, params, userID, groupID, status, sortBy, sortOrder) subs, pag, err := s.userSubRepo.List(ctx, params, userID, groupID, status, platform, sortBy, sortOrder)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
......
...@@ -98,6 +98,9 @@ type UsageLog struct { ...@@ -98,6 +98,9 @@ type UsageLog struct {
AccountID int64 AccountID int64
RequestID string RequestID string
Model string Model string
// UpstreamModel is the actual model sent to the upstream provider after mapping.
// Nil means no mapping was applied (requested model was used as-is).
UpstreamModel *string
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex". // ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
ServiceTier *string ServiceTier *string
// ReasoningEffort is the request's reasoning effort level. // ReasoningEffort is the request's reasoning effort level.
......
package service
import "strings"
func optionalTrimmedStringPtr(raw string) *string {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return nil
}
return &trimmed
}
// optionalNonEqualStringPtr returns a pointer to value if it is non-empty and
// differs from compare; otherwise nil. Used to store upstream_model only when
// it differs from the requested model.
func optionalNonEqualStringPtr(value, compare string) *string {
if value == "" || value == compare {
return nil
}
return &value
}
...@@ -18,7 +18,7 @@ type UserSubscriptionRepository interface { ...@@ -18,7 +18,7 @@ type UserSubscriptionRepository interface {
ListByUserID(ctx context.Context, userID int64) ([]UserSubscription, error) ListByUserID(ctx context.Context, userID int64) ([]UserSubscription, error)
ListActiveByUserID(ctx context.Context, userID int64) ([]UserSubscription, error) ListActiveByUserID(ctx context.Context, userID int64) ([]UserSubscription, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error)
List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error)
ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error)
ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error
......
...@@ -486,4 +486,5 @@ var ProviderSet = wire.NewSet( ...@@ -486,4 +486,5 @@ var ProviderSet = wire.NewSet(
ProvideIdempotencyCleanupService, ProvideIdempotencyCleanupService,
ProvideScheduledTestService, ProvideScheduledTestService,
ProvideScheduledTestRunnerService, ProvideScheduledTestRunnerService,
NewGroupCapacityService,
) )
...@@ -247,6 +247,12 @@ func install(c *gin.Context) { ...@@ -247,6 +247,12 @@ func install(c *gin.Context) {
return return
} }
req.Admin.Email = strings.TrimSpace(req.Admin.Email)
req.Database.Host = strings.TrimSpace(req.Database.Host)
req.Database.User = strings.TrimSpace(req.Database.User)
req.Database.DBName = strings.TrimSpace(req.Database.DBName)
req.Redis.Host = strings.TrimSpace(req.Redis.Host)
// ========== COMPREHENSIVE INPUT VALIDATION ========== // ========== COMPREHENSIVE INPUT VALIDATION ==========
// Database validation // Database validation
if !validateHostname(req.Database.Host) { if !validateHostname(req.Database.Host) {
...@@ -319,13 +325,6 @@ func install(c *gin.Context) { ...@@ -319,13 +325,6 @@ func install(c *gin.Context) {
return return
} }
// Trim whitespace from string inputs
req.Admin.Email = strings.TrimSpace(req.Admin.Email)
req.Database.Host = strings.TrimSpace(req.Database.Host)
req.Database.User = strings.TrimSpace(req.Database.User)
req.Database.DBName = strings.TrimSpace(req.Database.DBName)
req.Redis.Host = strings.TrimSpace(req.Redis.Host)
cfg := &SetupConfig{ cfg := &SetupConfig{
Database: req.Database, Database: req.Database,
Redis: req.Redis, Redis: req.Redis,
......
...@@ -180,7 +180,37 @@ func (s *FrontendServer) injectSettings(settingsJSON []byte) []byte { ...@@ -180,7 +180,37 @@ func (s *FrontendServer) injectSettings(settingsJSON []byte) []byte {
// Inject before </head> // Inject before </head>
headClose := []byte("</head>") headClose := []byte("</head>")
return bytes.Replace(s.baseHTML, headClose, append(script, headClose...), 1) result := bytes.Replace(s.baseHTML, headClose, append(script, headClose...), 1)
// Replace <title> with custom site name so the browser tab shows it immediately
result = injectSiteTitle(result, settingsJSON)
return result
}
// injectSiteTitle replaces the static <title> in HTML with the configured site name.
// This ensures the browser tab shows the correct title before JS executes.
func injectSiteTitle(html, settingsJSON []byte) []byte {
var cfg struct {
SiteName string `json:"site_name"`
}
if err := json.Unmarshal(settingsJSON, &cfg); err != nil || cfg.SiteName == "" {
return html
}
// Find and replace the existing <title>...</title>
titleStart := bytes.Index(html, []byte("<title>"))
titleEnd := bytes.Index(html, []byte("</title>"))
if titleStart == -1 || titleEnd == -1 || titleEnd <= titleStart {
return html
}
newTitle := []byte("<title>" + cfg.SiteName + " - AI API Gateway</title>")
var buf bytes.Buffer
buf.Write(html[:titleStart])
buf.Write(newTitle)
buf.Write(html[titleEnd+len("</title>"):])
return buf.Bytes()
} }
// replaceNoncePlaceholder replaces the nonce placeholder with actual nonce value // replaceNoncePlaceholder replaces the nonce placeholder with actual nonce value
......
...@@ -20,6 +20,78 @@ func init() { ...@@ -20,6 +20,78 @@ func init() {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
} }
func TestInjectSiteTitle(t *testing.T) {
t.Run("replaces_title_with_site_name", func(t *testing.T) {
html := []byte(`<html><head><title>Sub2API - AI API Gateway</title></head><body></body></html>`)
settingsJSON := []byte(`{"site_name":"MyCustomSite"}`)
result := injectSiteTitle(html, settingsJSON)
assert.Contains(t, string(result), "<title>MyCustomSite - AI API Gateway</title>")
assert.NotContains(t, string(result), "Sub2API")
})
t.Run("returns_unchanged_when_site_name_empty", func(t *testing.T) {
html := []byte(`<html><head><title>Sub2API - AI API Gateway</title></head><body></body></html>`)
settingsJSON := []byte(`{"site_name":""}`)
result := injectSiteTitle(html, settingsJSON)
assert.Equal(t, string(html), string(result))
})
t.Run("returns_unchanged_when_site_name_missing", func(t *testing.T) {
html := []byte(`<html><head><title>Sub2API - AI API Gateway</title></head><body></body></html>`)
settingsJSON := []byte(`{"other_field":"value"}`)
result := injectSiteTitle(html, settingsJSON)
assert.Equal(t, string(html), string(result))
})
t.Run("returns_unchanged_when_invalid_json", func(t *testing.T) {
html := []byte(`<html><head><title>Sub2API - AI API Gateway</title></head><body></body></html>`)
settingsJSON := []byte(`{invalid json}`)
result := injectSiteTitle(html, settingsJSON)
assert.Equal(t, string(html), string(result))
})
t.Run("returns_unchanged_when_no_title_tag", func(t *testing.T) {
html := []byte(`<html><head></head><body></body></html>`)
settingsJSON := []byte(`{"site_name":"MyCustomSite"}`)
result := injectSiteTitle(html, settingsJSON)
assert.Equal(t, string(html), string(result))
})
t.Run("returns_unchanged_when_title_has_attributes", func(t *testing.T) {
// The function looks for "<title>" literally, so attributes are not supported
// This is acceptable since index.html uses plain <title> without attributes
html := []byte(`<html><head><title lang="en">Sub2API</title></head><body></body></html>`)
settingsJSON := []byte(`{"site_name":"NewSite"}`)
result := injectSiteTitle(html, settingsJSON)
// Should return unchanged since <title> with attributes is not matched
assert.Equal(t, string(html), string(result))
})
t.Run("preserves_rest_of_html", func(t *testing.T) {
html := []byte(`<html><head><meta charset="UTF-8"><title>Sub2API</title><script src="app.js"></script></head><body><div id="app"></div></body></html>`)
settingsJSON := []byte(`{"site_name":"TestSite"}`)
result := injectSiteTitle(html, settingsJSON)
assert.Contains(t, string(result), `<meta charset="UTF-8">`)
assert.Contains(t, string(result), `<script src="app.js"></script>`)
assert.Contains(t, string(result), `<div id="app"></div>`)
assert.Contains(t, string(result), "<title>TestSite - AI API Gateway</title>")
})
}
func TestReplaceNoncePlaceholder(t *testing.T) { func TestReplaceNoncePlaceholder(t *testing.T) {
t.Run("replaces_single_placeholder", func(t *testing.T) { t.Run("replaces_single_placeholder", func(t *testing.T) {
html := []byte(`<script nonce="__CSP_NONCE_VALUE__">console.log('test');</script>`) html := []byte(`<script nonce="__CSP_NONCE_VALUE__">console.log('test');</script>`)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment