Commit c5781c69 authored by IanShaw027's avatar IanShaw027
Browse files

fix(merge): 解决与 main 分支的配置冲突

- 合并 main 分支的上游错误日志配置
- 保留调度配置
- 合并 beta header 和 failover 配置
parents e1a9c1ec 34c10204
...@@ -12,7 +12,6 @@ import ( ...@@ -12,7 +12,6 @@ import (
"github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/Wei-Shaw/sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/server" "github.com/Wei-Shaw/sub2api/internal/server"
"github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/server/middleware"
...@@ -31,7 +30,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -31,7 +30,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
wire.Build( wire.Build(
// Infrastructure layer ProviderSets // Infrastructure layer ProviderSets
config.ProviderSet, config.ProviderSet,
infrastructure.ProviderSet,
// Business layer ProviderSets // Business layer ProviderSets
repository.ProviderSet, repository.ProviderSet,
......
...@@ -12,7 +12,6 @@ import ( ...@@ -12,7 +12,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/handler/admin" "github.com/Wei-Shaw/sub2api/internal/handler/admin"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/Wei-Shaw/sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/server" "github.com/Wei-Shaw/sub2api/internal/server"
"github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/server/middleware"
...@@ -35,18 +34,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -35,18 +34,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
client, err := infrastructure.ProvideEnt(configConfig) client, err := repository.ProvideEnt(configConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
db, err := infrastructure.ProvideSQLDB(client) db, err := repository.ProvideSQLDB(client)
if err != nil { if err != nil {
return nil, err return nil, err
} }
userRepository := repository.NewUserRepository(client, db) userRepository := repository.NewUserRepository(client, db)
settingRepository := repository.NewSettingRepository(client) settingRepository := repository.NewSettingRepository(client)
settingService := service.NewSettingService(settingRepository, configConfig) settingService := service.NewSettingService(settingRepository, configConfig)
redisClient := infrastructure.ProvideRedis(configConfig) redisClient := repository.ProvideRedis(configConfig)
emailCache := repository.NewEmailCache(redisClient) emailCache := repository.NewEmailCache(redisClient)
emailService := service.NewEmailService(settingRepository, emailCache) emailService := service.NewEmailService(settingRepository, emailCache)
turnstileVerifier := repository.NewTurnstileVerifier() turnstileVerifier := repository.NewTurnstileVerifier()
......
...@@ -121,6 +121,17 @@ type GatewayConfig struct { ...@@ -121,6 +121,17 @@ type GatewayConfig struct {
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期 // 应大于最长 LLM 请求时间,防止请求完成前槽位过期
ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"` ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"`
// 是否记录上游错误响应体摘要(避免输出请求内容)
LogUpstreamErrorBody bool `mapstructure:"log_upstream_error_body"`
// 上游错误响应体记录最大字节数(超过会截断)
LogUpstreamErrorBodyMaxBytes int `mapstructure:"log_upstream_error_body_max_bytes"`
// API-key 账号在客户端未提供 anthropic-beta 时,是否按需自动补齐(默认关闭以保持兼容)
InjectBetaForApiKey bool `mapstructure:"inject_beta_for_apikey"`
// 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义)
FailoverOn400 bool `mapstructure:"failover_on_400"`
// Scheduling: 账号调度相关配置 // Scheduling: 账号调度相关配置
Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"` Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"`
} }
...@@ -334,6 +345,10 @@ func setDefaults() { ...@@ -334,6 +345,10 @@ func setDefaults() {
// Gateway // Gateway
viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久 viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久
viper.SetDefault("gateway.log_upstream_error_body", false)
viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048)
viper.SetDefault("gateway.inject_beta_for_apikey", false)
viper.SetDefault("gateway.failover_on_400", false)
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化) // HTTP 上游连接池配置(针对 5000+ 并发用户优化)
......
package infrastructure
import (
"database/sql"
"errors"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/wire"
"github.com/redis/go-redis/v9"
entsql "entgo.io/ent/dialect/sql"
)
// ProviderSet 是基础设施层的 Wire 依赖提供者集合。
//
// Wire 是 Google 开发的编译时依赖注入工具。ProviderSet 将相关的依赖提供函数
// 组织在一起,便于在应用程序启动时自动组装依赖关系。
//
// 包含的提供者:
// - ProvideEnt: 提供 Ent ORM 客户端
// - ProvideSQLDB: 提供底层 SQL 数据库连接
// - ProvideRedis: 提供 Redis 客户端
var ProviderSet = wire.NewSet(
ProvideEnt,
ProvideSQLDB,
ProvideRedis,
)
// ProvideEnt 为依赖注入提供 Ent 客户端。
//
// 该函数是 InitEnt 的包装器,符合 Wire 的依赖提供函数签名要求。
// Wire 会在编译时分析依赖关系,自动生成初始化代码。
//
// 依赖:config.Config
// 提供:*ent.Client
func ProvideEnt(cfg *config.Config) (*ent.Client, error) {
client, _, err := InitEnt(cfg)
return client, err
}
// ProvideSQLDB 从 Ent 客户端提取底层的 *sql.DB 连接。
//
// 某些 Repository 需要直接执行原生 SQL(如复杂的批量更新、聚合查询),
// 此时需要访问底层的 sql.DB 而不是通过 Ent ORM。
//
// 设计说明:
// - Ent 底层使用 sql.DB,通过 Driver 接口可以访问
// - 这种设计允许在同一事务中混用 Ent 和原生 SQL
//
// 依赖:*ent.Client
// 提供:*sql.DB
func ProvideSQLDB(client *ent.Client) (*sql.DB, error) {
if client == nil {
return nil, errors.New("nil ent client")
}
// 从 Ent 客户端获取底层驱动
drv, ok := client.Driver().(*entsql.Driver)
if !ok {
return nil, errors.New("ent driver does not expose *sql.DB")
}
// 返回驱动持有的 sql.DB 实例
return drv.DB(), nil
}
// ProvideRedis 为依赖注入提供 Redis 客户端。
//
// Redis 用于:
// - 分布式锁(如并发控制)
// - 缓存(如用户会话、API 响应缓存)
// - 速率限制
// - 实时统计数据
//
// 依赖:config.Config
// 提供:*redis.Client
func ProvideRedis(cfg *config.Config) *redis.Client {
return InitRedis(cfg)
}
...@@ -54,6 +54,9 @@ type CustomToolSpec struct { ...@@ -54,6 +54,9 @@ type CustomToolSpec struct {
InputSchema map[string]any `json:"input_schema"` InputSchema map[string]any `json:"input_schema"`
} }
// ClaudeCustomToolSpec 兼容旧命名(MCP custom 工具规格)
type ClaudeCustomToolSpec = CustomToolSpec
// SystemBlock system prompt 数组形式的元素 // SystemBlock system prompt 数组形式的元素
type SystemBlock struct { type SystemBlock struct {
Type string `json:"type"` Type string `json:"type"`
......
...@@ -14,13 +14,16 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st ...@@ -14,13 +14,16 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
// 用于存储 tool_use id -> name 映射 // 用于存储 tool_use id -> name 映射
toolIDToName := make(map[string]string) toolIDToName := make(map[string]string)
// 检测是否启用 thinking
isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
// 只有 Gemini 模型支持 dummy thought workaround // 只有 Gemini 模型支持 dummy thought workaround
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures // Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
allowDummyThought := strings.HasPrefix(mappedModel, "gemini-") allowDummyThought := strings.HasPrefix(mappedModel, "gemini-")
// 检测是否启用 thinking
requestedThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
// 为避免 Claude 模型的 thought signature/消息块约束导致 400(上游要求 thinking 块开头等),
// 非 Gemini 模型默认不启用 thinking(除非未来支持完整签名链路)。
isThinkingEnabled := requestedThinkingEnabled && allowDummyThought
// 1. 构建 contents // 1. 构建 contents
contents, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought) contents, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
if err != nil { if err != nil {
...@@ -31,7 +34,15 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st ...@@ -31,7 +34,15 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model) systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model)
// 3. 构建 generationConfig // 3. 构建 generationConfig
generationConfig := buildGenerationConfig(claudeReq) reqForGen := claudeReq
if requestedThinkingEnabled && !allowDummyThought {
log.Printf("[Warning] Disabling thinking for non-Gemini model in antigravity transform: model=%s", mappedModel)
// shallow copy to avoid mutating caller's request
clone := *claudeReq
clone.Thinking = nil
reqForGen = &clone
}
generationConfig := buildGenerationConfig(reqForGen)
// 4. 构建 tools // 4. 构建 tools
tools := buildTools(claudeReq.Tools) tools := buildTools(claudeReq.Tools)
...@@ -150,6 +161,7 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT ...@@ -150,6 +161,7 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
parts = append([]GeminiPart{{ parts = append([]GeminiPart{{
Text: "Thinking...", Text: "Thinking...",
Thought: true, Thought: true,
ThoughtSignature: dummyThoughtSignature,
}}, parts...) }}, parts...)
} }
} }
...@@ -171,6 +183,34 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT ...@@ -171,6 +183,34 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
// 参考: https://ai.google.dev/gemini-api/docs/thought-signatures // 参考: https://ai.google.dev/gemini-api/docs/thought-signatures
const dummyThoughtSignature = "skip_thought_signature_validator" const dummyThoughtSignature = "skip_thought_signature_validator"
// isValidThoughtSignature 验证 thought signature 是否有效
// Claude API 要求 signature 必须是 base64 编码的字符串,长度至少 32 字节
func isValidThoughtSignature(signature string) bool {
// 空字符串无效
if signature == "" {
return false
}
// signature 应该是 base64 编码,长度至少 40 个字符(约 30 字节)
// 参考 Claude API 文档和实际观察到的有效 signature
if len(signature) < 40 {
log.Printf("[Debug] Signature too short: len=%d", len(signature))
return false
}
// 检查是否是有效的 base64 字符
// base64 字符集: A-Z, a-z, 0-9, +, /, =
for i, c := range signature {
if (c < 'A' || c > 'Z') && (c < 'a' || c > 'z') &&
(c < '0' || c > '9') && c != '+' && c != '/' && c != '=' {
log.Printf("[Debug] Invalid base64 character at position %d: %c (code=%d)", i, c, c)
return false
}
}
return true
}
// buildParts 构建消息的 parts // buildParts 构建消息的 parts
// allowDummyThought: 只有 Gemini 模型支持 dummy thought signature // allowDummyThought: 只有 Gemini 模型支持 dummy thought signature
func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, error) { func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, error) {
...@@ -199,22 +239,30 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu ...@@ -199,22 +239,30 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
} }
case "thinking": case "thinking":
part := GeminiPart{ if allowDummyThought {
// Gemini 模型可以使用 dummy signature
parts = append(parts, GeminiPart{
Text: block.Thinking, Text: block.Thinking,
Thought: true, Thought: true,
ThoughtSignature: dummyThoughtSignature,
})
continue
} }
// 保留原有 signature(Claude 模型需要有效的 signature)
if block.Signature != "" { // Claude 模型:仅在提供有效 signature 时保留 thinking block;否则跳过以避免上游校验失败。
part.ThoughtSignature = block.Signature signature := strings.TrimSpace(block.Signature)
} else if !allowDummyThought { if signature == "" || signature == dummyThoughtSignature {
// Claude 模型需要有效 signature,跳过无 signature 的 thinking block log.Printf("[Warning] Skipping thinking block for Claude model (missing or dummy signature)")
log.Printf("Warning: skipping thinking block without signature for Claude model")
continue continue
} else {
// Gemini 模型使用 dummy signature
part.ThoughtSignature = dummyThoughtSignature
} }
parts = append(parts, part) if !isValidThoughtSignature(signature) {
log.Printf("[Debug] Thinking signature may be invalid (passing through anyway): len=%d", len(signature))
}
parts = append(parts, GeminiPart{
Text: block.Thinking,
Thought: true,
ThoughtSignature: signature,
})
case "image": case "image":
if block.Source != nil && block.Source.Type == "base64" { if block.Source != nil && block.Source.Type == "base64" {
...@@ -239,10 +287,9 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu ...@@ -239,10 +287,9 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
ID: block.ID, ID: block.ID,
}, },
} }
// 保留原有 signature,或对 Gemini 模型使用 dummy signature // 只有 Gemini 模型使用 dummy signature
if block.Signature != "" { // Claude 模型不设置 signature(避免验证问题)
part.ThoughtSignature = block.Signature if allowDummyThought {
} else if allowDummyThought {
part.ThoughtSignature = dummyThoughtSignature part.ThoughtSignature = dummyThoughtSignature
} }
parts = append(parts, part) parts = append(parts, part)
...@@ -386,9 +433,9 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { ...@@ -386,9 +433,9 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
// 普通工具 // 普通工具
var funcDecls []GeminiFunctionDecl var funcDecls []GeminiFunctionDecl
for _, tool := range tools { for i, tool := range tools {
// 跳过无效工具名称 // 跳过无效工具名称
if tool.Name == "" { if strings.TrimSpace(tool.Name) == "" {
log.Printf("Warning: skipping tool with empty name") log.Printf("Warning: skipping tool with empty name")
continue continue
} }
...@@ -397,10 +444,18 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { ...@@ -397,10 +444,18 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
var inputSchema map[string]any var inputSchema map[string]any
// 检查是否为 custom 类型工具 (MCP) // 检查是否为 custom 类型工具 (MCP)
if tool.Type == "custom" && tool.Custom != nil { if tool.Type == "custom" {
// Custom 格式: 从 custom 字段获取 description 和 input_schema if tool.Custom == nil || tool.Custom.InputSchema == nil {
log.Printf("[Warning] Skipping invalid custom tool '%s': missing custom spec or input_schema", tool.Name)
continue
}
description = tool.Custom.Description description = tool.Custom.Description
inputSchema = tool.Custom.InputSchema inputSchema = tool.Custom.InputSchema
// 调试日志:记录 custom 工具的 schema
if schemaJSON, err := json.Marshal(inputSchema); err == nil {
log.Printf("[Debug] Tool[%d] '%s' (custom) original schema: %s", i, tool.Name, string(schemaJSON))
}
} else { } else {
// 标准格式: 从顶层字段获取 // 标准格式: 从顶层字段获取
description = tool.Description description = tool.Description
...@@ -409,7 +464,6 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { ...@@ -409,7 +464,6 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
// 清理 JSON Schema // 清理 JSON Schema
params := cleanJSONSchema(inputSchema) params := cleanJSONSchema(inputSchema)
// 为 nil schema 提供默认值 // 为 nil schema 提供默认值
if params == nil { if params == nil {
params = map[string]any{ params = map[string]any{
...@@ -418,6 +472,11 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { ...@@ -418,6 +472,11 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
} }
} }
// 调试日志:记录清理后的 schema
if paramsJSON, err := json.Marshal(params); err == nil {
log.Printf("[Debug] Tool[%d] '%s' cleaned schema: %s", i, tool.Name, string(paramsJSON))
}
funcDecls = append(funcDecls, GeminiFunctionDecl{ funcDecls = append(funcDecls, GeminiFunctionDecl{
Name: tool.Name, Name: tool.Name,
Description: description, Description: description,
...@@ -479,24 +538,54 @@ func cleanJSONSchema(schema map[string]any) map[string]any { ...@@ -479,24 +538,54 @@ func cleanJSONSchema(schema map[string]any) map[string]any {
} }
// excludedSchemaKeys 不支持的 schema 字段 // excludedSchemaKeys 不支持的 schema 字段
// 基于 Claude API (Vertex AI) 的实际支持情况
// 支持: type, description, enum, properties, required, additionalProperties, items
// 不支持: minItems, maxItems, minLength, maxLength, pattern, minimum, maximum 等验证字段
var excludedSchemaKeys = map[string]bool{ var excludedSchemaKeys = map[string]bool{
// 元 schema 字段
"$schema": true, "$schema": true,
"$id": true, "$id": true,
"$ref": true, "$ref": true,
"additionalProperties": true,
// 字符串验证(Gemini 不支持)
"minLength": true, "minLength": true,
"maxLength": true, "maxLength": true,
"minItems": true, "pattern": true,
"maxItems": true,
"uniqueItems": true, // 数字验证(Claude API 通过 Vertex AI 不支持这些字段)
"minimum": true, "minimum": true,
"maximum": true, "maximum": true,
"exclusiveMinimum": true, "exclusiveMinimum": true,
"exclusiveMaximum": true, "exclusiveMaximum": true,
"pattern": true, "multipleOf": true,
"format": true,
// 数组验证(Claude API 通过 Vertex AI 不支持这些字段)
"uniqueItems": true,
"minItems": true,
"maxItems": true,
// 组合 schema(Gemini 不支持)
"oneOf": true,
"anyOf": true,
"allOf": true,
"not": true,
"if": true,
"then": true,
"else": true,
"$defs": true,
"definitions": true,
// 对象验证(仅保留 properties/required/additionalProperties)
"minProperties": true,
"maxProperties": true,
"patternProperties": true,
"propertyNames": true,
"dependencies": true,
"dependentSchemas": true,
"dependentRequired": true,
// 其他不支持的字段
"default": true, "default": true,
"strict": true,
"const": true, "const": true,
"examples": true, "examples": true,
"deprecated": true, "deprecated": true,
...@@ -504,6 +593,9 @@ var excludedSchemaKeys = map[string]bool{ ...@@ -504,6 +593,9 @@ var excludedSchemaKeys = map[string]bool{
"writeOnly": true, "writeOnly": true,
"contentMediaType": true, "contentMediaType": true,
"contentEncoding": true, "contentEncoding": true,
// Claude 特有字段
"strict": true,
} }
// cleanSchemaValue 递归清理 schema 值 // cleanSchemaValue 递归清理 schema 值
...@@ -523,6 +615,31 @@ func cleanSchemaValue(value any) any { ...@@ -523,6 +615,31 @@ func cleanSchemaValue(value any) any {
continue continue
} }
// 特殊处理 format 字段:只保留 Gemini 支持的 format 值
if k == "format" {
if formatStr, ok := val.(string); ok {
// Gemini 只支持 date-time, date, time
if formatStr == "date-time" || formatStr == "date" || formatStr == "time" {
result[k] = val
}
// 其他 format 值直接跳过
}
continue
}
// 特殊处理 additionalProperties:Claude API 只支持布尔值,不支持 schema 对象
if k == "additionalProperties" {
if boolVal, ok := val.(bool); ok {
result[k] = boolVal
log.Printf("[Debug] additionalProperties is bool: %v", boolVal)
} else {
// 如果是 schema 对象,转换为 false(更安全的默认值)
result[k] = false
log.Printf("[Debug] additionalProperties is not bool (type: %T), converting to false", val)
}
continue
}
// 递归清理所有值 // 递归清理所有值
result[k] = cleanSchemaValue(val) result[k] = cleanSchemaValue(val)
} }
......
package antigravity
import (
"encoding/json"
"testing"
)
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
tests := []struct {
name string
content string
allowDummyThought bool
expectedParts int
description string
}{
{
name: "Claude model - skip thinking block without signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
{"type": "text", "text": "World"}
]`,
allowDummyThought: false,
expectedParts: 2, // 只有两个text block
description: "Claude模型应该跳过无signature的thinking block",
},
{
name: "Claude model - keep thinking block with signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": "valid_sig"},
{"type": "text", "text": "World"}
]`,
allowDummyThought: false,
expectedParts: 3, // 三个block都保留
description: "Claude模型应该保留有signature的thinking block",
},
{
name: "Gemini model - use dummy signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
{"type": "text", "text": "World"}
]`,
allowDummyThought: true,
expectedParts: 3, // 三个block都保留,thinking使用dummy signature
description: "Gemini模型应该为无signature的thinking block使用dummy signature",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
toolIDToName := make(map[string]string)
parts, err := buildParts(json.RawMessage(tt.content), toolIDToName, tt.allowDummyThought)
if err != nil {
t.Fatalf("buildParts() error = %v", err)
}
if len(parts) != tt.expectedParts {
t.Errorf("%s: got %d parts, want %d parts", tt.description, len(parts), tt.expectedParts)
}
})
}
}
// TestBuildTools_CustomTypeTools 测试custom类型工具转换
func TestBuildTools_CustomTypeTools(t *testing.T) {
tests := []struct {
name string
tools []ClaudeTool
expectedLen int
description string
}{
{
name: "Standard tool format",
tools: []ClaudeTool{
{
Name: "get_weather",
Description: "Get weather information",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"location": map[string]any{"type": "string"},
},
},
},
},
expectedLen: 1,
description: "标准工具格式应该正常转换",
},
{
name: "Custom type tool (MCP format)",
tools: []ClaudeTool{
{
Type: "custom",
Name: "mcp_tool",
Custom: &ClaudeCustomToolSpec{
Description: "MCP tool description",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"param": map[string]any{"type": "string"},
},
},
},
},
},
expectedLen: 1,
description: "Custom类型工具应该从Custom字段读取description和input_schema",
},
{
name: "Mixed standard and custom tools",
tools: []ClaudeTool{
{
Name: "standard_tool",
Description: "Standard tool",
InputSchema: map[string]any{"type": "object"},
},
{
Type: "custom",
Name: "custom_tool",
Custom: &ClaudeCustomToolSpec{
Description: "Custom tool",
InputSchema: map[string]any{"type": "object"},
},
},
},
expectedLen: 1, // 返回一个GeminiToolDeclaration,包含2个function declarations
description: "混合标准和custom工具应该都能正确转换",
},
{
name: "Invalid custom tool - nil Custom field",
tools: []ClaudeTool{
{
Type: "custom",
Name: "invalid_custom",
// Custom 为 nil
},
},
expectedLen: 0, // 应该被跳过
description: "Custom字段为nil的custom工具应该被跳过",
},
{
name: "Invalid custom tool - nil InputSchema",
tools: []ClaudeTool{
{
Type: "custom",
Name: "invalid_custom",
Custom: &ClaudeCustomToolSpec{
Description: "Invalid",
// InputSchema 为 nil
},
},
},
expectedLen: 0, // 应该被跳过
description: "InputSchema为nil的custom工具应该被跳过",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := buildTools(tt.tools)
if len(result) != tt.expectedLen {
t.Errorf("%s: got %d tool declarations, want %d", tt.description, len(result), tt.expectedLen)
}
// 验证function declarations存在
if len(result) > 0 && result[0].FunctionDeclarations != nil {
if len(result[0].FunctionDeclarations) != len(tt.tools) {
t.Errorf("%s: got %d function declarations, want %d",
tt.description, len(result[0].FunctionDeclarations), len(tt.tools))
}
}
})
}
}
...@@ -16,6 +16,12 @@ const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleav ...@@ -16,6 +16,12 @@ const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleav
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta) // HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
// ApiKeyBetaHeader API-key 账号建议使用的 anthropic-beta header(不包含 oauth)
const ApiKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// ApiKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不包含 oauth / claude-code)
const ApiKeyHaikuBetaHeader = BetaInterleavedThinking
// Claude Code 客户端默认请求头 // Claude Code 客户端默认请求头
var DefaultHeaders = map[string]string{ var DefaultHeaders = map[string]string{
"User-Agent": "claude-cli/2.0.62 (external, cli)", "User-Agent": "claude-cli/2.0.62 (external, cli)",
......
...@@ -4,7 +4,7 @@ import ( ...@@ -4,7 +4,7 @@ import (
"math" "math"
"net/http" "net/http"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
......
...@@ -9,7 +9,7 @@ import ( ...@@ -9,7 +9,7 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" errors2 "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
...@@ -82,7 +82,7 @@ func TestErrorFrom(t *testing.T) { ...@@ -82,7 +82,7 @@ func TestErrorFrom(t *testing.T) {
}, },
{ {
name: "application_error", name: "application_error",
err: infraerrors.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}), err: errors2.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}),
wantWritten: true, wantWritten: true,
wantHTTPCode: http.StatusForbidden, wantHTTPCode: http.StatusForbidden,
wantBody: Response{ wantBody: Response{
...@@ -94,7 +94,7 @@ func TestErrorFrom(t *testing.T) { ...@@ -94,7 +94,7 @@ func TestErrorFrom(t *testing.T) {
}, },
{ {
name: "bad_request_error", name: "bad_request_error",
err: infraerrors.BadRequest("INVALID_REQUEST", "invalid request"), err: errors2.BadRequest("INVALID_REQUEST", "invalid request"),
wantWritten: true, wantWritten: true,
wantHTTPCode: http.StatusBadRequest, wantHTTPCode: http.StatusBadRequest,
wantBody: Response{ wantBody: Response{
...@@ -105,7 +105,7 @@ func TestErrorFrom(t *testing.T) { ...@@ -105,7 +105,7 @@ func TestErrorFrom(t *testing.T) {
}, },
{ {
name: "unauthorized_error", name: "unauthorized_error",
err: infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized"), err: errors2.Unauthorized("UNAUTHORIZED", "unauthorized"),
wantWritten: true, wantWritten: true,
wantHTTPCode: http.StatusUnauthorized, wantHTTPCode: http.StatusUnauthorized,
wantBody: Response{ wantBody: Response{
...@@ -116,7 +116,7 @@ func TestErrorFrom(t *testing.T) { ...@@ -116,7 +116,7 @@ func TestErrorFrom(t *testing.T) {
}, },
{ {
name: "not_found_error", name: "not_found_error",
err: infraerrors.NotFound("NOT_FOUND", "not found"), err: errors2.NotFound("NOT_FOUND", "not found"),
wantWritten: true, wantWritten: true,
wantHTTPCode: http.StatusNotFound, wantHTTPCode: http.StatusNotFound,
wantBody: Response{ wantBody: Response{
...@@ -127,7 +127,7 @@ func TestErrorFrom(t *testing.T) { ...@@ -127,7 +127,7 @@ func TestErrorFrom(t *testing.T) {
}, },
{ {
name: "conflict_error", name: "conflict_error",
err: infraerrors.Conflict("CONFLICT", "conflict"), err: errors2.Conflict("CONFLICT", "conflict"),
wantWritten: true, wantWritten: true,
wantHTTPCode: http.StatusConflict, wantHTTPCode: http.StatusConflict,
wantBody: Response{ wantBody: Response{
...@@ -143,7 +143,7 @@ func TestErrorFrom(t *testing.T) { ...@@ -143,7 +143,7 @@ func TestErrorFrom(t *testing.T) {
wantHTTPCode: http.StatusInternalServerError, wantHTTPCode: http.StatusInternalServerError,
wantBody: Response{ wantBody: Response{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
Message: infraerrors.UnknownMessage, Message: errors2.UnknownMessage,
}, },
}, },
} }
......
package infrastructure package repository
import ( import (
"database/sql" "database/sql"
......
package infrastructure package repository
import ( import (
"database/sql" "database/sql"
......
// Package infrastructure 提供应用程序的基础设施层组件。 // Package infrastructure 提供应用程序的基础设施层组件。
// 包括数据库连接初始化、ORM 客户端管理、Redis 连接、数据库迁移等核心功能。 // 包括数据库连接初始化、ORM 客户端管理、Redis 连接、数据库迁移等核心功能。
package infrastructure package repository
import ( import (
"context" "context"
......
...@@ -7,7 +7,7 @@ import ( ...@@ -7,7 +7,7 @@ import (
"strings" "strings"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/lib/pq" "github.com/lib/pq"
) )
......
...@@ -17,7 +17,6 @@ import ( ...@@ -17,7 +17,6 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
_ "github.com/Wei-Shaw/sub2api/ent/runtime" _ "github.com/Wei-Shaw/sub2api/ent/runtime"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
...@@ -97,7 +96,7 @@ func TestMain(m *testing.M) { ...@@ -97,7 +96,7 @@ func TestMain(m *testing.M) {
log.Printf("failed to open sql db: %v", err) log.Printf("failed to open sql db: %v", err)
os.Exit(1) os.Exit(1)
} }
if err := infrastructure.ApplyMigrations(ctx, integrationDB); err != nil { if err := ApplyMigrations(ctx, integrationDB); err != nil {
log.Printf("failed to apply db migrations: %v", err) log.Printf("failed to apply db migrations: %v", err)
os.Exit(1) os.Exit(1)
} }
......
package infrastructure package repository
import ( import (
"context" "context"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment