Commit fc8a39e0 authored by yangjianbo's avatar yangjianbo
Browse files

test: 删除CI工作流,大幅提升后端单元测试覆盖率至50%+



删除因GitHub计费锁定而失败的CI工作流。
为6个核心Go源文件补充单元测试,全部达到50%以上覆盖率:
- response/response.go: 97.6%
- antigravity/oauth.go: 90.1%
- antigravity/client.go: 88.6% (新增27个HTTP客户端测试)
- geminicli/oauth.go: 91.8%
- service/oauth_service.go: 61.2%
- service/gemini_oauth_service.go: 51.9%

新增/增强8个测试文件,共计5600+行测试代码。
Co-Authored-By: default avatarClaude Opus 4.6 <noreply@anthropic.com>
parent 9da80e9f
name: CI
on:
push:
pull_request:
permissions:
contents: read
jobs:
# ==========================================================================
# 后端测试(与前端并行运行)
# ==========================================================================
backend-test:
runs-on: ubuntu-latest
services:
postgres:
image: postgres:16-alpine
env:
POSTGRES_USER: test
POSTGRES_PASSWORD: test
POSTGRES_DB: sub2api_test
ports:
- 5432:5432
options: >-
--health-cmd "pg_isready -U test"
--health-interval 10s
--health-timeout 5s
--health-retries 5
redis:
image: redis:7-alpine
ports:
- 6379:6379
options: >-
--health-cmd "redis-cli ping"
--health-interval 10s
--health-timeout 5s
--health-retries 5
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: backend/go.mod
check-latest: false
cache: true
- name: 验证 Go 版本
run: go version | grep -q 'go1.25.7'
- name: 单元测试
working-directory: backend
run: make test-unit
- name: 集成测试
working-directory: backend
env:
DATABASE_URL: postgres://test:test@localhost:5432/sub2api_test?sslmode=disable
REDIS_URL: redis://localhost:6379/0
run: make test-integration
- name: Race 检测
working-directory: backend
run: go test -tags=unit -race -count=1 ./...
- name: 覆盖率收集
working-directory: backend
run: |
go test -tags=unit -coverprofile=coverage.out -count=1 ./...
echo "## 后端测试覆盖率" >> $GITHUB_STEP_SUMMARY
echo '```' >> $GITHUB_STEP_SUMMARY
go tool cover -func=coverage.out | tail -1 >> $GITHUB_STEP_SUMMARY
echo '```' >> $GITHUB_STEP_SUMMARY
- name: 覆盖率门禁(≥8%)
working-directory: backend
run: |
COVERAGE=$(go tool cover -func=coverage.out | tail -1 | awk '{print $3}' | sed 's/%//')
echo "当前覆盖率: ${COVERAGE}%"
if [ "$(echo "$COVERAGE < 8" | bc -l)" -eq 1 ]; then
echo "::error::后端覆盖率 ${COVERAGE}% 低于门禁值 8%"
exit 1
fi
# ==========================================================================
# 后端代码检查
# ==========================================================================
golangci-lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: backend/go.mod
check-latest: false
cache: true
- name: 验证 Go 版本
run: go version | grep -q 'go1.25.7'
- name: golangci-lint
uses: golangci/golangci-lint-action@v9
with:
version: v2.7
args: --timeout=5m
working-directory: backend
# ==========================================================================
# 前端测试(与后端并行运行)
# ==========================================================================
frontend-test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: 安装 pnpm
uses: pnpm/action-setup@v4
with:
version: 9
- name: 安装 Node.js
uses: actions/setup-node@v4
with:
node-version: '20'
cache: 'pnpm'
cache-dependency-path: frontend/pnpm-lock.yaml
- name: 安装依赖
working-directory: frontend
run: pnpm install --frozen-lockfile
- name: 类型检查
working-directory: frontend
run: pnpm run typecheck
- name: Lint 检查
working-directory: frontend
run: pnpm run lint:check
- name: 单元测试
working-directory: frontend
run: pnpm run test:run
- name: 覆盖率收集
working-directory: frontend
run: |
pnpm run test:coverage -- --exclude '**/integration/**' || true
echo "## 前端测试覆盖率" >> $GITHUB_STEP_SUMMARY
if [ -f coverage/coverage-final.json ]; then
echo "覆盖率报告已生成" >> $GITHUB_STEP_SUMMARY
fi
- name: 覆盖率门禁(≥20%)
working-directory: frontend
run: |
if [ ! -f coverage/coverage-final.json ]; then
echo "::warning::覆盖率报告未生成,跳过门禁检查"
exit 0
fi
# 使用 node 解析覆盖率 JSON
COVERAGE=$(node -e "
const data = require('./coverage/coverage-final.json');
let totalStatements = 0, coveredStatements = 0;
for (const file of Object.values(data)) {
const stmts = file.s;
totalStatements += Object.keys(stmts).length;
coveredStatements += Object.values(stmts).filter(v => v > 0).length;
}
const pct = totalStatements > 0 ? (coveredStatements / totalStatements * 100) : 0;
console.log(pct.toFixed(1));
")
echo "当前前端覆盖率: ${COVERAGE}%"
if [ "$(echo "$COVERAGE < 20" | bc -l 2>/dev/null || node -e "console.log($COVERAGE < 20 ? 1 : 0)")" = "1" ]; then
echo "::warning::前端覆盖率 ${COVERAGE}% 低于门禁值 20%(当前为警告,不阻塞)"
fi
# ==========================================================================
# Docker 构建验证
# ==========================================================================
docker-build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Docker 构建验证
run: docker build -t aicodex2api:ci-test .
This diff is collapsed.
This diff is collapsed.
...@@ -14,6 +14,44 @@ import ( ...@@ -14,6 +14,44 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// ---------- 辅助函数 ----------
// parseResponseBody 从 httptest.ResponseRecorder 中解析 JSON 响应体
func parseResponseBody(t *testing.T, w *httptest.ResponseRecorder) Response {
t.Helper()
var got Response
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
return got
}
// parsePaginatedBody 从响应体中解析分页数据(Data 字段是 PaginatedData)
func parsePaginatedBody(t *testing.T, w *httptest.ResponseRecorder) (Response, PaginatedData) {
t.Helper()
// 先用 raw json 解析,因为 Data 是 any 类型
var raw struct {
Code int `json:"code"`
Message string `json:"message"`
Reason string `json:"reason,omitempty"`
Data json.RawMessage `json:"data,omitempty"`
}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &raw))
var pd PaginatedData
require.NoError(t, json.Unmarshal(raw.Data, &pd))
return Response{Code: raw.Code, Message: raw.Message, Reason: raw.Reason}, pd
}
// newContextWithQuery 创建一个带有 URL query 参数的 gin.Context 用于测试 ParsePagination
func newContextWithQuery(query string) (*httptest.ResponseRecorder, *gin.Context) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/?"+query, nil)
return w, c
}
// ---------- 现有测试 ----------
func TestErrorWithDetails(t *testing.T) { func TestErrorWithDetails(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
...@@ -169,3 +207,582 @@ func TestErrorFrom(t *testing.T) { ...@@ -169,3 +207,582 @@ func TestErrorFrom(t *testing.T) {
}) })
} }
} }
// ---------- 新增测试 ----------
func TestSuccess(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
data any
wantCode int
wantBody Response
}{
{
name: "返回字符串数据",
data: "hello",
wantCode: http.StatusOK,
wantBody: Response{Code: 0, Message: "success", Data: "hello"},
},
{
name: "返回nil数据",
data: nil,
wantCode: http.StatusOK,
wantBody: Response{Code: 0, Message: "success"},
},
{
name: "返回map数据",
data: map[string]string{"key": "value"},
wantCode: http.StatusOK,
wantBody: Response{Code: 0, Message: "success"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Success(c, tt.data)
require.Equal(t, tt.wantCode, w.Code)
// 只验证 code 和 message,data 字段类型在 JSON 反序列化时会变成 map/slice
got := parseResponseBody(t, w)
require.Equal(t, 0, got.Code)
require.Equal(t, "success", got.Message)
if tt.data == nil {
require.Nil(t, got.Data)
} else {
require.NotNil(t, got.Data)
}
})
}
}
func TestCreated(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
data any
wantCode int
}{
{
name: "创建成功_返回数据",
data: map[string]int{"id": 42},
wantCode: http.StatusCreated,
},
{
name: "创建成功_nil数据",
data: nil,
wantCode: http.StatusCreated,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Created(c, tt.data)
require.Equal(t, tt.wantCode, w.Code)
got := parseResponseBody(t, w)
require.Equal(t, 0, got.Code)
require.Equal(t, "success", got.Message)
})
}
}
func TestError(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
statusCode int
message string
}{
{
name: "400错误",
statusCode: http.StatusBadRequest,
message: "bad request",
},
{
name: "500错误",
statusCode: http.StatusInternalServerError,
message: "internal error",
},
{
name: "自定义状态码",
statusCode: 418,
message: "I'm a teapot",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Error(c, tt.statusCode, tt.message)
require.Equal(t, tt.statusCode, w.Code)
got := parseResponseBody(t, w)
require.Equal(t, tt.statusCode, got.Code)
require.Equal(t, tt.message, got.Message)
require.Empty(t, got.Reason)
require.Nil(t, got.Metadata)
require.Nil(t, got.Data)
})
}
}
func TestBadRequest(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
BadRequest(c, "参数无效")
require.Equal(t, http.StatusBadRequest, w.Code)
got := parseResponseBody(t, w)
require.Equal(t, http.StatusBadRequest, got.Code)
require.Equal(t, "参数无效", got.Message)
}
func TestUnauthorized(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Unauthorized(c, "未登录")
require.Equal(t, http.StatusUnauthorized, w.Code)
got := parseResponseBody(t, w)
require.Equal(t, http.StatusUnauthorized, got.Code)
require.Equal(t, "未登录", got.Message)
}
func TestForbidden(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Forbidden(c, "无权限")
require.Equal(t, http.StatusForbidden, w.Code)
got := parseResponseBody(t, w)
require.Equal(t, http.StatusForbidden, got.Code)
require.Equal(t, "无权限", got.Message)
}
func TestNotFound(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
NotFound(c, "资源不存在")
require.Equal(t, http.StatusNotFound, w.Code)
got := parseResponseBody(t, w)
require.Equal(t, http.StatusNotFound, got.Code)
require.Equal(t, "资源不存在", got.Message)
}
func TestInternalError(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
InternalError(c, "服务器内部错误")
require.Equal(t, http.StatusInternalServerError, w.Code)
got := parseResponseBody(t, w)
require.Equal(t, http.StatusInternalServerError, got.Code)
require.Equal(t, "服务器内部错误", got.Message)
}
func TestPaginated(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
items any
total int64
page int
pageSize int
wantPages int
wantTotal int64
wantPage int
wantPageSize int
}{
{
name: "标准分页_多页",
items: []string{"a", "b"},
total: 25,
page: 1,
pageSize: 10,
wantPages: 3,
wantTotal: 25,
wantPage: 1,
wantPageSize: 10,
},
{
name: "总数刚好整除",
items: []string{"a"},
total: 20,
page: 2,
pageSize: 10,
wantPages: 2,
wantTotal: 20,
wantPage: 2,
wantPageSize: 10,
},
{
name: "总数为0_pages至少为1",
items: []string{},
total: 0,
page: 1,
pageSize: 10,
wantPages: 1,
wantTotal: 0,
wantPage: 1,
wantPageSize: 10,
},
{
name: "单页数据",
items: []int{1, 2, 3},
total: 3,
page: 1,
pageSize: 20,
wantPages: 1,
wantTotal: 3,
wantPage: 1,
wantPageSize: 20,
},
{
name: "总数为1",
items: []string{"only"},
total: 1,
page: 1,
pageSize: 10,
wantPages: 1,
wantTotal: 1,
wantPage: 1,
wantPageSize: 10,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Paginated(c, tt.items, tt.total, tt.page, tt.pageSize)
require.Equal(t, http.StatusOK, w.Code)
resp, pd := parsePaginatedBody(t, w)
require.Equal(t, 0, resp.Code)
require.Equal(t, "success", resp.Message)
require.Equal(t, tt.wantTotal, pd.Total)
require.Equal(t, tt.wantPage, pd.Page)
require.Equal(t, tt.wantPageSize, pd.PageSize)
require.Equal(t, tt.wantPages, pd.Pages)
})
}
}
func TestPaginatedWithResult(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
items any
pagination *PaginationResult
wantTotal int64
wantPage int
wantPageSize int
wantPages int
}{
{
name: "正常分页结果",
items: []string{"a", "b"},
pagination: &PaginationResult{
Total: 50,
Page: 3,
PageSize: 10,
Pages: 5,
},
wantTotal: 50,
wantPage: 3,
wantPageSize: 10,
wantPages: 5,
},
{
name: "pagination为nil_使用默认值",
items: []string{},
pagination: nil,
wantTotal: 0,
wantPage: 1,
wantPageSize: 20,
wantPages: 1,
},
{
name: "单页结果",
items: []int{1},
pagination: &PaginationResult{
Total: 1,
Page: 1,
PageSize: 20,
Pages: 1,
},
wantTotal: 1,
wantPage: 1,
wantPageSize: 20,
wantPages: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
PaginatedWithResult(c, tt.items, tt.pagination)
require.Equal(t, http.StatusOK, w.Code)
resp, pd := parsePaginatedBody(t, w)
require.Equal(t, 0, resp.Code)
require.Equal(t, "success", resp.Message)
require.Equal(t, tt.wantTotal, pd.Total)
require.Equal(t, tt.wantPage, pd.Page)
require.Equal(t, tt.wantPageSize, pd.PageSize)
require.Equal(t, tt.wantPages, pd.Pages)
})
}
}
func TestParsePagination(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
query string
wantPage int
wantPageSize int
}{
{
name: "无参数_使用默认值",
query: "",
wantPage: 1,
wantPageSize: 20,
},
{
name: "仅指定page",
query: "page=3",
wantPage: 3,
wantPageSize: 20,
},
{
name: "仅指定page_size",
query: "page_size=50",
wantPage: 1,
wantPageSize: 50,
},
{
name: "同时指定page和page_size",
query: "page=2&page_size=30",
wantPage: 2,
wantPageSize: 30,
},
{
name: "使用limit代替page_size",
query: "limit=15",
wantPage: 1,
wantPageSize: 15,
},
{
name: "page_size优先于limit",
query: "page_size=25&limit=50",
wantPage: 1,
wantPageSize: 25,
},
{
name: "page为0_使用默认值",
query: "page=0",
wantPage: 1,
wantPageSize: 20,
},
{
name: "page_size超过1000_使用默认值",
query: "page_size=1001",
wantPage: 1,
wantPageSize: 20,
},
{
name: "page_size恰好1000_有效",
query: "page_size=1000",
wantPage: 1,
wantPageSize: 1000,
},
{
name: "page为非数字_使用默认值",
query: "page=abc",
wantPage: 1,
wantPageSize: 20,
},
{
name: "page_size为非数字_使用默认值",
query: "page_size=xyz",
wantPage: 1,
wantPageSize: 20,
},
{
name: "limit为非数字_使用默认值",
query: "limit=abc",
wantPage: 1,
wantPageSize: 20,
},
{
name: "page_size为0_使用默认值",
query: "page_size=0",
wantPage: 1,
wantPageSize: 20,
},
{
name: "limit为0_使用默认值",
query: "limit=0",
wantPage: 1,
wantPageSize: 20,
},
{
name: "大页码",
query: "page=999&page_size=100",
wantPage: 999,
wantPageSize: 100,
},
{
name: "page_size为1_最小有效值",
query: "page_size=1",
wantPage: 1,
wantPageSize: 1,
},
{
name: "混合数字和字母的page",
query: "page=12a",
wantPage: 1,
wantPageSize: 20,
},
{
name: "limit超过1000_使用默认值",
query: "limit=2000",
wantPage: 1,
wantPageSize: 20,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, c := newContextWithQuery(tt.query)
page, pageSize := ParsePagination(c)
require.Equal(t, tt.wantPage, page, "page 不符合预期")
require.Equal(t, tt.wantPageSize, pageSize, "pageSize 不符合预期")
})
}
}
func Test_parseInt(t *testing.T) {
tests := []struct {
name string
input string
wantVal int
wantErr bool
}{
{
name: "正常数字",
input: "123",
wantVal: 123,
wantErr: false,
},
{
name: "零",
input: "0",
wantVal: 0,
wantErr: false,
},
{
name: "单个数字",
input: "5",
wantVal: 5,
wantErr: false,
},
{
name: "大数字",
input: "99999",
wantVal: 99999,
wantErr: false,
},
{
name: "包含字母_返回0",
input: "abc",
wantVal: 0,
wantErr: false,
},
{
name: "数字开头接字母_返回0",
input: "12a",
wantVal: 0,
wantErr: false,
},
{
name: "包含负号_返回0",
input: "-1",
wantVal: 0,
wantErr: false,
},
{
name: "包含小数点_返回0",
input: "1.5",
wantVal: 0,
wantErr: false,
},
{
name: "包含空格_返回0",
input: "1 2",
wantVal: 0,
wantErr: false,
},
{
name: "空字符串",
input: "",
wantVal: 0,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
val, err := parseInt(tt.input)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
require.Equal(t, tt.wantVal, val)
})
}
}
//go:build unit
package service
import (
"context"
"fmt"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
// --- mock: ClaudeOAuthClient ---
type mockClaudeOAuthClient struct {
getOrgUUIDFunc func(ctx context.Context, sessionKey, proxyURL string) (string, error)
getAuthCodeFunc func(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error)
exchangeCodeFunc func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error)
refreshTokenFunc func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error)
}
func (m *mockClaudeOAuthClient) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
if m.getOrgUUIDFunc != nil {
return m.getOrgUUIDFunc(ctx, sessionKey, proxyURL)
}
panic("GetOrganizationUUID not implemented")
}
func (m *mockClaudeOAuthClient) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
if m.getAuthCodeFunc != nil {
return m.getAuthCodeFunc(ctx, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL)
}
panic("GetAuthorizationCode not implemented")
}
func (m *mockClaudeOAuthClient) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) {
if m.exchangeCodeFunc != nil {
return m.exchangeCodeFunc(ctx, code, codeVerifier, state, proxyURL, isSetupToken)
}
panic("ExchangeCodeForToken not implemented")
}
func (m *mockClaudeOAuthClient) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
if m.refreshTokenFunc != nil {
return m.refreshTokenFunc(ctx, refreshToken, proxyURL)
}
panic("RefreshToken not implemented")
}
// --- mock: ProxyRepository (最小实现,仅覆盖 OAuthService 依赖的方法) ---
type mockProxyRepoForOAuth struct {
getByIDFunc func(ctx context.Context, id int64) (*Proxy, error)
}
func (m *mockProxyRepoForOAuth) Create(ctx context.Context, proxy *Proxy) error {
panic("Create not implemented")
}
func (m *mockProxyRepoForOAuth) GetByID(ctx context.Context, id int64) (*Proxy, error) {
if m.getByIDFunc != nil {
return m.getByIDFunc(ctx, id)
}
return nil, fmt.Errorf("proxy not found")
}
func (m *mockProxyRepoForOAuth) ListByIDs(ctx context.Context, ids []int64) ([]Proxy, error) {
panic("ListByIDs not implemented")
}
func (m *mockProxyRepoForOAuth) Update(ctx context.Context, proxy *Proxy) error {
panic("Update not implemented")
}
func (m *mockProxyRepoForOAuth) Delete(ctx context.Context, id int64) error {
panic("Delete not implemented")
}
func (m *mockProxyRepoForOAuth) List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error) {
panic("List not implemented")
}
func (m *mockProxyRepoForOAuth) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) {
panic("ListWithFilters not implemented")
}
func (m *mockProxyRepoForOAuth) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) {
panic("ListWithFiltersAndAccountCount not implemented")
}
func (m *mockProxyRepoForOAuth) ListActive(ctx context.Context) ([]Proxy, error) {
panic("ListActive not implemented")
}
func (m *mockProxyRepoForOAuth) ListActiveWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) {
panic("ListActiveWithAccountCount not implemented")
}
func (m *mockProxyRepoForOAuth) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
panic("ExistsByHostPortAuth not implemented")
}
func (m *mockProxyRepoForOAuth) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
panic("CountAccountsByProxyID not implemented")
}
func (m *mockProxyRepoForOAuth) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) {
panic("ListAccountSummariesByProxyID not implemented")
}
// =====================
// 测试用例
// =====================
func TestNewOAuthService(t *testing.T) {
t.Parallel()
proxyRepo := &mockProxyRepoForOAuth{}
client := &mockClaudeOAuthClient{}
svc := NewOAuthService(proxyRepo, client)
if svc == nil {
t.Fatal("NewOAuthService 返回 nil")
}
if svc.proxyRepo != proxyRepo {
t.Fatal("proxyRepo 未正确设置")
}
if svc.oauthClient != client {
t.Fatal("oauthClient 未正确设置")
}
if svc.sessionStore == nil {
t.Fatal("sessionStore 应被自动初始化")
}
// 清理
svc.Stop()
}
func TestOAuthService_GenerateAuthURL(t *testing.T) {
t.Parallel()
svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{})
defer svc.Stop()
result, err := svc.GenerateAuthURL(context.Background(), nil)
if err != nil {
t.Fatalf("GenerateAuthURL 返回错误: %v", err)
}
if result == nil {
t.Fatal("GenerateAuthURL 返回 nil")
}
if result.AuthURL == "" {
t.Fatal("AuthURL 为空")
}
if result.SessionID == "" {
t.Fatal("SessionID 为空")
}
// 验证 session 已存储
session, ok := svc.sessionStore.Get(result.SessionID)
if !ok {
t.Fatal("session 未在 sessionStore 中找到")
}
if session.Scope != oauth.ScopeOAuth {
t.Fatalf("scope 不匹配: got=%q want=%q", session.Scope, oauth.ScopeOAuth)
}
}
func TestOAuthService_GenerateAuthURL_WithProxy(t *testing.T) {
t.Parallel()
proxyRepo := &mockProxyRepoForOAuth{
getByIDFunc: func(ctx context.Context, id int64) (*Proxy, error) {
return &Proxy{
ID: 1,
Protocol: "http",
Host: "proxy.example.com",
Port: 8080,
}, nil
},
}
svc := NewOAuthService(proxyRepo, &mockClaudeOAuthClient{})
defer svc.Stop()
proxyID := int64(1)
result, err := svc.GenerateAuthURL(context.Background(), &proxyID)
if err != nil {
t.Fatalf("GenerateAuthURL 返回错误: %v", err)
}
session, ok := svc.sessionStore.Get(result.SessionID)
if !ok {
t.Fatal("session 未在 sessionStore 中找到")
}
if session.ProxyURL != "http://proxy.example.com:8080" {
t.Fatalf("ProxyURL 不匹配: got=%q", session.ProxyURL)
}
}
func TestOAuthService_GenerateSetupTokenURL(t *testing.T) {
t.Parallel()
svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{})
defer svc.Stop()
result, err := svc.GenerateSetupTokenURL(context.Background(), nil)
if err != nil {
t.Fatalf("GenerateSetupTokenURL 返回错误: %v", err)
}
if result == nil {
t.Fatal("GenerateSetupTokenURL 返回 nil")
}
// 验证 scope 是 inference
session, ok := svc.sessionStore.Get(result.SessionID)
if !ok {
t.Fatal("session 未在 sessionStore 中找到")
}
if session.Scope != oauth.ScopeInference {
t.Fatalf("scope 不匹配: got=%q want=%q", session.Scope, oauth.ScopeInference)
}
}
func TestOAuthService_ExchangeCode_SessionNotFound(t *testing.T) {
t.Parallel()
svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{})
defer svc.Stop()
_, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{
SessionID: "nonexistent-session",
Code: "test-code",
})
if err == nil {
t.Fatal("ExchangeCode 应返回错误(session 不存在)")
}
if err.Error() != "session not found or expired" {
t.Fatalf("错误信息不匹配: got=%q", err.Error())
}
}
func TestOAuthService_ExchangeCode_Success(t *testing.T) {
t.Parallel()
exchangeCalled := false
client := &mockClaudeOAuthClient{
exchangeCodeFunc: func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) {
exchangeCalled = true
if code != "auth-code-123" {
t.Errorf("code 不匹配: got=%q", code)
}
if isSetupToken {
t.Error("isSetupToken 应为 false(ScopeOAuth)")
}
return &oauth.TokenResponse{
AccessToken: "access-token-abc",
TokenType: "Bearer",
ExpiresIn: 3600,
RefreshToken: "refresh-token-xyz",
Scope: oauth.ScopeOAuth,
Organization: &oauth.OrgInfo{UUID: "org-uuid-111"},
Account: &oauth.AccountInfo{UUID: "acc-uuid-222", EmailAddress: "test@example.com"},
}, nil
},
}
svc := NewOAuthService(&mockProxyRepoForOAuth{}, client)
defer svc.Stop()
// 先生成 URL 以创建 session
result, err := svc.GenerateAuthURL(context.Background(), nil)
if err != nil {
t.Fatalf("GenerateAuthURL 返回错误: %v", err)
}
// 交换 code
tokenInfo, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{
SessionID: result.SessionID,
Code: "auth-code-123",
})
if err != nil {
t.Fatalf("ExchangeCode 返回错误: %v", err)
}
if !exchangeCalled {
t.Fatal("ExchangeCodeForToken 未被调用")
}
if tokenInfo.AccessToken != "access-token-abc" {
t.Fatalf("AccessToken 不匹配: got=%q", tokenInfo.AccessToken)
}
if tokenInfo.TokenType != "Bearer" {
t.Fatalf("TokenType 不匹配: got=%q", tokenInfo.TokenType)
}
if tokenInfo.RefreshToken != "refresh-token-xyz" {
t.Fatalf("RefreshToken 不匹配: got=%q", tokenInfo.RefreshToken)
}
if tokenInfo.OrgUUID != "org-uuid-111" {
t.Fatalf("OrgUUID 不匹配: got=%q", tokenInfo.OrgUUID)
}
if tokenInfo.AccountUUID != "acc-uuid-222" {
t.Fatalf("AccountUUID 不匹配: got=%q", tokenInfo.AccountUUID)
}
if tokenInfo.EmailAddress != "test@example.com" {
t.Fatalf("EmailAddress 不匹配: got=%q", tokenInfo.EmailAddress)
}
if tokenInfo.ExpiresIn != 3600 {
t.Fatalf("ExpiresIn 不匹配: got=%d", tokenInfo.ExpiresIn)
}
if tokenInfo.ExpiresAt == 0 {
t.Fatal("ExpiresAt 不应为 0")
}
// 验证 session 已被删除
_, ok := svc.sessionStore.Get(result.SessionID)
if ok {
t.Fatal("session 应在交换成功后被删除")
}
}
func TestOAuthService_ExchangeCode_SetupToken(t *testing.T) {
t.Parallel()
client := &mockClaudeOAuthClient{
exchangeCodeFunc: func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) {
if !isSetupToken {
t.Error("isSetupToken 应为 true(ScopeInference)")
}
return &oauth.TokenResponse{
AccessToken: "setup-token",
TokenType: "Bearer",
ExpiresIn: 3600,
Scope: oauth.ScopeInference,
}, nil
},
}
svc := NewOAuthService(&mockProxyRepoForOAuth{}, client)
defer svc.Stop()
// 使用 SetupToken URL(inference scope)
result, err := svc.GenerateSetupTokenURL(context.Background(), nil)
if err != nil {
t.Fatalf("GenerateSetupTokenURL 返回错误: %v", err)
}
tokenInfo, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{
SessionID: result.SessionID,
Code: "setup-code",
})
if err != nil {
t.Fatalf("ExchangeCode 返回错误: %v", err)
}
if tokenInfo.AccessToken != "setup-token" {
t.Fatalf("AccessToken 不匹配: got=%q", tokenInfo.AccessToken)
}
}
func TestOAuthService_ExchangeCode_ClientError(t *testing.T) {
t.Parallel()
client := &mockClaudeOAuthClient{
exchangeCodeFunc: func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) {
return nil, fmt.Errorf("upstream error: invalid code")
},
}
svc := NewOAuthService(&mockProxyRepoForOAuth{}, client)
defer svc.Stop()
result, _ := svc.GenerateAuthURL(context.Background(), nil)
_, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{
SessionID: result.SessionID,
Code: "bad-code",
})
if err == nil {
t.Fatal("ExchangeCode 应返回错误")
}
if err.Error() != "upstream error: invalid code" {
t.Fatalf("错误信息不匹配: got=%q", err.Error())
}
}
func TestOAuthService_RefreshToken(t *testing.T) {
t.Parallel()
client := &mockClaudeOAuthClient{
refreshTokenFunc: func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
if refreshToken != "my-refresh-token" {
t.Errorf("refreshToken 不匹配: got=%q", refreshToken)
}
if proxyURL != "" {
t.Errorf("proxyURL 应为空: got=%q", proxyURL)
}
return &oauth.TokenResponse{
AccessToken: "new-access-token",
TokenType: "Bearer",
ExpiresIn: 7200,
RefreshToken: "new-refresh-token",
Scope: oauth.ScopeOAuth,
}, nil
},
}
svc := NewOAuthService(&mockProxyRepoForOAuth{}, client)
defer svc.Stop()
tokenInfo, err := svc.RefreshToken(context.Background(), "my-refresh-token", "")
if err != nil {
t.Fatalf("RefreshToken 返回错误: %v", err)
}
if tokenInfo.AccessToken != "new-access-token" {
t.Fatalf("AccessToken 不匹配: got=%q", tokenInfo.AccessToken)
}
if tokenInfo.RefreshToken != "new-refresh-token" {
t.Fatalf("RefreshToken 不匹配: got=%q", tokenInfo.RefreshToken)
}
if tokenInfo.ExpiresIn != 7200 {
t.Fatalf("ExpiresIn 不匹配: got=%d", tokenInfo.ExpiresIn)
}
if tokenInfo.ExpiresAt == 0 {
t.Fatal("ExpiresAt 不应为 0")
}
}
func TestOAuthService_RefreshToken_Error(t *testing.T) {
t.Parallel()
client := &mockClaudeOAuthClient{
refreshTokenFunc: func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
return nil, fmt.Errorf("invalid_grant: token expired")
},
}
svc := NewOAuthService(&mockProxyRepoForOAuth{}, client)
defer svc.Stop()
_, err := svc.RefreshToken(context.Background(), "expired-token", "")
if err == nil {
t.Fatal("RefreshToken 应返回错误")
}
}
func TestOAuthService_RefreshAccountToken_NoRefreshToken(t *testing.T) {
t.Parallel()
svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{})
defer svc.Stop()
// 无 refresh_token 的账号
account := &Account{
ID: 1,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "some-token",
},
}
_, err := svc.RefreshAccountToken(context.Background(), account)
if err == nil {
t.Fatal("RefreshAccountToken 应返回错误(无 refresh_token)")
}
if err.Error() != "no refresh token available" {
t.Fatalf("错误信息不匹配: got=%q", err.Error())
}
}
func TestOAuthService_RefreshAccountToken_EmptyRefreshToken(t *testing.T) {
t.Parallel()
svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{})
defer svc.Stop()
account := &Account{
ID: 2,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "some-token",
"refresh_token": "",
},
}
_, err := svc.RefreshAccountToken(context.Background(), account)
if err == nil {
t.Fatal("RefreshAccountToken 应返回错误(refresh_token 为空)")
}
}
func TestOAuthService_RefreshAccountToken_Success(t *testing.T) {
t.Parallel()
client := &mockClaudeOAuthClient{
refreshTokenFunc: func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
if refreshToken != "account-refresh-token" {
t.Errorf("refreshToken 不匹配: got=%q", refreshToken)
}
return &oauth.TokenResponse{
AccessToken: "refreshed-access",
TokenType: "Bearer",
ExpiresIn: 3600,
RefreshToken: "new-refresh",
}, nil
},
}
svc := NewOAuthService(&mockProxyRepoForOAuth{}, client)
defer svc.Stop()
account := &Account{
ID: 3,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-access",
"refresh_token": "account-refresh-token",
},
}
tokenInfo, err := svc.RefreshAccountToken(context.Background(), account)
if err != nil {
t.Fatalf("RefreshAccountToken 返回错误: %v", err)
}
if tokenInfo.AccessToken != "refreshed-access" {
t.Fatalf("AccessToken 不匹配: got=%q", tokenInfo.AccessToken)
}
}
func TestOAuthService_RefreshAccountToken_WithProxy(t *testing.T) {
t.Parallel()
proxyRepo := &mockProxyRepoForOAuth{
getByIDFunc: func(ctx context.Context, id int64) (*Proxy, error) {
return &Proxy{
Protocol: "socks5",
Host: "socks.example.com",
Port: 1080,
Username: "user",
Password: "pass",
}, nil
},
}
client := &mockClaudeOAuthClient{
refreshTokenFunc: func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
if proxyURL != "socks5://user:pass@socks.example.com:1080" {
t.Errorf("proxyURL 不匹配: got=%q", proxyURL)
}
return &oauth.TokenResponse{
AccessToken: "refreshed",
ExpiresIn: 3600,
}, nil
},
}
svc := NewOAuthService(proxyRepo, client)
defer svc.Stop()
proxyID := int64(10)
account := &Account{
ID: 4,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
ProxyID: &proxyID,
Credentials: map[string]any{
"refresh_token": "rt-with-proxy",
},
}
_, err := svc.RefreshAccountToken(context.Background(), account)
if err != nil {
t.Fatalf("RefreshAccountToken 返回错误: %v", err)
}
}
func TestOAuthService_ExchangeCode_NilOrg(t *testing.T) {
t.Parallel()
client := &mockClaudeOAuthClient{
exchangeCodeFunc: func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) {
return &oauth.TokenResponse{
AccessToken: "token-no-org",
TokenType: "Bearer",
ExpiresIn: 3600,
Organization: nil,
Account: nil,
}, nil
},
}
svc := NewOAuthService(&mockProxyRepoForOAuth{}, client)
defer svc.Stop()
result, _ := svc.GenerateAuthURL(context.Background(), nil)
tokenInfo, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{
SessionID: result.SessionID,
Code: "code",
})
if err != nil {
t.Fatalf("ExchangeCode 返回错误: %v", err)
}
if tokenInfo.OrgUUID != "" {
t.Fatalf("OrgUUID 应为空: got=%q", tokenInfo.OrgUUID)
}
if tokenInfo.AccountUUID != "" {
t.Fatalf("AccountUUID 应为空: got=%q", tokenInfo.AccountUUID)
}
}
func TestOAuthService_Stop_NoPanic(t *testing.T) {
t.Parallel()
svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{})
// 调用 Stop 不应 panic
svc.Stop()
// 多次调用也不应 panic
svc.Stop()
}
package logredact
import (
"strings"
"testing"
)
func TestRedactText_JSONLike(t *testing.T) {
in := `{"access_token":"ya29.a0AfH6SMDUMMY","refresh_token":"1//0gDUMMY","other":"ok"}`
out := RedactText(in)
if out == in {
t.Fatalf("expected redaction, got unchanged")
}
if want := `"access_token":"***"`; !strings.Contains(out, want) {
t.Fatalf("expected %q in %q", want, out)
}
if want := `"refresh_token":"***"`; !strings.Contains(out, want) {
t.Fatalf("expected %q in %q", want, out)
}
}
func TestRedactText_QueryLike(t *testing.T) {
in := "access_token=ya29.a0AfH6SMDUMMY refresh_token=1//0gDUMMY"
out := RedactText(in)
if strings.Contains(out, "ya29") || strings.Contains(out, "1//0") {
t.Fatalf("expected tokens redacted, got %q", out)
}
}
func TestRedactText_GOCSPX(t *testing.T) {
in := "client_secret=GOCSPX-abcdefghijklmnopqrstuvwxyz_0123456789"
out := RedactText(in)
if strings.Contains(out, "abcdefghijklmnopqrstuvwxyz") {
t.Fatalf("expected secret redacted, got %q", out)
}
if !strings.Contains(out, "client_secret=***") {
t.Fatalf("expected key redacted, got %q", out)
}
}
...@@ -49,3 +49,27 @@ func TestValidateURLFormat(t *testing.T) { ...@@ -49,3 +49,27 @@ func TestValidateURLFormat(t *testing.T) {
t.Fatalf("expected trailing slash to be removed from path, got %s", normalized) t.Fatalf("expected trailing slash to be removed from path, got %s", normalized)
} }
} }
func TestValidateHTTPURL(t *testing.T) {
if _, err := ValidateHTTPURL("http://example.com", false, ValidationOptions{}); err == nil {
t.Fatalf("expected http to fail when allow_insecure_http is false")
}
if _, err := ValidateHTTPURL("http://example.com", true, ValidationOptions{}); err != nil {
t.Fatalf("expected http to pass when allow_insecure_http is true, got %v", err)
}
if _, err := ValidateHTTPURL("https://example.com", false, ValidationOptions{RequireAllowlist: true}); err == nil {
t.Fatalf("expected require allowlist to fail when empty")
}
if _, err := ValidateHTTPURL("https://example.com", false, ValidationOptions{AllowedHosts: []string{"api.example.com"}}); err == nil {
t.Fatalf("expected host not in allowlist to fail")
}
if _, err := ValidateHTTPURL("https://api.example.com", false, ValidationOptions{AllowedHosts: []string{"api.example.com"}}); err != nil {
t.Fatalf("expected allowlisted host to pass, got %v", err)
}
if _, err := ValidateHTTPURL("https://sub.api.example.com", false, ValidationOptions{AllowedHosts: []string{"*.example.com"}}); err != nil {
t.Fatalf("expected wildcard allowlist to pass, got %v", err)
}
if _, err := ValidateHTTPURL("https://localhost", false, ValidationOptions{AllowPrivate: false}); err == nil {
t.Fatalf("expected localhost to be blocked when allow_private_hosts is false")
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment