Commit 3d79773b authored by kyx236's avatar kyx236
Browse files

Merge branch 'main' of https://github.com/james-6-23/sub2api

parents 6aa8cbbf 742e73c9
......@@ -16,6 +16,7 @@ var DefaultModels = []Model{
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
{ID: "gemini-3.1-pro-preview", Type: "model", DisplayName: "Gemini 3.1 Pro Preview", CreatedAt: ""},
}
// DefaultTestModel is the default model to preselect in test flows.
......
......@@ -6,10 +6,14 @@ import (
"encoding/base64"
"encoding/hex"
"fmt"
"net/http"
"net/url"
"os"
"strings"
"sync"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
type OAuthConfig struct {
......@@ -164,15 +168,24 @@ func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error
}
// Fall back to built-in Gemini CLI OAuth client when not configured.
// SECURITY: This repo does not embed the built-in client secret; it must be provided via env.
if effective.ClientID == "" && effective.ClientSecret == "" {
secret := strings.TrimSpace(GeminiCLIOAuthClientSecret)
if secret == "" {
if v, ok := os.LookupEnv(GeminiCLIOAuthClientSecretEnv); ok {
secret = strings.TrimSpace(v)
}
}
if secret == "" {
return OAuthConfig{}, infraerrors.Newf(http.StatusBadRequest, "GEMINI_CLI_OAUTH_CLIENT_SECRET_MISSING", "built-in Gemini CLI OAuth client_secret is not configured; set %s or provide a custom OAuth client", GeminiCLIOAuthClientSecretEnv)
}
effective.ClientID = GeminiCLIOAuthClientID
effective.ClientSecret = GeminiCLIOAuthClientSecret
effective.ClientSecret = secret
} else if effective.ClientID == "" || effective.ClientSecret == "" {
return OAuthConfig{}, fmt.Errorf("OAuth client not configured: please set both client_id and client_secret (or leave both empty to use the built-in Gemini CLI client)")
return OAuthConfig{}, infraerrors.New(http.StatusBadRequest, "GEMINI_OAUTH_CLIENT_NOT_CONFIGURED", "OAuth client not configured: please set both client_id and client_secret (or leave both empty to use the built-in Gemini CLI client)")
}
isBuiltinClient := effective.ClientID == GeminiCLIOAuthClientID &&
effective.ClientSecret == GeminiCLIOAuthClientSecret
isBuiltinClient := effective.ClientID == GeminiCLIOAuthClientID
if effective.Scopes == "" {
// Use different default scopes based on OAuth type
......
package geminicli
import (
"encoding/hex"
"strings"
"sync"
"testing"
"time"
)
// ---------------------------------------------------------------------------
// SessionStore 测试
// ---------------------------------------------------------------------------
func TestSessionStore_SetAndGet(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
session := &OAuthSession{
State: "test-state",
OAuthType: "code_assist",
CreatedAt: time.Now(),
}
store.Set("sid-1", session)
got, ok := store.Get("sid-1")
if !ok {
t.Fatal("期望 Get 返回 ok=true,实际返回 false")
}
if got.State != "test-state" {
t.Errorf("期望 State=%q,实际=%q", "test-state", got.State)
}
}
func TestSessionStore_GetNotFound(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
_, ok := store.Get("不存在的ID")
if ok {
t.Error("期望不存在的 sessionID 返回 ok=false")
}
}
func TestSessionStore_GetExpired(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
// 创建一个已过期的 session(CreatedAt 设置为 SessionTTL+1 分钟之前)
session := &OAuthSession{
State: "expired-state",
OAuthType: "code_assist",
CreatedAt: time.Now().Add(-(SessionTTL + 1*time.Minute)),
}
store.Set("expired-sid", session)
_, ok := store.Get("expired-sid")
if ok {
t.Error("期望过期的 session 返回 ok=false")
}
}
func TestSessionStore_Delete(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
session := &OAuthSession{
State: "to-delete",
OAuthType: "code_assist",
CreatedAt: time.Now(),
}
store.Set("del-sid", session)
// 先确认存在
if _, ok := store.Get("del-sid"); !ok {
t.Fatal("删除前 session 应该存在")
}
store.Delete("del-sid")
if _, ok := store.Get("del-sid"); ok {
t.Error("删除后 session 不应该存在")
}
}
func TestSessionStore_Stop_Idempotent(t *testing.T) {
store := NewSessionStore()
// 多次调用 Stop 不应 panic
store.Stop()
store.Stop()
store.Stop()
}
func TestSessionStore_ConcurrentAccess(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
const goroutines = 50
var wg sync.WaitGroup
wg.Add(goroutines * 3)
// 并发写入
for i := 0; i < goroutines; i++ {
go func(idx int) {
defer wg.Done()
sid := "concurrent-" + string(rune('A'+idx%26))
store.Set(sid, &OAuthSession{
State: sid,
OAuthType: "code_assist",
CreatedAt: time.Now(),
})
}(i)
}
// 并发读取
for i := 0; i < goroutines; i++ {
go func(idx int) {
defer wg.Done()
sid := "concurrent-" + string(rune('A'+idx%26))
store.Get(sid) // 可能找到也可能没找到,关键是不 panic
}(i)
}
// 并发删除
for i := 0; i < goroutines; i++ {
go func(idx int) {
defer wg.Done()
sid := "concurrent-" + string(rune('A'+idx%26))
store.Delete(sid)
}(i)
}
wg.Wait()
}
// ---------------------------------------------------------------------------
// GenerateRandomBytes 测试
// ---------------------------------------------------------------------------
func TestGenerateRandomBytes(t *testing.T) {
tests := []int{0, 1, 16, 32, 64}
for _, n := range tests {
b, err := GenerateRandomBytes(n)
if err != nil {
t.Errorf("GenerateRandomBytes(%d) 出错: %v", n, err)
continue
}
if len(b) != n {
t.Errorf("GenerateRandomBytes(%d) 返回长度=%d,期望=%d", n, len(b), n)
}
}
}
func TestGenerateRandomBytes_Uniqueness(t *testing.T) {
// 两次调用应该返回不同的结果(极小概率相同,32字节足够)
a, _ := GenerateRandomBytes(32)
b, _ := GenerateRandomBytes(32)
if string(a) == string(b) {
t.Error("两次 GenerateRandomBytes(32) 返回了相同结果,随机性可能有问题")
}
}
// ---------------------------------------------------------------------------
// GenerateState 测试
// ---------------------------------------------------------------------------
func TestGenerateState(t *testing.T) {
state, err := GenerateState()
if err != nil {
t.Fatalf("GenerateState() 出错: %v", err)
}
if state == "" {
t.Error("GenerateState() 返回空字符串")
}
// base64url 编码不应包含 padding '='
if strings.Contains(state, "=") {
t.Errorf("GenerateState() 结果包含 '=' padding: %s", state)
}
// base64url 不应包含 '+' 或 '/'
if strings.ContainsAny(state, "+/") {
t.Errorf("GenerateState() 结果包含非 base64url 字符: %s", state)
}
}
// ---------------------------------------------------------------------------
// GenerateSessionID 测试
// ---------------------------------------------------------------------------
func TestGenerateSessionID(t *testing.T) {
sid, err := GenerateSessionID()
if err != nil {
t.Fatalf("GenerateSessionID() 出错: %v", err)
}
// 16 字节 -> 32 个 hex 字符
if len(sid) != 32 {
t.Errorf("GenerateSessionID() 长度=%d,期望=32", len(sid))
}
// 必须是合法的 hex 字符串
if _, err := hex.DecodeString(sid); err != nil {
t.Errorf("GenerateSessionID() 不是合法的 hex 字符串: %s, err=%v", sid, err)
}
}
func TestGenerateSessionID_Uniqueness(t *testing.T) {
a, _ := GenerateSessionID()
b, _ := GenerateSessionID()
if a == b {
t.Error("两次 GenerateSessionID() 返回了相同结果")
}
}
// ---------------------------------------------------------------------------
// GenerateCodeVerifier 测试
// ---------------------------------------------------------------------------
func TestGenerateCodeVerifier(t *testing.T) {
verifier, err := GenerateCodeVerifier()
if err != nil {
t.Fatalf("GenerateCodeVerifier() 出错: %v", err)
}
if verifier == "" {
t.Error("GenerateCodeVerifier() 返回空字符串")
}
// RFC 7636 要求 code_verifier 至少 43 个字符
if len(verifier) < 43 {
t.Errorf("GenerateCodeVerifier() 长度=%d,RFC 7636 要求至少 43 字符", len(verifier))
}
// base64url 编码不应包含 padding 和非 URL 安全字符
if strings.Contains(verifier, "=") {
t.Errorf("GenerateCodeVerifier() 包含 '=' padding: %s", verifier)
}
if strings.ContainsAny(verifier, "+/") {
t.Errorf("GenerateCodeVerifier() 包含非 base64url 字符: %s", verifier)
}
}
// ---------------------------------------------------------------------------
// GenerateCodeChallenge 测试
// ---------------------------------------------------------------------------
func TestGenerateCodeChallenge(t *testing.T) {
// 使用已知输入验证输出
// RFC 7636 附录 B 示例: verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
// 预期 challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
expected := "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
challenge := GenerateCodeChallenge(verifier)
if challenge != expected {
t.Errorf("GenerateCodeChallenge(%q) = %q,期望 %q", verifier, challenge, expected)
}
}
func TestGenerateCodeChallenge_NoPadding(t *testing.T) {
challenge := GenerateCodeChallenge("test-verifier-string")
if strings.Contains(challenge, "=") {
t.Errorf("GenerateCodeChallenge() 结果包含 '=' padding: %s", challenge)
}
}
// ---------------------------------------------------------------------------
// base64URLEncode 测试
// ---------------------------------------------------------------------------
func TestBase64URLEncode(t *testing.T) {
tests := []struct {
name string
input []byte
}{
{"空字节", []byte{}},
{"单字节", []byte{0xff}},
{"多字节", []byte{0x01, 0x02, 0x03, 0x04, 0x05}},
{"全零", []byte{0x00, 0x00, 0x00}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := base64URLEncode(tt.input)
// 不应包含 '=' padding
if strings.Contains(result, "=") {
t.Errorf("base64URLEncode(%v) 包含 '=' padding: %s", tt.input, result)
}
// 不应包含标准 base64 的 '+' 或 '/'
if strings.ContainsAny(result, "+/") {
t.Errorf("base64URLEncode(%v) 包含非 URL 安全字符: %s", tt.input, result)
}
})
}
}
// ---------------------------------------------------------------------------
// hasRestrictedScope 测试
// ---------------------------------------------------------------------------
func TestHasRestrictedScope(t *testing.T) {
tests := []struct {
scope string
expected bool
}{
// 受限 scope
{"https://www.googleapis.com/auth/generative-language", true},
{"https://www.googleapis.com/auth/generative-language.retriever", true},
{"https://www.googleapis.com/auth/generative-language.tuning", true},
{"https://www.googleapis.com/auth/drive", true},
{"https://www.googleapis.com/auth/drive.readonly", true},
{"https://www.googleapis.com/auth/drive.file", true},
// 非受限 scope
{"https://www.googleapis.com/auth/cloud-platform", false},
{"https://www.googleapis.com/auth/userinfo.email", false},
{"https://www.googleapis.com/auth/userinfo.profile", false},
// 边界情况
{"", false},
{"random-scope", false},
}
for _, tt := range tests {
t.Run(tt.scope, func(t *testing.T) {
got := hasRestrictedScope(tt.scope)
if got != tt.expected {
t.Errorf("hasRestrictedScope(%q) = %v,期望 %v", tt.scope, got, tt.expected)
}
})
}
}
// ---------------------------------------------------------------------------
// BuildAuthorizationURL 测试
// ---------------------------------------------------------------------------
func TestBuildAuthorizationURL(t *testing.T) {
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-secret")
authURL, err := BuildAuthorizationURL(
OAuthConfig{},
"test-state",
"test-challenge",
"https://example.com/callback",
"",
"code_assist",
)
if err != nil {
t.Fatalf("BuildAuthorizationURL() 出错: %v", err)
}
// 检查返回的 URL 包含期望的参数
checks := []string{
"response_type=code",
"client_id=" + GeminiCLIOAuthClientID,
"redirect_uri=",
"state=test-state",
"code_challenge=test-challenge",
"code_challenge_method=S256",
"access_type=offline",
"prompt=consent",
"include_granted_scopes=true",
}
for _, check := range checks {
if !strings.Contains(authURL, check) {
t.Errorf("BuildAuthorizationURL() URL 缺少参数 %q\nURL: %s", check, authURL)
}
}
// 不应包含 project_id(因为传的是空字符串)
if strings.Contains(authURL, "project_id=") {
t.Errorf("BuildAuthorizationURL() 空 projectID 时不应包含 project_id 参数")
}
// URL 应该以正确的授权端点开头
if !strings.HasPrefix(authURL, AuthorizeURL+"?") {
t.Errorf("BuildAuthorizationURL() URL 应以 %s? 开头,实际: %s", AuthorizeURL, authURL)
}
}
func TestBuildAuthorizationURL_EmptyRedirectURI(t *testing.T) {
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-secret")
_, err := BuildAuthorizationURL(
OAuthConfig{},
"test-state",
"test-challenge",
"", // 空 redirectURI
"",
"code_assist",
)
if err == nil {
t.Error("BuildAuthorizationURL() 空 redirectURI 应该报错")
}
if !strings.Contains(err.Error(), "redirect_uri") {
t.Errorf("错误消息应包含 'redirect_uri',实际: %v", err)
}
}
func TestBuildAuthorizationURL_WithProjectID(t *testing.T) {
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-secret")
authURL, err := BuildAuthorizationURL(
OAuthConfig{},
"test-state",
"test-challenge",
"https://example.com/callback",
"my-project-123",
"code_assist",
)
if err != nil {
t.Fatalf("BuildAuthorizationURL() 出错: %v", err)
}
if !strings.Contains(authURL, "project_id=my-project-123") {
t.Errorf("BuildAuthorizationURL() 带 projectID 时应包含 project_id 参数\nURL: %s", authURL)
}
}
func TestBuildAuthorizationURL_UsesBuiltinSecretFallback(t *testing.T) {
t.Setenv(GeminiCLIOAuthClientSecretEnv, "")
authURL, err := BuildAuthorizationURL(
OAuthConfig{},
"test-state",
"test-challenge",
"https://example.com/callback",
"",
"code_assist",
)
if err != nil {
t.Fatalf("BuildAuthorizationURL() 不应报错: %v", err)
}
if !strings.Contains(authURL, "client_id="+GeminiCLIOAuthClientID) {
t.Errorf("应使用内置 Gemini CLI client_id,实际 URL: %s", authURL)
}
}
// ---------------------------------------------------------------------------
// EffectiveOAuthConfig 测试 - 原有测试
// ---------------------------------------------------------------------------
func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
// 内置的 Gemini CLI client secret 不嵌入在此仓库中。
// 测试通过环境变量设置一个假的 secret 来模拟运维配置。
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
tests := []struct {
name string
input OAuthConfig
......@@ -15,7 +445,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
wantErr bool
}{
{
name: "Google One with built-in client (empty config)",
name: "Google One 使用内置客户端(空配置)",
input: OAuthConfig{},
oauthType: "google_one",
wantClientID: GeminiCLIOAuthClientID,
......@@ -23,18 +453,18 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
wantErr: false,
},
{
name: "Google One always uses built-in client (even if custom credentials passed)",
name: "Google One 使用自定义客户端(传入自定义凭据时使用自定义)",
input: OAuthConfig{
ClientID: "custom-client-id",
ClientSecret: "custom-client-secret",
},
oauthType: "google_one",
wantClientID: "custom-client-id",
wantScopes: DefaultCodeAssistScopes, // Uses code assist scopes even with custom client
wantScopes: DefaultCodeAssistScopes,
wantErr: false,
},
{
name: "Google One with built-in client and custom scopes (should filter restricted scopes)",
name: "Google One 内置客户端 + 自定义 scopes(应过滤受限 scopes",
input: OAuthConfig{
Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly",
},
......@@ -44,7 +474,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
wantErr: false,
},
{
name: "Google One with built-in client and only restricted scopes (should fallback to default)",
name: "Google One 内置客户端 + 仅受限 scopes(应回退到默认)",
input: OAuthConfig{
Scopes: "https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly",
},
......@@ -54,7 +484,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
wantErr: false,
},
{
name: "Code Assist with built-in client",
name: "Code Assist 使用内置客户端",
input: OAuthConfig{},
oauthType: "code_assist",
wantClientID: GeminiCLIOAuthClientID,
......@@ -84,7 +514,9 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
}
func TestEffectiveOAuthConfig_ScopeFiltering(t *testing.T) {
// Test that Google One with built-in client filters out restricted scopes
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
// 测试 Google One + 内置客户端过滤受限 scopes
cfg, err := EffectiveOAuthConfig(OAuthConfig{
Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly https://www.googleapis.com/auth/userinfo.profile",
}, "google_one")
......@@ -93,21 +525,242 @@ func TestEffectiveOAuthConfig_ScopeFiltering(t *testing.T) {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
// Should only contain cloud-platform, userinfo.email, and userinfo.profile
// Should NOT contain generative-language or drive scopes
// 应仅包含 cloud-platformuserinfo.email userinfo.profile
// 不应包含 generative-language drive scopes
if strings.Contains(cfg.Scopes, "generative-language") {
t.Errorf("Scopes should not contain generative-language when using built-in client, got: %v", cfg.Scopes)
t.Errorf("使用内置客户端时 Scopes 不应包含 generative-language,实际: %v", cfg.Scopes)
}
if strings.Contains(cfg.Scopes, "drive") {
t.Errorf("Scopes should not contain drive when using built-in client, got: %v", cfg.Scopes)
t.Errorf("使用内置客户端时 Scopes 不应包含 drive,实际: %v", cfg.Scopes)
}
if !strings.Contains(cfg.Scopes, "cloud-platform") {
t.Errorf("Scopes should contain cloud-platform, got: %v", cfg.Scopes)
t.Errorf("Scopes 应包含 cloud-platform,实际: %v", cfg.Scopes)
}
if !strings.Contains(cfg.Scopes, "userinfo.email") {
t.Errorf("Scopes should contain userinfo.email, got: %v", cfg.Scopes)
t.Errorf("Scopes 应包含 userinfo.email,实际: %v", cfg.Scopes)
}
if !strings.Contains(cfg.Scopes, "userinfo.profile") {
t.Errorf("Scopes should contain userinfo.profile, got: %v", cfg.Scopes)
t.Errorf("Scopes 应包含 userinfo.profile,实际: %v", cfg.Scopes)
}
}
// ---------------------------------------------------------------------------
// EffectiveOAuthConfig 测试 - 新增分支覆盖
// ---------------------------------------------------------------------------
func TestEffectiveOAuthConfig_OnlyClientID_NoSecret(t *testing.T) {
// 只提供 clientID 不提供 secret 应报错
_, err := EffectiveOAuthConfig(OAuthConfig{
ClientID: "some-client-id",
}, "code_assist")
if err == nil {
t.Error("只提供 ClientID 不提供 ClientSecret 应该报错")
}
if !strings.Contains(err.Error(), "client_id") || !strings.Contains(err.Error(), "client_secret") {
t.Errorf("错误消息应提及 client_id 和 client_secret,实际: %v", err)
}
}
func TestEffectiveOAuthConfig_OnlyClientSecret_NoID(t *testing.T) {
// 只提供 secret 不提供 clientID 应报错
_, err := EffectiveOAuthConfig(OAuthConfig{
ClientSecret: "some-client-secret",
}, "code_assist")
if err == nil {
t.Error("只提供 ClientSecret 不提供 ClientID 应该报错")
}
if !strings.Contains(err.Error(), "client_id") || !strings.Contains(err.Error(), "client_secret") {
t.Errorf("错误消息应提及 client_id 和 client_secret,实际: %v", err)
}
}
func TestEffectiveOAuthConfig_AIStudio_DefaultScopes_BuiltinClient(t *testing.T) {
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
// ai_studio 类型,使用内置客户端,scopes 为空 -> 应使用 DefaultCodeAssistScopes(因为内置客户端不能请求 generative-language scope)
cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "ai_studio")
if err != nil {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
if cfg.Scopes != DefaultCodeAssistScopes {
t.Errorf("ai_studio + 内置客户端应使用 DefaultCodeAssistScopes,实际: %q", cfg.Scopes)
}
}
func TestEffectiveOAuthConfig_AIStudio_DefaultScopes_CustomClient(t *testing.T) {
// ai_studio 类型,使用自定义客户端,scopes 为空 -> 应使用 DefaultAIStudioScopes
cfg, err := EffectiveOAuthConfig(OAuthConfig{
ClientID: "custom-id",
ClientSecret: "custom-secret",
}, "ai_studio")
if err != nil {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
if cfg.Scopes != DefaultAIStudioScopes {
t.Errorf("ai_studio + 自定义客户端应使用 DefaultAIStudioScopes,实际: %q", cfg.Scopes)
}
}
func TestEffectiveOAuthConfig_AIStudio_ScopeNormalization(t *testing.T) {
// ai_studio 类型,旧的 generative-language scope 应被归一化为 generative-language.retriever
cfg, err := EffectiveOAuthConfig(OAuthConfig{
ClientID: "custom-id",
ClientSecret: "custom-secret",
Scopes: "https://www.googleapis.com/auth/generative-language https://www.googleapis.com/auth/cloud-platform",
}, "ai_studio")
if err != nil {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
if strings.Contains(cfg.Scopes, "auth/generative-language ") || strings.HasSuffix(cfg.Scopes, "auth/generative-language") {
// 确保不包含未归一化的旧 scope(仅 generative-language 而非 generative-language.retriever)
parts := strings.Fields(cfg.Scopes)
for _, p := range parts {
if p == "https://www.googleapis.com/auth/generative-language" {
t.Errorf("ai_studio 应将 generative-language 归一化为 generative-language.retriever,实际 scopes: %q", cfg.Scopes)
}
}
}
if !strings.Contains(cfg.Scopes, "generative-language.retriever") {
t.Errorf("ai_studio 归一化后应包含 generative-language.retriever,实际: %q", cfg.Scopes)
}
}
func TestEffectiveOAuthConfig_CommaSeparatedScopes(t *testing.T) {
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
// 逗号分隔的 scopes 应被归一化为空格分隔
cfg, err := EffectiveOAuthConfig(OAuthConfig{
ClientID: "custom-id",
ClientSecret: "custom-secret",
Scopes: "https://www.googleapis.com/auth/cloud-platform,https://www.googleapis.com/auth/userinfo.email",
}, "code_assist")
if err != nil {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
// 应该用空格分隔,而非逗号
if strings.Contains(cfg.Scopes, ",") {
t.Errorf("逗号分隔的 scopes 应被归一化为空格分隔,实际: %q", cfg.Scopes)
}
if !strings.Contains(cfg.Scopes, "cloud-platform") {
t.Errorf("归一化后应包含 cloud-platform,实际: %q", cfg.Scopes)
}
if !strings.Contains(cfg.Scopes, "userinfo.email") {
t.Errorf("归一化后应包含 userinfo.email,实际: %q", cfg.Scopes)
}
}
func TestEffectiveOAuthConfig_MixedCommaAndSpaceScopes(t *testing.T) {
// 混合逗号和空格分隔的 scopes
cfg, err := EffectiveOAuthConfig(OAuthConfig{
ClientID: "custom-id",
ClientSecret: "custom-secret",
Scopes: "https://www.googleapis.com/auth/cloud-platform, https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile",
}, "code_assist")
if err != nil {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
parts := strings.Fields(cfg.Scopes)
if len(parts) != 3 {
t.Errorf("归一化后应有 3 个 scope,实际: %d,scopes: %q", len(parts), cfg.Scopes)
}
}
func TestEffectiveOAuthConfig_WhitespaceTriming(t *testing.T) {
// 输入中的前后空白应被清理
cfg, err := EffectiveOAuthConfig(OAuthConfig{
ClientID: " custom-id ",
ClientSecret: " custom-secret ",
Scopes: " https://www.googleapis.com/auth/cloud-platform ",
}, "code_assist")
if err != nil {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
if cfg.ClientID != "custom-id" {
t.Errorf("ClientID 应去除前后空白,实际: %q", cfg.ClientID)
}
if cfg.ClientSecret != "custom-secret" {
t.Errorf("ClientSecret 应去除前后空白,实际: %q", cfg.ClientSecret)
}
if cfg.Scopes != "https://www.googleapis.com/auth/cloud-platform" {
t.Errorf("Scopes 应去除前后空白,实际: %q", cfg.Scopes)
}
}
func TestEffectiveOAuthConfig_NoEnvSecret(t *testing.T) {
t.Setenv(GeminiCLIOAuthClientSecretEnv, "")
cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "code_assist")
if err != nil {
t.Fatalf("不设置环境变量时应回退到内置 secret,实际报错: %v", err)
}
if strings.TrimSpace(cfg.ClientSecret) == "" {
t.Error("ClientSecret 不应为空")
}
if cfg.ClientID != GeminiCLIOAuthClientID {
t.Errorf("ClientID 应回退为内置客户端 ID,实际: %q", cfg.ClientID)
}
}
func TestEffectiveOAuthConfig_AIStudio_BuiltinClient_CustomScopes(t *testing.T) {
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
// ai_studio + 内置客户端 + 自定义 scopes -> 应过滤受限 scopes
cfg, err := EffectiveOAuthConfig(OAuthConfig{
Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever",
}, "ai_studio")
if err != nil {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
// 内置客户端应过滤 generative-language.retriever
if strings.Contains(cfg.Scopes, "generative-language") {
t.Errorf("ai_studio + 内置客户端应过滤受限 scopes,实际: %q", cfg.Scopes)
}
if !strings.Contains(cfg.Scopes, "cloud-platform") {
t.Errorf("应保留 cloud-platform scope,实际: %q", cfg.Scopes)
}
}
func TestEffectiveOAuthConfig_UnknownOAuthType_DefaultScopes(t *testing.T) {
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
// 未知的 oauthType 应回退到默认的 code_assist scopes
cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "unknown_type")
if err != nil {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
if cfg.Scopes != DefaultCodeAssistScopes {
t.Errorf("未知 oauthType 应使用 DefaultCodeAssistScopes,实际: %q", cfg.Scopes)
}
}
func TestEffectiveOAuthConfig_EmptyOAuthType_DefaultScopes(t *testing.T) {
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
// 空的 oauthType 应走 default 分支,使用 DefaultCodeAssistScopes
cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "")
if err != nil {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
if cfg.Scopes != DefaultCodeAssistScopes {
t.Errorf("空 oauthType 应使用 DefaultCodeAssistScopes,实际: %q", cfg.Scopes)
}
}
func TestEffectiveOAuthConfig_CustomClient_NoScopeFiltering(t *testing.T) {
// 自定义客户端 + google_one + 包含受限 scopes -> 不应被过滤(因为不是内置客户端)
cfg, err := EffectiveOAuthConfig(OAuthConfig{
ClientID: "custom-id",
ClientSecret: "custom-secret",
Scopes: "https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly",
}, "google_one")
if err != nil {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
// 自定义客户端不应过滤任何 scope
if !strings.Contains(cfg.Scopes, "generative-language.retriever") {
t.Errorf("自定义客户端不应过滤 generative-language.retriever,实际: %q", cfg.Scopes)
}
if !strings.Contains(cfg.Scopes, "drive.readonly") {
t.Errorf("自定义客户端不应过滤 drive.readonly,实际: %q", cfg.Scopes)
}
}
......@@ -18,11 +18,11 @@ package httpclient
import (
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
......@@ -32,6 +32,7 @@ const (
defaultMaxIdleConns = 100 // 最大空闲连接数
defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数
defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间(建议小于上游 LB 超时)
validatedHostTTL = 30 * time.Second // DNS Rebinding 校验缓存 TTL
)
// Options 定义共享 HTTP 客户端的构建参数
......@@ -40,7 +41,6 @@ type Options struct {
Timeout time.Duration // 请求总超时时间
ResponseHeaderTimeout time.Duration // 等待响应头超时时间
InsecureSkipVerify bool // 是否跳过 TLS 证书验证(已禁用,不允许设置为 true)
ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退
ValidateResolvedIP bool // 是否校验解析后的 IP(防止 DNS Rebinding)
AllowPrivateHosts bool // 允许私有地址解析(与 ValidateResolvedIP 一起使用)
......@@ -53,6 +53,9 @@ type Options struct {
// sharedClients 存储按配置参数缓存的 http.Client 实例
var sharedClients sync.Map
// 允许测试替换校验函数,生产默认指向真实实现。
var validateResolvedIP = urlvalidator.ValidateResolvedIP
// GetClient 返回共享的 HTTP 客户端实例
// 性能优化:相同配置复用同一客户端,避免重复创建 Transport
// 安全说明:代理配置失败时直接返回错误,不会回退到直连,避免 IP 关联风险
......@@ -84,7 +87,7 @@ func buildClient(opts Options) (*http.Client, error) {
var rt http.RoundTripper = transport
if opts.ValidateResolvedIP && !opts.AllowPrivateHosts {
rt = &validatedTransport{base: transport}
rt = newValidatedTransport(transport)
}
return &http.Client{
Transport: rt,
......@@ -116,15 +119,13 @@ func buildTransport(opts Options) (*http.Transport, error) {
return nil, fmt.Errorf("insecure_skip_verify is not allowed; install a trusted certificate instead")
}
proxyURL := strings.TrimSpace(opts.ProxyURL)
if proxyURL == "" {
return transport, nil
}
parsed, err := url.Parse(proxyURL)
_, parsed, err := proxyurl.Parse(opts.ProxyURL)
if err != nil {
return nil, err
}
if parsed == nil {
return transport, nil
}
if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil {
return nil, err
......@@ -134,12 +135,11 @@ func buildTransport(opts Options) (*http.Transport, error) {
}
func buildClientKey(opts Options) string {
return fmt.Sprintf("%s|%s|%s|%t|%t|%t|%t|%d|%d|%d",
return fmt.Sprintf("%s|%s|%s|%t|%t|%t|%d|%d|%d",
strings.TrimSpace(opts.ProxyURL),
opts.Timeout.String(),
opts.ResponseHeaderTimeout.String(),
opts.InsecureSkipVerify,
opts.ProxyStrict,
opts.ValidateResolvedIP,
opts.AllowPrivateHosts,
opts.MaxIdleConns,
......@@ -150,16 +150,55 @@ func buildClientKey(opts Options) string {
type validatedTransport struct {
base http.RoundTripper
validatedHosts sync.Map // map[string]time.Time, value 为过期时间
now func() time.Time
}
func newValidatedTransport(base http.RoundTripper) *validatedTransport {
return &validatedTransport{
base: base,
now: time.Now,
}
}
func (t *validatedTransport) isValidatedHost(host string, now time.Time) bool {
if t == nil {
return false
}
raw, ok := t.validatedHosts.Load(host)
if !ok {
return false
}
expireAt, ok := raw.(time.Time)
if !ok {
t.validatedHosts.Delete(host)
return false
}
if now.Before(expireAt) {
return true
}
t.validatedHosts.Delete(host)
return false
}
func (t *validatedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if req != nil && req.URL != nil {
host := strings.TrimSpace(req.URL.Hostname())
host := strings.ToLower(strings.TrimSpace(req.URL.Hostname()))
if host != "" {
if err := urlvalidator.ValidateResolvedIP(host); err != nil {
now := time.Now()
if t != nil && t.now != nil {
now = t.now()
}
if !t.isValidatedHost(host, now) {
if err := validateResolvedIP(host); err != nil {
return nil, err
}
t.validatedHosts.Store(host, now.Add(validatedHostTTL))
}
}
}
if t == nil || t.base == nil {
return nil, fmt.Errorf("validated transport base is nil")
}
return t.base.RoundTrip(req)
}
package httpclient
import (
"errors"
"io"
"net/http"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
func TestValidatedTransport_CacheHostValidation(t *testing.T) {
originalValidate := validateResolvedIP
defer func() { validateResolvedIP = originalValidate }()
var validateCalls int32
validateResolvedIP = func(host string) error {
atomic.AddInt32(&validateCalls, 1)
require.Equal(t, "api.openai.com", host)
return nil
}
var baseCalls int32
base := roundTripFunc(func(_ *http.Request) (*http.Response, error) {
atomic.AddInt32(&baseCalls, 1)
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(`{}`)),
Header: make(http.Header),
}, nil
})
now := time.Unix(1730000000, 0)
transport := newValidatedTransport(base)
transport.now = func() time.Time { return now }
req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil)
require.NoError(t, err)
_, err = transport.RoundTrip(req)
require.NoError(t, err)
_, err = transport.RoundTrip(req)
require.NoError(t, err)
require.Equal(t, int32(1), atomic.LoadInt32(&validateCalls))
require.Equal(t, int32(2), atomic.LoadInt32(&baseCalls))
}
func TestValidatedTransport_ExpiredCacheTriggersRevalidation(t *testing.T) {
originalValidate := validateResolvedIP
defer func() { validateResolvedIP = originalValidate }()
var validateCalls int32
validateResolvedIP = func(_ string) error {
atomic.AddInt32(&validateCalls, 1)
return nil
}
base := roundTripFunc(func(_ *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(`{}`)),
Header: make(http.Header),
}, nil
})
now := time.Unix(1730001000, 0)
transport := newValidatedTransport(base)
transport.now = func() time.Time { return now }
req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil)
require.NoError(t, err)
_, err = transport.RoundTrip(req)
require.NoError(t, err)
now = now.Add(validatedHostTTL + time.Second)
_, err = transport.RoundTrip(req)
require.NoError(t, err)
require.Equal(t, int32(2), atomic.LoadInt32(&validateCalls))
}
func TestValidatedTransport_ValidationErrorStopsRoundTrip(t *testing.T) {
originalValidate := validateResolvedIP
defer func() { validateResolvedIP = originalValidate }()
expectedErr := errors.New("dns rebinding rejected")
validateResolvedIP = func(_ string) error {
return expectedErr
}
var baseCalls int32
base := roundTripFunc(func(_ *http.Request) (*http.Response, error) {
atomic.AddInt32(&baseCalls, 1)
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(`{}`))}, nil
})
transport := newValidatedTransport(base)
req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil)
require.NoError(t, err)
_, err = transport.RoundTrip(req)
require.ErrorIs(t, err, expectedErr)
require.Equal(t, int32(0), atomic.LoadInt32(&baseCalls))
}
package httputil
import (
"bytes"
"io"
"net/http"
)
const (
requestBodyReadInitCap = 512
requestBodyReadMaxInitCap = 1 << 20
)
// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based on content length.
func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) {
if req == nil || req.Body == nil {
return nil, nil
}
capHint := requestBodyReadInitCap
if req.ContentLength > 0 {
switch {
case req.ContentLength < int64(requestBodyReadInitCap):
capHint = requestBodyReadInitCap
case req.ContentLength > int64(requestBodyReadMaxInitCap):
capHint = requestBodyReadMaxInitCap
default:
capHint = int(req.ContentLength)
}
}
buf := bytes.NewBuffer(make([]byte, 0, capHint))
if _, err := io.Copy(buf, req.Body); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
......@@ -44,6 +44,16 @@ func GetClientIP(c *gin.Context) string {
return normalizeIP(c.ClientIP())
}
// GetTrustedClientIP 从 Gin 的可信代理解析链提取客户端 IP。
// 该方法依赖 gin.Engine.SetTrustedProxies 配置,不会优先直接信任原始转发头值。
// 适用于 ACL / 风控等安全敏感场景。
func GetTrustedClientIP(c *gin.Context) string {
if c == nil {
return ""
}
return normalizeIP(c.ClientIP())
}
// normalizeIP 规范化 IP 地址,去除端口号和空格。
func normalizeIP(ip string) string {
ip = strings.TrimSpace(ip)
......@@ -54,29 +64,89 @@ func normalizeIP(ip string) string {
return ip
}
// isPrivateIP 检查 IP 是否为私有地址。
func isPrivateIP(ipStr string) bool {
ip := net.ParseIP(ipStr)
if ip == nil {
return false
}
// privateNets 预编译私有 IP CIDR 块,避免每次调用 isPrivateIP 时重复解析
var privateNets []*net.IPNet
// CompiledIPRules 表示预编译的 IP 匹配规则。
// PatternCount 记录原始规则数量,用于保留“规则存在但全无效”时的行为语义。
type CompiledIPRules struct {
CIDRs []*net.IPNet
IPs []net.IP
PatternCount int
}
// 私有 IP 范围
privateBlocks := []string{
func init() {
for _, cidr := range []string{
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
"127.0.0.0/8",
"::1/128",
"fc00::/7",
} {
_, block, err := net.ParseCIDR(cidr)
if err != nil {
panic("invalid CIDR: " + cidr)
}
privateNets = append(privateNets, block)
}
}
for _, block := range privateBlocks {
_, cidr, err := net.ParseCIDR(block)
if err != nil {
// CompileIPRules 将 IP/CIDR 字符串规则预编译为可复用结构。
// 非法规则会被忽略,但 PatternCount 会保留原始规则条数。
func CompileIPRules(patterns []string) *CompiledIPRules {
compiled := &CompiledIPRules{
CIDRs: make([]*net.IPNet, 0, len(patterns)),
IPs: make([]net.IP, 0, len(patterns)),
PatternCount: len(patterns),
}
for _, pattern := range patterns {
normalized := strings.TrimSpace(pattern)
if normalized == "" {
continue
}
if strings.Contains(normalized, "/") {
_, cidr, err := net.ParseCIDR(normalized)
if err != nil || cidr == nil {
continue
}
compiled.CIDRs = append(compiled.CIDRs, cidr)
continue
}
parsedIP := net.ParseIP(normalized)
if parsedIP == nil {
continue
}
if cidr.Contains(ip) {
compiled.IPs = append(compiled.IPs, parsedIP)
}
return compiled
}
func matchesCompiledRules(parsedIP net.IP, rules *CompiledIPRules) bool {
if parsedIP == nil || rules == nil {
return false
}
for _, cidr := range rules.CIDRs {
if cidr.Contains(parsedIP) {
return true
}
}
for _, ruleIP := range rules.IPs {
if parsedIP.Equal(ruleIP) {
return true
}
}
return false
}
// isPrivateIP 检查 IP 是否为私有地址。
func isPrivateIP(ipStr string) bool {
ip := net.ParseIP(ipStr)
if ip == nil {
return false
}
for _, block := range privateNets {
if block.Contains(ip) {
return true
}
}
......@@ -127,19 +197,32 @@ func MatchesAnyPattern(clientIP string, patterns []string) bool {
// 2. 如果白名单不为空,IP 必须在白名单中
// 3. 如果白名单为空,允许访问(除非被黑名单拒绝)
func CheckIPRestriction(clientIP string, whitelist, blacklist []string) (bool, string) {
return CheckIPRestrictionWithCompiledRules(
clientIP,
CompileIPRules(whitelist),
CompileIPRules(blacklist),
)
}
// CheckIPRestrictionWithCompiledRules 使用预编译规则检查 IP 是否允许访问。
func CheckIPRestrictionWithCompiledRules(clientIP string, whitelist, blacklist *CompiledIPRules) (bool, string) {
// 规范化 IP
clientIP = normalizeIP(clientIP)
if clientIP == "" {
return false, "access denied"
}
parsedIP := net.ParseIP(clientIP)
if parsedIP == nil {
return false, "access denied"
}
// 1. 检查黑名单
if len(blacklist) > 0 && MatchesAnyPattern(clientIP, blacklist) {
if blacklist != nil && blacklist.PatternCount > 0 && matchesCompiledRules(parsedIP, blacklist) {
return false, "access denied"
}
// 2. 检查白名单(如果设置了白名单,IP 必须在其中)
if len(whitelist) > 0 && !MatchesAnyPattern(clientIP, whitelist) {
if whitelist != nil && whitelist.PatternCount > 0 && !matchesCompiledRules(parsedIP, whitelist) {
return false, "access denied"
}
......
//go:build unit
package ip
import (
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestIsPrivateIP(t *testing.T) {
tests := []struct {
name string
ip string
expected bool
}{
// 私有 IPv4
{"10.x 私有地址", "10.0.0.1", true},
{"10.x 私有地址段末", "10.255.255.255", true},
{"172.16.x 私有地址", "172.16.0.1", true},
{"172.31.x 私有地址", "172.31.255.255", true},
{"192.168.x 私有地址", "192.168.1.1", true},
{"127.0.0.1 本地回环", "127.0.0.1", true},
{"127.x 回环段", "127.255.255.255", true},
// 公网 IPv4
{"8.8.8.8 公网 DNS", "8.8.8.8", false},
{"1.1.1.1 公网", "1.1.1.1", false},
{"172.15.255.255 非私有", "172.15.255.255", false},
{"172.32.0.0 非私有", "172.32.0.0", false},
{"11.0.0.1 公网", "11.0.0.1", false},
// IPv6
{"::1 IPv6 回环", "::1", true},
{"fc00:: IPv6 私有", "fc00::1", true},
{"fd00:: IPv6 私有", "fd00::1", true},
{"2001:db8::1 IPv6 公网", "2001:db8::1", false},
// 无效输入
{"空字符串", "", false},
{"非法字符串", "not-an-ip", false},
{"不完整 IP", "192.168", false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := isPrivateIP(tc.ip)
require.Equal(t, tc.expected, got, "isPrivateIP(%q)", tc.ip)
})
}
}
func TestGetTrustedClientIPUsesGinClientIP(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
require.NoError(t, r.SetTrustedProxies(nil))
r.GET("/t", func(c *gin.Context) {
c.String(200, GetTrustedClientIP(c))
})
w := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/t", nil)
req.RemoteAddr = "9.9.9.9:12345"
req.Header.Set("X-Forwarded-For", "1.2.3.4")
req.Header.Set("X-Real-IP", "1.2.3.4")
req.Header.Set("CF-Connecting-IP", "1.2.3.4")
r.ServeHTTP(w, req)
require.Equal(t, 200, w.Code)
require.Equal(t, "9.9.9.9", w.Body.String())
}
func TestCheckIPRestrictionWithCompiledRules(t *testing.T) {
whitelist := CompileIPRules([]string{"10.0.0.0/8", "192.168.1.2"})
blacklist := CompileIPRules([]string{"10.1.1.1"})
allowed, reason := CheckIPRestrictionWithCompiledRules("10.2.3.4", whitelist, blacklist)
require.True(t, allowed)
require.Equal(t, "", reason)
allowed, reason = CheckIPRestrictionWithCompiledRules("10.1.1.1", whitelist, blacklist)
require.False(t, allowed)
require.Equal(t, "access denied", reason)
}
func TestCheckIPRestrictionWithCompiledRules_InvalidWhitelistStillDenies(t *testing.T) {
// 与旧实现保持一致:白名单有配置但全无效时,最终应拒绝访问。
invalidWhitelist := CompileIPRules([]string{"not-a-valid-pattern"})
allowed, reason := CheckIPRestrictionWithCompiledRules("8.8.8.8", invalidWhitelist, nil)
require.False(t, allowed)
require.Equal(t, "access denied", reason)
}
package logger
import "github.com/Wei-Shaw/sub2api/internal/config"
func OptionsFromConfig(cfg config.LogConfig) InitOptions {
return InitOptions{
Level: cfg.Level,
Format: cfg.Format,
ServiceName: cfg.ServiceName,
Environment: cfg.Environment,
Caller: cfg.Caller,
StacktraceLevel: cfg.StacktraceLevel,
Output: OutputOptions{
ToStdout: cfg.Output.ToStdout,
ToFile: cfg.Output.ToFile,
FilePath: cfg.Output.FilePath,
},
Rotation: RotationOptions{
MaxSizeMB: cfg.Rotation.MaxSizeMB,
MaxBackups: cfg.Rotation.MaxBackups,
MaxAgeDays: cfg.Rotation.MaxAgeDays,
Compress: cfg.Rotation.Compress,
LocalTime: cfg.Rotation.LocalTime,
},
Sampling: SamplingOptions{
Enabled: cfg.Sampling.Enabled,
Initial: cfg.Sampling.Initial,
Thereafter: cfg.Sampling.Thereafter,
},
}
}
package logger
import (
"context"
"fmt"
"io"
"log"
"log/slog"
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"gopkg.in/natefinch/lumberjack.v2"
)
type Level = zapcore.Level
const (
LevelDebug = zapcore.DebugLevel
LevelInfo = zapcore.InfoLevel
LevelWarn = zapcore.WarnLevel
LevelError = zapcore.ErrorLevel
LevelFatal = zapcore.FatalLevel
)
type Sink interface {
WriteLogEvent(event *LogEvent)
}
type LogEvent struct {
Time time.Time
Level string
Component string
Message string
LoggerName string
Fields map[string]any
}
var (
mu sync.RWMutex
global atomic.Pointer[zap.Logger]
sugar atomic.Pointer[zap.SugaredLogger]
atomicLevel zap.AtomicLevel
initOptions InitOptions
currentSink atomic.Value // sinkState
stdLogUndo func()
bootstrapOnce sync.Once
)
type sinkState struct {
sink Sink
}
func InitBootstrap() {
bootstrapOnce.Do(func() {
if err := Init(bootstrapOptions()); err != nil {
_, _ = fmt.Fprintf(os.Stderr, "logger bootstrap init failed: %v\n", err)
}
})
}
func Init(options InitOptions) error {
mu.Lock()
defer mu.Unlock()
return initLocked(options)
}
func initLocked(options InitOptions) error {
normalized := options.normalized()
zl, al, err := buildLogger(normalized)
if err != nil {
return err
}
prev := global.Load()
global.Store(zl)
sugar.Store(zl.Sugar())
atomicLevel = al
initOptions = normalized
bridgeSlogLocked()
bridgeStdLogLocked()
if prev != nil {
_ = prev.Sync()
}
return nil
}
func Reconfigure(mutator func(*InitOptions) error) error {
mu.Lock()
defer mu.Unlock()
next := initOptions
if mutator != nil {
if err := mutator(&next); err != nil {
return err
}
}
return initLocked(next)
}
func SetLevel(level string) error {
lv, ok := parseLevel(level)
if !ok {
return fmt.Errorf("invalid log level: %s", level)
}
mu.Lock()
defer mu.Unlock()
atomicLevel.SetLevel(lv)
initOptions.Level = strings.ToLower(strings.TrimSpace(level))
return nil
}
func CurrentLevel() string {
mu.RLock()
defer mu.RUnlock()
if global.Load() == nil {
return "info"
}
return atomicLevel.Level().String()
}
func SetSink(sink Sink) {
currentSink.Store(sinkState{sink: sink})
}
func loadSink() Sink {
v := currentSink.Load()
if v == nil {
return nil
}
state, ok := v.(sinkState)
if !ok {
return nil
}
return state.sink
}
// WriteSinkEvent 直接写入日志 sink,不经过全局日志级别门控。
// 用于需要“可观测性入库”与“业务输出级别”解耦的场景(例如 ops 系统日志索引)。
func WriteSinkEvent(level, component, message string, fields map[string]any) {
sink := loadSink()
if sink == nil {
return
}
level = strings.ToLower(strings.TrimSpace(level))
if level == "" {
level = "info"
}
component = strings.TrimSpace(component)
message = strings.TrimSpace(message)
if message == "" {
return
}
eventFields := make(map[string]any, len(fields)+1)
for k, v := range fields {
eventFields[k] = v
}
if component != "" {
if _, ok := eventFields["component"]; !ok {
eventFields["component"] = component
}
}
sink.WriteLogEvent(&LogEvent{
Time: time.Now(),
Level: level,
Component: component,
Message: message,
LoggerName: component,
Fields: eventFields,
})
}
func L() *zap.Logger {
if l := global.Load(); l != nil {
return l
}
return zap.NewNop()
}
func S() *zap.SugaredLogger {
if s := sugar.Load(); s != nil {
return s
}
return zap.NewNop().Sugar()
}
func With(fields ...zap.Field) *zap.Logger {
return L().With(fields...)
}
func Sync() {
l := global.Load()
if l != nil {
_ = l.Sync()
}
}
func bridgeStdLogLocked() {
if stdLogUndo != nil {
stdLogUndo()
stdLogUndo = nil
}
prevFlags := log.Flags()
prevPrefix := log.Prefix()
prevWriter := log.Writer()
log.SetFlags(0)
log.SetPrefix("")
base := global.Load()
if base == nil {
base = zap.NewNop()
}
log.SetOutput(newStdLogBridge(base.Named("stdlog")))
stdLogUndo = func() {
log.SetOutput(prevWriter)
log.SetFlags(prevFlags)
log.SetPrefix(prevPrefix)
}
}
func bridgeSlogLocked() {
base := global.Load()
if base == nil {
base = zap.NewNop()
}
slog.SetDefault(slog.New(newSlogZapHandler(base.Named("slog"))))
}
func buildLogger(options InitOptions) (*zap.Logger, zap.AtomicLevel, error) {
level, _ := parseLevel(options.Level)
atomic := zap.NewAtomicLevelAt(level)
encoderCfg := zapcore.EncoderConfig{
TimeKey: "time",
LevelKey: "level",
NameKey: "logger",
CallerKey: "caller",
MessageKey: "msg",
StacktraceKey: "stacktrace",
LineEnding: zapcore.DefaultLineEnding,
EncodeLevel: zapcore.CapitalLevelEncoder,
EncodeTime: zapcore.ISO8601TimeEncoder,
EncodeDuration: zapcore.MillisDurationEncoder,
EncodeCaller: zapcore.ShortCallerEncoder,
}
var enc zapcore.Encoder
if options.Format == "console" {
enc = zapcore.NewConsoleEncoder(encoderCfg)
} else {
enc = zapcore.NewJSONEncoder(encoderCfg)
}
sinkCore := newSinkCore()
cores := make([]zapcore.Core, 0, 3)
if options.Output.ToStdout {
infoPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool {
return lvl >= atomic.Level() && lvl < zapcore.WarnLevel
})
errPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool {
return lvl >= atomic.Level() && lvl >= zapcore.WarnLevel
})
cores = append(cores, zapcore.NewCore(enc, zapcore.Lock(os.Stdout), infoPriority))
cores = append(cores, zapcore.NewCore(enc, zapcore.Lock(os.Stderr), errPriority))
}
if options.Output.ToFile {
fileCore, filePath, fileErr := buildFileCore(enc, atomic, options)
if fileErr != nil {
_, _ = fmt.Fprintf(os.Stderr, "time=%s level=WARN msg=\"日志文件输出初始化失败,降级为仅标准输出\" path=%s err=%v\n",
time.Now().Format(time.RFC3339Nano),
filePath,
fileErr,
)
} else {
cores = append(cores, fileCore)
}
}
if len(cores) == 0 {
cores = append(cores, zapcore.NewCore(enc, zapcore.Lock(os.Stdout), atomic))
}
core := zapcore.NewTee(cores...)
if options.Sampling.Enabled {
core = zapcore.NewSamplerWithOptions(core, samplingTick(), options.Sampling.Initial, options.Sampling.Thereafter)
}
core = sinkCore.Wrap(core)
stacktraceLevel, _ := parseStacktraceLevel(options.StacktraceLevel)
zapOpts := make([]zap.Option, 0, 5)
if options.Caller {
zapOpts = append(zapOpts, zap.AddCaller())
}
if stacktraceLevel <= zapcore.FatalLevel {
zapOpts = append(zapOpts, zap.AddStacktrace(stacktraceLevel))
}
logger := zap.New(core, zapOpts...).With(
zap.String("service", options.ServiceName),
zap.String("env", options.Environment),
)
return logger, atomic, nil
}
func buildFileCore(enc zapcore.Encoder, atomic zap.AtomicLevel, options InitOptions) (zapcore.Core, string, error) {
filePath := options.Output.FilePath
if strings.TrimSpace(filePath) == "" {
filePath = resolveLogFilePath("")
}
dir := filepath.Dir(filePath)
if err := os.MkdirAll(dir, 0o755); err != nil {
return nil, filePath, err
}
lj := &lumberjack.Logger{
Filename: filePath,
MaxSize: options.Rotation.MaxSizeMB,
MaxBackups: options.Rotation.MaxBackups,
MaxAge: options.Rotation.MaxAgeDays,
Compress: options.Rotation.Compress,
LocalTime: options.Rotation.LocalTime,
}
return zapcore.NewCore(enc, zapcore.AddSync(lj), atomic), filePath, nil
}
type sinkCore struct {
core zapcore.Core
fields []zapcore.Field
}
func newSinkCore() *sinkCore {
return &sinkCore{}
}
func (s *sinkCore) Wrap(core zapcore.Core) zapcore.Core {
cp := *s
cp.core = core
return &cp
}
func (s *sinkCore) Enabled(level zapcore.Level) bool {
return s.core.Enabled(level)
}
func (s *sinkCore) With(fields []zapcore.Field) zapcore.Core {
nextFields := append([]zapcore.Field{}, s.fields...)
nextFields = append(nextFields, fields...)
return &sinkCore{
core: s.core.With(fields),
fields: nextFields,
}
}
func (s *sinkCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry {
// Delegate to inner core (tee) so each sub-core's level enabler is respected.
// Then add ourselves for sink forwarding only.
ce = s.core.Check(entry, ce)
if ce != nil {
ce = ce.AddCore(entry, s)
}
return ce
}
func (s *sinkCore) Write(entry zapcore.Entry, fields []zapcore.Field) error {
// Only handle sink forwarding — the inner cores write via their own
// Write methods (added to CheckedEntry by s.core.Check above).
sink := loadSink()
if sink == nil {
return nil
}
enc := zapcore.NewMapObjectEncoder()
for _, f := range s.fields {
f.AddTo(enc)
}
for _, f := range fields {
f.AddTo(enc)
}
event := &LogEvent{
Time: entry.Time,
Level: strings.ToLower(entry.Level.String()),
Component: entry.LoggerName,
Message: entry.Message,
LoggerName: entry.LoggerName,
Fields: enc.Fields,
}
sink.WriteLogEvent(event)
return nil
}
func (s *sinkCore) Sync() error {
return s.core.Sync()
}
type stdLogBridge struct {
logger *zap.Logger
}
func newStdLogBridge(l *zap.Logger) io.Writer {
if l == nil {
l = zap.NewNop()
}
return &stdLogBridge{logger: l}
}
func (b *stdLogBridge) Write(p []byte) (int, error) {
msg := normalizeStdLogMessage(string(p))
if msg == "" {
return len(p), nil
}
level := inferStdLogLevel(msg)
entry := b.logger.WithOptions(zap.AddCallerSkip(4))
switch level {
case LevelDebug:
entry.Debug(msg, zap.Bool("legacy_stdlog", true))
case LevelWarn:
entry.Warn(msg, zap.Bool("legacy_stdlog", true))
case LevelError, LevelFatal:
entry.Error(msg, zap.Bool("legacy_stdlog", true))
default:
entry.Info(msg, zap.Bool("legacy_stdlog", true))
}
return len(p), nil
}
func normalizeStdLogMessage(raw string) string {
msg := strings.TrimSpace(strings.ReplaceAll(raw, "\n", " "))
if msg == "" {
return ""
}
return strings.Join(strings.Fields(msg), " ")
}
func inferStdLogLevel(msg string) Level {
lower := strings.ToLower(strings.TrimSpace(msg))
if lower == "" {
return LevelInfo
}
if strings.HasPrefix(lower, "[debug]") || strings.HasPrefix(lower, "debug:") {
return LevelDebug
}
if strings.HasPrefix(lower, "[warn]") || strings.HasPrefix(lower, "[warning]") || strings.HasPrefix(lower, "warn:") || strings.HasPrefix(lower, "warning:") {
return LevelWarn
}
if strings.HasPrefix(lower, "[error]") || strings.HasPrefix(lower, "error:") || strings.HasPrefix(lower, "fatal:") || strings.HasPrefix(lower, "panic:") {
return LevelError
}
if strings.Contains(lower, " failed") || strings.Contains(lower, "error") || strings.Contains(lower, "panic") || strings.Contains(lower, "fatal") {
return LevelError
}
if strings.Contains(lower, "warning") || strings.Contains(lower, "warn") || strings.Contains(lower, " queue full") || strings.Contains(lower, "fallback") {
return LevelWarn
}
return LevelInfo
}
// LegacyPrintf 用于平滑迁移历史的 printf 风格日志到结构化 logger。
func LegacyPrintf(component, format string, args ...any) {
msg := normalizeStdLogMessage(fmt.Sprintf(format, args...))
if msg == "" {
return
}
initialized := global.Load() != nil
if !initialized {
// 在日志系统未初始化前,回退到标准库 log,避免测试/工具链丢日志。
log.Print(msg)
return
}
l := L()
if component != "" {
l = l.With(zap.String("component", component))
}
l = l.WithOptions(zap.AddCallerSkip(1))
switch inferStdLogLevel(msg) {
case LevelDebug:
l.Debug(msg, zap.Bool("legacy_printf", true))
case LevelWarn:
l.Warn(msg, zap.Bool("legacy_printf", true))
case LevelError, LevelFatal:
l.Error(msg, zap.Bool("legacy_printf", true))
default:
l.Info(msg, zap.Bool("legacy_printf", true))
}
}
type contextKey string
const loggerContextKey contextKey = "ctx_logger"
func IntoContext(ctx context.Context, l *zap.Logger) context.Context {
if ctx == nil {
ctx = context.Background()
}
if l == nil {
l = L()
}
return context.WithValue(ctx, loggerContextKey, l)
}
func FromContext(ctx context.Context) *zap.Logger {
if ctx == nil {
return L()
}
if l, ok := ctx.Value(loggerContextKey).(*zap.Logger); ok && l != nil {
return l
}
return L()
}
package logger
import (
"encoding/json"
"io"
"os"
"path/filepath"
"strings"
"testing"
)
func TestInit_DualOutput(t *testing.T) {
tmpDir := t.TempDir()
logPath := filepath.Join(tmpDir, "logs", "sub2api.log")
origStdout := os.Stdout
origStderr := os.Stderr
stdoutR, stdoutW, err := os.Pipe()
if err != nil {
t.Fatalf("create stdout pipe: %v", err)
}
stderrR, stderrW, err := os.Pipe()
if err != nil {
t.Fatalf("create stderr pipe: %v", err)
}
os.Stdout = stdoutW
os.Stderr = stderrW
t.Cleanup(func() {
os.Stdout = origStdout
os.Stderr = origStderr
_ = stdoutR.Close()
_ = stderrR.Close()
_ = stdoutW.Close()
_ = stderrW.Close()
})
err = Init(InitOptions{
Level: "debug",
Format: "json",
ServiceName: "sub2api",
Environment: "test",
Output: OutputOptions{
ToStdout: true,
ToFile: true,
FilePath: logPath,
},
Rotation: RotationOptions{
MaxSizeMB: 10,
MaxBackups: 2,
MaxAgeDays: 1,
},
Sampling: SamplingOptions{Enabled: false},
})
if err != nil {
t.Fatalf("Init() error: %v", err)
}
L().Info("dual-output-info")
L().Warn("dual-output-warn")
Sync()
_ = stdoutW.Close()
_ = stderrW.Close()
stdoutBytes, _ := io.ReadAll(stdoutR)
stderrBytes, _ := io.ReadAll(stderrR)
stdoutText := string(stdoutBytes)
stderrText := string(stderrBytes)
if !strings.Contains(stdoutText, "dual-output-info") {
t.Fatalf("stdout missing info log: %s", stdoutText)
}
if !strings.Contains(stderrText, "dual-output-warn") {
t.Fatalf("stderr missing warn log: %s", stderrText)
}
fileBytes, err := os.ReadFile(logPath)
if err != nil {
t.Fatalf("read log file: %v", err)
}
fileText := string(fileBytes)
if !strings.Contains(fileText, "dual-output-info") || !strings.Contains(fileText, "dual-output-warn") {
t.Fatalf("file missing logs: %s", fileText)
}
}
func TestInit_FileOutputFailureDowngrade(t *testing.T) {
origStdout := os.Stdout
origStderr := os.Stderr
_, stdoutW, err := os.Pipe()
if err != nil {
t.Fatalf("create stdout pipe: %v", err)
}
stderrR, stderrW, err := os.Pipe()
if err != nil {
t.Fatalf("create stderr pipe: %v", err)
}
os.Stdout = stdoutW
os.Stderr = stderrW
t.Cleanup(func() {
os.Stdout = origStdout
os.Stderr = origStderr
_ = stdoutW.Close()
_ = stderrR.Close()
_ = stderrW.Close()
})
err = Init(InitOptions{
Level: "info",
Format: "json",
Output: OutputOptions{
ToStdout: true,
ToFile: true,
FilePath: filepath.Join(os.DevNull, "logs", "sub2api.log"),
},
Rotation: RotationOptions{
MaxSizeMB: 10,
MaxBackups: 1,
MaxAgeDays: 1,
},
})
if err != nil {
t.Fatalf("Init() should downgrade instead of failing, got: %v", err)
}
_ = stderrW.Close()
stderrBytes, _ := io.ReadAll(stderrR)
if !strings.Contains(string(stderrBytes), "日志文件输出初始化失败") {
t.Fatalf("stderr should contain fallback warning, got: %s", string(stderrBytes))
}
}
func TestInit_CallerShouldPointToCallsite(t *testing.T) {
origStdout := os.Stdout
origStderr := os.Stderr
stdoutR, stdoutW, err := os.Pipe()
if err != nil {
t.Fatalf("create stdout pipe: %v", err)
}
_, stderrW, err := os.Pipe()
if err != nil {
t.Fatalf("create stderr pipe: %v", err)
}
os.Stdout = stdoutW
os.Stderr = stderrW
t.Cleanup(func() {
os.Stdout = origStdout
os.Stderr = origStderr
_ = stdoutR.Close()
_ = stdoutW.Close()
_ = stderrW.Close()
})
if err := Init(InitOptions{
Level: "info",
Format: "json",
ServiceName: "sub2api",
Environment: "test",
Caller: true,
Output: OutputOptions{
ToStdout: true,
ToFile: false,
},
Sampling: SamplingOptions{Enabled: false},
}); err != nil {
t.Fatalf("Init() error: %v", err)
}
L().Info("caller-check")
Sync()
_ = stdoutW.Close()
logBytes, _ := io.ReadAll(stdoutR)
var line string
for _, item := range strings.Split(string(logBytes), "\n") {
if strings.Contains(item, "caller-check") {
line = item
break
}
}
if line == "" {
t.Fatalf("log output missing caller-check: %s", string(logBytes))
}
var payload map[string]any
if err := json.Unmarshal([]byte(line), &payload); err != nil {
t.Fatalf("parse log json failed: %v, line=%s", err, line)
}
caller, _ := payload["caller"].(string)
if !strings.Contains(caller, "logger_test.go:") {
t.Fatalf("caller should point to this test file, got: %s", caller)
}
}
package logger
import (
"os"
"path/filepath"
"strings"
"time"
)
const (
// DefaultContainerLogPath 为容器内默认日志文件路径。
DefaultContainerLogPath = "/app/data/logs/sub2api.log"
defaultLogFilename = "sub2api.log"
)
type InitOptions struct {
Level string
Format string
ServiceName string
Environment string
Caller bool
StacktraceLevel string
Output OutputOptions
Rotation RotationOptions
Sampling SamplingOptions
}
type OutputOptions struct {
ToStdout bool
ToFile bool
FilePath string
}
type RotationOptions struct {
MaxSizeMB int
MaxBackups int
MaxAgeDays int
Compress bool
LocalTime bool
}
type SamplingOptions struct {
Enabled bool
Initial int
Thereafter int
}
func (o InitOptions) normalized() InitOptions {
out := o
out.Level = strings.ToLower(strings.TrimSpace(out.Level))
if out.Level == "" {
out.Level = "info"
}
out.Format = strings.ToLower(strings.TrimSpace(out.Format))
if out.Format == "" {
out.Format = "console"
}
out.ServiceName = strings.TrimSpace(out.ServiceName)
if out.ServiceName == "" {
out.ServiceName = "sub2api"
}
out.Environment = strings.TrimSpace(out.Environment)
if out.Environment == "" {
out.Environment = "production"
}
out.StacktraceLevel = strings.ToLower(strings.TrimSpace(out.StacktraceLevel))
if out.StacktraceLevel == "" {
out.StacktraceLevel = "error"
}
if !out.Output.ToStdout && !out.Output.ToFile {
out.Output.ToStdout = true
}
out.Output.FilePath = resolveLogFilePath(out.Output.FilePath)
if out.Rotation.MaxSizeMB <= 0 {
out.Rotation.MaxSizeMB = 100
}
if out.Rotation.MaxBackups < 0 {
out.Rotation.MaxBackups = 10
}
if out.Rotation.MaxAgeDays < 0 {
out.Rotation.MaxAgeDays = 7
}
if out.Sampling.Enabled {
if out.Sampling.Initial <= 0 {
out.Sampling.Initial = 100
}
if out.Sampling.Thereafter <= 0 {
out.Sampling.Thereafter = 100
}
}
return out
}
func resolveLogFilePath(explicit string) string {
explicit = strings.TrimSpace(explicit)
if explicit != "" {
return explicit
}
dataDir := strings.TrimSpace(os.Getenv("DATA_DIR"))
if dataDir != "" {
return filepath.Join(dataDir, "logs", defaultLogFilename)
}
return DefaultContainerLogPath
}
func bootstrapOptions() InitOptions {
return InitOptions{
Level: "info",
Format: "console",
ServiceName: "sub2api",
Environment: "bootstrap",
Output: OutputOptions{
ToStdout: true,
ToFile: false,
},
Rotation: RotationOptions{
MaxSizeMB: 100,
MaxBackups: 10,
MaxAgeDays: 7,
Compress: true,
LocalTime: true,
},
Sampling: SamplingOptions{
Enabled: false,
Initial: 100,
Thereafter: 100,
},
}
}
func parseLevel(level string) (Level, bool) {
switch strings.ToLower(strings.TrimSpace(level)) {
case "debug":
return LevelDebug, true
case "info":
return LevelInfo, true
case "warn":
return LevelWarn, true
case "error":
return LevelError, true
default:
return LevelInfo, false
}
}
func parseStacktraceLevel(level string) (Level, bool) {
switch strings.ToLower(strings.TrimSpace(level)) {
case "none":
return LevelFatal + 1, true
case "error":
return LevelError, true
case "fatal":
return LevelFatal, true
default:
return LevelError, false
}
}
func samplingTick() time.Duration {
return time.Second
}
package logger
import (
"os"
"path/filepath"
"testing"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
func TestResolveLogFilePath_Default(t *testing.T) {
t.Setenv("DATA_DIR", "")
got := resolveLogFilePath("")
if got != DefaultContainerLogPath {
t.Fatalf("resolveLogFilePath() = %q, want %q", got, DefaultContainerLogPath)
}
}
func TestResolveLogFilePath_WithDataDir(t *testing.T) {
t.Setenv("DATA_DIR", "/tmp/sub2api-data")
got := resolveLogFilePath("")
want := filepath.Join("/tmp/sub2api-data", "logs", "sub2api.log")
if got != want {
t.Fatalf("resolveLogFilePath() = %q, want %q", got, want)
}
}
func TestResolveLogFilePath_ExplicitPath(t *testing.T) {
t.Setenv("DATA_DIR", "/tmp/ignore")
got := resolveLogFilePath("/var/log/custom.log")
if got != "/var/log/custom.log" {
t.Fatalf("resolveLogFilePath() = %q, want explicit path", got)
}
}
func TestNormalizedOptions_InvalidFallback(t *testing.T) {
t.Setenv("DATA_DIR", "")
opts := InitOptions{
Level: "TRACE",
Format: "TEXT",
ServiceName: "",
Environment: "",
StacktraceLevel: "panic",
Output: OutputOptions{
ToStdout: false,
ToFile: false,
},
Rotation: RotationOptions{
MaxSizeMB: 0,
MaxBackups: -1,
MaxAgeDays: -1,
},
Sampling: SamplingOptions{
Enabled: true,
Initial: 0,
Thereafter: 0,
},
}
out := opts.normalized()
if out.Level != "trace" {
// normalized 仅做 trim/lower,不做校验;校验在 config 层。
t.Fatalf("normalized level should preserve value for upstream validation, got %q", out.Level)
}
if !out.Output.ToStdout {
t.Fatalf("normalized output should fallback to stdout")
}
if out.Output.FilePath != DefaultContainerLogPath {
t.Fatalf("normalized file path = %q", out.Output.FilePath)
}
if out.Rotation.MaxSizeMB != 100 {
t.Fatalf("normalized max_size_mb = %d", out.Rotation.MaxSizeMB)
}
if out.Rotation.MaxBackups != 10 {
t.Fatalf("normalized max_backups = %d", out.Rotation.MaxBackups)
}
if out.Rotation.MaxAgeDays != 7 {
t.Fatalf("normalized max_age_days = %d", out.Rotation.MaxAgeDays)
}
if out.Sampling.Initial != 100 || out.Sampling.Thereafter != 100 {
t.Fatalf("normalized sampling defaults invalid: %+v", out.Sampling)
}
}
func TestBuildFileCore_InvalidPathFallback(t *testing.T) {
t.Setenv("DATA_DIR", "")
opts := bootstrapOptions()
opts.Output.ToFile = true
opts.Output.FilePath = filepath.Join(os.DevNull, "logs", "sub2api.log")
encoderCfg := zapcore.EncoderConfig{
TimeKey: "time",
LevelKey: "level",
MessageKey: "msg",
EncodeTime: zapcore.ISO8601TimeEncoder,
EncodeLevel: zapcore.CapitalLevelEncoder,
}
encoder := zapcore.NewJSONEncoder(encoderCfg)
_, _, err := buildFileCore(encoder, zap.NewAtomicLevel(), opts)
if err == nil {
t.Fatalf("buildFileCore() expected error for invalid path")
}
}
package logger
import (
"context"
"log/slog"
"strings"
"time"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
type slogZapHandler struct {
logger *zap.Logger
attrs []slog.Attr
groups []string
}
func newSlogZapHandler(logger *zap.Logger) slog.Handler {
if logger == nil {
logger = zap.NewNop()
}
return &slogZapHandler{
logger: logger,
attrs: make([]slog.Attr, 0, 8),
groups: make([]string, 0, 4),
}
}
func (h *slogZapHandler) Enabled(_ context.Context, level slog.Level) bool {
switch {
case level >= slog.LevelError:
return h.logger.Core().Enabled(LevelError)
case level >= slog.LevelWarn:
return h.logger.Core().Enabled(LevelWarn)
case level <= slog.LevelDebug:
return h.logger.Core().Enabled(LevelDebug)
default:
return h.logger.Core().Enabled(LevelInfo)
}
}
func (h *slogZapHandler) Handle(_ context.Context, record slog.Record) error {
fields := make([]zap.Field, 0, len(h.attrs)+record.NumAttrs()+3)
fields = append(fields, slogAttrsToZapFields(h.groups, h.attrs)...)
record.Attrs(func(attr slog.Attr) bool {
fields = append(fields, slogAttrToZapField(h.groups, attr))
return true
})
switch {
case record.Level >= slog.LevelError:
h.logger.Error(record.Message, fields...)
case record.Level >= slog.LevelWarn:
h.logger.Warn(record.Message, fields...)
case record.Level <= slog.LevelDebug:
h.logger.Debug(record.Message, fields...)
default:
h.logger.Info(record.Message, fields...)
}
return nil
}
func (h *slogZapHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
next := *h
next.attrs = append(append([]slog.Attr{}, h.attrs...), attrs...)
return &next
}
func (h *slogZapHandler) WithGroup(name string) slog.Handler {
name = strings.TrimSpace(name)
if name == "" {
return h
}
next := *h
next.groups = append(append([]string{}, h.groups...), name)
return &next
}
func slogAttrsToZapFields(groups []string, attrs []slog.Attr) []zap.Field {
fields := make([]zap.Field, 0, len(attrs))
for _, attr := range attrs {
fields = append(fields, slogAttrToZapField(groups, attr))
}
return fields
}
func slogAttrToZapField(groups []string, attr slog.Attr) zap.Field {
if len(groups) > 0 {
attr.Key = strings.Join(append(append([]string{}, groups...), attr.Key), ".")
}
value := attr.Value.Resolve()
switch value.Kind() {
case slog.KindBool:
return zap.Bool(attr.Key, value.Bool())
case slog.KindInt64:
return zap.Int64(attr.Key, value.Int64())
case slog.KindUint64:
return zap.Uint64(attr.Key, value.Uint64())
case slog.KindFloat64:
return zap.Float64(attr.Key, value.Float64())
case slog.KindDuration:
return zap.Duration(attr.Key, value.Duration())
case slog.KindTime:
return zap.Time(attr.Key, value.Time())
case slog.KindString:
return zap.String(attr.Key, value.String())
case slog.KindGroup:
groupFields := make([]zap.Field, 0, len(value.Group()))
for _, nested := range value.Group() {
groupFields = append(groupFields, slogAttrToZapField(nil, nested))
}
return zap.Object(attr.Key, zapObjectFields(groupFields))
case slog.KindAny:
if t, ok := value.Any().(time.Time); ok {
return zap.Time(attr.Key, t)
}
return zap.Any(attr.Key, value.Any())
default:
return zap.String(attr.Key, value.String())
}
}
type zapObjectFields []zap.Field
func (z zapObjectFields) MarshalLogObject(enc zapcore.ObjectEncoder) error {
for _, field := range z {
field.AddTo(enc)
}
return nil
}
package logger
import (
"context"
"log/slog"
"testing"
"time"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
type captureState struct {
writes []capturedWrite
}
type capturedWrite struct {
fields []zapcore.Field
}
type captureCore struct {
state *captureState
withFields []zapcore.Field
}
func newCaptureCore() *captureCore {
return &captureCore{state: &captureState{}}
}
func (c *captureCore) Enabled(zapcore.Level) bool {
return true
}
func (c *captureCore) With(fields []zapcore.Field) zapcore.Core {
nextFields := make([]zapcore.Field, 0, len(c.withFields)+len(fields))
nextFields = append(nextFields, c.withFields...)
nextFields = append(nextFields, fields...)
return &captureCore{
state: c.state,
withFields: nextFields,
}
}
func (c *captureCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry {
return ce.AddCore(entry, c)
}
func (c *captureCore) Write(entry zapcore.Entry, fields []zapcore.Field) error {
allFields := make([]zapcore.Field, 0, len(c.withFields)+len(fields))
allFields = append(allFields, c.withFields...)
allFields = append(allFields, fields...)
c.state.writes = append(c.state.writes, capturedWrite{
fields: allFields,
})
return nil
}
func (c *captureCore) Sync() error {
return nil
}
func TestSlogZapHandler_Handle_DoesNotAppendTimeField(t *testing.T) {
core := newCaptureCore()
handler := newSlogZapHandler(zap.New(core))
record := slog.NewRecord(time.Date(2026, 1, 1, 12, 0, 0, 0, time.UTC), slog.LevelInfo, "hello", 0)
record.AddAttrs(slog.String("component", "http.access"))
if err := handler.Handle(context.Background(), record); err != nil {
t.Fatalf("handle slog record: %v", err)
}
if len(core.state.writes) != 1 {
t.Fatalf("write calls = %d, want 1", len(core.state.writes))
}
var hasComponent bool
for _, field := range core.state.writes[0].fields {
if field.Key == "time" {
t.Fatalf("unexpected duplicate time field in slog adapter output")
}
if field.Key == "component" {
hasComponent = true
}
}
if !hasComponent {
t.Fatalf("component field should be preserved")
}
}
package logger
import (
"io"
"log"
"os"
"strings"
"testing"
)
func TestInferStdLogLevel(t *testing.T) {
cases := []struct {
msg string
want Level
}{
{msg: "Warning: queue full", want: LevelWarn},
{msg: "Forward request failed: timeout", want: LevelError},
{msg: "[ERROR] upstream unavailable", want: LevelError},
{msg: "[OpenAI WS Mode] reconnect_retry account_id=22 retry=1 max_retries=5", want: LevelInfo},
{msg: "service started", want: LevelInfo},
{msg: "debug: cache miss", want: LevelDebug},
}
for _, tc := range cases {
got := inferStdLogLevel(tc.msg)
if got != tc.want {
t.Fatalf("inferStdLogLevel(%q)=%v want=%v", tc.msg, got, tc.want)
}
}
}
func TestNormalizeStdLogMessage(t *testing.T) {
raw := " [TokenRefresh] cycle complete \n total=1 failed=0 \n"
got := normalizeStdLogMessage(raw)
want := "[TokenRefresh] cycle complete total=1 failed=0"
if got != want {
t.Fatalf("normalizeStdLogMessage()=%q want=%q", got, want)
}
}
func TestStdLogBridgeRoutesLevels(t *testing.T) {
origStdout := os.Stdout
origStderr := os.Stderr
stdoutR, stdoutW, err := os.Pipe()
if err != nil {
t.Fatalf("create stdout pipe: %v", err)
}
stderrR, stderrW, err := os.Pipe()
if err != nil {
t.Fatalf("create stderr pipe: %v", err)
}
os.Stdout = stdoutW
os.Stderr = stderrW
t.Cleanup(func() {
os.Stdout = origStdout
os.Stderr = origStderr
_ = stdoutR.Close()
_ = stdoutW.Close()
_ = stderrR.Close()
_ = stderrW.Close()
})
if err := Init(InitOptions{
Level: "debug",
Format: "json",
ServiceName: "sub2api",
Environment: "test",
Output: OutputOptions{
ToStdout: true,
ToFile: false,
},
Sampling: SamplingOptions{Enabled: false},
}); err != nil {
t.Fatalf("Init() error: %v", err)
}
log.Printf("service started")
log.Printf("Warning: queue full")
log.Printf("Forward request failed: timeout")
Sync()
_ = stdoutW.Close()
_ = stderrW.Close()
stdoutBytes, _ := io.ReadAll(stdoutR)
stderrBytes, _ := io.ReadAll(stderrR)
stdoutText := string(stdoutBytes)
stderrText := string(stderrBytes)
if !strings.Contains(stdoutText, "service started") {
t.Fatalf("stdout missing info log: %s", stdoutText)
}
if !strings.Contains(stderrText, "Warning: queue full") {
t.Fatalf("stderr missing warn log: %s", stderrText)
}
if !strings.Contains(stderrText, "Forward request failed: timeout") {
t.Fatalf("stderr missing error log: %s", stderrText)
}
if !strings.Contains(stderrText, "\"legacy_stdlog\":true") {
t.Fatalf("stderr missing legacy_stdlog marker: %s", stderrText)
}
}
func TestLegacyPrintfRoutesLevels(t *testing.T) {
origStdout := os.Stdout
origStderr := os.Stderr
stdoutR, stdoutW, err := os.Pipe()
if err != nil {
t.Fatalf("create stdout pipe: %v", err)
}
stderrR, stderrW, err := os.Pipe()
if err != nil {
t.Fatalf("create stderr pipe: %v", err)
}
os.Stdout = stdoutW
os.Stderr = stderrW
t.Cleanup(func() {
os.Stdout = origStdout
os.Stderr = origStderr
_ = stdoutR.Close()
_ = stdoutW.Close()
_ = stderrR.Close()
_ = stderrW.Close()
})
if err := Init(InitOptions{
Level: "debug",
Format: "json",
ServiceName: "sub2api",
Environment: "test",
Output: OutputOptions{
ToStdout: true,
ToFile: false,
},
Sampling: SamplingOptions{Enabled: false},
}); err != nil {
t.Fatalf("Init() error: %v", err)
}
LegacyPrintf("service.test", "request started")
LegacyPrintf("service.test", "Warning: queue full")
LegacyPrintf("service.test", "forward failed: timeout")
Sync()
_ = stdoutW.Close()
_ = stderrW.Close()
stdoutBytes, _ := io.ReadAll(stdoutR)
stderrBytes, _ := io.ReadAll(stderrR)
stdoutText := string(stdoutBytes)
stderrText := string(stderrBytes)
if !strings.Contains(stdoutText, "request started") {
t.Fatalf("stdout missing info log: %s", stdoutText)
}
if !strings.Contains(stderrText, "Warning: queue full") {
t.Fatalf("stderr missing warn log: %s", stderrText)
}
if !strings.Contains(stderrText, "forward failed: timeout") {
t.Fatalf("stderr missing error log: %s", stderrText)
}
if !strings.Contains(stderrText, "\"legacy_printf\":true") {
t.Fatalf("stderr missing legacy_printf marker: %s", stderrText)
}
if !strings.Contains(stderrText, "\"component\":\"service.test\"") {
t.Fatalf("stderr missing component field: %s", stderrText)
}
}
......@@ -50,6 +50,7 @@ type OAuthSession struct {
type SessionStore struct {
mu sync.RWMutex
sessions map[string]*OAuthSession
stopOnce sync.Once
stopCh chan struct{}
}
......@@ -65,7 +66,9 @@ func NewSessionStore() *SessionStore {
// Stop stops the cleanup goroutine
func (s *SessionStore) Stop() {
s.stopOnce.Do(func() {
close(s.stopCh)
})
}
// Set stores a session
......
package oauth
import (
"sync"
"testing"
"time"
)
func TestSessionStore_Stop_Idempotent(t *testing.T) {
store := NewSessionStore()
store.Stop()
store.Stop()
select {
case <-store.stopCh:
// ok
case <-time.After(time.Second):
t.Fatal("stopCh 未关闭")
}
}
func TestSessionStore_Stop_Concurrent(t *testing.T) {
store := NewSessionStore()
var wg sync.WaitGroup
for range 50 {
wg.Add(1)
go func() {
defer wg.Done()
store.Stop()
}()
}
wg.Wait()
select {
case <-store.stopCh:
// ok
case <-time.After(time.Second):
t.Fatal("stopCh 未关闭")
}
}
......@@ -15,8 +15,8 @@ type Model struct {
// DefaultModels OpenAI models list
var DefaultModels = []Model{
{ID: "gpt-5.3", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3"},
{ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"},
{ID: "gpt-5.3-codex-spark", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex Spark"},
{ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"},
{ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"},
{ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"},
......
......@@ -17,6 +17,8 @@ import (
const (
// OAuth Client ID for OpenAI (Codex CLI official)
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
// OAuth Client ID for Sora mobile flow (aligned with sora2api)
SoraClientID = "app_LlGpXReQgckcGGUo2JrYvtJK"
// OAuth endpoints
AuthorizeURL = "https://auth.openai.com/oauth/authorize"
......@@ -34,10 +36,18 @@ const (
SessionTTL = 30 * time.Minute
)
const (
// OAuthPlatformOpenAI uses OpenAI Codex-compatible OAuth client.
OAuthPlatformOpenAI = "openai"
// OAuthPlatformSora uses Sora OAuth client.
OAuthPlatformSora = "sora"
)
// OAuthSession stores OAuth flow state for OpenAI
type OAuthSession struct {
State string `json:"state"`
CodeVerifier string `json:"code_verifier"`
ClientID string `json:"client_id,omitempty"`
ProxyURL string `json:"proxy_url,omitempty"`
RedirectURI string `json:"redirect_uri"`
CreatedAt time.Time `json:"created_at"`
......@@ -47,6 +57,7 @@ type OAuthSession struct {
type SessionStore struct {
mu sync.RWMutex
sessions map[string]*OAuthSession
stopOnce sync.Once
stopCh chan struct{}
}
......@@ -92,7 +103,9 @@ func (s *SessionStore) Delete(sessionID string) {
// Stop stops the cleanup goroutine
func (s *SessionStore) Stop() {
s.stopOnce.Do(func() {
close(s.stopCh)
})
}
// cleanup removes expired sessions periodically
......@@ -169,13 +182,20 @@ func base64URLEncode(data []byte) string {
// BuildAuthorizationURL builds the OpenAI OAuth authorization URL
func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string {
return BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, OAuthPlatformOpenAI)
}
// BuildAuthorizationURLForPlatform builds authorization URL by platform.
func BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, platform string) string {
if redirectURI == "" {
redirectURI = DefaultRedirectURI
}
clientID, codexFlow := OAuthClientConfigByPlatform(platform)
params := url.Values{}
params.Set("response_type", "code")
params.Set("client_id", ClientID)
params.Set("client_id", clientID)
params.Set("redirect_uri", redirectURI)
params.Set("scope", DefaultScopes)
params.Set("state", state)
......@@ -183,11 +203,25 @@ func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string {
params.Set("code_challenge_method", "S256")
// OpenAI specific parameters
params.Set("id_token_add_organizations", "true")
if codexFlow {
params.Set("codex_cli_simplified_flow", "true")
}
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
}
// OAuthClientConfigByPlatform returns oauth client_id and whether codex simplified flow should be enabled.
// Sora 授权流程复用 Codex CLI 的 client_id(支持 localhost redirect_uri),
// 但不启用 codex_cli_simplified_flow;拿到的 access_token 绑定同一 OpenAI 账号,对 Sora API 同样可用。
func OAuthClientConfigByPlatform(platform string) (clientID string, codexFlow bool) {
switch strings.ToLower(strings.TrimSpace(platform)) {
case OAuthPlatformSora:
return ClientID, false
default:
return ClientID, true
}
}
// TokenRequest represents the token exchange request body
type TokenRequest struct {
GrantType string `json:"grant_type"`
......@@ -291,9 +325,11 @@ func (r *RefreshTokenRequest) ToFormData() string {
return params.Encode()
}
// ParseIDToken parses the ID Token JWT and extracts claims
// Note: This does NOT verify the signature - it only decodes the payload
// For production, you should verify the token signature using OpenAI's public keys
// ParseIDToken parses the ID Token JWT and extracts claims.
// 注意:当前仅解码 payload 并校验 exp,未验证 JWT 签名。
// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名:
//
// https://auth.openai.com/.well-known/jwks.json
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
parts := strings.Split(idToken, ".")
if len(parts) != 3 {
......@@ -324,6 +360,13 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
}
// 校验 ID Token 是否已过期(允许 2 分钟时钟偏差,防止因服务器时钟略有差异误判刚颁发的令牌)
const clockSkewTolerance = 120 // 秒
now := time.Now().Unix()
if claims.Exp > 0 && now > claims.Exp+clockSkewTolerance {
return nil, fmt.Errorf("id_token has expired (exp: %d, now: %d, skew_tolerance: %ds)", claims.Exp, now, clockSkewTolerance)
}
return &claims, nil
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment