Commit 538ae31a authored by 陈曦's avatar 陈曦
Browse files

Merge v0.1.121 and fix conflicts

parents 74828a7c 48912014
Pipeline #82338 passed with stage
in 17 seconds
......@@ -75,7 +75,7 @@ func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) {
// {"type":"auto"} → "auto"
// {"type":"any"} → "required"
// {"type":"none"} → "none"
// {"type":"tool","name":"X"} → {"type":"function","function":{"name":"X"}}
// {"type":"tool","name":"X"} → {"type":"function","name":"X"}
func convertAnthropicToolChoiceToResponses(raw json.RawMessage) (json.RawMessage, error) {
var tc struct {
Type string `json:"type"`
......@@ -94,8 +94,8 @@ func convertAnthropicToolChoiceToResponses(raw json.RawMessage) (json.RawMessage
return json.Marshal("none")
case "tool":
return json.Marshal(map[string]any{
"type": "function",
"function": map[string]string{"name": tc.Name},
"type": "function",
"name": tc.Name,
})
default:
// Pass through unknown types as-is
......
......@@ -281,6 +281,8 @@ func TestChatCompletionsToResponses_LegacyFunctions(t *testing.T) {
var tc map[string]any
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
assert.Equal(t, "function", tc["type"])
assert.Equal(t, "get_weather", tc["name"])
assert.NotContains(t, tc, "function")
}
func TestChatCompletionsToResponses_ServiceTier(t *testing.T) {
......
......@@ -420,7 +420,7 @@ func convertChatToolsToResponses(tools []ChatTool, functions []ChatFunction) []R
//
// "auto" → "auto"
// "none" → "none"
// {"name":"X"} → {"type":"function","function":{"name":"X"}}
// {"name":"X"} → {"type":"function","name":"X"}
func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage, error) {
// Try string first ("auto", "none", etc.) — pass through as-is.
var s string
......@@ -436,7 +436,7 @@ func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage,
return nil, err
}
return json.Marshal(map[string]any{
"type": "function",
"function": map[string]string{"name": obj.Name},
"type": "function",
"name": obj.Name,
})
}
......@@ -428,7 +428,8 @@ func normalizeAnthropicInputSchema(schema json.RawMessage) json.RawMessage {
// "auto" → {"type":"auto"}
// "required" → {"type":"any"}
// "none" → {"type":"none"}
// {"type":"function","function":{"name":"X"}} → {"type":"tool","name":"X"}
// {"type":"function","name":"X"} → {"type":"tool","name":"X"}
// {"type":"function","function":{"name":"X"}} → {"type":"tool","name":"X"} // legacy
func convertResponsesToAnthropicToolChoice(raw json.RawMessage) (json.RawMessage, error) {
// Try as string first
var s string
......@@ -448,14 +449,22 @@ func convertResponsesToAnthropicToolChoice(raw json.RawMessage) (json.RawMessage
// Try as object with type=function
var tc struct {
Type string `json:"type"`
Name string `json:"name"`
Function struct {
Name string `json:"name"`
} `json:"function"`
}
if err := json.Unmarshal(raw, &tc); err == nil && tc.Type == "function" && tc.Function.Name != "" {
if err := json.Unmarshal(raw, &tc); err == nil && tc.Type == "function" {
name := strings.TrimSpace(tc.Name)
if name == "" {
name = strings.TrimSpace(tc.Function.Name)
}
if name == "" {
return raw, nil
}
return json.Marshal(map[string]string{
"type": "tool",
"name": tc.Function.Name,
"name": name,
})
}
......
......@@ -2,16 +2,28 @@ package httputil
import (
"bytes"
"compress/gzip"
"compress/zlib"
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/klauspost/compress/zstd"
)
const (
requestBodyReadInitCap = 512
requestBodyReadMaxInitCap = 1 << 20
// maxDecompressedBodySize limits the decompressed request body to 64 MB
// to prevent decompression bomb attacks.
maxDecompressedBodySize = 64 << 20
)
// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based on content length.
// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based
// on content length, transparently decoding any Content-Encoding the upstream
// client used to compress the body (zstd, gzip, deflate).
func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) {
if req == nil || req.Body == nil {
return nil, nil
......@@ -33,5 +45,49 @@ func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) {
if _, err := io.Copy(buf, req.Body); err != nil {
return nil, err
}
return buf.Bytes(), nil
raw := buf.Bytes()
enc := strings.ToLower(strings.TrimSpace(req.Header.Get("Content-Encoding")))
if enc == "" || enc == "identity" {
return raw, nil
}
decoded, err := decompressRequestBody(enc, raw)
if err != nil {
return nil, fmt.Errorf("decode Content-Encoding %q: %w", enc, err)
}
req.Header.Del("Content-Encoding")
req.Header.Del("Content-Length")
req.ContentLength = int64(len(decoded))
return decoded, nil
}
// decompressRequestBody decodes a request body compressed with the given
// Content-Encoding token (expected to be lower-cased by the caller).
// Supported encodings: zstd, gzip (and its legacy alias x-gzip), and
// deflate (zlib-wrapped, per RFC 9110). The decompressed size is capped at
// maxDecompressedBodySize; exceeding the cap is reported as an error rather
// than silently truncating, which would hand a corrupt body to the caller.
func decompressRequestBody(encoding string, raw []byte) ([]byte, error) {
	switch encoding {
	case "zstd":
		dec, err := zstd.NewReader(bytes.NewReader(raw))
		if err != nil {
			return nil, err
		}
		defer dec.Close()
		return readAllWithLimit(dec)
	case "gzip", "x-gzip":
		gr, err := gzip.NewReader(bytes.NewReader(raw))
		if err != nil {
			return nil, err
		}
		defer func() { _ = gr.Close() }()
		return readAllWithLimit(gr)
	case "deflate":
		zr, err := zlib.NewReader(bytes.NewReader(raw))
		if err != nil {
			return nil, err
		}
		defer func() { _ = zr.Close() }()
		return readAllWithLimit(zr)
	default:
		return nil, errors.New("unsupported Content-Encoding")
	}
}

// readAllWithLimit reads at most maxDecompressedBodySize bytes from r.
// It reads one extra byte so that a body exceeding the limit is detected
// and rejected instead of being silently truncated at the cap.
func readAllWithLimit(r io.Reader) ([]byte, error) {
	decoded, err := io.ReadAll(io.LimitReader(r, maxDecompressedBodySize+1))
	if err != nil {
		return nil, err
	}
	if len(decoded) > maxDecompressedBodySize {
		return nil, errors.New("decompressed body exceeds size limit")
	}
	return decoded, nil
}
package httputil
import (
"bytes"
"compress/gzip"
"compress/zlib"
"net/http"
"strings"
"testing"
"github.com/klauspost/compress/zstd"
)
const samplePayload = `{"model":"gpt-5.5","input":"hi","stream":false}`
// newRequestWithBody builds a POST request to /v1/responses carrying the
// given body, optionally tagged with a Content-Encoding header, with
// ContentLength set to the raw (possibly compressed) body size.
func newRequestWithBody(t *testing.T, body []byte, encoding string) *http.Request {
	t.Helper()
	r, mkErr := http.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body))
	if mkErr != nil {
		t.Fatalf("NewRequest: %v", mkErr)
	}
	r.ContentLength = int64(len(body))
	if encoding != "" {
		r.Header.Set("Content-Encoding", encoding)
	}
	return r
}
func TestReadRequestBodyWithPrealloc_PassesThroughIdentity(t *testing.T) {
req := newRequestWithBody(t, []byte(samplePayload), "")
got, err := ReadRequestBodyWithPrealloc(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(got) != samplePayload {
t.Fatalf("body mismatch: got %q", got)
}
}
// Compress the sample payload with zstd and verify the reader transparently
// decodes it and rewrites the request's encoding/length metadata.
func TestReadRequestBodyWithPrealloc_DecodesZstd(t *testing.T) {
	w, _ := zstd.NewWriter(nil)
	compressed := w.EncodeAll([]byte(samplePayload), nil)
	_ = w.Close()
	req := newRequestWithBody(t, compressed, "zstd")
	decoded, err := ReadRequestBodyWithPrealloc(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	switch {
	case string(decoded) != samplePayload:
		t.Fatalf("body mismatch: got %q", decoded)
	case req.Header.Get("Content-Encoding") != "":
		t.Fatalf("Content-Encoding should be cleared after decoding")
	case req.ContentLength != int64(len(samplePayload)):
		t.Fatalf("ContentLength not updated: %d", req.ContentLength)
	}
}
// Round-trip: gzip-compress the payload, then expect the original back.
func TestReadRequestBodyWithPrealloc_DecodesGzip(t *testing.T) {
	var compressed bytes.Buffer
	zw := gzip.NewWriter(&compressed)
	if _, err := zw.Write([]byte(samplePayload)); err != nil {
		t.Fatalf("gzip write: %v", err)
	}
	if err := zw.Close(); err != nil {
		t.Fatalf("gzip close: %v", err)
	}
	decoded, err := ReadRequestBodyWithPrealloc(newRequestWithBody(t, compressed.Bytes(), "gzip"))
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if string(decoded) != samplePayload {
		t.Fatalf("body mismatch: got %q", decoded)
	}
}
// Round-trip: zlib-compress the payload (the "deflate" token means the
// zlib-wrapped format), then expect the original bytes back.
func TestReadRequestBodyWithPrealloc_DecodesDeflate(t *testing.T) {
	var compressed bytes.Buffer
	w := zlib.NewWriter(&compressed)
	if _, err := w.Write([]byte(samplePayload)); err != nil {
		t.Fatalf("zlib write: %v", err)
	}
	if err := w.Close(); err != nil {
		t.Fatalf("zlib close: %v", err)
	}
	decoded, err := ReadRequestBodyWithPrealloc(newRequestWithBody(t, compressed.Bytes(), "deflate"))
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if string(decoded) != samplePayload {
		t.Fatalf("body mismatch: got %q", decoded)
	}
}
func TestReadRequestBodyWithPrealloc_RejectsUnsupportedEncoding(t *testing.T) {
req := newRequestWithBody(t, []byte(samplePayload), "br")
_, err := ReadRequestBodyWithPrealloc(req)
if err == nil {
t.Fatal("expected error for unsupported encoding, got nil")
}
if !strings.Contains(err.Error(), "br") {
t.Fatalf("error should mention encoding, got %v", err)
}
}
func TestReadRequestBodyWithPrealloc_RejectsCorruptZstd(t *testing.T) {
req := newRequestWithBody(t, []byte("not actually zstd"), "zstd")
_, err := ReadRequestBodyWithPrealloc(req)
if err == nil {
t.Fatal("expected error for corrupt zstd body, got nil")
}
}
// A request built without a body reader should yield nil bytes, not an error.
func TestReadRequestBodyWithPrealloc_NilBody(t *testing.T) {
	req, err := http.NewRequest(http.MethodPost, "/v1/responses", nil)
	if err != nil {
		t.Fatalf("NewRequest: %v", err)
	}
	body, readErr := ReadRequestBodyWithPrealloc(req)
	if readErr != nil {
		t.Fatalf("unexpected error: %v", readErr)
	}
	if body != nil {
		t.Fatalf("expected nil body, got %q", body)
	}
}
func TestReadRequestBodyWithPrealloc_RespectsIdentityEncoding(t *testing.T) {
req := newRequestWithBody(t, []byte(samplePayload), "identity")
got, err := ReadRequestBodyWithPrealloc(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(got) != samplePayload {
t.Fatalf("body mismatch: got %q", got)
}
}
......@@ -64,6 +64,10 @@ func (s *schedulerCacheRecorder) TryLockBucket(ctx context.Context, bucket servi
return true, nil
}
// UnlockBucket is a no-op on the recorder stub; tests only need the call to
// succeed.
func (s *schedulerCacheRecorder) UnlockBucket(ctx context.Context, bucket service.SchedulerBucket) error {
	return nil
}
// ListBuckets reports no tracked buckets; the recorder stub does not model
// the bucket set.
func (s *schedulerCacheRecorder) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
	return nil, nil
}
......
......@@ -24,6 +24,49 @@ const (
defaultSchedulerSnapshotMGetChunkSize = 128
defaultSchedulerSnapshotWriteChunkSize = 256
	// snapshotGraceTTLSeconds is the grace period (seconds) before an old snapshot expires.
	// Replaces an immediate DEL so readers still on the old version have time to finish their ZRANGE.
snapshotGraceTTLSeconds = 60
)
var (
	// activateSnapshotScript atomically CAS-switches the active snapshot
	// version. The switch only happens when the new version number is >= the
	// currently active one, preventing concurrent writers from rolling the
	// version back. The old snapshot is given a grace-period EXPIRE instead
	// of an immediate DEL to avoid racing with readers.
	//
	// KEYS[1] = activeKey (sched:active:{bucket})
	// KEYS[2] = readyKey (sched:ready:{bucket})
	// KEYS[3] = bucketSetKey (sched:buckets)
	// KEYS[4] = snapshotKey (the freshly written snapshot key)
	// ARGV[1] = new version number as a string
	// ARGV[2] = bucket string (used for SADD)
	// ARGV[3] = snapshot key prefix (used to build the old snapshot key)
	// ARGV[4] = grace-period TTL in seconds
	//
	// Returns 1 = activated, 0 = version too old, not activated
	activateSnapshotScript = redis.NewScript(`
local currentActive = redis.call('GET', KEYS[1])
local newVersion = tonumber(ARGV[1])
if currentActive ~= false then
local curVersion = tonumber(currentActive)
if curVersion and newVersion < curVersion then
redis.call('DEL', KEYS[4])
return 0
end
end
redis.call('SET', KEYS[1], ARGV[1])
redis.call('SET', KEYS[2], '1')
redis.call('SADD', KEYS[3], ARGV[2])
if currentActive ~= false and currentActive ~= ARGV[1] then
redis.call('EXPIRE', ARGV[3] .. currentActive, tonumber(ARGV[4]))
end
return 1
`)
)
type schedulerCache struct {
......@@ -108,9 +151,9 @@ func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.Schedul
}
func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error {
activeKey := schedulerBucketKey(schedulerActivePrefix, bucket)
oldActive, _ := c.rdb.Get(ctx, activeKey).Result()
	// Phase 1: allocate a new version number and write the snapshot data.
	// INCR guarantees each caller gets a unique, monotonically increasing version.
	// The snapshotKey written here is a new versioned key unknown to readers, so there is no race.
versionKey := schedulerBucketKey(schedulerVersionPrefix, bucket)
version, err := c.rdb.Incr(ctx, versionKey).Result()
if err != nil {
......@@ -124,7 +167,6 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul
return err
}
pipe := c.rdb.Pipeline()
if len(accounts) > 0 {
	// Use the ordinal index as the score to preserve the database's ordering semantics.
members := make([]redis.Z, 0, len(accounts))
......@@ -134,6 +176,7 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul
Member: strconv.FormatInt(account.ID, 10),
})
}
pipe := c.rdb.Pipeline()
for start := 0; start < len(members); start += c.writeChunkSize {
end := start + c.writeChunkSize
if end > len(members) {
......@@ -141,18 +184,25 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul
}
pipe.ZAdd(ctx, snapshotKey, members[start:end]...)
}
} else {
pipe.Del(ctx, snapshotKey)
}
pipe.Set(ctx, activeKey, versionStr, 0)
pipe.Set(ctx, schedulerBucketKey(schedulerReadyPrefix, bucket), "1", 0)
pipe.SAdd(ctx, schedulerBucketSetKey, bucket.String())
if _, err := pipe.Exec(ctx); err != nil {
return err
if _, err := pipe.Exec(ctx); err != nil {
return err
}
}
if oldActive != "" && oldActive != versionStr {
_ = c.rdb.Del(ctx, schedulerSnapshotKey(bucket, oldActive)).Err()
	// Phase 2: atomically CAS-activate the version.
	// The Lua script guarantees the active pointer only moves when the new
	// version is >= the currently active one, preventing rollback under
	// concurrent writers. Old snapshots get a grace-period EXPIRE instead of
	// an immediate DEL to avoid racing with readers.
activeKey := schedulerBucketKey(schedulerActivePrefix, bucket)
readyKey := schedulerBucketKey(schedulerReadyPrefix, bucket)
snapshotKeyPrefix := fmt.Sprintf("%s%d:%s:%s:v", schedulerSnapshotPrefix, bucket.GroupID, bucket.Platform, bucket.Mode)
keys := []string{activeKey, readyKey, schedulerBucketSetKey, snapshotKey}
args := []any{versionStr, bucket.String(), snapshotKeyPrefix, snapshotGraceTTLSeconds}
_, err = activateSnapshotScript.Run(ctx, c.rdb, keys, args...).Result()
if err != nil {
return err
}
return nil
......@@ -232,6 +282,11 @@ func (c *schedulerCache) TryLockBucket(ctx context.Context, bucket service.Sched
return c.rdb.SetNX(ctx, key, time.Now().UnixNano(), ttl).Result()
}
func (c *schedulerCache) UnlockBucket(ctx context.Context, bucket service.SchedulerBucket) error {
key := schedulerBucketKey(schedulerLockPrefix, bucket)
return c.rdb.Del(ctx, key).Err()
}
func (c *schedulerCache) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
raw, err := c.rdb.SMembers(ctx, schedulerBucketSetKey).Result()
if err != nil {
......@@ -394,11 +449,69 @@ func buildSchedulerMetadataAccount(account service.Account) service.Account {
SessionWindowStart: account.SessionWindowStart,
SessionWindowEnd: account.SessionWindowEnd,
SessionWindowStatus: account.SessionWindowStatus,
AccountGroups: filterSchedulerAccountGroups(account.AccountGroups),
GroupIDs: filterSchedulerGroupIDs(account.GroupIDs, account.AccountGroups),
Credentials: filterSchedulerCredentials(account.Credentials),
Extra: filterSchedulerExtra(account.Extra),
}
}
// filterSchedulerAccountGroups produces a slimmed-down copy of the account's
// group memberships for the scheduler snapshot: entries with a non-positive
// GroupID are dropped, and only the scalar fields (AccountID, GroupID,
// Priority, CreatedAt) are kept — the nested Account/Group pointers are
// omitted. Returns nil when nothing survives the filter.
func filterSchedulerAccountGroups(accountGroups []service.AccountGroup) []service.AccountGroup {
	var slim []service.AccountGroup
	for _, membership := range accountGroups {
		if membership.GroupID <= 0 {
			continue
		}
		slim = append(slim, service.AccountGroup{
			AccountID: membership.AccountID,
			GroupID:   membership.GroupID,
			Priority:  membership.Priority,
			CreatedAt: membership.CreatedAt,
		})
	}
	return slim
}
// filterSchedulerGroupIDs merges the explicit group ID list with the IDs
// carried by accountGroups, dropping non-positive values and duplicates
// while preserving first-seen order (explicit IDs first, then membership
// IDs). Returns nil when nothing survives.
func filterSchedulerGroupIDs(groupIDs []int64, accountGroups []service.AccountGroup) []int64 {
	total := len(groupIDs) + len(accountGroups)
	if total == 0 {
		return nil
	}
	seen := make(map[int64]struct{}, total)
	merged := make([]int64, 0, total)
	add := func(id int64) {
		if id <= 0 {
			return
		}
		if _, dup := seen[id]; dup {
			return
		}
		seen[id] = struct{}{}
		merged = append(merged, id)
	}
	for _, id := range groupIDs {
		add(id)
	}
	for _, membership := range accountGroups {
		add(membership.GroupID)
	}
	if len(merged) == 0 {
		return nil
	}
	return merged
}
func filterSchedulerCredentials(credentials map[string]any) map[string]any {
if len(credentials) == 0 {
return nil
......
......@@ -56,6 +56,15 @@ func TestSchedulerCacheSnapshotUsesSlimMetadataButKeepsFullAccount(t *testing.T)
SessionWindowStart: &now,
SessionWindowEnd: &windowEnd,
SessionWindowStatus: "active",
GroupIDs: []int64{bucket.GroupID},
AccountGroups: []service.AccountGroup{
{
AccountID: 101,
GroupID: bucket.GroupID,
Priority: 5,
Group: &service.Group{ID: bucket.GroupID, Name: "gemini-group"},
},
},
}
require.NoError(t, cache.SetSnapshot(ctx, bucket, []service.Account{account}))
......@@ -79,10 +88,17 @@ func TestSchedulerCacheSnapshotUsesSlimMetadataButKeepsFullAccount(t *testing.T)
require.Equal(t, 4, got.GetMaxSessions())
require.Equal(t, 11, got.GetSessionIdleTimeoutMinutes())
require.Nil(t, got.Extra["unused_large_field"])
require.Equal(t, []int64{bucket.GroupID}, got.GroupIDs)
require.Len(t, got.AccountGroups, 1)
require.Equal(t, account.ID, got.AccountGroups[0].AccountID)
require.Equal(t, bucket.GroupID, got.AccountGroups[0].GroupID)
require.Nil(t, got.AccountGroups[0].Group)
full, err := cache.GetAccount(ctx, account.ID)
require.NoError(t, err)
require.NotNil(t, full)
require.Equal(t, "secret-access-token", full.GetCredential("access_token"))
require.Equal(t, strings.Repeat("x", 4096), full.GetCredential("huge_blob"))
require.Len(t, full.AccountGroups, 1)
require.NotNil(t, full.AccountGroups[0].Group)
}
......@@ -31,3 +31,43 @@ func TestBuildSchedulerMetadataAccount_KeepsOpenAIWSFlags(t *testing.T) {
require.Equal(t, true, got.Extra["mixed_scheduling"])
require.Nil(t, got.Extra["unused_large_field"])
}
// TestBuildSchedulerMetadataAccount_KeepsSlimGroupMembership verifies that
// the slim metadata projection keeps group membership (GroupIDs and
// AccountGroups) while stripping the heavy nested Account/Group pointers,
// dropping non-positive group IDs, and de-duplicating IDs.
func TestBuildSchedulerMetadataAccount_KeepsSlimGroupMembership(t *testing.T) {
	account := service.Account{
		ID:       42,
		Platform: service.PlatformAnthropic,
		// 7 is duplicated and 0 is invalid; both must be cleaned up.
		GroupIDs: []int64{7, 9, 7, 0},
		AccountGroups: []service.AccountGroup{
			{
				AccountID: 42,
				GroupID:   7,
				Priority:  2,
				Account:   &service.Account{ID: 42, Name: "drop-from-metadata"},
				Group:     &service.Group{ID: 7, Name: "drop-from-metadata"},
			},
			{
				AccountID: 42,
				GroupID:   11,
				Priority:  3,
				Group:     &service.Group{ID: 11, Name: "drop-from-metadata"},
			},
			{
				AccountID: 42,
				GroupID:   0, // invalid membership: must be filtered out
				Priority:  4,
			},
		},
	}
	got := buildSchedulerMetadataAccount(account)
	// Expected order: explicit GroupIDs first (deduped), then IDs only
	// present in AccountGroups (11).
	require.Equal(t, []int64{7, 9, 11}, got.GroupIDs)
	require.Len(t, got.AccountGroups, 2)
	require.Equal(t, int64(42), got.AccountGroups[0].AccountID)
	require.Equal(t, int64(7), got.AccountGroups[0].GroupID)
	require.Equal(t, 2, got.AccountGroups[0].Priority)
	// Nested pointers are dropped from the slim projection.
	require.Nil(t, got.AccountGroups[0].Account)
	require.Nil(t, got.AccountGroups[0].Group)
	require.Equal(t, int64(11), got.AccountGroups[1].GroupID)
	require.Nil(t, got.Groups)
}
......@@ -740,6 +740,7 @@ func TestAPIContracts(t *testing.T) {
"allow_ungrouped_key_scheduling": false,
"backend_mode_enabled": false,
"enable_cch_signing": false,
"enable_anthropic_cache_ttl_1h_injection": false,
"enable_fingerprint_unification": true,
"enable_metadata_passthrough": false,
"web_search_emulation_enabled": false,
......@@ -748,6 +749,16 @@ func TestAPIContracts(t *testing.T) {
"payment_visible_method_alipay_enabled": true,
"payment_visible_method_wxpay_enabled": false,
"openai_advanced_scheduler_enabled": true,
"openai_fast_policy_settings": {
"rules": [
{
"service_tier": "priority",
"action": "filter",
"scope": "all",
"fallback_action": "pass"
}
]
},
"custom_menu_items": [],
"custom_endpoints": [],
"payment_enabled": false,
......@@ -924,12 +935,23 @@ func TestAPIContracts(t *testing.T) {
"enable_fingerprint_unification": true,
"enable_metadata_passthrough": false,
"enable_cch_signing": false,
"enable_anthropic_cache_ttl_1h_injection": false,
"web_search_emulation_enabled": false,
"payment_visible_method_alipay_source": "",
"payment_visible_method_wxpay_source": "",
"payment_visible_method_alipay_enabled": false,
"payment_visible_method_wxpay_enabled": false,
"openai_advanced_scheduler_enabled": false,
"openai_fast_policy_settings": {
"rules": [
{
"service_tier": "priority",
"action": "filter",
"scope": "all",
"fallback_action": "pass"
}
]
},
"payment_enabled": false,
"payment_min_amount": 0,
"payment_max_amount": 0,
......
......@@ -64,6 +64,7 @@ func isOpenAIImageModel(model string) bool {
type AccountTestService struct {
accountRepo AccountRepository
geminiTokenProvider *GeminiTokenProvider
claudeTokenProvider *ClaudeTokenProvider
antigravityGatewayService *AntigravityGatewayService
httpUpstream HTTPUpstream
cfg *config.Config
......@@ -74,6 +75,7 @@ type AccountTestService struct {
func NewAccountTestService(
accountRepo AccountRepository,
geminiTokenProvider *GeminiTokenProvider,
claudeTokenProvider *ClaudeTokenProvider,
antigravityGatewayService *AntigravityGatewayService,
httpUpstream HTTPUpstream,
cfg *config.Config,
......@@ -82,6 +84,7 @@ func NewAccountTestService(
return &AccountTestService{
accountRepo: accountRepo,
geminiTokenProvider: geminiTokenProvider,
claudeTokenProvider: claudeTokenProvider,
antigravityGatewayService: antigravityGatewayService,
httpUpstream: httpUpstream,
cfg: cfg,
......@@ -210,6 +213,9 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
if account.IsBedrock() {
return s.testBedrockAccountConnection(c, ctx, account, testModelID)
}
if account.Type == AccountTypeServiceAccount {
return s.testClaudeVertexServiceAccountConnection(c, ctx, account, testModelID)
}
// Determine authentication method and API URL
var authToken string
......@@ -313,6 +319,74 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
return s.processClaudeStream(c, resp.Body)
}
// testClaudeVertexServiceAccountConnection streams a connectivity test for a
// Claude account backed by a GCP service account on Vertex AI. It resolves
// the model ID (explicit mapping first, otherwise Vertex normalization),
// builds the Vertex Anthropic request body and URL, authenticates with a
// Bearer token from the Claude token provider, and relays the upstream
// stream to the client via processClaudeStream. Because SSE headers are
// flushed up front, all failures are reported as SSE error events rather
// than HTTP status codes.
func (s *AccountTestService) testClaudeVertexServiceAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error {
	// Prefer the account's explicit model mapping; otherwise normalize the
	// Claude model ID into its Vertex form.
	if mappedModel, matched := account.ResolveMappedModel(testModelID); matched {
		testModelID = mappedModel
	} else {
		testModelID = normalizeVertexAnthropicModelID(claude.NormalizeModelID(testModelID))
	}
	// SSE response headers; X-Accel-Buffering disables proxy buffering so
	// events reach the client immediately.
	c.Writer.Header().Set("Content-Type", "text/event-stream")
	c.Writer.Header().Set("Cache-Control", "no-cache")
	c.Writer.Header().Set("Connection", "keep-alive")
	c.Writer.Header().Set("X-Accel-Buffering", "no")
	c.Writer.Flush()
	payload, err := createTestPayload(testModelID)
	if err != nil {
		return s.sendErrorAndEnd(c, "Failed to create test payload")
	}
	// Marshal error ignored: payload was just produced by createTestPayload.
	payloadBytes, _ := json.Marshal(payload)
	vertexBody, err := buildVertexAnthropicRequestBody(payloadBytes)
	if err != nil {
		return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to create Vertex request body: %s", err.Error()))
	}
	if s.claudeTokenProvider == nil {
		return s.sendErrorAndEnd(c, "Claude token provider not configured")
	}
	accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account)
	if err != nil {
		return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to get service account access token: %s", err.Error()))
	}
	fullURL, err := buildVertexAnthropicURL(account.VertexProjectID(), account.VertexLocation(testModelID), testModelID, true)
	if err != nil {
		return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to build Vertex URL: %s", err.Error()))
	}
	s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(vertexBody))
	if err != nil {
		return s.sendErrorAndEnd(c, "Failed to create request")
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+accessToken)
	// Route through the account's proxy when one is configured.
	proxyURL := ""
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}
	resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
	if err != nil {
		return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
	}
	defer func() { _ = resp.Body.Close() }()
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		errMsg := fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))
		// A 403 is recorded on the account so the failure is visible later.
		if resp.StatusCode == http.StatusForbidden {
			_ = s.accountRepo.SetError(ctx, account.ID, errMsg)
		}
		return s.sendErrorAndEnd(c, errMsg)
	}
	return s.processClaudeStream(c, resp.Body)
}
// testBedrockAccountConnection tests a Bedrock (SigV4 or API Key) account using non-streaming invoke
func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error {
region := bedrockRuntimeRegion(account)
......@@ -711,8 +785,8 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
testModelID = geminicli.DefaultTestModel
}
// For API Key accounts with model mapping, map the model
if account.Type == AccountTypeAPIKey {
// For static upstream credentials with model mapping, map the model
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
mapping := account.GetModelMapping()
if len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists {
......@@ -740,6 +814,8 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload)
case AccountTypeOAuth:
req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload)
case AccountTypeServiceAccount:
req, err = s.buildGeminiServiceAccountRequest(ctx, account, testModelID, payload)
default:
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
}
......@@ -893,6 +969,27 @@ func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, accoun
return s.buildCodeAssistRequest(ctx, accessToken, projectID, modelID, payload)
}
// buildGeminiServiceAccountRequest assembles a streamGenerateContent request
// against Vertex AI for a service-account backed Gemini account,
// authenticated with a Bearer token minted by the Gemini token provider.
func (s *AccountTestService) buildGeminiServiceAccountRequest(ctx context.Context, account *Account, modelID string, payload []byte) (*http.Request, error) {
	if s.geminiTokenProvider == nil {
		return nil, fmt.Errorf("gemini token provider not configured")
	}
	token, err := s.geminiTokenProvider.GetAccessToken(ctx, account)
	if err != nil {
		return nil, fmt.Errorf("failed to get service account access token: %w", err)
	}
	endpoint, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(modelID), modelID, "streamGenerateContent", true)
	if err != nil {
		return nil, err
	}
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return nil, err
	}
	httpReq.Header.Set("Authorization", "Bearer "+token)
	httpReq.Header.Set("Content-Type", "application/json")
	return httpReq, nil
}
// buildCodeAssistRequest builds request for Google Code Assist API (used by Gemini CLI and Antigravity)
func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessToken, projectID, modelID string, payload []byte) (*http.Request, error) {
var inner map[string]any
......@@ -1227,7 +1324,7 @@ func (s *AccountTestService) testOpenAIImageAPIKey(c *gin.Context, ctx context.C
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
}
apiURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/images/generations"
apiURL := buildOpenAIImagesURL(normalizedBaseURL, openAIImagesGenerationsEndpoint)
// Set SSE headers
c.Writer.Header().Set("Content-Type", "text/event-stream")
......
......@@ -8,6 +8,7 @@ import (
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
......@@ -48,3 +49,42 @@ func TestAccountTestService_OpenAIImageOAuthHandlesOutputItemDoneFallback(t *tes
require.Contains(t, rec.Body.String(), "data:image/png;base64,aGVsbG8=")
require.Contains(t, rec.Body.String(), "\"success\":true")
}
// TestAccountTestService_OpenAIImageAPIKeyUsesConfiguredV1BaseURL verifies
// that an API-key OpenAI image account whose base_url already ends in /v1
// has its request routed to {base_url}/images/generations without the /v1
// segment being duplicated, and that the API key is sent as a Bearer token.
func TestAccountTestService_OpenAIImageAPIKeyUsesConfiguredV1BaseURL(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
	// Canned upstream success response with a single base64 image.
	upstream := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header: http.Header{
				"Content-Type": []string{"application/json"},
			},
			Body: io.NopCloser(strings.NewReader(`{"data":[{"b64_json":"aGVsbG8=","revised_prompt":"draw a cat"}]}`)),
		},
	}
	svc := &AccountTestService{
		httpUpstream: upstream,
		cfg:          &config.Config{},
	}
	// base_url deliberately includes the /v1 suffix to exercise the
	// URL-building path under test.
	account := &Account{
		ID:       54,
		Name:     "openai-apikey",
		Platform: PlatformOpenAI,
		Type:     AccountTypeAPIKey,
		Credentials: map[string]any{
			"api_key":  "test-api-key",
			"base_url": "https://image-upstream.example/v1",
		},
	}
	err := svc.testOpenAIImageAPIKey(c, context.Background(), account, "gpt-image-2", "draw a cat")
	require.NoError(t, err)
	require.NotNil(t, upstream.lastReq)
	// No duplicated /v1 segment in the final URL.
	require.Equal(t, "https://image-upstream.example/v1/images/generations", upstream.lastReq.URL.String())
	require.Equal(t, "Bearer test-api-key", upstream.lastReq.Header.Get("Authorization"))
	require.Contains(t, rec.Body.String(), "data:image/png;base64,aGVsbG8=")
	require.Contains(t, rec.Body.String(), "\"success\":true")
}
......@@ -9,6 +9,7 @@ import (
"log/slog"
"net/http"
"sort"
"strconv"
"strings"
"time"
......@@ -58,6 +59,7 @@ type AdminService interface {
// API Key management (admin)
AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error)
AdminResetAPIKeyRateLimitUsage(ctx context.Context, keyID int64) (*APIKey, error)
AdminSetCaptureRequests(ctx context.Context, keyID int64, enabled bool) (*APIKey, error)
// ReplaceUserGroup 替换用户的专属分组:授予新分组权限、迁移 Key、移除旧分组权限
......@@ -292,6 +294,7 @@ type UpdateAccountInput struct {
// BulkUpdateAccountsInput describes the payload for bulk updating accounts.
type BulkUpdateAccountsInput struct {
AccountIDs []int64
Filters *BulkUpdateAccountFilters
Name string
ProxyID *int64
Concurrency *int
......@@ -308,6 +311,15 @@ type BulkUpdateAccountsInput struct {
SkipMixedChannelCheck bool
}
type BulkUpdateAccountFilters struct {
Platform string
Type string
Status string
Group string
Search string
PrivacyMode string
}
// BulkUpdateAccountResult captures the result for a single account update.
type BulkUpdateAccountResult struct {
AccountID int64 `json:"account_id"`
......@@ -1962,6 +1974,30 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
return result, nil
}
// AdminResetAPIKeyRateLimitUsage resets all API key rate-limit usage windows.
// It zeroes the 5h/1d/7d counters and their window start markers, persists
// the key, then invalidates the auth cache and cached rate-limit state so
// the reset takes effect immediately.
func (s *adminServiceImpl) AdminResetAPIKeyRateLimitUsage(ctx context.Context, keyID int64) (*APIKey, error) {
	key, err := s.apiKeyRepo.GetByID(ctx, keyID)
	if err != nil {
		return nil, err
	}
	// Clear every usage window in one shot.
	key.Usage5h, key.Usage1d, key.Usage7d = 0, 0, 0
	key.Window5hStart, key.Window1dStart, key.Window7dStart = nil, nil, nil
	if err := s.apiKeyRepo.Update(ctx, key); err != nil {
		return nil, fmt.Errorf("reset api key rate limit usage: %w", err)
	}
	if s.authCacheInvalidator != nil {
		s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, key.Key)
	}
	if s.billingCacheService != nil {
		// Best-effort invalidation; a stale entry is harmless post-reset.
		_ = s.billingCacheService.InvalidateAPIKeyRateLimit(ctx, key.ID)
	}
	return key, nil
}
// AdminSetCaptureRequests sets or clears the request-capture flag for the given API key and immediately invalidates the auth cache.
func (s *adminServiceImpl) AdminSetCaptureRequests(ctx context.Context, keyID int64, enabled bool) (*APIKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, keyID)
......@@ -2303,6 +2339,14 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
// BulkUpdateAccounts updates multiple accounts in one request.
// It merges credentials/extra keys instead of overwriting the whole object.
func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) {
if len(input.AccountIDs) == 0 && input.Filters != nil {
accountIDs, err := s.resolveBulkUpdateTargetIDs(ctx, input.Filters)
if err != nil {
return nil, err
}
input.AccountIDs = accountIDs
}
result := &BulkUpdateAccountsResult{
SuccessIDs: make([]int64, 0, len(input.AccountIDs)),
FailedIDs: make([]int64, 0, len(input.AccountIDs)),
......@@ -2418,6 +2462,55 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
return result, nil
}
// resolveBulkUpdateTargetIDs expands a filter description into the concrete
// list of account IDs it matches, paging through ListAccounts until every
// match has been collected. A nil filter resolves to no targets.
func (s *adminServiceImpl) resolveBulkUpdateTargetIDs(ctx context.Context, filters *BulkUpdateAccountFilters) ([]int64, error) {
	if filters == nil {
		return nil, nil
	}
	// Translate the textual group filter into the numeric form that
	// ListAccounts expects; empty means no group constraint.
	var groupID int64
	switch group := strings.TrimSpace(filters.Group); group {
	case "":
	case "ungrouped":
		groupID = AccountListGroupUngrouped
	default:
		id, err := strconv.ParseInt(group, 10, 64)
		if err != nil {
			return nil, fmt.Errorf("invalid group filter: %w", err)
		}
		groupID = id
	}
	const pageSize = 500
	ids := make([]int64, 0, pageSize)
	for page := 1; ; page++ {
		accounts, total, err := s.ListAccounts(
			ctx, page, pageSize,
			filters.Platform, filters.Type, filters.Status, filters.Search,
			groupID, filters.PrivacyMode,
			"", "",
		)
		if err != nil {
			return nil, err
		}
		for i := range accounts {
			ids = append(ids, accounts[i].ID)
		}
		// Stop once everything is collected or the listing runs dry.
		if len(accounts) == 0 || int64(len(ids)) >= total {
			return ids, nil
		}
	}
}
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
if err := s.accountRepo.Delete(ctx, id); err != nil {
return err
......
......@@ -5,8 +5,10 @@ package service
import (
"context"
"errors"
"reflect"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
......@@ -25,6 +27,19 @@ type accountRepoStubForBulkUpdate struct {
getByIDCalled []int64
listByGroupData map[int64][]Account
listByGroupErr map[int64]error
listData []Account
listResult *pagination.PaginationResult
listErr error
listCalled bool
lastListParams pagination.PaginationParams
lastListFilters struct {
platform string
accountType string
status string
search string
groupID int64
privacyMode string
}
}
func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
......@@ -73,6 +88,24 @@ func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID in
return nil, nil
}
// ListWithFilters is the stubbed account-list call: it records every argument
// it receives for later assertions and serves the canned data configured on
// the stub (listErr first, then listResult, else a result derived from
// listData's length).
func (s *accountRepoStubForBulkUpdate) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) {
	// Record the call and all filter arguments for the test to inspect.
	s.listCalled = true
	s.lastListParams = params
	recorded := &s.lastListFilters
	recorded.platform = platform
	recorded.accountType = accountType
	recorded.status = status
	recorded.search = search
	recorded.groupID = groupID
	recorded.privacyMode = privacyMode

	switch {
	case s.listErr != nil:
		return nil, nil, s.listErr
	case s.listResult != nil:
		return s.listData, s.listResult, nil
	default:
		// No explicit result configured: synthesize one from the data size.
		return s.listData, &pagination.PaginationResult{Total: int64(len(s.listData))}, nil
	}
}
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{}
......@@ -170,3 +203,46 @@ func TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingCon
// No BindGroups should have been called since the check runs before any write.
require.Empty(t, repo.bindGroupsCalls)
}
// TestAdminServiceBulkUpdateAccounts_ResolvesIDsFromFilters verifies that a
// bulk update whose AccountIDs list is empty resolves its targets through the
// account list filters and applies the update to every matching account.
//
// Filters is populated via reflection on purpose: the test doubles as a
// structural check that BulkUpdateAccountsInput exposes a pointer field named
// Filters carrying the Platform/Type/Status/Group/PrivacyMode/Search strings.
func TestAdminServiceBulkUpdateAccounts_ResolvesIDsFromFilters(t *testing.T) {
	// Stub repository serves two matching accounts in a single page.
	repo := &accountRepoStubForBulkUpdate{
		listData: []Account{
			{ID: 7},
			{ID: 11},
		},
		listResult: &pagination.PaginationResult{Total: 2},
	}
	svc := &adminServiceImpl{accountRepo: repo}
	schedulable := true
	input := &BulkUpdateAccountsInput{
		Schedulable: &schedulable,
	}
	// Build the Filters value reflectively; a failure here means the input
	// struct no longer has the shape filter-target bulk updates depend on.
	filtersField := reflect.ValueOf(input).Elem().FieldByName("Filters")
	require.True(t, filtersField.IsValid(), "BulkUpdateAccountsInput should expose Filters for filter-target bulk update")
	require.Equal(t, reflect.Ptr, filtersField.Kind(), "BulkUpdateAccountsInput.Filters should be a pointer field")
	filtersValue := reflect.New(filtersField.Type().Elem())
	filtersValue.Elem().FieldByName("Platform").SetString(PlatformOpenAI)
	filtersValue.Elem().FieldByName("Type").SetString(AccountTypeOAuth)
	filtersValue.Elem().FieldByName("Status").SetString(StatusActive)
	filtersValue.Elem().FieldByName("Group").SetString("12")
	filtersValue.Elem().FieldByName("PrivacyMode").SetString(PrivacyModeCFBlocked)
	filtersValue.Elem().FieldByName("Search").SetString("bulk-target")
	filtersField.Set(filtersValue)
	result, err := svc.BulkUpdateAccounts(context.Background(), input)
	require.NoError(t, err)
	// The repo must have been queried with every filter verbatim, with the
	// textual group "12" parsed into the numeric group ID 12.
	require.True(t, repo.listCalled, "expected filter-target bulk update to resolve matching IDs via account list filters")
	require.Equal(t, PlatformOpenAI, repo.lastListFilters.platform)
	require.Equal(t, AccountTypeOAuth, repo.lastListFilters.accountType)
	require.Equal(t, StatusActive, repo.lastListFilters.status)
	require.Equal(t, "bulk-target", repo.lastListFilters.search)
	require.Equal(t, int64(12), repo.lastListFilters.groupID)
	require.Equal(t, PrivacyModeCFBlocked, repo.lastListFilters.privacyMode)
	// Both resolved accounts were updated and reported as successes.
	require.Equal(t, []int64{7, 11}, repo.bulkUpdateIDs)
	require.Equal(t, 2, result.Success)
	require.Equal(t, 0, result.Failed)
	require.Equal(t, []int64{7, 11}, result.SuccessIDs)
}
......@@ -508,6 +508,18 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
return nil
}
// InvalidateAPIKeyRateLimit invalidates the Redis rate-limit usage cache for an API key.
// It is a no-op when no cache backend is configured; on failure it logs a
// warning and propagates the error to the caller.
func (s *BillingCacheService) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
	if s.cache == nil {
		return nil
	}
	err := s.cache.InvalidateAPIKeyRateLimit(ctx, keyID)
	if err != nil {
		logger.LegacyPrintf("service.billing_cache", "Warning: invalidate api key rate limit cache failed for key %d: %v", keyID, err)
	}
	return err
}
// ============================================
// API Key 限速缓存方法
// ============================================
......
......@@ -17,7 +17,7 @@ const (
// ClaudeTokenCache token cache interface.
type ClaudeTokenCache = GeminiTokenCache
// ClaudeTokenProvider manages access_token for Claude OAuth accounts.
// ClaudeTokenProvider manages access_token for Claude OAuth and Vertex service account accounts.
type ClaudeTokenProvider struct {
accountRepo AccountRepository
tokenCache ClaudeTokenCache
......@@ -56,8 +56,11 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
return "", errors.New("not an anthropic oauth account")
if account.Platform != PlatformAnthropic || (account.Type != AccountTypeOAuth && account.Type != AccountTypeServiceAccount) {
return "", errors.New("not an anthropic oauth or service account")
}
if account.Type == AccountTypeServiceAccount {
return p.getServiceAccountAccessToken(ctx, account)
}
cacheKey := ClaudeTokenCacheKey(account)
......@@ -157,3 +160,7 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return accessToken, nil
}
// getServiceAccountAccessToken resolves an access token for a Claude account
// of type service_account by delegating to the shared Vertex service-account
// token helper, passing along the provider's token cache for reuse.
func (p *ClaudeTokenProvider) getServiceAccountAccessToken(ctx context.Context, account *Account) (string, error) {
	return getVertexServiceAccountAccessToken(ctx, p.tokenCache, account)
}
......@@ -137,7 +137,7 @@ func (p *testClaudeTokenProvider) GetAccessToken(ctx context.Context, account *A
return "", errors.New("account is nil")
}
if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
return "", errors.New("not an anthropic oauth account")
return "", errors.New("not an anthropic oauth or service account")
}
cacheKey := ClaudeTokenCacheKey(account)
......@@ -371,7 +371,7 @@ func TestClaudeTokenProvider_WrongPlatform(t *testing.T) {
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an anthropic oauth account")
require.Contains(t, err.Error(), "not an anthropic oauth or service account")
require.Empty(t, token)
}
......@@ -385,7 +385,7 @@ func TestClaudeTokenProvider_WrongAccountType(t *testing.T) {
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an anthropic oauth account")
require.Contains(t, err.Error(), "not an anthropic oauth or service account")
require.Empty(t, token)
}
......@@ -399,7 +399,7 @@ func TestClaudeTokenProvider_SetupTokenType(t *testing.T) {
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an anthropic oauth account")
require.Contains(t, err.Error(), "not an anthropic oauth or service account")
require.Empty(t, token)
}
......
......@@ -41,11 +41,12 @@ const (
// Account type constants
const (
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference)
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope)
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference)
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope)
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
AccountTypeServiceAccount = domain.AccountTypeServiceAccount // Google Service Account 类型账号(用于 Vertex AI)
)
// Redeem type constants
......@@ -306,6 +307,12 @@ const (
// SettingKeyBetaPolicySettings stores JSON config for beta policy rules.
SettingKeyBetaPolicySettings = "beta_policy_settings"
// SettingKeyOpenAIFastPolicySettings stores JSON config for OpenAI
// service_tier (fast/flex) policy rules. Mirrors BetaPolicySettings but
// targets OpenAI's body-level service_tier field instead of Claude's
// anthropic-beta header.
SettingKeyOpenAIFastPolicySettings = "openai_fast_policy_settings"
// =========================
// Claude Code Version Check
// =========================
......@@ -329,6 +336,8 @@ const (
SettingKeyEnableMetadataPassthrough = "enable_metadata_passthrough"
// SettingKeyEnableCCHSigning 是否对 billing header 中的 cch 进行 xxHash64 签名(默认 false)
SettingKeyEnableCCHSigning = "enable_cch_signing"
// SettingKeyEnableAnthropicCacheTTL1hInjection 是否对 Anthropic OAuth/SetupToken 请求体注入 1h cache_control ttl(默认 false)
SettingKeyEnableAnthropicCacheTTL1hInjection = "enable_anthropic_cache_ttl_1h_injection"
// Balance Low Notification
SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关
......
package service
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// TestGatewayService_BuildAnthropicVertexServiceAccountRequest verifies that
// buildUpstreamRequest, given an Anthropic service-account (Vertex AI)
// account, produces a request that:
//   - targets the regional Vertex rawPredict endpoint derived from the
//     credentials' project_id/location and the resolved model ID,
//   - replaces the inbound Authorization with the Vertex bearer token and
//     strips x-api-key/anthropic-version, while keeping anthropic-beta,
//   - rewrites the JSON body: drops "model" and injects "anthropic_version".
func TestGatewayService_BuildAnthropicVertexServiceAccountRequest(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
	// Inbound credentials/headers that must NOT leak to the upstream request.
	c.Request.Header.Set("Authorization", "Bearer inbound-token")
	c.Request.Header.Set("X-Api-Key", "inbound-api-key")
	c.Request.Header.Set("Anthropic-Version", "2023-06-01")
	// anthropic-beta, by contrast, is expected to pass through unchanged.
	c.Request.Header.Set("Anthropic-Beta", "interleaved-thinking-2025-05-14")
	account := &Account{
		ID:       301,
		Platform: PlatformAnthropic,
		Type:     AccountTypeServiceAccount,
		Credentials: map[string]any{
			"project_id": "vertex-proj",
			"location":   "us-east5",
		},
	}
	body := []byte(`{"model":"claude-sonnet-4-5","stream":false,"max_tokens":32,"messages":[{"role":"user","content":"hello"}]}`)
	svc := &GatewayService{}
	// NOTE(review): positional args — "vertex-token" is the upstream access
	// token (asserted in Authorization below); "service_account" presumably
	// selects the auth/token mode and the final bools are feature flags —
	// confirm against buildUpstreamRequest's signature.
	req, err := svc.buildUpstreamRequest(
		context.Background(),
		c,
		account,
		body,
		"vertex-token",
		"service_account",
		"claude-sonnet-4-5@20250929",
		false,
		false,
	)
	require.NoError(t, err)
	// URL is built from location, project_id and the "@"-suffixed model ID.
	require.Equal(t, "https://us-east5-aiplatform.googleapis.com/v1/projects/vertex-proj/locations/us-east5/publishers/anthropic/models/claude-sonnet-4-5@20250929:rawPredict", req.URL.String())
	require.Equal(t, "Bearer vertex-token", getHeaderRaw(req.Header, "authorization"))
	require.Empty(t, getHeaderRaw(req.Header, "x-api-key"))
	require.Empty(t, getHeaderRaw(req.Header, "anthropic-version"))
	require.Equal(t, "interleaved-thinking-2025-05-14", getHeaderRaw(req.Header, "anthropic-beta"))
	// Body rewrite: "model" removed (moved into the URL path) and the fixed
	// vertexAnthropicVersion injected; user content is untouched.
	got := readRequestBodyForTest(t, req)
	require.Equal(t, "", gjson.GetBytes(got, "model").String())
	require.Equal(t, vertexAnthropicVersion, gjson.GetBytes(got, "anthropic_version").String())
	require.Equal(t, "hello", gjson.GetBytes(got, "messages.0.content").String())
}
// readRequestBodyForTest drains req.Body and returns its bytes, failing the
// test if the body is nil or unreadable. Note the body is consumed and not
// restored, so callers can read it only once.
func readRequestBodyForTest(t *testing.T, req *http.Request) []byte {
	t.Helper()
	require.NotNil(t, req.Body)
	payload, readErr := io.ReadAll(req.Body)
	require.NoError(t, readErr)
	return payload
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment