Commit 7319122e authored by LLLLLLiulei's avatar LLLLLLiulei
Browse files

merge upstream/main

parents 029994a8 4809fa4f
// Package googleapi provides helpers for Google-style API responses.
package googleapi
import (
"encoding/json"
"fmt"
"strings"
)
// ErrorResponse represents a Google API error response
type ErrorResponse struct {
Error ErrorDetail `json:"error"`
}
// ErrorDetail contains the error details from Google API
type ErrorDetail struct {
Code int `json:"code"`
Message string `json:"message"`
Status string `json:"status"`
Details []json.RawMessage `json:"details,omitempty"`
}
// ErrorDetailInfo contains additional error information
type ErrorDetailInfo struct {
Type string `json:"@type"`
Reason string `json:"reason,omitempty"`
Domain string `json:"domain,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
}
// ErrorHelp contains help links
type ErrorHelp struct {
Type string `json:"@type"`
Links []HelpLink `json:"links,omitempty"`
}
// HelpLink represents a help link
type HelpLink struct {
Description string `json:"description"`
URL string `json:"url"`
}
// ParseError parses a Google API error response and extracts key information
func ParseError(body string) (*ErrorResponse, error) {
var errResp ErrorResponse
if err := json.Unmarshal([]byte(body), &errResp); err != nil {
return nil, fmt.Errorf("failed to parse error response: %w", err)
}
return &errResp, nil
}
// ExtractActivationURL extracts the API activation URL from error details
func ExtractActivationURL(body string) string {
var errResp ErrorResponse
if err := json.Unmarshal([]byte(body), &errResp); err != nil {
return ""
}
// Check error details for activation URL
for _, detailRaw := range errResp.Error.Details {
// Parse as ErrorDetailInfo
var info ErrorDetailInfo
if err := json.Unmarshal(detailRaw, &info); err == nil {
if info.Metadata != nil {
if activationURL, ok := info.Metadata["activationUrl"]; ok && activationURL != "" {
return activationURL
}
}
}
// Parse as ErrorHelp
var help ErrorHelp
if err := json.Unmarshal(detailRaw, &help); err == nil {
for _, link := range help.Links {
if strings.Contains(link.Description, "activation") ||
strings.Contains(link.Description, "API activation") ||
strings.Contains(link.URL, "/apis/api/") {
return link.URL
}
}
}
}
return ""
}
// IsServiceDisabledError checks if the error is a SERVICE_DISABLED error
func IsServiceDisabledError(body string) bool {
var errResp ErrorResponse
if err := json.Unmarshal([]byte(body), &errResp); err != nil {
return false
}
// Check if it's a 403 PERMISSION_DENIED with SERVICE_DISABLED reason
if errResp.Error.Code != 403 || errResp.Error.Status != "PERMISSION_DENIED" {
return false
}
for _, detailRaw := range errResp.Error.Details {
var info ErrorDetailInfo
if err := json.Unmarshal(detailRaw, &info); err == nil {
if info.Reason == "SERVICE_DISABLED" {
return true
}
}
}
return false
}
package googleapi
import (
"testing"
)
func TestExtractActivationURL(t *testing.T) {
// Test case from the user's error message
errorBody := `{
"error": {
"code": 403,
"message": "Gemini for Google Cloud API has not been used in project project-6eca5881-ab73-4736-843 before or it is disabled. Enable it by visiting https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843 then retry. If you enabled this API recently, wait a few minutes for the action to propagate to our systems and retry.",
"status": "PERMISSION_DENIED",
"details": [
{
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
"reason": "SERVICE_DISABLED",
"domain": "googleapis.com",
"metadata": {
"service": "cloudaicompanion.googleapis.com",
"activationUrl": "https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843",
"consumer": "projects/project-6eca5881-ab73-4736-843",
"serviceTitle": "Gemini for Google Cloud API",
"containerInfo": "project-6eca5881-ab73-4736-843"
}
},
{
"@type": "type.googleapis.com/google.rpc.LocalizedMessage",
"locale": "en-US",
"message": "Gemini for Google Cloud API has not been used in project project-6eca5881-ab73-4736-843 before or it is disabled. Enable it by visiting https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843 then retry. If you enabled this API recently, wait a few minutes for the action to propagate to our systems and retry."
},
{
"@type": "type.googleapis.com/google.rpc.Help",
"links": [
{
"description": "Google developers console API activation",
"url": "https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843"
}
]
}
]
}
}`
activationURL := ExtractActivationURL(errorBody)
expectedURL := "https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843"
if activationURL != expectedURL {
t.Errorf("Expected activation URL %s, got %s", expectedURL, activationURL)
}
}
func TestIsServiceDisabledError(t *testing.T) {
tests := []struct {
name string
body string
expected bool
}{
{
name: "SERVICE_DISABLED error",
body: `{
"error": {
"code": 403,
"status": "PERMISSION_DENIED",
"details": [
{
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
"reason": "SERVICE_DISABLED"
}
]
}
}`,
expected: true,
},
{
name: "Other 403 error",
body: `{
"error": {
"code": 403,
"status": "PERMISSION_DENIED",
"details": [
{
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
"reason": "OTHER_REASON"
}
]
}
}`,
expected: false,
},
{
name: "404 error",
body: `{
"error": {
"code": 404,
"status": "NOT_FOUND"
}
}`,
expected: false,
},
{
name: "Invalid JSON",
body: `invalid json`,
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsServiceDisabledError(tt.body)
if result != tt.expected {
t.Errorf("Expected %v, got %v", tt.expected, result)
}
})
}
}
func TestParseError(t *testing.T) {
errorBody := `{
"error": {
"code": 403,
"message": "API not enabled",
"status": "PERMISSION_DENIED"
}
}`
errResp, err := ParseError(errorBody)
if err != nil {
t.Fatalf("Failed to parse error: %v", err)
}
if errResp.Error.Code != 403 {
t.Errorf("Expected code 403, got %d", errResp.Error.Code)
}
if errResp.Error.Status != "PERMISSION_DENIED" {
t.Errorf("Expected status PERMISSION_DENIED, got %s", errResp.Error.Status)
}
if errResp.Error.Message != "API not enabled" {
t.Errorf("Expected message 'API not enabled', got %s", errResp.Error.Message)
}
}
......@@ -15,6 +15,8 @@ type Model struct {
// DefaultModels OpenAI models list
var DefaultModels = []Model{
{ID: "gpt-5.3", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3"},
{ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"},
{ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"},
{ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"},
{ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"},
......
......@@ -1089,8 +1089,9 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
result, err := client.ExecContext(
ctx,
"UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) || $1::jsonb, updated_at = NOW() WHERE id = $2 AND deleted_at IS NULL",
payload, id,
string(payload), id,
)
if err != nil {
return err
}
......
package repository
import (
"context"
"encoding/json"
"log"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const (
errorPassthroughCacheKey = "error_passthrough_rules"
errorPassthroughPubSubKey = "error_passthrough_rules_updated"
errorPassthroughCacheTTL = 24 * time.Hour
)
type errorPassthroughCache struct {
rdb *redis.Client
localCache []*model.ErrorPassthroughRule
localMu sync.RWMutex
}
// NewErrorPassthroughCache 创建错误透传规则缓存
func NewErrorPassthroughCache(rdb *redis.Client) service.ErrorPassthroughCache {
return &errorPassthroughCache{
rdb: rdb,
}
}
// Get 从缓存获取规则列表
func (c *errorPassthroughCache) Get(ctx context.Context) ([]*model.ErrorPassthroughRule, bool) {
// 先检查本地缓存
c.localMu.RLock()
if c.localCache != nil {
rules := c.localCache
c.localMu.RUnlock()
return rules, true
}
c.localMu.RUnlock()
// 从 Redis 获取
data, err := c.rdb.Get(ctx, errorPassthroughCacheKey).Bytes()
if err != nil {
if err != redis.Nil {
log.Printf("[ErrorPassthroughCache] Failed to get from Redis: %v", err)
}
return nil, false
}
var rules []*model.ErrorPassthroughRule
if err := json.Unmarshal(data, &rules); err != nil {
log.Printf("[ErrorPassthroughCache] Failed to unmarshal rules: %v", err)
return nil, false
}
// 更新本地缓存
c.localMu.Lock()
c.localCache = rules
c.localMu.Unlock()
return rules, true
}
// Set 设置缓存
func (c *errorPassthroughCache) Set(ctx context.Context, rules []*model.ErrorPassthroughRule) error {
data, err := json.Marshal(rules)
if err != nil {
return err
}
if err := c.rdb.Set(ctx, errorPassthroughCacheKey, data, errorPassthroughCacheTTL).Err(); err != nil {
return err
}
// 更新本地缓存
c.localMu.Lock()
c.localCache = rules
c.localMu.Unlock()
return nil
}
// Invalidate 使缓存失效
func (c *errorPassthroughCache) Invalidate(ctx context.Context) error {
// 清除本地缓存
c.localMu.Lock()
c.localCache = nil
c.localMu.Unlock()
// 清除 Redis 缓存
return c.rdb.Del(ctx, errorPassthroughCacheKey).Err()
}
// NotifyUpdate 通知其他实例刷新缓存
func (c *errorPassthroughCache) NotifyUpdate(ctx context.Context) error {
return c.rdb.Publish(ctx, errorPassthroughPubSubKey, "refresh").Err()
}
// SubscribeUpdates 订阅缓存更新通知
func (c *errorPassthroughCache) SubscribeUpdates(ctx context.Context, handler func()) {
go func() {
sub := c.rdb.Subscribe(ctx, errorPassthroughPubSubKey)
defer func() { _ = sub.Close() }()
ch := sub.Channel()
for {
select {
case <-ctx.Done():
return
case msg := <-ch:
if msg == nil {
return
}
// 清除本地缓存,下次访问时会从 Redis 或数据库重新加载
c.localMu.Lock()
c.localCache = nil
c.localMu.Unlock()
// 调用处理函数
handler()
}
}
}()
}
package repository
import (
"context"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type errorPassthroughRepository struct {
client *ent.Client
}
// NewErrorPassthroughRepository 创建错误透传规则仓库
func NewErrorPassthroughRepository(client *ent.Client) service.ErrorPassthroughRepository {
return &errorPassthroughRepository{client: client}
}
// List 获取所有规则
func (r *errorPassthroughRepository) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) {
rules, err := r.client.ErrorPassthroughRule.Query().
Order(ent.Asc(errorpassthroughrule.FieldPriority)).
All(ctx)
if err != nil {
return nil, err
}
result := make([]*model.ErrorPassthroughRule, len(rules))
for i, rule := range rules {
result[i] = r.toModel(rule)
}
return result, nil
}
// GetByID 根据 ID 获取规则
func (r *errorPassthroughRepository) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) {
rule, err := r.client.ErrorPassthroughRule.Get(ctx, id)
if err != nil {
if ent.IsNotFound(err) {
return nil, nil
}
return nil, err
}
return r.toModel(rule), nil
}
// Create 创建规则
func (r *errorPassthroughRepository) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
builder := r.client.ErrorPassthroughRule.Create().
SetName(rule.Name).
SetEnabled(rule.Enabled).
SetPriority(rule.Priority).
SetMatchMode(rule.MatchMode).
SetPassthroughCode(rule.PassthroughCode).
SetPassthroughBody(rule.PassthroughBody)
if len(rule.ErrorCodes) > 0 {
builder.SetErrorCodes(rule.ErrorCodes)
}
if len(rule.Keywords) > 0 {
builder.SetKeywords(rule.Keywords)
}
if len(rule.Platforms) > 0 {
builder.SetPlatforms(rule.Platforms)
}
if rule.ResponseCode != nil {
builder.SetResponseCode(*rule.ResponseCode)
}
if rule.CustomMessage != nil {
builder.SetCustomMessage(*rule.CustomMessage)
}
if rule.Description != nil {
builder.SetDescription(*rule.Description)
}
created, err := builder.Save(ctx)
if err != nil {
return nil, err
}
return r.toModel(created), nil
}
// Update 更新规则
func (r *errorPassthroughRepository) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
builder := r.client.ErrorPassthroughRule.UpdateOneID(rule.ID).
SetName(rule.Name).
SetEnabled(rule.Enabled).
SetPriority(rule.Priority).
SetMatchMode(rule.MatchMode).
SetPassthroughCode(rule.PassthroughCode).
SetPassthroughBody(rule.PassthroughBody)
// 处理可选字段
if len(rule.ErrorCodes) > 0 {
builder.SetErrorCodes(rule.ErrorCodes)
} else {
builder.ClearErrorCodes()
}
if len(rule.Keywords) > 0 {
builder.SetKeywords(rule.Keywords)
} else {
builder.ClearKeywords()
}
if len(rule.Platforms) > 0 {
builder.SetPlatforms(rule.Platforms)
} else {
builder.ClearPlatforms()
}
if rule.ResponseCode != nil {
builder.SetResponseCode(*rule.ResponseCode)
} else {
builder.ClearResponseCode()
}
if rule.CustomMessage != nil {
builder.SetCustomMessage(*rule.CustomMessage)
} else {
builder.ClearCustomMessage()
}
if rule.Description != nil {
builder.SetDescription(*rule.Description)
} else {
builder.ClearDescription()
}
updated, err := builder.Save(ctx)
if err != nil {
return nil, err
}
return r.toModel(updated), nil
}
// Delete 删除规则
func (r *errorPassthroughRepository) Delete(ctx context.Context, id int64) error {
return r.client.ErrorPassthroughRule.DeleteOneID(id).Exec(ctx)
}
// toModel 将 Ent 实体转换为服务模型
func (r *errorPassthroughRepository) toModel(e *ent.ErrorPassthroughRule) *model.ErrorPassthroughRule {
rule := &model.ErrorPassthroughRule{
ID: int64(e.ID),
Name: e.Name,
Enabled: e.Enabled,
Priority: e.Priority,
ErrorCodes: e.ErrorCodes,
Keywords: e.Keywords,
MatchMode: e.MatchMode,
Platforms: e.Platforms,
PassthroughCode: e.PassthroughCode,
PassthroughBody: e.PassthroughBody,
CreatedAt: e.CreatedAt,
UpdatedAt: e.UpdatedAt,
}
if e.ResponseCode != nil {
rule.ResponseCode = e.ResponseCode
}
if e.CustomMessage != nil {
rule.CustomMessage = e.CustomMessage
}
if e.Description != nil {
rule.Description = e.Description
}
// 确保切片不为 nil
if rule.ErrorCodes == nil {
rule.ErrorCodes = []int{}
}
if rule.Keywords == nil {
rule.Keywords = []string{}
}
if rule.Platforms == nil {
rule.Platforms = []string{}
}
return rule
}
......@@ -6,6 +6,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/imroc/req/v3"
......@@ -38,9 +39,20 @@ func (c *geminiCliCodeAssistClient) LoadCodeAssist(ctx context.Context, accessTo
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccessState() {
body := geminicli.SanitizeBodyForLogs(resp.String())
fmt.Printf("[CodeAssist] LoadCodeAssist failed: status %d, body: %s\n", resp.StatusCode, body)
return nil, fmt.Errorf("loadCodeAssist failed: status %d, body: %s", resp.StatusCode, body)
body := resp.String()
sanitizedBody := geminicli.SanitizeBodyForLogs(body)
fmt.Printf("[CodeAssist] LoadCodeAssist failed: status %d, body: %s\n", resp.StatusCode, sanitizedBody)
// Check if this is a SERVICE_DISABLED error and extract activation URL
if googleapi.IsServiceDisabledError(body) {
activationURL := googleapi.ExtractActivationURL(body)
if activationURL != "" {
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL)
}
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com")
}
return nil, fmt.Errorf("loadCodeAssist failed: status %d, body: %s", resp.StatusCode, sanitizedBody)
}
fmt.Printf("[CodeAssist] LoadCodeAssist success: status %d, response: %+v\n", resp.StatusCode, out)
return &out, nil
......@@ -67,9 +79,20 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccessState() {
body := geminicli.SanitizeBodyForLogs(resp.String())
fmt.Printf("[CodeAssist] OnboardUser failed: status %d, body: %s\n", resp.StatusCode, body)
return nil, fmt.Errorf("onboardUser failed: status %d, body: %s", resp.StatusCode, body)
body := resp.String()
sanitizedBody := geminicli.SanitizeBodyForLogs(body)
fmt.Printf("[CodeAssist] OnboardUser failed: status %d, body: %s\n", resp.StatusCode, sanitizedBody)
// Check if this is a SERVICE_DISABLED error and extract activation URL
if googleapi.IsServiceDisabledError(body) {
activationURL := googleapi.ExtractActivationURL(body)
if activationURL != "" {
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL)
}
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com")
}
return nil, fmt.Errorf("onboardUser failed: status %d, body: %s", resp.StatusCode, sanitizedBody)
}
fmt.Printf("[CodeAssist] OnboardUser success: status %d, response: %+v\n", resp.StatusCode, out)
return &out, nil
......
......@@ -3,6 +3,7 @@ package repository
import (
"context"
"fmt"
"log"
"strconv"
"time"
......@@ -153,6 +154,21 @@ func NewSessionLimitCache(rdb *redis.Client, defaultIdleTimeoutMinutes int) serv
if defaultIdleTimeoutMinutes <= 0 {
defaultIdleTimeoutMinutes = 5 // 默认 5 分钟
}
// 预加载 Lua 脚本到 Redis,避免 Pipeline 中出现 NOSCRIPT 错误
ctx := context.Background()
scripts := []*redis.Script{
registerSessionScript,
refreshSessionScript,
getActiveSessionCountScript,
isSessionActiveScript,
}
for _, script := range scripts {
if err := script.Load(ctx, rdb).Err(); err != nil {
log.Printf("[SessionLimitCache] Failed to preload Lua script: %v", err)
}
}
return &sessionLimitCache{
rdb: rdb,
defaultIdleTimeout: time.Duration(defaultIdleTimeoutMinutes) * time.Minute,
......
package repository
import (
"context"
"database/sql"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type userGroupRateRepository struct {
sql sqlExecutor
}
// NewUserGroupRateRepository 创建用户专属分组倍率仓储
func NewUserGroupRateRepository(sqlDB *sql.DB) service.UserGroupRateRepository {
return &userGroupRateRepository{sql: sqlDB}
}
// GetByUserID 获取用户的所有专属分组倍率
func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) {
query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1`
rows, err := r.sql.QueryContext(ctx, query, userID)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
result := make(map[int64]float64)
for rows.Next() {
var groupID int64
var rate float64
if err := rows.Scan(&groupID, &rate); err != nil {
return nil, err
}
result[groupID] = rate
}
if err := rows.Err(); err != nil {
return nil, err
}
return result, nil
}
// GetByUserAndGroup 获取用户在特定分组的专属倍率
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
var rate float64
err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rate)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &rate, nil
}
// SyncUserGroupRates 同步用户的分组专属倍率
func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error {
if len(rates) == 0 {
// 如果传入空 map,删除该用户的所有专属倍率
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
return err
}
// 分离需要删除和需要 upsert 的记录
var toDelete []int64
toUpsert := make(map[int64]float64)
for groupID, rate := range rates {
if rate == nil {
toDelete = append(toDelete, groupID)
} else {
toUpsert[groupID] = *rate
}
}
// 删除指定的记录
for _, groupID := range toDelete {
_, err := r.sql.ExecContext(ctx,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`,
userID, groupID)
if err != nil {
return err
}
}
// Upsert 记录
now := time.Now()
for groupID, rate := range toUpsert {
_, err := r.sql.ExecContext(ctx, `
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
VALUES ($1, $2, $3, $4, $4)
ON CONFLICT (user_id, group_id) DO UPDATE SET rate_multiplier = $3, updated_at = $4
`, userID, groupID, rate, now)
if err != nil {
return err
}
}
return nil
}
// DeleteByGroupID 删除指定分组的所有用户专属倍率
func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID)
return err
}
// DeleteByUserID 删除指定用户的所有专属倍率
func (r *userGroupRateRepository) DeleteByUserID(ctx context.Context, userID int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
return err
}
......@@ -66,6 +66,8 @@ var ProviderSet = wire.NewSet(
NewUserSubscriptionRepository,
NewUserAttributeDefinitionRepository,
NewUserAttributeValueRepository,
NewUserGroupRateRepository,
NewErrorPassthroughRepository,
// Cache implementations
NewGatewayCache,
......@@ -86,6 +88,7 @@ var ProviderSet = wire.NewSet(
NewProxyLatencyCache,
NewTotpCache,
NewRefreshTokenCache,
NewErrorPassthroughCache,
// Encryptors
NewAESEncryptor,
......
......@@ -593,7 +593,7 @@ func newContractDeps(t *testing.T) *contractDeps {
}
userService := service.NewUserService(userRepo, nil)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg)
usageRepo := newStubUsageLogRepo()
usageService := service.NewUsageService(usageRepo, userRepo, nil, nil)
......@@ -607,7 +607,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
......
......@@ -93,6 +93,7 @@ func newTestAPIKeyService(repo service.APIKeyRepository) *service.APIKeyService
nil, // userRepo (unused in GetByKey)
nil, // groupRepo
nil, // userSubRepo
nil, // userGroupRateRepo
nil, // cache
&config.Config{},
)
......@@ -187,6 +188,7 @@ func TestApiKeyAuthWithSubscriptionGoogleSetsGroupContext(t *testing.T) {
nil,
nil,
nil,
nil,
&config.Config{RunMode: config.RunModeSimple},
)
......
......@@ -59,7 +59,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
......@@ -73,7 +73,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("standard_mode_enforces_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeStandard}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
now := time.Now()
sub := &service.UserSubscription{
......@@ -150,7 +150,7 @@ func TestAPIKeyAuthSetsGroupContext(t *testing.T) {
}
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
router := gin.New()
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
router.GET("/t", func(c *gin.Context) {
......@@ -208,7 +208,7 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
}
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
router := gin.New()
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
......
......@@ -67,6 +67,9 @@ func RegisterAdminRoutes(
// 用户属性管理
registerUserAttributeRoutes(admin, h)
// 错误透传规则管理
registerErrorPassthroughRoutes(admin, h)
}
}
......@@ -391,3 +394,14 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
attrs.DELETE("/:id", h.Admin.UserAttribute.DeleteDefinition)
}
}
func registerErrorPassthroughRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
rules := admin.Group("/error-passthrough-rules")
{
rules.GET("", h.Admin.ErrorPassthrough.List)
rules.GET("/:id", h.Admin.ErrorPassthrough.GetByID)
rules.POST("", h.Admin.ErrorPassthrough.Create)
rules.PUT("/:id", h.Admin.ErrorPassthrough.Update)
rules.DELETE("/:id", h.Admin.ErrorPassthrough.Delete)
}
}
......@@ -49,6 +49,7 @@ func RegisterUserRoutes(
groups := authenticated.Group("/groups")
{
groups.GET("/available", h.APIKey.GetAvailableGroups)
groups.GET("/rates", h.APIKey.GetUserGroupRates)
}
// 使用记录
......
......@@ -94,6 +94,9 @@ type UpdateUserInput struct {
Concurrency *int // 使用指针区分"未提供"和"设置为0"
Status string
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
// GroupRates 用户专属分组倍率配置
// map[groupID]*rate,nil 表示删除该分组的专属倍率
GroupRates map[int64]*float64
}
type CreateGroupInput struct {
......@@ -296,6 +299,7 @@ type adminServiceImpl struct {
proxyRepo ProxyRepository
apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository
userGroupRateRepo UserGroupRateRepository
billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber
proxyLatencyCache ProxyLatencyCache
......@@ -310,6 +314,7 @@ func NewAdminService(
proxyRepo ProxyRepository,
apiKeyRepo APIKeyRepository,
redeemCodeRepo RedeemCodeRepository,
userGroupRateRepo UserGroupRateRepository,
billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber,
proxyLatencyCache ProxyLatencyCache,
......@@ -322,6 +327,7 @@ func NewAdminService(
proxyRepo: proxyRepo,
apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo,
userGroupRateRepo: userGroupRateRepo,
billingCacheService: billingCacheService,
proxyProber: proxyProber,
proxyLatencyCache: proxyLatencyCache,
......@@ -336,11 +342,35 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi
if err != nil {
return nil, 0, err
}
// 批量加载用户专属分组倍率
if s.userGroupRateRepo != nil && len(users) > 0 {
for i := range users {
rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID)
if err != nil {
log.Printf("failed to load user group rates: user_id=%d err=%v", users[i].ID, err)
continue
}
users[i].GroupRates = rates
}
}
return users, result.Total, nil
}
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) {
return s.userRepo.GetByID(ctx, id)
user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
// 加载用户专属分组倍率
if s.userGroupRateRepo != nil {
rates, err := s.userGroupRateRepo.GetByUserID(ctx, id)
if err != nil {
log.Printf("failed to load user group rates: user_id=%d err=%v", id, err)
} else {
user.GroupRates = rates
}
}
return user, nil
}
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) {
......@@ -409,6 +439,14 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
// 同步用户专属分组倍率
if input.GroupRates != nil && s.userGroupRateRepo != nil {
if err := s.userGroupRateRepo.SyncUserGroupRates(ctx, user.ID, input.GroupRates); err != nil {
log.Printf("failed to sync user group rates: user_id=%d err=%v", user.ID, err)
}
}
if s.authCacheInvalidator != nil {
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID)
......@@ -944,6 +982,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
if err != nil {
return err
}
// 注意:user_group_rate_multipliers 表通过外键 ON DELETE CASCADE 自动清理
// 事务成功后,异步失效受影响用户的订阅缓存
if len(affectedUserIDs) > 0 && s.billingCacheService != nil {
......
......@@ -1106,7 +1106,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody)
......@@ -1779,6 +1779,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
// 处理错误响应
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
contentType := resp.Header.Get("Content-Type")
// 尽早关闭原始响应体,释放连接;后续逻辑仍可能需要读取 body,因此用内存副本重新包装。
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
......@@ -1849,10 +1850,8 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: unwrappedForOps}
}
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
}
......
......@@ -115,15 +115,16 @@ type UpdateAPIKeyRequest struct {
// APIKeyService API Key服务
type APIKeyService struct {
apiKeyRepo APIKeyRepository
userRepo UserRepository
groupRepo GroupRepository
userSubRepo UserSubscriptionRepository
cache APIKeyCache
cfg *config.Config
authCacheL1 *ristretto.Cache
authCfg apiKeyAuthCacheConfig
authGroup singleflight.Group
apiKeyRepo APIKeyRepository
userRepo UserRepository
groupRepo GroupRepository
userSubRepo UserSubscriptionRepository
userGroupRateRepo UserGroupRateRepository
cache APIKeyCache
cfg *config.Config
authCacheL1 *ristretto.Cache
authCfg apiKeyAuthCacheConfig
authGroup singleflight.Group
}
// NewAPIKeyService 创建API Key服务实例
......@@ -132,16 +133,18 @@ func NewAPIKeyService(
userRepo UserRepository,
groupRepo GroupRepository,
userSubRepo UserSubscriptionRepository,
userGroupRateRepo UserGroupRateRepository,
cache APIKeyCache,
cfg *config.Config,
) *APIKeyService {
svc := &APIKeyService{
apiKeyRepo: apiKeyRepo,
userRepo: userRepo,
groupRepo: groupRepo,
userSubRepo: userSubRepo,
cache: cache,
cfg: cfg,
apiKeyRepo: apiKeyRepo,
userRepo: userRepo,
groupRepo: groupRepo,
userSubRepo: userSubRepo,
userGroupRateRepo: userGroupRateRepo,
cache: cache,
cfg: cfg,
}
svc.initAuthCache(cfg)
return svc
......@@ -627,6 +630,19 @@ func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword
return keys, nil
}
// GetUserGroupRates 获取用户的专属分组倍率配置
// 返回 map[groupID]rateMultiplier
func (s *APIKeyService) GetUserGroupRates(ctx context.Context, userID int64) (map[int64]float64, error) {
if s.userGroupRateRepo == nil {
return nil, nil
}
rates, err := s.userGroupRateRepo.GetByUserID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user group rates: %w", err)
}
return rates, nil
}
// CheckAPIKeyQuotaAndExpiry checks if the API key is valid for use (not expired, quota not exhausted)
// Returns nil if valid, error if invalid
func (s *APIKeyService) CheckAPIKeyQuotaAndExpiry(apiKey *APIKey) error {
......
......@@ -167,7 +167,7 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
groupID := int64(9)
cacheEntry := &APIKeyAuthCacheEntry{
......@@ -223,7 +223,7 @@ func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return &APIKeyAuthCacheEntry{NotFound: true}, nil
}
......@@ -256,7 +256,7 @@ func TestAPIKeyService_GetByKey_CacheMissStoresL2(t *testing.T) {
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return nil, redis.Nil
}
......@@ -293,7 +293,7 @@ func TestAPIKeyService_GetByKey_UsesL1Cache(t *testing.T) {
L1TTLSeconds: 60,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
require.NotNil(t, svc.authCacheL1)
_, err := svc.GetByKey(context.Background(), "k-l1")
......@@ -320,7 +320,7 @@ func TestAPIKeyService_InvalidateAuthCacheByUserID(t *testing.T) {
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByUserID(context.Background(), 7)
require.Len(t, cache.deleteAuthKeys, 2)
......@@ -338,7 +338,7 @@ func TestAPIKeyService_InvalidateAuthCacheByGroupID(t *testing.T) {
L2TTLSeconds: 60,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByGroupID(context.Background(), 9)
require.Len(t, cache.deleteAuthKeys, 2)
......@@ -356,7 +356,7 @@ func TestAPIKeyService_InvalidateAuthCacheByKey(t *testing.T) {
L2TTLSeconds: 60,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByKey(context.Background(), "k1")
require.Len(t, cache.deleteAuthKeys, 1)
......@@ -375,7 +375,7 @@ func TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss(t *testing.T) {
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return nil, redis.Nil
}
......@@ -411,7 +411,7 @@ func TestAPIKeyService_GetByKey_SingleflightCollapses(t *testing.T) {
Singleflight: true,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
start := make(chan struct{})
wg := sync.WaitGroup{}
......
package service
import (
"context"
"log"
"sort"
"strings"
"sync"
"github.com/Wei-Shaw/sub2api/internal/model"
)
// ErrorPassthroughRepository 定义错误透传规则的数据访问接口
type ErrorPassthroughRepository interface {
// List 获取所有规则
List(ctx context.Context) ([]*model.ErrorPassthroughRule, error)
// GetByID 根据 ID 获取规则
GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error)
// Create 创建规则
Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error)
// Update 更新规则
Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error)
// Delete 删除规则
Delete(ctx context.Context, id int64) error
}
// ErrorPassthroughCache 定义错误透传规则的缓存接口
type ErrorPassthroughCache interface {
// Get 从缓存获取规则列表
Get(ctx context.Context) ([]*model.ErrorPassthroughRule, bool)
// Set 设置缓存
Set(ctx context.Context, rules []*model.ErrorPassthroughRule) error
// Invalidate 使缓存失效
Invalidate(ctx context.Context) error
// NotifyUpdate 通知其他实例刷新缓存
NotifyUpdate(ctx context.Context) error
// SubscribeUpdates 订阅缓存更新通知
SubscribeUpdates(ctx context.Context, handler func())
}
// ErrorPassthroughService 错误透传规则服务
type ErrorPassthroughService struct {
repo ErrorPassthroughRepository
cache ErrorPassthroughCache
// 本地内存缓存,用于快速匹配
localCache []*model.ErrorPassthroughRule
localCacheMu sync.RWMutex
}
// NewErrorPassthroughService 创建错误透传规则服务
func NewErrorPassthroughService(
repo ErrorPassthroughRepository,
cache ErrorPassthroughCache,
) *ErrorPassthroughService {
svc := &ErrorPassthroughService{
repo: repo,
cache: cache,
}
// 启动时加载规则到本地缓存
ctx := context.Background()
if err := svc.refreshLocalCache(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to load rules on startup: %v", err)
}
// 订阅缓存更新通知
if cache != nil {
cache.SubscribeUpdates(ctx, func() {
if err := svc.refreshLocalCache(context.Background()); err != nil {
log.Printf("[ErrorPassthroughService] Failed to refresh cache on notification: %v", err)
}
})
}
return svc
}
// List 获取所有规则
func (s *ErrorPassthroughService) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) {
return s.repo.List(ctx)
}
// GetByID 根据 ID 获取规则
func (s *ErrorPassthroughService) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) {
return s.repo.GetByID(ctx, id)
}
// Create 创建规则
func (s *ErrorPassthroughService) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
if err := rule.Validate(); err != nil {
return nil, err
}
created, err := s.repo.Create(ctx, rule)
if err != nil {
return nil, err
}
// 刷新缓存
s.invalidateAndNotify(ctx)
return created, nil
}
// Update 更新规则
func (s *ErrorPassthroughService) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
if err := rule.Validate(); err != nil {
return nil, err
}
updated, err := s.repo.Update(ctx, rule)
if err != nil {
return nil, err
}
// 刷新缓存
s.invalidateAndNotify(ctx)
return updated, nil
}
// Delete 删除规则
func (s *ErrorPassthroughService) Delete(ctx context.Context, id int64) error {
if err := s.repo.Delete(ctx, id); err != nil {
return err
}
// 刷新缓存
s.invalidateAndNotify(ctx)
return nil
}
// MatchRule 匹配透传规则
// 返回第一个匹配的规则,如果没有匹配则返回 nil
func (s *ErrorPassthroughService) MatchRule(platform string, statusCode int, body []byte) *model.ErrorPassthroughRule {
rules := s.getCachedRules()
if len(rules) == 0 {
return nil
}
bodyStr := strings.ToLower(string(body))
for _, rule := range rules {
if !rule.Enabled {
continue
}
if !s.platformMatches(rule, platform) {
continue
}
if s.ruleMatches(rule, statusCode, bodyStr) {
return rule
}
}
return nil
}
// getCachedRules 获取缓存的规则列表(按优先级排序)
func (s *ErrorPassthroughService) getCachedRules() []*model.ErrorPassthroughRule {
s.localCacheMu.RLock()
rules := s.localCache
s.localCacheMu.RUnlock()
if rules != nil {
return rules
}
// 如果本地缓存为空,尝试刷新
ctx := context.Background()
if err := s.refreshLocalCache(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to refresh cache: %v", err)
return nil
}
s.localCacheMu.RLock()
defer s.localCacheMu.RUnlock()
return s.localCache
}
// refreshLocalCache 刷新本地缓存
func (s *ErrorPassthroughService) refreshLocalCache(ctx context.Context) error {
// 先尝试从 Redis 缓存获取
if s.cache != nil {
if rules, ok := s.cache.Get(ctx); ok {
s.setLocalCache(rules)
return nil
}
}
// 从数据库加载(repo.List 已按 priority 排序)
rules, err := s.repo.List(ctx)
if err != nil {
return err
}
// 更新 Redis 缓存
if s.cache != nil {
if err := s.cache.Set(ctx, rules); err != nil {
log.Printf("[ErrorPassthroughService] Failed to set cache: %v", err)
}
}
// 更新本地缓存(setLocalCache 内部会确保排序)
s.setLocalCache(rules)
return nil
}
// setLocalCache 设置本地缓存
func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughRule) {
// 按优先级排序
sorted := make([]*model.ErrorPassthroughRule, len(rules))
copy(sorted, rules)
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].Priority < sorted[j].Priority
})
s.localCacheMu.Lock()
s.localCache = sorted
s.localCacheMu.Unlock()
}
// invalidateAndNotify 使缓存失效并通知其他实例
func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) {
// 刷新本地缓存
if err := s.refreshLocalCache(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to refresh local cache: %v", err)
}
// 通知其他实例
if s.cache != nil {
if err := s.cache.NotifyUpdate(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to notify cache update: %v", err)
}
}
}
// platformMatches 检查平台是否匹配
func (s *ErrorPassthroughService) platformMatches(rule *model.ErrorPassthroughRule, platform string) bool {
// 如果没有配置平台限制,则匹配所有平台
if len(rule.Platforms) == 0 {
return true
}
platform = strings.ToLower(platform)
for _, p := range rule.Platforms {
if strings.ToLower(p) == platform {
return true
}
}
return false
}
// ruleMatches 检查规则是否匹配
func (s *ErrorPassthroughService) ruleMatches(rule *model.ErrorPassthroughRule, statusCode int, bodyLower string) bool {
hasErrorCodes := len(rule.ErrorCodes) > 0
hasKeywords := len(rule.Keywords) > 0
// 如果没有配置任何条件,不匹配
if !hasErrorCodes && !hasKeywords {
return false
}
codeMatch := !hasErrorCodes || s.containsInt(rule.ErrorCodes, statusCode)
keywordMatch := !hasKeywords || s.containsAnyKeyword(bodyLower, rule.Keywords)
if rule.MatchMode == model.MatchModeAll {
// "all" 模式:所有配置的条件都必须满足
return codeMatch && keywordMatch
}
// "any" 模式:任一条件满足即可
if hasErrorCodes && hasKeywords {
return codeMatch || keywordMatch
}
return codeMatch && keywordMatch
}
// containsInt 检查切片是否包含指定整数
func (s *ErrorPassthroughService) containsInt(slice []int, val int) bool {
for _, v := range slice {
if v == val {
return true
}
}
return false
}
// containsAnyKeyword 检查字符串是否包含任一关键词(不区分大小写)
func (s *ErrorPassthroughService) containsAnyKeyword(bodyLower string, keywords []string) bool {
for _, kw := range keywords {
if strings.Contains(bodyLower, strings.ToLower(kw)) {
return true
}
}
return false
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment