Unverified Commit 6bccb8a8 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge branch 'main' into feature/antigravity-user-agent-configurable

parents 1fc6ef3d 3de1e0e4
//go:build unit
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
//
// Unit tests for TLS fingerprint dialer.
......@@ -9,26 +11,161 @@
package tlsfingerprint
import (
"context"
"encoding/json"
"io"
"net/http"
"net/url"
"os"
"strings"
"testing"
"time"
)
// FingerprintResponse represents the response from tls.peet.ws/api/all.
type FingerprintResponse struct {
IP string `json:"ip"`
TLS TLSInfo `json:"tls"`
HTTP2 any `json:"http2"`
// TestDialerBasicConnection tests that the dialer can establish TLS connections.
func TestDialerBasicConnection(t *testing.T) {
skipNetworkTest(t)
// Create a dialer with default profile
profile := &Profile{
Name: "Test Profile",
EnableGREASE: false,
}
dialer := NewDialer(profile, nil)
// Create HTTP client with custom TLS dialer
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
// Make a request to a known HTTPS endpoint
resp, err := client.Get("https://www.google.com")
if err != nil {
t.Fatalf("failed to connect: %v", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
// This test uses tls.peet.ws to verify the fingerprint.
// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
func TestJA3Fingerprint(t *testing.T) {
skipNetworkTest(t)
profile := &Profile{
Name: "Claude CLI Test",
EnableGREASE: false,
}
dialer := NewDialer(profile, nil)
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
// Use tls.peet.ws fingerprint detection API
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
resp, err := client.Do(req)
if err != nil {
t.Fatalf("failed to get fingerprint: %v", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response: %v", err)
}
var fpResp FingerprintResponse
if err := json.Unmarshal(body, &fpResp); err != nil {
t.Logf("Response body: %s", string(body))
t.Fatalf("failed to parse fingerprint response: %v", err)
}
// Log all fingerprint information
t.Logf("JA3: %s", fpResp.TLS.JA3)
t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash)
t.Logf("JA4: %s", fpResp.TLS.JA4)
t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint)
t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash)
// Verify JA3 hash matches expected value
expectedJA3Hash := "1a28e69016765d92e3b381168d68922c"
if fpResp.TLS.JA3Hash == expectedJA3Hash {
t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash)
} else {
t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash)
}
// Verify JA4 fingerprint
// JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
// Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP)
// The suffix _a33745022dd6_1f22a2ca17c4 should match
expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4"
if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) {
t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix)
} else {
t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix)
}
// Verify JA4 prefix (t13d5911h1 or t13i5911h1)
// d = domain (SNI present), i = IP (no SNI)
// Since we connect to tls.peet.ws (domain), we expect 'd'
expectedJA4Prefix := "t13d5911h1"
if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) {
t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix)
} else {
// Also accept 'i' variant for IP connections
altPrefix := "t13i5911h1"
if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) {
t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix)
} else {
t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix)
}
}
// Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") {
t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites")
} else {
t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites")
}
// Verify extension list (should be 11 extensions including SNI)
// Expected: 0-11-10-35-16-22-23-13-43-45-51
expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51"
if strings.Contains(fpResp.TLS.JA3, expectedExtensions) {
t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions)
} else {
t.Logf("Warning: JA3 extension list may differ")
}
}
// TLSInfo contains TLS fingerprint details.
type TLSInfo struct {
JA3 string `json:"ja3"`
JA3Hash string `json:"ja3_hash"`
JA4 string `json:"ja4"`
PeetPrint string `json:"peetprint"`
PeetPrintHash string `json:"peetprint_hash"`
ClientRandom string `json:"client_random"`
SessionID string `json:"session_id"`
func skipNetworkTest(t *testing.T) {
if testing.Short() {
t.Skip("跳过网络测试(short 模式)")
}
if os.Getenv("TLSFINGERPRINT_NETWORK_TESTS") != "1" {
t.Skip("跳过网络测试(需要设置 TLSFINGERPRINT_NETWORK_TESTS=1)")
}
}
// TestDialerWithProfile tests that different profiles produce different fingerprints.
......@@ -158,3 +295,137 @@ func mustParseURL(rawURL string) *url.URL {
}
return u
}
// TestProfileExpectation defines expected fingerprint values for a profile.
type TestProfileExpectation struct {
Profile *Profile
ExpectedJA3 string // Expected JA3 hash (empty = don't check)
ExpectedJA4 string // Expected full JA4 (empty = don't check)
JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check)
}
// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws.
// Run with: go test -v -run TestAllProfiles ./internal/pkg/tlsfingerprint/...
func TestAllProfiles(t *testing.T) {
skipNetworkTest(t)
// Define all profiles to test with their expected fingerprints
// These profiles are from config.yaml gateway.tls_fingerprint.profiles
profiles := []TestProfileExpectation{
{
// Linux x64 Node.js v22.17.1
// Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4
Profile: &Profile{
Name: "linux_x64_node_v22171",
EnableGREASE: false,
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
PointFormats: []uint8{0, 1, 2},
},
JA4CipherHash: "a33745022dd6", // stable part
},
{
// MacOS arm64 Node.js v22.18.0
// Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea
// Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406
Profile: &Profile{
Name: "macos_arm64_node_v22180",
EnableGREASE: false,
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
PointFormats: []uint8{0, 1, 2},
},
JA4CipherHash: "a33745022dd6", // stable part (same cipher suites)
},
}
for _, tc := range profiles {
tc := tc // capture range variable
t.Run(tc.Profile.Name, func(t *testing.T) {
fp := fetchFingerprint(t, tc.Profile)
if fp == nil {
return // fetchFingerprint already called t.Fatal
}
t.Logf("Profile: %s", tc.Profile.Name)
t.Logf(" JA3: %s", fp.JA3)
t.Logf(" JA3 Hash: %s", fp.JA3Hash)
t.Logf(" JA4: %s", fp.JA4)
t.Logf(" PeetPrint: %s", fp.PeetPrint)
t.Logf(" PeetPrintHash: %s", fp.PeetPrintHash)
// Verify expectations
if tc.ExpectedJA3 != "" {
if fp.JA3Hash == tc.ExpectedJA3 {
t.Logf(" ✓ JA3 hash matches: %s", tc.ExpectedJA3)
} else {
t.Errorf(" ✗ JA3 hash mismatch: got %s, expected %s", fp.JA3Hash, tc.ExpectedJA3)
}
}
if tc.ExpectedJA4 != "" {
if fp.JA4 == tc.ExpectedJA4 {
t.Logf(" ✓ JA4 matches: %s", tc.ExpectedJA4)
} else {
t.Errorf(" ✗ JA4 mismatch: got %s, expected %s", fp.JA4, tc.ExpectedJA4)
}
}
// Check JA4 cipher hash (stable middle part)
// JA4 format: prefix_cipherHash_extHash
if tc.JA4CipherHash != "" {
if strings.Contains(fp.JA4, "_"+tc.JA4CipherHash+"_") {
t.Logf(" ✓ JA4 cipher hash matches: %s", tc.JA4CipherHash)
} else {
t.Errorf(" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s", fp.JA4, tc.JA4CipherHash)
}
}
})
}
}
// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info.
func fetchFingerprint(t *testing.T, profile *Profile) *TLSInfo {
t.Helper()
dialer := NewDialer(profile, nil)
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
return nil
}
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
resp, err := client.Do(req)
if err != nil {
t.Fatalf("failed to get fingerprint: %v", err)
return nil
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response: %v", err)
return nil
}
var fpResp FingerprintResponse
if err := json.Unmarshal(body, &fpResp); err != nil {
t.Logf("Response body: %s", string(body))
t.Fatalf("failed to parse fingerprint response: %v", err)
return nil
}
return &fpResp.TLS
}
package tlsfingerprint
// FingerprintResponse represents the response from tls.peet.ws/api/all.
// 共享测试类型,供 unit 和 integration 测试文件使用。
type FingerprintResponse struct {
IP string `json:"ip"`
TLS TLSInfo `json:"tls"`
HTTP2 any `json:"http2"`
}
// TLSInfo contains TLS fingerprint details.
type TLSInfo struct {
JA3 string `json:"ja3"`
JA3Hash string `json:"ja3_hash"`
JA4 string `json:"ja4"`
PeetPrint string `json:"peetprint"`
PeetPrintHash string `json:"peetprint_hash"`
ClientRandom string `json:"client_random"`
SessionID string `json:"session_id"`
}
......@@ -15,7 +15,6 @@ import (
"database/sql"
"encoding/json"
"errors"
"log"
"strconv"
"time"
......@@ -25,6 +24,7 @@ import (
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate"
dbproxy "github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
......@@ -127,7 +127,7 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
account.CreatedAt = created.CreatedAt
account.UpdatedAt = created.UpdatedAt
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
log.Printf("[SchedulerOutbox] enqueue account create failed: account=%d err=%v", account.ID, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account create failed: account=%d err=%v", account.ID, err)
}
return nil
}
......@@ -388,7 +388,7 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
}
account.UpdatedAt = updated.UpdatedAt
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
}
if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable {
r.syncSchedulerAccountSnapshot(ctx, account.ID)
......@@ -429,7 +429,7 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
}
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, buildSchedulerGroupPayload(groupIDs)); err != nil {
log.Printf("[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err)
}
return nil
}
......@@ -533,7 +533,7 @@ func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error
},
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, &id, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue last used failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue last used failed: account=%d err=%v", id, err)
}
return nil
}
......@@ -568,7 +568,7 @@ func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map
}
payload := map[string]any{"last_used": lastUsedPayload}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, nil, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue batch last used failed: err=%v", err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue batch last used failed: err=%v", err)
}
return nil
}
......@@ -583,7 +583,7 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
}
r.syncSchedulerAccountSnapshot(ctx, id)
return nil
......@@ -603,11 +603,11 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac
}
account, err := r.GetByID(ctx, accountID)
if err != nil {
log.Printf("[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err)
logger.LegacyPrintf("repository.account", "[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err)
return
}
if err := r.schedulerCache.SetAccount(ctx, account); err != nil {
log.Printf("[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err)
logger.LegacyPrintf("repository.account", "[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err)
}
}
......@@ -631,7 +631,7 @@ func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID i
}
payload := buildSchedulerGroupPayload([]int64{groupID})
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v", accountID, groupID, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v", accountID, groupID, err)
}
return nil
}
......@@ -648,7 +648,7 @@ func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, grou
}
payload := buildSchedulerGroupPayload([]int64{groupID})
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v", accountID, groupID, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v", accountID, groupID, err)
}
return nil
}
......@@ -721,7 +721,7 @@ func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, gro
}
payload := buildSchedulerGroupPayload(mergeGroupIDs(existingGroupIDs, groupIDs))
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v", accountID, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v", accountID, err)
}
return nil
}
......@@ -829,7 +829,7 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err)
}
return nil
}
......@@ -876,7 +876,7 @@ func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, sco
return service.ErrAccountNotFound
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err)
}
return nil
}
......@@ -890,7 +890,7 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue overload failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue overload failed: account=%d err=%v", id, err)
}
return nil
}
......@@ -909,7 +909,7 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64,
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
}
r.syncSchedulerAccountSnapshot(ctx, id)
return nil
......@@ -928,7 +928,7 @@ func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v", id, err)
}
return nil
}
......@@ -944,7 +944,7 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err)
}
return nil
}
......@@ -968,7 +968,7 @@ func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id
return service.ErrAccountNotFound
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v", id, err)
}
return nil
}
......@@ -992,7 +992,7 @@ func (r *accountRepository) ClearModelRateLimits(ctx context.Context, id int64)
return service.ErrAccountNotFound
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue clear model rate limit failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear model rate limit failed: account=%d err=%v", id, err)
}
return nil
}
......@@ -1014,7 +1014,7 @@ func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, s
// 触发调度器缓存更新(仅当窗口时间有变化时)
if start != nil || end != nil {
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue session window update failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue session window update failed: account=%d err=%v", id, err)
}
}
return nil
......@@ -1029,7 +1029,7 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
}
if !schedulable {
r.syncSchedulerAccountSnapshot(ctx, id)
......@@ -1057,7 +1057,7 @@ func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now ti
}
if rows > 0 {
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventFullRebuild, nil, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v", err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v", err)
}
}
return rows, nil
......@@ -1093,7 +1093,7 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
return service.ErrAccountNotFound
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
}
return nil
}
......@@ -1187,7 +1187,7 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
if rows > 0 {
payload := map[string]any{"account_ids": ids}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
}
shouldSync := false
if updates.Status != nil && (*updates.Status == service.StatusError || *updates.Status == service.StatusDisabled) {
......@@ -1560,3 +1560,64 @@ func joinClauses(clauses []string, sep string) string {
func itoa(v int) string {
return strconv.Itoa(v)
}
// FindByExtraField 根据 extra 字段中的键值对查找账号。
// 该方法限定 platform='sora',避免误查询其他平台的账号。
// 使用 PostgreSQL JSONB @> 操作符进行高效查询(需要 GIN 索引支持)。
//
// 应用场景:查找通过 linked_openai_account_id 关联的 Sora 账号。
//
// FindByExtraField finds accounts by key-value pairs in the extra field.
// Limited to platform='sora' to avoid querying accounts from other platforms.
// Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index).
//
// Use case: Finding Sora accounts linked via linked_openai_account_id.
func (r *accountRepository) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
accounts, err := r.client.Account.Query().
Where(
dbaccount.PlatformEQ("sora"), // 限定平台为 sora
dbaccount.DeletedAtIsNil(),
func(s *entsql.Selector) {
path := sqljson.Path(key)
switch v := value.(type) {
case string:
preds := []*entsql.Predicate{sqljson.ValueEQ(dbaccount.FieldExtra, v, path)}
if parsed, err := strconv.ParseInt(v, 10, 64); err == nil {
preds = append(preds, sqljson.ValueEQ(dbaccount.FieldExtra, parsed, path))
}
if len(preds) == 1 {
s.Where(preds[0])
} else {
s.Where(entsql.Or(preds...))
}
case int:
s.Where(entsql.Or(
sqljson.ValueEQ(dbaccount.FieldExtra, v, path),
sqljson.ValueEQ(dbaccount.FieldExtra, strconv.Itoa(v), path),
))
case int64:
s.Where(entsql.Or(
sqljson.ValueEQ(dbaccount.FieldExtra, v, path),
sqljson.ValueEQ(dbaccount.FieldExtra, strconv.FormatInt(v, 10), path),
))
case json.Number:
if parsed, err := v.Int64(); err == nil {
s.Where(entsql.Or(
sqljson.ValueEQ(dbaccount.FieldExtra, parsed, path),
sqljson.ValueEQ(dbaccount.FieldExtra, v.String(), path),
))
} else {
s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, v.String(), path))
}
default:
s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, value, path))
}
},
).
All(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil)
}
return r.accountsToService(ctx, accounts)
}
......@@ -34,6 +34,7 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro
SetName(key.Name).
SetStatus(key.Status).
SetNillableGroupID(key.GroupID).
SetNillableLastUsedAt(key.LastUsedAt).
SetQuota(key.Quota).
SetQuotaUsed(key.QuotaUsed).
SetNillableExpiresAt(key.ExpiresAt)
......@@ -48,6 +49,7 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro
created, err := builder.Save(ctx)
if err == nil {
key.ID = created.ID
key.LastUsedAt = created.LastUsedAt
key.CreatedAt = created.CreatedAt
key.UpdatedAt = created.UpdatedAt
}
......@@ -140,6 +142,10 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group.FieldImagePrice1k,
group.FieldImagePrice2k,
group.FieldImagePrice4k,
group.FieldSoraImagePrice360,
group.FieldSoraImagePrice540,
group.FieldSoraVideoPricePerRequest,
group.FieldSoraVideoPricePerRequestHd,
group.FieldClaudeCodeOnly,
group.FieldFallbackGroupID,
group.FieldFallbackGroupIDOnInvalidRequest,
......@@ -375,36 +381,34 @@ func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64)
return keys, nil
}
// IncrementQuotaUsed atomically increments the quota_used field and returns the new value
// IncrementQuotaUsed 使用 Ent 原子递增 quota_used 字段并返回新值
func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
// Use raw SQL for atomic increment to avoid race conditions
// First get current value
m, err := r.activeQuery().
Where(apikey.IDEQ(id)).
Select(apikey.FieldQuotaUsed).
Only(ctx)
updated, err := r.client.APIKey.UpdateOneID(id).
Where(apikey.DeletedAtIsNil()).
AddQuotaUsed(amount).
Save(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return 0, service.ErrAPIKeyNotFound
}
return 0, err
}
return updated.QuotaUsed, nil
}
newValue := m.QuotaUsed + amount
// Update with new value
func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
affected, err := r.client.APIKey.Update().
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
SetQuotaUsed(newValue).
SetLastUsedAt(usedAt).
SetUpdatedAt(usedAt).
Save(ctx)
if err != nil {
return 0, err
return err
}
if affected == 0 {
return 0, service.ErrAPIKeyNotFound
return service.ErrAPIKeyNotFound
}
return newValue, nil
return nil
}
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
......@@ -419,6 +423,7 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
Status: m.Status,
IPWhitelist: m.IPWhitelist,
IPBlacklist: m.IPBlacklist,
LastUsedAt: m.LastUsedAt,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
GroupID: m.GroupID,
......@@ -477,6 +482,10 @@ func groupEntityToService(g *dbent.Group) *service.Group {
ImagePrice1K: g.ImagePrice1k,
ImagePrice2K: g.ImagePrice2k,
ImagePrice4K: g.ImagePrice4k,
SoraImagePrice360: g.SoraImagePrice360,
SoraImagePrice540: g.SoraImagePrice540,
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd,
DefaultValidityDays: g.DefaultValidityDays,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
......
......@@ -4,11 +4,14 @@ package repository
import (
"context"
"sync"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
......@@ -383,3 +386,87 @@ func (s *APIKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, group
s.Require().NoError(s.repo.Create(s.ctx, k), "create api key")
return k
}
// --- IncrementQuotaUsed ---
func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_Basic() {
user := s.mustCreateUser("incr-basic@test.com")
key := s.mustCreateApiKey(user.ID, "sk-incr-basic", "Incr", nil)
newQuota, err := s.repo.IncrementQuotaUsed(s.ctx, key.ID, 1.5)
s.Require().NoError(err, "IncrementQuotaUsed")
s.Require().Equal(1.5, newQuota, "第一次递增后应为 1.5")
newQuota, err = s.repo.IncrementQuotaUsed(s.ctx, key.ID, 2.5)
s.Require().NoError(err, "IncrementQuotaUsed second")
s.Require().Equal(4.0, newQuota, "第二次递增后应为 4.0")
}
func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_NotFound() {
_, err := s.repo.IncrementQuotaUsed(s.ctx, 999999, 1.0)
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "不存在的 key 应返回 ErrAPIKeyNotFound")
}
func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() {
user := s.mustCreateUser("incr-deleted@test.com")
key := s.mustCreateApiKey(user.ID, "sk-incr-del", "Deleted", nil)
s.Require().NoError(s.repo.Delete(s.ctx, key.ID), "Delete")
_, err := s.repo.IncrementQuotaUsed(s.ctx, key.ID, 1.0)
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound")
}
// TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。
// 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。
func TestIncrementQuotaUsed_Concurrent(t *testing.T) {
client := testEntClient(t)
repo := NewAPIKeyRepository(client).(*apiKeyRepository)
ctx := context.Background()
// 创建测试用户和 API Key
u, err := client.User.Create().
SetEmail("concurrent-incr-" + time.Now().Format(time.RFC3339Nano) + "@test.com").
SetPasswordHash("hash").
SetStatus(service.StatusActive).
SetRole(service.RoleUser).
Save(ctx)
require.NoError(t, err, "create user")
k := &service.APIKey{
UserID: u.ID,
Key: "sk-concurrent-" + time.Now().Format(time.RFC3339Nano),
Name: "Concurrent",
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, k), "create api key")
t.Cleanup(func() {
_ = client.APIKey.DeleteOneID(k.ID).Exec(ctx)
_ = client.User.DeleteOneID(u.ID).Exec(ctx)
})
// 10 个 goroutine 各递增 1.0,总计应为 10.0
const goroutines = 10
const increment = 1.0
var wg sync.WaitGroup
errs := make([]error, goroutines)
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
_, errs[idx] = repo.IncrementQuotaUsed(ctx, k.ID, increment)
}(i)
}
wg.Wait()
for i, e := range errs {
require.NoError(t, e, "goroutine %d failed", i)
}
// 验证最终结果
got, err := repo.GetByID(ctx, k.ID)
require.NoError(t, err, "GetByID")
require.Equal(t, float64(goroutines)*increment, got.QuotaUsed,
"并发递增后总和应为 %v,实际为 %v", float64(goroutines)*increment, got.QuotaUsed)
}
package repository
import (
"context"
"database/sql"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
)
func newAPIKeyRepoSQLite(t *testing.T) (*apiKeyRepository, *dbent.Client) {
t.Helper()
db, err := sql.Open("sqlite", "file:api_key_repo_last_used?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
return &apiKeyRepository{client: client}, client
}
func mustCreateAPIKeyRepoUser(t *testing.T, ctx context.Context, client *dbent.Client, email string) *service.User {
t.Helper()
u, err := client.User.Create().
SetEmail(email).
SetPasswordHash("test-password-hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
return userEntityToService(u)
}
func TestAPIKeyRepository_CreateWithLastUsedAt(t *testing.T) {
repo, client := newAPIKeyRepoSQLite(t)
ctx := context.Background()
user := mustCreateAPIKeyRepoUser(t, ctx, client, "create-last-used@test.com")
lastUsed := time.Now().UTC().Add(-time.Hour).Truncate(time.Second)
key := &service.APIKey{
UserID: user.ID,
Key: "sk-create-last-used",
Name: "CreateWithLastUsed",
Status: service.StatusActive,
LastUsedAt: &lastUsed,
}
require.NoError(t, repo.Create(ctx, key))
require.NotNil(t, key.LastUsedAt)
require.WithinDuration(t, lastUsed, *key.LastUsedAt, time.Second)
got, err := repo.GetByID(ctx, key.ID)
require.NoError(t, err)
require.NotNil(t, got.LastUsedAt)
require.WithinDuration(t, lastUsed, *got.LastUsedAt, time.Second)
}
func TestAPIKeyRepository_UpdateLastUsed(t *testing.T) {
repo, client := newAPIKeyRepoSQLite(t)
ctx := context.Background()
user := mustCreateAPIKeyRepoUser(t, ctx, client, "update-last-used@test.com")
key := &service.APIKey{
UserID: user.ID,
Key: "sk-update-last-used",
Name: "UpdateLastUsed",
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, key))
before, err := repo.GetByID(ctx, key.ID)
require.NoError(t, err)
require.Nil(t, before.LastUsedAt)
target := time.Now().UTC().Add(2 * time.Minute).Truncate(time.Second)
require.NoError(t, repo.UpdateLastUsed(ctx, key.ID, target))
after, err := repo.GetByID(ctx, key.ID)
require.NoError(t, err)
require.NotNil(t, after.LastUsedAt)
require.WithinDuration(t, target, *after.LastUsedAt, time.Second)
require.WithinDuration(t, target, after.UpdatedAt, time.Second)
}
func TestAPIKeyRepository_UpdateLastUsedDeletedKey(t *testing.T) {
repo, client := newAPIKeyRepoSQLite(t)
ctx := context.Background()
user := mustCreateAPIKeyRepoUser(t, ctx, client, "deleted-last-used@test.com")
key := &service.APIKey{
UserID: user.ID,
Key: "sk-update-last-used-deleted",
Name: "UpdateLastUsedDeleted",
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, key))
require.NoError(t, repo.Delete(ctx, key.ID))
err := repo.UpdateLastUsed(ctx, key.ID, time.Now().UTC())
require.ErrorIs(t, err, service.ErrAPIKeyNotFound)
}
func TestAPIKeyRepository_UpdateLastUsedDBError(t *testing.T) {
repo, client := newAPIKeyRepoSQLite(t)
ctx := context.Background()
user := mustCreateAPIKeyRepoUser(t, ctx, client, "db-error-last-used@test.com")
key := &service.APIKey{
UserID: user.ID,
Key: "sk-update-last-used-db-error",
Name: "UpdateLastUsedDBError",
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, key))
require.NoError(t, client.Close())
err := repo.UpdateLastUsed(ctx, key.ID, time.Now().UTC())
require.Error(t, err)
}
func TestAPIKeyRepository_CreateDuplicateKey(t *testing.T) {
repo, client := newAPIKeyRepoSQLite(t)
ctx := context.Background()
user := mustCreateAPIKeyRepoUser(t, ctx, client, "duplicate-key@test.com")
first := &service.APIKey{
UserID: user.ID,
Key: "sk-duplicate",
Name: "first",
Status: service.StatusActive,
}
second := &service.APIKey{
UserID: user.ID,
Key: "sk-duplicate",
Name: "second",
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, first))
err := repo.Create(ctx, second)
require.ErrorIs(t, err, service.ErrAPIKeyExists)
}
......@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"log"
"math/rand/v2"
"strconv"
"time"
......@@ -16,8 +17,19 @@ const (
billingBalanceKeyPrefix = "billing:balance:"
billingSubKeyPrefix = "billing:sub:"
billingCacheTTL = 5 * time.Minute
billingCacheJitter = 30 * time.Second
)
// jitteredTTL 返回带随机抖动的 TTL,防止缓存雪崩
func jitteredTTL() time.Duration {
// 只做“减法抖动”,确保实际 TTL 不会超过 billingCacheTTL(避免上界预期被打破)。
if billingCacheJitter <= 0 {
return billingCacheTTL
}
jitter := time.Duration(rand.IntN(int(billingCacheJitter)))
return billingCacheTTL - jitter
}
// billingBalanceKey generates the Redis key for user balance cache.
func billingBalanceKey(userID int64) string {
return fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
......@@ -82,14 +94,15 @@ func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float6
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
key := billingBalanceKey(userID)
return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err()
return c.rdb.Set(ctx, key, balance, jitteredTTL()).Err()
}
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
key := billingBalanceKey(userID)
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(jitteredTTL().Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
return err
}
return nil
}
......@@ -163,16 +176,17 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID
pipe := c.rdb.Pipeline()
pipe.HSet(ctx, key, fields)
pipe.Expire(ctx, key, billingCacheTTL)
pipe.Expire(ctx, key, jitteredTTL())
_, err := pipe.Exec(ctx)
return err
}
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
key := billingSubKey(userID, groupID)
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result()
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(jitteredTTL().Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
return err
}
return nil
}
......
......@@ -278,6 +278,90 @@ func (s *BillingCacheSuite) TestSubscriptionCache() {
}
}
// TestDeductUserBalance_ErrorPropagation 验证 P2-12 修复:
// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。
func (s *BillingCacheSuite) TestDeductUserBalance_ErrorPropagation() {
tests := []struct {
name string
fn func(ctx context.Context, cache service.BillingCache)
expectErr bool
}{
{
name: "key_not_exists_returns_nil",
fn: func(ctx context.Context, cache service.BillingCache) {
// key 不存在时,Lua 脚本返回 0(redis.Nil),应返回 nil 而非错误
err := cache.DeductUserBalance(ctx, 99999, 1.0)
require.NoError(s.T(), err, "DeductUserBalance on non-existent key should return nil")
},
},
{
name: "existing_key_deducts_successfully",
fn: func(ctx context.Context, cache service.BillingCache) {
require.NoError(s.T(), cache.SetUserBalance(ctx, 200, 50.0))
err := cache.DeductUserBalance(ctx, 200, 10.0)
require.NoError(s.T(), err, "DeductUserBalance should succeed")
bal, err := cache.GetUserBalance(ctx, 200)
require.NoError(s.T(), err)
require.Equal(s.T(), 40.0, bal, "余额应为 40.0")
},
},
{
name: "cancelled_context_propagates_error",
fn: func(ctx context.Context, cache service.BillingCache) {
require.NoError(s.T(), cache.SetUserBalance(ctx, 201, 50.0))
cancelCtx, cancel := context.WithCancel(ctx)
cancel() // 立即取消
err := cache.DeductUserBalance(cancelCtx, 201, 10.0)
require.Error(s.T(), err, "cancelled context should propagate error")
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
rdb := testRedis(s.T())
cache := NewBillingCache(rdb)
ctx := context.Background()
tt.fn(ctx, cache)
})
}
}
// TestUpdateSubscriptionUsage_ErrorPropagation 验证 P2-12 修复:
// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。
func (s *BillingCacheSuite) TestUpdateSubscriptionUsage_ErrorPropagation() {
s.Run("key_not_exists_returns_nil", func() {
rdb := testRedis(s.T())
cache := NewBillingCache(rdb)
ctx := context.Background()
err := cache.UpdateSubscriptionUsage(ctx, 88888, 77777, 1.0)
require.NoError(s.T(), err, "UpdateSubscriptionUsage on non-existent key should return nil")
})
s.Run("cancelled_context_propagates_error", func() {
rdb := testRedis(s.T())
cache := NewBillingCache(rdb)
ctx := context.Background()
data := &service.SubscriptionCacheData{
Status: "active",
ExpiresAt: time.Now().Add(1 * time.Hour),
Version: 1,
}
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, 301, 401, data))
cancelCtx, cancel := context.WithCancel(ctx)
cancel()
err := cache.UpdateSubscriptionUsage(cancelCtx, 301, 401, 1.0)
require.Error(s.T(), err, "cancelled context should propagate error")
})
}
func TestBillingCacheSuite(t *testing.T) {
suite.Run(t, new(BillingCacheSuite))
}
package repository
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- Task 6.1 验证: math/rand/v2 迁移后 jitteredTTL 行为正确 ---
func TestJitteredTTL_WithinExpectedRange(t *testing.T) {
// jitteredTTL 使用减法抖动: billingCacheTTL - [0, billingCacheJitter)
// 所以结果应在 [billingCacheTTL - billingCacheJitter, billingCacheTTL] 范围内
lowerBound := billingCacheTTL - billingCacheJitter // 5min - 30s = 4min30s
upperBound := billingCacheTTL // 5min
for i := 0; i < 200; i++ {
ttl := jitteredTTL()
assert.GreaterOrEqual(t, int64(ttl), int64(lowerBound),
"TTL 不应低于 %v,实际得到 %v", lowerBound, ttl)
assert.LessOrEqual(t, int64(ttl), int64(upperBound),
"TTL 不应超过 %v(上界不变保证),实际得到 %v", upperBound, ttl)
}
}
func TestJitteredTTL_NeverExceedsBase(t *testing.T) {
// 关键安全性测试:jitteredTTL 使用减法抖动,确保永远不超过 billingCacheTTL
for i := 0; i < 500; i++ {
ttl := jitteredTTL()
assert.LessOrEqual(t, int64(ttl), int64(billingCacheTTL),
"jitteredTTL 不应超过基础 TTL(上界预期不被打破)")
}
}
func TestJitteredTTL_HasVariance(t *testing.T) {
// 验证抖动确实产生了不同的值
results := make(map[time.Duration]bool)
for i := 0; i < 100; i++ {
ttl := jitteredTTL()
results[ttl] = true
}
require.Greater(t, len(results), 1,
"jitteredTTL 应产生不同的值(抖动生效),但 100 次调用结果全部相同")
}
func TestJitteredTTL_AverageNearCenter(t *testing.T) {
// 验证平均值大约在抖动范围中间
var sum time.Duration
runs := 1000
for i := 0; i < runs; i++ {
sum += jitteredTTL()
}
avg := sum / time.Duration(runs)
expectedCenter := billingCacheTTL - billingCacheJitter/2 // 4min45s
// 允许 ±5s 的误差
tolerance := 5 * time.Second
assert.InDelta(t, float64(expectedCenter), float64(avg), float64(tolerance),
"平均 TTL 应接近抖动范围中心 %v", expectedCenter)
}
func TestBillingKeyGeneration(t *testing.T) {
t.Run("balance_key", func(t *testing.T) {
key := billingBalanceKey(12345)
assert.Equal(t, "billing:balance:12345", key)
})
t.Run("sub_key", func(t *testing.T) {
key := billingSubKey(100, 200)
assert.Equal(t, "billing:sub:100:200", key)
})
}
func BenchmarkJitteredTTL(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = jitteredTTL()
}
}
......@@ -5,6 +5,7 @@ package repository
import (
"math"
"testing"
"time"
"github.com/stretchr/testify/require"
)
......@@ -85,3 +86,26 @@ func TestBillingSubKey(t *testing.T) {
})
}
}
func TestJitteredTTL(t *testing.T) {
const (
minTTL = 4*time.Minute + 30*time.Second // 270s = 5min - 30s
maxTTL = 5*time.Minute + 30*time.Second // 330s = 5min + 30s
)
for i := 0; i < 200; i++ {
ttl := jitteredTTL()
require.GreaterOrEqual(t, ttl, minTTL, "jitteredTTL() 返回值低于下限: %v", ttl)
require.LessOrEqual(t, ttl, maxTTL, "jitteredTTL() 返回值超过上限: %v", ttl)
}
}
func TestJitteredTTL_HasVariation(t *testing.T) {
// 多次调用应该产生不同的值(验证抖动存在)
seen := make(map[time.Duration]struct{}, 50)
for i := 0; i < 50; i++ {
seen[jitteredTTL()] = struct{}{}
}
// 50 次调用中应该至少有 2 个不同的值
require.Greater(t, len(seen), 1, "jitteredTTL() 应产生不同的 TTL 值")
}
......@@ -4,12 +4,12 @@ import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"net/url"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
......@@ -41,7 +41,7 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
}
targetURL := s.baseURL + "/api/organizations"
log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL)
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1: Getting organization UUID from %s", targetURL)
resp, err := client.R().
SetContext(ctx).
......@@ -53,11 +53,11 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
Get(targetURL)
if err != nil {
log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err)
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 FAILED - Request error: %v", err)
return "", fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 1 Response - Status: %d", resp.StatusCode)
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 Response - Status: %d", resp.StatusCode)
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
......@@ -69,21 +69,21 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
// 如果只有一个组织,直接使用
if len(orgs) == 1 {
log.Printf("[OAuth] Step 1 SUCCESS - Single org found, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 SUCCESS - Single org found, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
return orgs[0].UUID, nil
}
// 如果有多个组织,优先选择 raven_type 为 "team" 的组织
for _, org := range orgs {
if org.RavenType != nil && *org.RavenType == "team" {
log.Printf("[OAuth] Step 1 SUCCESS - Selected team org, UUID: %s, Name: %s, RavenType: %s",
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 SUCCESS - Selected team org, UUID: %s, Name: %s, RavenType: %s",
org.UUID, org.Name, *org.RavenType)
return org.UUID, nil
}
}
// 如果没有 team 类型的组织,使用第一个
log.Printf("[OAuth] Step 1 SUCCESS - No team org found, using first org, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 SUCCESS - No team org found, using first org, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
return orgs[0].UUID, nil
}
......@@ -103,9 +103,9 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
"code_challenge_method": "S256",
}
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL)
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2: Getting authorization code from %s", authURL)
reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
var result struct {
RedirectURI string `json:"redirect_uri"`
......@@ -128,11 +128,11 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
Post(authURL)
if err != nil {
log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err)
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 FAILED - Request error: %v", err)
return "", fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
......@@ -160,7 +160,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
fullCode = authCode + "#" + responseState
}
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code")
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 SUCCESS - Got authorization code")
return fullCode, nil
}
......@@ -192,9 +192,9 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
reqBody["expires_in"] = 31536000 // 365 * 24 * 60 * 60 seconds
}
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL)
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL)
reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
var tokenResp oauth.TokenResponse
......@@ -208,17 +208,17 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
Post(s.tokenURL)
if err != nil {
log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err)
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 FAILED - Request error: %v", err)
return nil, fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
}
log.Printf("[OAuth] Step 3 SUCCESS - Got access token")
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 SUCCESS - Got access token")
return &tokenResp, nil
}
......
......@@ -147,100 +147,6 @@ var (
return 1
`)
// getAccountsLoadBatchScript - batch load query with expired slot cleanup
// ARGV[1] = slot TTL (seconds)
// ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ...
getAccountsLoadBatchScript = redis.NewScript(`
local result = {}
local slotTTL = tonumber(ARGV[1])
-- Get current server time
local timeResult = redis.call('TIME')
local nowSeconds = tonumber(timeResult[1])
local cutoffTime = nowSeconds - slotTTL
local i = 2
while i <= #ARGV do
local accountID = ARGV[i]
local maxConcurrency = tonumber(ARGV[i + 1])
local slotKey = 'concurrency:account:' .. accountID
-- Clean up expired slots before counting
redis.call('ZREMRANGEBYSCORE', slotKey, '-inf', cutoffTime)
local currentConcurrency = redis.call('ZCARD', slotKey)
local waitKey = 'wait:account:' .. accountID
local waitingCount = redis.call('GET', waitKey)
if waitingCount == false then
waitingCount = 0
else
waitingCount = tonumber(waitingCount)
end
local loadRate = 0
if maxConcurrency > 0 then
loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency)
end
table.insert(result, accountID)
table.insert(result, currentConcurrency)
table.insert(result, waitingCount)
table.insert(result, loadRate)
i = i + 2
end
return result
`)
// getUsersLoadBatchScript - batch load query for users with expired slot cleanup
// ARGV[1] = slot TTL (seconds)
// ARGV[2..n] = userID1, maxConcurrency1, userID2, maxConcurrency2, ...
getUsersLoadBatchScript = redis.NewScript(`
local result = {}
local slotTTL = tonumber(ARGV[1])
-- Get current server time
local timeResult = redis.call('TIME')
local nowSeconds = tonumber(timeResult[1])
local cutoffTime = nowSeconds - slotTTL
local i = 2
while i <= #ARGV do
local userID = ARGV[i]
local maxConcurrency = tonumber(ARGV[i + 1])
local slotKey = 'concurrency:user:' .. userID
-- Clean up expired slots before counting
redis.call('ZREMRANGEBYSCORE', slotKey, '-inf', cutoffTime)
local currentConcurrency = redis.call('ZCARD', slotKey)
local waitKey = 'concurrency:wait:' .. userID
local waitingCount = redis.call('GET', waitKey)
if waitingCount == false then
waitingCount = 0
else
waitingCount = tonumber(waitingCount)
end
local loadRate = 0
if maxConcurrency > 0 then
loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency)
end
table.insert(result, userID)
table.insert(result, currentConcurrency)
table.insert(result, waitingCount)
table.insert(result, loadRate)
i = i + 2
end
return result
`)
// cleanupExpiredSlotsScript - remove expired slots
// KEYS[1] = concurrency:account:{accountID}
// ARGV[1] = TTL (seconds)
......@@ -399,29 +305,53 @@ func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []
return map[int64]*service.AccountLoadInfo{}, nil
}
args := []any{c.slotTTLSeconds}
for _, acc := range accounts {
args = append(args, acc.ID, acc.MaxConcurrency)
}
result, err := getAccountsLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice()
// 使用 Pipeline 替代 Lua 脚本,兼容 Redis Cluster(Lua 内动态拼 key 会 CROSSSLOT)。
// 每个账号执行 3 个命令:ZREMRANGEBYSCORE(清理过期)、ZCARD(并发数)、GET(等待数)。
now, err := c.rdb.Time(ctx).Result()
if err != nil {
return nil, err
return nil, fmt.Errorf("redis TIME: %w", err)
}
cutoffTime := now.Unix() - int64(c.slotTTLSeconds)
loadMap := make(map[int64]*service.AccountLoadInfo)
for i := 0; i < len(result); i += 4 {
if i+3 >= len(result) {
break
pipe := c.rdb.Pipeline()
type accountCmds struct {
id int64
maxConcurrency int
zcardCmd *redis.IntCmd
getCmd *redis.StringCmd
}
cmds := make([]accountCmds, 0, len(accounts))
for _, acc := range accounts {
slotKey := accountSlotKeyPrefix + strconv.FormatInt(acc.ID, 10)
waitKey := accountWaitKeyPrefix + strconv.FormatInt(acc.ID, 10)
pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10))
ac := accountCmds{
id: acc.ID,
maxConcurrency: acc.MaxConcurrency,
zcardCmd: pipe.ZCard(ctx, slotKey),
getCmd: pipe.Get(ctx, waitKey),
}
cmds = append(cmds, ac)
}
accountID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64)
currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1]))
waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2]))
loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3]))
if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) {
return nil, fmt.Errorf("pipeline exec: %w", err)
}
loadMap[accountID] = &service.AccountLoadInfo{
AccountID: accountID,
loadMap := make(map[int64]*service.AccountLoadInfo, len(accounts))
for _, ac := range cmds {
currentConcurrency := int(ac.zcardCmd.Val())
waitingCount := 0
if v, err := ac.getCmd.Int(); err == nil {
waitingCount = v
}
loadRate := 0
if ac.maxConcurrency > 0 {
loadRate = (currentConcurrency + waitingCount) * 100 / ac.maxConcurrency
}
loadMap[ac.id] = &service.AccountLoadInfo{
AccountID: ac.id,
CurrentConcurrency: currentConcurrency,
WaitingCount: waitingCount,
LoadRate: loadRate,
......@@ -436,29 +366,52 @@ func (c *concurrencyCache) GetUsersLoadBatch(ctx context.Context, users []servic
return map[int64]*service.UserLoadInfo{}, nil
}
args := []any{c.slotTTLSeconds}
for _, u := range users {
args = append(args, u.ID, u.MaxConcurrency)
}
result, err := getUsersLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice()
// 使用 Pipeline 替代 Lua 脚本,兼容 Redis Cluster。
now, err := c.rdb.Time(ctx).Result()
if err != nil {
return nil, err
return nil, fmt.Errorf("redis TIME: %w", err)
}
cutoffTime := now.Unix() - int64(c.slotTTLSeconds)
loadMap := make(map[int64]*service.UserLoadInfo)
for i := 0; i < len(result); i += 4 {
if i+3 >= len(result) {
break
pipe := c.rdb.Pipeline()
type userCmds struct {
id int64
maxConcurrency int
zcardCmd *redis.IntCmd
getCmd *redis.StringCmd
}
cmds := make([]userCmds, 0, len(users))
for _, u := range users {
slotKey := userSlotKeyPrefix + strconv.FormatInt(u.ID, 10)
waitKey := waitQueueKeyPrefix + strconv.FormatInt(u.ID, 10)
pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10))
uc := userCmds{
id: u.ID,
maxConcurrency: u.MaxConcurrency,
zcardCmd: pipe.ZCard(ctx, slotKey),
getCmd: pipe.Get(ctx, waitKey),
}
cmds = append(cmds, uc)
}
userID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64)
currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1]))
waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2]))
loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3]))
if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) {
return nil, fmt.Errorf("pipeline exec: %w", err)
}
loadMap[userID] = &service.UserLoadInfo{
UserID: userID,
loadMap := make(map[int64]*service.UserLoadInfo, len(users))
for _, uc := range cmds {
currentConcurrency := int(uc.zcardCmd.Val())
waitingCount := 0
if v, err := uc.getCmd.Int(); err == nil {
waitingCount = v
}
loadRate := 0
if uc.maxConcurrency > 0 {
loadRate = (currentConcurrency + waitingCount) * 100 / uc.maxConcurrency
}
loadMap[uc.id] = &service.UserLoadInfo{
UserID: uc.id,
CurrentConcurrency: currentConcurrency,
WaitingCount: waitingCount,
LoadRate: loadRate,
......
......@@ -5,6 +5,7 @@ package repository
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/ent"
......@@ -66,6 +67,18 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
// 创建 Ent 客户端,绑定到已配置的数据库驱动。
client := ent.NewClient(ent.Driver(drv))
// 启动阶段:从配置或数据库中确保系统密钥可用。
if err := ensureBootstrapSecrets(migrationCtx, client, cfg); err != nil {
_ = client.Close()
return nil, nil, err
}
// 在密钥补齐后执行完整配置校验,避免空 jwt.secret 导致服务运行时失败。
if err := cfg.Validate(); err != nil {
_ = client.Close()
return nil, nil, fmt.Errorf("validate config after secret bootstrap: %w", err)
}
// SIMPLE 模式:启动时补齐各平台默认分组。
// - anthropic/openai/gemini: 确保存在 <platform>-default
// - antigravity: 仅要求存在 >=2 个未软删除分组(用于 claude/gemini 混合调度场景)
......
......@@ -18,14 +18,21 @@ type githubReleaseClient struct {
downloadHTTPClient *http.Client
}
type githubReleaseClientError struct {
err error
}
// NewGitHubReleaseClient 创建 GitHub Release 客户端
// proxyURL 为空时直连 GitHub,支持 http/https/socks5/socks5h 协议
func NewGitHubReleaseClient(proxyURL string) service.GitHubReleaseClient {
func NewGitHubReleaseClient(proxyURL string, allowDirectOnProxyError bool) service.GitHubReleaseClient {
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second,
ProxyURL: proxyURL,
})
if err != nil {
if proxyURL != "" && !allowDirectOnProxyError {
return &githubReleaseClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)}
}
sharedClient = &http.Client{Timeout: 30 * time.Second}
}
......@@ -35,6 +42,9 @@ func NewGitHubReleaseClient(proxyURL string) service.GitHubReleaseClient {
ProxyURL: proxyURL,
})
if err != nil {
if proxyURL != "" && !allowDirectOnProxyError {
return &githubReleaseClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)}
}
downloadClient = &http.Client{Timeout: 10 * time.Minute}
}
......@@ -44,6 +54,18 @@ func NewGitHubReleaseClient(proxyURL string) service.GitHubReleaseClient {
}
}
func (c *githubReleaseClientError) FetchLatestRelease(ctx context.Context, repo string) (*service.GitHubRelease, error) {
return nil, c.err
}
func (c *githubReleaseClientError) DownloadFile(ctx context.Context, url, dest string, maxSize int64) error {
return c.err
}
func (c *githubReleaseClientError) FetchChecksumFile(ctx context.Context, url string) ([]byte, error) {
return nil, c.err
}
func (c *githubReleaseClient) FetchLatestRelease(ctx context.Context, repo string) (*service.GitHubRelease, error) {
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo)
......
......@@ -4,11 +4,11 @@ import (
"context"
"database/sql"
"errors"
"log"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
......@@ -47,6 +47,10 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetNillableSoraImagePrice360(groupIn.SoraImagePrice360).
SetNillableSoraImagePrice540(groupIn.SoraImagePrice540).
SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest).
SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD).
SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetNillableFallbackGroupID(groupIn.FallbackGroupID).
......@@ -68,7 +72,7 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
groupIn.CreatedAt = created.CreatedAt
groupIn.UpdatedAt = created.UpdatedAt
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue group create failed: group=%d err=%v", groupIn.ID, err)
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group create failed: group=%d err=%v", groupIn.ID, err)
}
}
return translatePersistenceError(err, nil, service.ErrGroupExists)
......@@ -110,6 +114,10 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetNillableSoraImagePrice360(groupIn.SoraImagePrice360).
SetNillableSoraImagePrice540(groupIn.SoraImagePrice540).
SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest).
SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD).
SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
......@@ -144,7 +152,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
}
groupIn.UpdatedAt = updated.UpdatedAt
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue group update failed: group=%d err=%v", groupIn.ID, err)
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group update failed: group=%d err=%v", groupIn.ID, err)
}
return nil
}
......@@ -155,7 +163,7 @@ func (r *groupRepository) Delete(ctx context.Context, id int64) error {
return translatePersistenceError(err, service.ErrGroupNotFound, nil)
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue group delete failed: group=%d err=%v", id, err)
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group delete failed: group=%d err=%v", id, err)
}
return nil
}
......@@ -183,7 +191,7 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
q = q.Where(group.IsExclusiveEQ(*isExclusive))
}
total, err := q.Count(ctx)
total, err := q.Clone().Count(ctx)
if err != nil {
return nil, nil, err
}
......@@ -288,7 +296,7 @@ func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, grou
}
affected, _ := res.RowsAffected()
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue group account clear failed: group=%d err=%v", groupID, err)
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group account clear failed: group=%d err=%v", groupID, err)
}
return affected, nil
}
......@@ -398,7 +406,7 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
}
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue group cascade delete failed: group=%d err=%v", id, err)
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group cascade delete failed: group=%d err=%v", id, err)
}
return affectedUserIDs, nil
......@@ -492,7 +500,7 @@ func (r *groupRepository) BindAccountsToGroup(ctx context.Context, groupID int64
// 发送调度器事件
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue bind accounts to group failed: group=%d err=%v", groupID, err)
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue bind accounts to group failed: group=%d err=%v", groupID, err)
}
return nil
......
package repository
import (
"context"
"database/sql"
"errors"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type idempotencyRepository struct {
sql sqlExecutor
}
func NewIdempotencyRepository(_ *dbent.Client, sqlDB *sql.DB) service.IdempotencyRepository {
return &idempotencyRepository{sql: sqlDB}
}
func (r *idempotencyRepository) CreateProcessing(ctx context.Context, record *service.IdempotencyRecord) (bool, error) {
if record == nil {
return false, nil
}
query := `
INSERT INTO idempotency_records (
scope, idempotency_key_hash, request_fingerprint, status, locked_until, expires_at
) VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (scope, idempotency_key_hash) DO NOTHING
RETURNING id, created_at, updated_at
`
var createdAt time.Time
var updatedAt time.Time
err := scanSingleRow(ctx, r.sql, query, []any{
record.Scope,
record.IdempotencyKeyHash,
record.RequestFingerprint,
record.Status,
record.LockedUntil,
record.ExpiresAt,
}, &record.ID, &createdAt, &updatedAt)
if errors.Is(err, sql.ErrNoRows) {
return false, nil
}
if err != nil {
return false, err
}
record.CreatedAt = createdAt
record.UpdatedAt = updatedAt
return true, nil
}
func (r *idempotencyRepository) GetByScopeAndKeyHash(ctx context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) {
query := `
SELECT
id, scope, idempotency_key_hash, request_fingerprint, status, response_status,
response_body, error_reason, locked_until, expires_at, created_at, updated_at
FROM idempotency_records
WHERE scope = $1 AND idempotency_key_hash = $2
`
record := &service.IdempotencyRecord{}
var responseStatus sql.NullInt64
var responseBody sql.NullString
var errorReason sql.NullString
var lockedUntil sql.NullTime
err := scanSingleRow(ctx, r.sql, query, []any{scope, keyHash},
&record.ID,
&record.Scope,
&record.IdempotencyKeyHash,
&record.RequestFingerprint,
&record.Status,
&responseStatus,
&responseBody,
&errorReason,
&lockedUntil,
&record.ExpiresAt,
&record.CreatedAt,
&record.UpdatedAt,
)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
if err != nil {
return nil, err
}
if responseStatus.Valid {
v := int(responseStatus.Int64)
record.ResponseStatus = &v
}
if responseBody.Valid {
v := responseBody.String
record.ResponseBody = &v
}
if errorReason.Valid {
v := errorReason.String
record.ErrorReason = &v
}
if lockedUntil.Valid {
v := lockedUntil.Time
record.LockedUntil = &v
}
return record, nil
}
func (r *idempotencyRepository) TryReclaim(
ctx context.Context,
id int64,
fromStatus string,
now, newLockedUntil, newExpiresAt time.Time,
) (bool, error) {
query := `
UPDATE idempotency_records
SET status = $2,
locked_until = $3,
error_reason = NULL,
updated_at = NOW(),
expires_at = $4
WHERE id = $1
AND status = $5
AND (locked_until IS NULL OR locked_until <= $6)
`
res, err := r.sql.ExecContext(ctx, query,
id,
service.IdempotencyStatusProcessing,
newLockedUntil,
newExpiresAt,
fromStatus,
now,
)
if err != nil {
return false, err
}
affected, err := res.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
func (r *idempotencyRepository) ExtendProcessingLock(
ctx context.Context,
id int64,
requestFingerprint string,
newLockedUntil,
newExpiresAt time.Time,
) (bool, error) {
query := `
UPDATE idempotency_records
SET locked_until = $2,
expires_at = $3,
updated_at = NOW()
WHERE id = $1
AND status = $4
AND request_fingerprint = $5
`
res, err := r.sql.ExecContext(
ctx,
query,
id,
newLockedUntil,
newExpiresAt,
service.IdempotencyStatusProcessing,
requestFingerprint,
)
if err != nil {
return false, err
}
affected, err := res.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
func (r *idempotencyRepository) MarkSucceeded(ctx context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error {
query := `
UPDATE idempotency_records
SET status = $2,
response_status = $3,
response_body = $4,
error_reason = NULL,
locked_until = NULL,
expires_at = $5,
updated_at = NOW()
WHERE id = $1
`
_, err := r.sql.ExecContext(ctx, query,
id,
service.IdempotencyStatusSucceeded,
responseStatus,
responseBody,
expiresAt,
)
return err
}
func (r *idempotencyRepository) MarkFailedRetryable(ctx context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error {
query := `
UPDATE idempotency_records
SET status = $2,
error_reason = $3,
locked_until = $4,
expires_at = $5,
updated_at = NOW()
WHERE id = $1
`
_, err := r.sql.ExecContext(ctx, query,
id,
service.IdempotencyStatusFailedRetryable,
errorReason,
lockedUntil,
expiresAt,
)
return err
}
func (r *idempotencyRepository) DeleteExpired(ctx context.Context, now time.Time, limit int) (int64, error) {
if limit <= 0 {
limit = 500
}
query := `
WITH victims AS (
SELECT id
FROM idempotency_records
WHERE expires_at <= $1
ORDER BY expires_at ASC
LIMIT $2
)
DELETE FROM idempotency_records
WHERE id IN (SELECT id FROM victims)
`
res, err := r.sql.ExecContext(ctx, query, now, limit)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
//go:build integration
package repository
import (
"context"
"crypto/sha256"
"encoding/hex"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
// hashedTestValue returns a unique SHA-256 hex string (64 chars) that fits VARCHAR(64) columns.
func hashedTestValue(t *testing.T, prefix string) string {
t.Helper()
sum := sha256.Sum256([]byte(uniqueTestValue(t, prefix)))
return hex.EncodeToString(sum[:])
}
func TestIdempotencyRepo_CreateProcessing_CompeteSameKey(t *testing.T) {
tx := testTx(t)
repo := &idempotencyRepository{sql: tx}
ctx := context.Background()
now := time.Now().UTC()
record := &service.IdempotencyRecord{
Scope: uniqueTestValue(t, "idem-scope-create"),
IdempotencyKeyHash: hashedTestValue(t, "idem-hash"),
RequestFingerprint: hashedTestValue(t, "idem-fp"),
Status: service.IdempotencyStatusProcessing,
LockedUntil: ptrTime(now.Add(30 * time.Second)),
ExpiresAt: now.Add(24 * time.Hour),
}
owner, err := repo.CreateProcessing(ctx, record)
require.NoError(t, err)
require.True(t, owner)
require.NotZero(t, record.ID)
duplicate := &service.IdempotencyRecord{
Scope: record.Scope,
IdempotencyKeyHash: record.IdempotencyKeyHash,
RequestFingerprint: hashedTestValue(t, "idem-fp-other"),
Status: service.IdempotencyStatusProcessing,
LockedUntil: ptrTime(now.Add(30 * time.Second)),
ExpiresAt: now.Add(24 * time.Hour),
}
owner, err = repo.CreateProcessing(ctx, duplicate)
require.NoError(t, err)
require.False(t, owner, "same scope+key hash should be de-duplicated")
}
func TestIdempotencyRepo_TryReclaim_StatusAndLockWindow(t *testing.T) {
tx := testTx(t)
repo := &idempotencyRepository{sql: tx}
ctx := context.Background()
now := time.Now().UTC()
record := &service.IdempotencyRecord{
Scope: uniqueTestValue(t, "idem-scope-reclaim"),
IdempotencyKeyHash: hashedTestValue(t, "idem-hash-reclaim"),
RequestFingerprint: hashedTestValue(t, "idem-fp-reclaim"),
Status: service.IdempotencyStatusProcessing,
LockedUntil: ptrTime(now.Add(10 * time.Second)),
ExpiresAt: now.Add(24 * time.Hour),
}
owner, err := repo.CreateProcessing(ctx, record)
require.NoError(t, err)
require.True(t, owner)
require.NoError(t, repo.MarkFailedRetryable(
ctx,
record.ID,
"RETRYABLE_FAILURE",
now.Add(-2*time.Second),
now.Add(24*time.Hour),
))
newLockedUntil := now.Add(20 * time.Second)
reclaimed, err := repo.TryReclaim(
ctx,
record.ID,
service.IdempotencyStatusFailedRetryable,
now,
newLockedUntil,
now.Add(24*time.Hour),
)
require.NoError(t, err)
require.True(t, reclaimed, "failed_retryable + expired lock should allow reclaim")
got, err := repo.GetByScopeAndKeyHash(ctx, record.Scope, record.IdempotencyKeyHash)
require.NoError(t, err)
require.NotNil(t, got)
require.Equal(t, service.IdempotencyStatusProcessing, got.Status)
require.NotNil(t, got.LockedUntil)
require.True(t, got.LockedUntil.After(now))
require.NoError(t, repo.MarkFailedRetryable(
ctx,
record.ID,
"RETRYABLE_FAILURE",
now.Add(20*time.Second),
now.Add(24*time.Hour),
))
reclaimed, err = repo.TryReclaim(
ctx,
record.ID,
service.IdempotencyStatusFailedRetryable,
now,
now.Add(40*time.Second),
now.Add(24*time.Hour),
)
require.NoError(t, err)
require.False(t, reclaimed, "within lock window should not reclaim")
}
func TestIdempotencyRepo_StatusTransition_ToSucceeded(t *testing.T) {
tx := testTx(t)
repo := &idempotencyRepository{sql: tx}
ctx := context.Background()
now := time.Now().UTC()
record := &service.IdempotencyRecord{
Scope: uniqueTestValue(t, "idem-scope-success"),
IdempotencyKeyHash: hashedTestValue(t, "idem-hash-success"),
RequestFingerprint: hashedTestValue(t, "idem-fp-success"),
Status: service.IdempotencyStatusProcessing,
LockedUntil: ptrTime(now.Add(10 * time.Second)),
ExpiresAt: now.Add(24 * time.Hour),
}
owner, err := repo.CreateProcessing(ctx, record)
require.NoError(t, err)
require.True(t, owner)
require.NoError(t, repo.MarkSucceeded(ctx, record.ID, 200, `{"ok":true}`, now.Add(24*time.Hour)))
got, err := repo.GetByScopeAndKeyHash(ctx, record.Scope, record.IdempotencyKeyHash)
require.NoError(t, err)
require.NotNil(t, got)
require.Equal(t, service.IdempotencyStatusSucceeded, got.Status)
require.NotNil(t, got.ResponseStatus)
require.Equal(t, 200, *got.ResponseStatus)
require.NotNil(t, got.ResponseBody)
require.Equal(t, `{"ok":true}`, *got.ResponseBody)
require.Nil(t, got.LockedUntil)
}
......@@ -48,6 +48,11 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass))
require.True(t, settingsRegclass.Valid, "expected settings table to exist")
// security_secrets table should exist
var securitySecretsRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.security_secrets')").Scan(&securitySecretsRegclass))
require.True(t, securitySecretsRegclass.Valid, "expected security_secrets table to exist")
// user_allowed_groups table should exist
var uagRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.user_allowed_groups')").Scan(&uagRegclass))
......
......@@ -4,6 +4,7 @@ import (
"context"
"net/http"
"net/url"
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
......@@ -56,12 +57,49 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
}
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "")
}
func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
if strings.TrimSpace(clientID) != "" {
return s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, strings.TrimSpace(clientID))
}
clientIDs := []string{
openai.ClientID,
openai.SoraClientID,
}
seen := make(map[string]struct{}, len(clientIDs))
var lastErr error
for _, clientID := range clientIDs {
clientID = strings.TrimSpace(clientID)
if clientID == "" {
continue
}
if _, ok := seen[clientID]; ok {
continue
}
seen[clientID] = struct{}{}
tokenResp, err := s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
if err == nil {
return tokenResp, nil
}
lastErr = err
}
if lastErr != nil {
return nil, lastErr
}
return nil, infraerrors.New(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed")
}
func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(proxyURL)
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
formData.Set("refresh_token", refreshToken)
formData.Set("client_id", openai.ClientID)
formData.Set("client_id", clientID)
formData.Set("scope", openai.RefreshScopes)
var tokenResp openai.TokenResponse
......
......@@ -136,6 +136,60 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() {
require.Equal(s.T(), "rt2", resp.RefreshToken)
}
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() {
var seenClientIDs []string
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
clientID := r.PostForm.Get("client_id")
seenClientIDs = append(seenClientIDs, clientID)
if clientID == openai.ClientID {
w.WriteHeader(http.StatusBadRequest)
_, _ = io.WriteString(w, "invalid_grant")
return
}
if clientID == openai.SoraClientID {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`)
return
}
w.WriteHeader(http.StatusBadRequest)
}))
resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
require.NoError(s.T(), err, "RefreshToken")
require.Equal(s.T(), "at-sora", resp.AccessToken)
require.Equal(s.T(), "rt-sora", resp.RefreshToken)
require.Equal(s.T(), []string{openai.ClientID, openai.SoraClientID}, seenClientIDs)
}
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() {
const customClientID = "custom-client-id"
var seenClientIDs []string
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
clientID := r.PostForm.Get("client_id")
seenClientIDs = append(seenClientIDs, clientID)
if clientID != customClientID {
w.WriteHeader(http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at-custom","refresh_token":"rt-custom","token_type":"bearer","expires_in":3600}`)
}))
resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", customClientID)
require.NoError(s.T(), err, "RefreshTokenWithClientID")
require.Equal(s.T(), "at-custom", resp.AccessToken)
require.Equal(s.T(), "rt-custom", resp.RefreshToken)
require.Equal(s.T(), []string{customClientID}, seenClientIDs)
}
func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment