Unverified Commit fa68cbad authored by InCerryGit's avatar InCerryGit Committed by GitHub
Browse files

Merge branch 'Wei-Shaw:main' into main

parents 995ef134 0f033930
...@@ -404,6 +404,17 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account ...@@ -404,6 +404,17 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
return nil return nil
} }
// UpdateCredentials replaces the credentials of the account identified by id.
//
// The map is passed through normalizeJSONMap before being handed to the ent
// update builder. Persistence failures are mapped by translatePersistenceError,
// with service.ErrAccountNotFound as the not-found error. After a successful
// save the scheduler snapshot for the account is re-synced.
func (r *accountRepository) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error {
	update := r.client.Account.UpdateOneID(id).SetCredentials(normalizeJSONMap(credentials))
	if _, saveErr := update.Save(ctx); saveErr != nil {
		return translatePersistenceError(saveErr, service.ErrAccountNotFound, nil)
	}

	// Only reached when the write succeeded.
	r.syncSchedulerAccountSnapshot(ctx, id)
	return nil
}
func (r *accountRepository) Delete(ctx context.Context, id int64) error { func (r *accountRepository) Delete(ctx context.Context, id int64) error {
groupIDs, err := r.loadAccountGroupIDs(ctx, id) groupIDs, err := r.loadAccountGroupIDs(ctx, id)
if err != nil { if err != nil {
...@@ -443,10 +454,10 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error { ...@@ -443,10 +454,10 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
} }
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "", "", 0) return r.ListWithFilters(ctx, params, "", "", "", "", 0, "")
} }
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) { func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, *pagination.PaginationResult, error) {
q := r.client.Account.Query() q := r.client.Account.Query()
if platform != "" { if platform != "" {
...@@ -479,6 +490,20 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati ...@@ -479,6 +490,20 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
} else if groupID > 0 { } else if groupID > 0 {
q = q.Where(dbaccount.HasAccountGroupsWith(dbaccountgroup.GroupIDEQ(groupID))) q = q.Where(dbaccount.HasAccountGroupsWith(dbaccountgroup.GroupIDEQ(groupID)))
} }
if privacyMode != "" {
q = q.Where(dbpredicate.Account(func(s *entsql.Selector) {
path := sqljson.Path("privacy_mode")
switch privacyMode {
case service.AccountPrivacyModeUnsetFilter:
s.Where(entsql.Or(
entsql.Not(sqljson.HasKey(dbaccount.FieldExtra, path)),
sqljson.ValueEQ(dbaccount.FieldExtra, "", path),
))
default:
s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, privacyMode, path))
}
}))
}
total, err := q.Count(ctx) total, err := q.Count(ctx)
if err != nil { if err != nil {
......
...@@ -208,15 +208,16 @@ func (s *AccountRepoSuite) TestList() { ...@@ -208,15 +208,16 @@ func (s *AccountRepoSuite) TestList() {
func (s *AccountRepoSuite) TestListWithFilters() { func (s *AccountRepoSuite) TestListWithFilters() {
tests := []struct { tests := []struct {
name string name string
setup func(client *dbent.Client) setup func(client *dbent.Client)
platform string platform string
accType string accType string
status string status string
search string search string
groupID int64 groupID int64
wantCount int privacyMode string
validate func(accounts []service.Account) wantCount int
validate func(accounts []service.Account)
}{ }{
{ {
name: "filter_by_platform", name: "filter_by_platform",
...@@ -281,6 +282,32 @@ func (s *AccountRepoSuite) TestListWithFilters() { ...@@ -281,6 +282,32 @@ func (s *AccountRepoSuite) TestListWithFilters() {
s.Require().Empty(accounts[0].GroupIDs) s.Require().Empty(accounts[0].GroupIDs)
}, },
}, },
{
name: "filter_by_privacy_mode",
setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), client, &service.Account{Name: "privacy-ok", Extra: map[string]any{"privacy_mode": service.PrivacyModeTrainingOff}})
mustCreateAccount(s.T(), client, &service.Account{Name: "privacy-fail", Extra: map[string]any{"privacy_mode": service.PrivacyModeFailed}})
},
privacyMode: service.PrivacyModeTrainingOff,
wantCount: 1,
validate: func(accounts []service.Account) {
s.Require().Equal("privacy-ok", accounts[0].Name)
},
},
{
name: "filter_by_privacy_mode_unset",
setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), client, &service.Account{Name: "privacy-unset", Extra: nil})
mustCreateAccount(s.T(), client, &service.Account{Name: "privacy-empty", Extra: map[string]any{"privacy_mode": ""}})
mustCreateAccount(s.T(), client, &service.Account{Name: "privacy-set", Extra: map[string]any{"privacy_mode": service.PrivacyModeTrainingOff}})
},
privacyMode: service.AccountPrivacyModeUnsetFilter,
wantCount: 2,
validate: func(accounts []service.Account) {
names := []string{accounts[0].Name, accounts[1].Name}
s.ElementsMatch([]string{"privacy-unset", "privacy-empty"}, names)
},
},
} }
for _, tt := range tests { for _, tt := range tests {
...@@ -293,7 +320,7 @@ func (s *AccountRepoSuite) TestListWithFilters() { ...@@ -293,7 +320,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
tt.setup(client) tt.setup(client)
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, tt.groupID) accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, tt.groupID, tt.privacyMode)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(accounts, tt.wantCount) s.Require().Len(accounts, tt.wantCount)
if tt.validate != nil { if tt.validate != nil {
...@@ -360,7 +387,7 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() { ...@@ -360,7 +387,7 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
s.Require().Len(got.Groups, 1, "expected Groups to be populated") s.Require().Len(got.Groups, 1, "expected Groups to be populated")
s.Require().Equal(group.ID, got.Groups[0].ID) s.Require().Equal(group.ID, got.Groups[0].ID)
accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc", 0) accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc", 0, "")
s.Require().NoError(err, "ListWithFilters") s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total) s.Require().Equal(int64(1), page.Total)
s.Require().Len(accounts, 1) s.Require().Len(accounts, 1)
......
...@@ -29,6 +29,11 @@ INSERT INTO ops_error_logs ( ...@@ -29,6 +29,11 @@ INSERT INTO ops_error_logs (
model, model,
request_path, request_path,
stream, stream,
inbound_endpoint,
upstream_endpoint,
requested_model,
upstream_model,
request_type,
user_agent, user_agent,
error_phase, error_phase,
error_type, error_type,
...@@ -57,7 +62,7 @@ INSERT INTO ops_error_logs ( ...@@ -57,7 +62,7 @@ INSERT INTO ops_error_logs (
retry_count, retry_count,
created_at created_at
) VALUES ( ) VALUES (
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38 $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43
)` )`
func NewOpsRepository(db *sql.DB) service.OpsRepository { func NewOpsRepository(db *sql.DB) service.OpsRepository {
...@@ -140,6 +145,11 @@ func opsInsertErrorLogArgs(input *service.OpsInsertErrorLogInput) []any { ...@@ -140,6 +145,11 @@ func opsInsertErrorLogArgs(input *service.OpsInsertErrorLogInput) []any {
opsNullString(input.Model), opsNullString(input.Model),
opsNullString(input.RequestPath), opsNullString(input.RequestPath),
input.Stream, input.Stream,
opsNullString(input.InboundEndpoint),
opsNullString(input.UpstreamEndpoint),
opsNullString(input.RequestedModel),
opsNullString(input.UpstreamModel),
opsNullInt16(input.RequestType),
opsNullString(input.UserAgent), opsNullString(input.UserAgent),
input.ErrorPhase, input.ErrorPhase,
input.ErrorType, input.ErrorType,
...@@ -231,7 +241,12 @@ SELECT ...@@ -231,7 +241,12 @@ SELECT
COALESCE(g.name, ''), COALESCE(g.name, ''),
CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END, CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END,
COALESCE(e.request_path, ''), COALESCE(e.request_path, ''),
e.stream e.stream,
COALESCE(e.inbound_endpoint, ''),
COALESCE(e.upstream_endpoint, ''),
COALESCE(e.requested_model, ''),
COALESCE(e.upstream_model, ''),
e.request_type
FROM ops_error_logs e FROM ops_error_logs e
LEFT JOIN accounts a ON e.account_id = a.id LEFT JOIN accounts a ON e.account_id = a.id
LEFT JOIN groups g ON e.group_id = g.id LEFT JOIN groups g ON e.group_id = g.id
...@@ -263,6 +278,7 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2) ...@@ -263,6 +278,7 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
var resolvedBy sql.NullInt64 var resolvedBy sql.NullInt64
var resolvedByName string var resolvedByName string
var resolvedRetryID sql.NullInt64 var resolvedRetryID sql.NullInt64
var requestType sql.NullInt64
if err := rows.Scan( if err := rows.Scan(
&item.ID, &item.ID,
&item.CreatedAt, &item.CreatedAt,
...@@ -294,6 +310,11 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2) ...@@ -294,6 +310,11 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
&clientIP, &clientIP,
&item.RequestPath, &item.RequestPath,
&item.Stream, &item.Stream,
&item.InboundEndpoint,
&item.UpstreamEndpoint,
&item.RequestedModel,
&item.UpstreamModel,
&requestType,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
...@@ -334,6 +355,10 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2) ...@@ -334,6 +355,10 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
item.GroupID = &v item.GroupID = &v
} }
item.GroupName = groupName item.GroupName = groupName
if requestType.Valid {
v := int16(requestType.Int64)
item.RequestType = &v
}
out = append(out, &item) out = append(out, &item)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
...@@ -393,6 +418,11 @@ SELECT ...@@ -393,6 +418,11 @@ SELECT
CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END, CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END,
COALESCE(e.request_path, ''), COALESCE(e.request_path, ''),
e.stream, e.stream,
COALESCE(e.inbound_endpoint, ''),
COALESCE(e.upstream_endpoint, ''),
COALESCE(e.requested_model, ''),
COALESCE(e.upstream_model, ''),
e.request_type,
COALESCE(e.user_agent, ''), COALESCE(e.user_agent, ''),
e.auth_latency_ms, e.auth_latency_ms,
e.routing_latency_ms, e.routing_latency_ms,
...@@ -427,6 +457,7 @@ LIMIT 1` ...@@ -427,6 +457,7 @@ LIMIT 1`
var responseLatency sql.NullInt64 var responseLatency sql.NullInt64
var ttft sql.NullInt64 var ttft sql.NullInt64
var requestBodyBytes sql.NullInt64 var requestBodyBytes sql.NullInt64
var requestType sql.NullInt64
err := r.db.QueryRowContext(ctx, q, id).Scan( err := r.db.QueryRowContext(ctx, q, id).Scan(
&out.ID, &out.ID,
...@@ -464,6 +495,11 @@ LIMIT 1` ...@@ -464,6 +495,11 @@ LIMIT 1`
&clientIP, &clientIP,
&out.RequestPath, &out.RequestPath,
&out.Stream, &out.Stream,
&out.InboundEndpoint,
&out.UpstreamEndpoint,
&out.RequestedModel,
&out.UpstreamModel,
&requestType,
&out.UserAgent, &out.UserAgent,
&authLatency, &authLatency,
&routingLatency, &routingLatency,
...@@ -540,6 +576,10 @@ LIMIT 1` ...@@ -540,6 +576,10 @@ LIMIT 1`
v := int(requestBodyBytes.Int64) v := int(requestBodyBytes.Int64)
out.RequestBodyBytes = &v out.RequestBodyBytes = &v
} }
if requestType.Valid {
v := int16(requestType.Int64)
out.RequestType = &v
}
// Normalize request_body to empty string when stored as JSON null. // Normalize request_body to empty string when stored as JSON null.
out.RequestBody = strings.TrimSpace(out.RequestBody) out.RequestBody = strings.TrimSpace(out.RequestBody)
...@@ -1479,3 +1519,10 @@ func opsNullInt(v any) any { ...@@ -1479,3 +1519,10 @@ func opsNullInt(v any) any {
return sql.NullInt64{} return sql.NullInt64{}
} }
} }
// opsNullInt16 converts an optional int16 into a sql.NullInt64 suitable for
// use as a query argument. A nil pointer yields the zero NullInt64
// (Valid == false); otherwise the pointed-to value is widened to int64 and
// marked valid.
func opsNullInt16(v *int16) any {
	var out sql.NullInt64
	if v != nil {
		out.Int64 = int64(*v)
		out.Valid = true
	}
	return out
}
...@@ -990,7 +990,7 @@ func (s *stubAccountRepo) List(ctx context.Context, params pagination.Pagination ...@@ -990,7 +990,7 @@ func (s *stubAccountRepo) List(ctx context.Context, params pagination.Pagination
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) { func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
......
...@@ -69,12 +69,30 @@ func RegisterGatewayRoutes( ...@@ -69,12 +69,30 @@ func RegisterGatewayRoutes(
}) })
gateway.GET("/models", h.Gateway.Models) gateway.GET("/models", h.Gateway.Models)
gateway.GET("/usage", h.Gateway.Usage) gateway.GET("/usage", h.Gateway.Usage)
// OpenAI Responses API // OpenAI Responses API: auto-route based on group platform
gateway.POST("/responses", h.OpenAIGateway.Responses) gateway.POST("/responses", func(c *gin.Context) {
gateway.POST("/responses/*subpath", h.OpenAIGateway.Responses) if getGroupPlatform(c) == service.PlatformOpenAI {
h.OpenAIGateway.Responses(c)
return
}
h.Gateway.Responses(c)
})
gateway.POST("/responses/*subpath", func(c *gin.Context) {
if getGroupPlatform(c) == service.PlatformOpenAI {
h.OpenAIGateway.Responses(c)
return
}
h.Gateway.Responses(c)
})
gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket) gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket)
// OpenAI Chat Completions API // OpenAI Chat Completions API: auto-route based on group platform
gateway.POST("/chat/completions", h.OpenAIGateway.ChatCompletions) gateway.POST("/chat/completions", func(c *gin.Context) {
if getGroupPlatform(c) == service.PlatformOpenAI {
h.OpenAIGateway.ChatCompletions(c)
return
}
h.Gateway.ChatCompletions(c)
})
} }
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
...@@ -92,12 +110,25 @@ func RegisterGatewayRoutes( ...@@ -92,12 +110,25 @@ func RegisterGatewayRoutes(
gemini.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels) gemini.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels)
} }
// OpenAI Responses API(不带v1前缀的别名) // OpenAI Responses API(不带v1前缀的别名)— auto-route based on group platform
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses) responsesHandler := func(c *gin.Context) {
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses) if getGroupPlatform(c) == service.PlatformOpenAI {
h.OpenAIGateway.Responses(c)
return
}
h.Gateway.Responses(c)
}
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, responsesHandler)
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, responsesHandler)
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket) r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
// OpenAI Chat Completions API(不带v1前缀的别名) // OpenAI Chat Completions API(不带v1前缀的别名)— auto-route based on group platform
r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions) r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) {
if getGroupPlatform(c) == service.PlatformOpenAI {
h.OpenAIGateway.ChatCompletions(c)
return
}
h.Gateway.ChatCompletions(c)
})
// Antigravity 模型列表 // Antigravity 模型列表
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels) r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels)
......
package service
import "context"
// accountCredentialsUpdater is an optional capability of an account
// repository: writing the credentials of the account with the given id
// without going through a full Update. persistAccountCredentials probes for
// it with a type assertion and falls back to Update when it is absent.
type accountCredentialsUpdater interface {
UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error
}
// persistAccountCredentials assigns a shallow copy of credentials to
// account.Credentials and then writes it through repo.
//
// When repo also implements accountCredentialsUpdater, the credentials are
// written via UpdateCredentials using account.ID; otherwise the whole account
// is saved with repo.Update. A nil repo or nil account is a no-op that returns
// nil. Note that account.Credentials is replaced before the write is
// attempted, so it keeps the new value even if persistence fails.
func persistAccountCredentials(ctx context.Context, repo AccountRepository, account *Account, credentials map[string]any) error {
	if account == nil || repo == nil {
		return nil
	}

	account.Credentials = cloneCredentials(credentials)

	updater, supportsNarrowWrite := any(repo).(accountCredentialsUpdater)
	if !supportsNarrowWrite {
		return repo.Update(ctx, account)
	}
	return updater.UpdateCredentials(ctx, account.ID, account.Credentials)
}
// cloneCredentials returns a new map holding the same key/value pairs as in.
// The copy is shallow: values are assigned as-is, so nested maps or slices
// remain shared with the input. A nil input yields an empty, non-nil map.
func cloneCredentials(in map[string]any) map[string]any {
	// len and range are both well-defined on a nil map, so the nil case needs
	// no separate branch: it produces an empty map here.
	cloned := make(map[string]any, len(in))
	for key := range in {
		cloned[key] = in[key]
	}
	return cloned
}
...@@ -15,6 +15,7 @@ var ( ...@@ -15,6 +15,7 @@ var (
) )
const AccountListGroupUngrouped int64 = -1 const AccountListGroupUngrouped int64 = -1
const AccountPrivacyModeUnsetFilter = "__unset__"
type AccountRepository interface { type AccountRepository interface {
Create(ctx context.Context, account *Account) error Create(ctx context.Context, account *Account) error
...@@ -37,7 +38,7 @@ type AccountRepository interface { ...@@ -37,7 +38,7 @@ type AccountRepository interface {
Delete(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error)
ListByGroup(ctx context.Context, groupID int64) ([]Account, error) ListByGroup(ctx context.Context, groupID int64) ([]Account, error)
ListActive(ctx context.Context) ([]Account, error) ListActive(ctx context.Context) ([]Account, error)
ListByPlatform(ctx context.Context, platform string) ([]Account, error) ListByPlatform(ctx context.Context, platform string) ([]Account, error)
......
...@@ -79,7 +79,7 @@ func (s *accountRepoStub) List(ctx context.Context, params pagination.Pagination ...@@ -79,7 +79,7 @@ func (s *accountRepoStub) List(ctx context.Context, params pagination.Pagination
panic("unexpected List call") panic("unexpected List call")
} }
func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call") panic("unexpected ListWithFilters call")
} }
......
...@@ -54,7 +54,7 @@ type AdminService interface { ...@@ -54,7 +54,7 @@ type AdminService interface {
ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error)
// Account management // Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, int64, error)
GetAccount(ctx context.Context, id int64) (*Account, error) GetAccount(ctx context.Context, id int64) (*Account, error)
GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error)
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
...@@ -1451,9 +1451,9 @@ func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGrou ...@@ -1451,9 +1451,9 @@ func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGrou
} }
// Account management implementations // Account management implementations
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) { func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID) accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID, privacyMode)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
......
...@@ -19,18 +19,20 @@ type accountRepoStubForAdminList struct { ...@@ -19,18 +19,20 @@ type accountRepoStubForAdminList struct {
listWithFiltersType string listWithFiltersType string
listWithFiltersStatus string listWithFiltersStatus string
listWithFiltersSearch string listWithFiltersSearch string
listWithFiltersPrivacy string
listWithFiltersAccounts []Account listWithFiltersAccounts []Account
listWithFiltersResult *pagination.PaginationResult listWithFiltersResult *pagination.PaginationResult
listWithFiltersErr error listWithFiltersErr error
} }
func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) {
s.listWithFiltersCalls++ s.listWithFiltersCalls++
s.listWithFiltersParams = params s.listWithFiltersParams = params
s.listWithFiltersPlatform = platform s.listWithFiltersPlatform = platform
s.listWithFiltersType = accountType s.listWithFiltersType = accountType
s.listWithFiltersStatus = status s.listWithFiltersStatus = status
s.listWithFiltersSearch = search s.listWithFiltersSearch = search
s.listWithFiltersPrivacy = privacyMode
if s.listWithFiltersErr != nil { if s.listWithFiltersErr != nil {
return nil, nil, s.listWithFiltersErr return nil, nil, s.listWithFiltersErr
...@@ -168,7 +170,7 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) { ...@@ -168,7 +170,7 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{accountRepo: repo} svc := &adminServiceImpl{accountRepo: repo}
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0) accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0, "")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(10), total) require.Equal(t, int64(10), total)
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts) require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)
...@@ -182,6 +184,22 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) { ...@@ -182,6 +184,22 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
}) })
} }
// TestAdminService_ListAccounts_WithPrivacyMode checks that the privacyMode
// argument given to ListAccounts is the value recorded by the repository stub,
// and that the stub's accounts and total are returned to the caller.
func TestAdminService_ListAccounts_WithPrivacyMode(t *testing.T) {
	t.Run("privacy_mode 参数正常传递到 repository 层", func(t *testing.T) {
		stub := &accountRepoStubForAdminList{
			listWithFiltersResult:   &pagination.PaginationResult{Total: 1},
			listWithFiltersAccounts: []Account{{ID: 2, Name: "acc2"}},
		}
		admin := &adminServiceImpl{accountRepo: stub}

		gotAccounts, gotTotal, err := admin.ListAccounts(
			context.Background(),
			1, 20,
			PlatformOpenAI, AccountTypeOAuth, StatusActive,
			"acc2",
			0,
			PrivacyModeCFBlocked,
		)

		require.NoError(t, err)
		// The filter value must be the one the stub recorded.
		require.Equal(t, PrivacyModeCFBlocked, stub.listWithFiltersPrivacy)
		require.Equal(t, int64(1), gotTotal)
		require.Equal(t, []Account{{ID: 2, Name: "acc2"}}, gotAccounts)
	})
}
func TestAdminService_ListProxies_WithSearch(t *testing.T) { func TestAdminService_ListProxies_WithSearch(t *testing.T) {
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) { t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
repo := &proxyRepoStubForAdminList{ repo := &proxyRepoStubForAdminList{
......
...@@ -643,6 +643,7 @@ urlFallbackLoop: ...@@ -643,6 +643,7 @@ urlFallbackLoop:
AccountID: p.account.ID, AccountID: p.account.ID,
AccountName: p.account.Name, AccountName: p.account.Name,
UpstreamStatusCode: 0, UpstreamStatusCode: 0,
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Kind: "request_error", Kind: "request_error",
Message: safeErr, Message: safeErr,
}) })
...@@ -720,6 +721,7 @@ urlFallbackLoop: ...@@ -720,6 +721,7 @@ urlFallbackLoop:
AccountName: p.account.Name, AccountName: p.account.Name,
UpstreamStatusCode: resp.StatusCode, UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"), UpstreamRequestID: resp.Header.Get("x-request-id"),
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Kind: "retry", Kind: "retry",
Message: upstreamMsg, Message: upstreamMsg,
Detail: getUpstreamDetail(respBody), Detail: getUpstreamDetail(respBody),
...@@ -754,6 +756,7 @@ urlFallbackLoop: ...@@ -754,6 +756,7 @@ urlFallbackLoop:
AccountName: p.account.Name, AccountName: p.account.Name,
UpstreamStatusCode: resp.StatusCode, UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"), UpstreamRequestID: resp.Header.Get("x-request-id"),
UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
Kind: "retry", Kind: "retry",
Message: upstreamMsg, Message: upstreamMsg,
Detail: getUpstreamDetail(respBody), Detail: getUpstreamDetail(respBody),
......
...@@ -138,7 +138,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * ...@@ -138,7 +138,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
p.markBackfillAttempted(account.ID) p.markBackfillAttempted(account.ID)
if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" { if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" {
account.Credentials["project_id"] = projectID account.Credentials["project_id"] = projectID
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { if updateErr := persistAccountCredentials(ctx, p.accountRepo, account, account.Credentials); updateErr != nil {
slog.Warn("antigravity_project_id_backfill_persist_failed", slog.Warn("antigravity_project_id_backfill_persist_failed",
"account_id", account.ID, "account_id", account.ID,
"error", updateErr, "error", updateErr,
......
...@@ -367,8 +367,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -367,8 +367,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
// 🔄 Refresh OAuth token after creation // 🔄 Refresh OAuth token after creation
if targetType == AccountTypeOAuth { if targetType == AccountTypeOAuth {
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil { if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
account.Credentials = refreshedCreds _ = persistAccountCredentials(ctx, s.accountRepo, account, refreshedCreds)
_ = s.accountRepo.Update(ctx, account)
} }
} }
item.Action = "created" item.Action = "created"
...@@ -402,8 +401,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -402,8 +401,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
// 🔄 Refresh OAuth token after update // 🔄 Refresh OAuth token after update
if targetType == AccountTypeOAuth { if targetType == AccountTypeOAuth {
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil { if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
existing.Credentials = refreshedCreds _ = persistAccountCredentials(ctx, s.accountRepo, existing, refreshedCreds)
_ = s.accountRepo.Update(ctx, existing)
} }
} }
...@@ -620,8 +618,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -620,8 +618,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
} }
// 🔄 Refresh OAuth token after creation // 🔄 Refresh OAuth token after creation
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil { if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
account.Credentials = refreshedCreds _ = persistAccountCredentials(ctx, s.accountRepo, account, refreshedCreds)
_ = s.accountRepo.Update(ctx, account)
} }
item.Action = "created" item.Action = "created"
result.Created++ result.Created++
...@@ -652,8 +649,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -652,8 +649,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
// 🔄 Refresh OAuth token after update // 🔄 Refresh OAuth token after update
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil { if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
existing.Credentials = refreshedCreds _ = persistAccountCredentials(ctx, s.accountRepo, existing, refreshedCreds)
_ = s.accountRepo.Update(ctx, existing)
} }
item.Action = "updated" item.Action = "updated"
...@@ -862,8 +858,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -862,8 +858,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
continue continue
} }
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil { if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
account.Credentials = refreshedCreds _ = persistAccountCredentials(ctx, s.accountRepo, account, refreshedCreds)
_ = s.accountRepo.Update(ctx, account)
} }
item.Action = "created" item.Action = "created"
result.Created++ result.Created++
...@@ -893,8 +888,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -893,8 +888,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
} }
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil { if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
existing.Credentials = refreshedCreds _ = persistAccountCredentials(ctx, s.accountRepo, existing, refreshedCreds)
_ = s.accountRepo.Update(ctx, existing)
} }
item.Action = "updated" item.Action = "updated"
......
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
"net/smtp" "net/smtp"
"net/url" "net/url"
"strconv" "strconv"
"strings"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
...@@ -111,7 +112,7 @@ func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) { ...@@ -111,7 +112,7 @@ func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) {
return nil, fmt.Errorf("get smtp settings: %w", err) return nil, fmt.Errorf("get smtp settings: %w", err)
} }
host := settings[SettingKeySMTPHost] host := strings.TrimSpace(settings[SettingKeySMTPHost])
if host == "" { if host == "" {
return nil, ErrEmailNotConfigured return nil, ErrEmailNotConfigured
} }
...@@ -128,10 +129,10 @@ func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) { ...@@ -128,10 +129,10 @@ func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) {
return &SMTPConfig{ return &SMTPConfig{
Host: host, Host: host,
Port: port, Port: port,
Username: settings[SettingKeySMTPUsername], Username: strings.TrimSpace(settings[SettingKeySMTPUsername]),
Password: settings[SettingKeySMTPPassword], Password: strings.TrimSpace(settings[SettingKeySMTPPassword]),
From: settings[SettingKeySMTPFrom], From: strings.TrimSpace(settings[SettingKeySMTPFrom]),
FromName: settings[SettingKeySMTPFromName], FromName: strings.TrimSpace(settings[SettingKeySMTPFromName]),
UseTLS: useTLS, UseTLS: useTLS,
}, nil }, nil
} }
......
package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
// ForwardAsChatCompletions accepts an OpenAI Chat Completions API request body,
// converts it to Anthropic Messages format (chained via Responses format),
// forwards to the Anthropic upstream, and converts the response back to Chat
// Completions format. This enables Chat Completions clients to access Anthropic
// models through Anthropic platform groups.
//
// The upstream request is always sent with streaming enabled. Whether the
// client receives SSE chunks or a single JSON document depends on the "stream"
// flag of the incoming request.
//
// Error behavior as implemented here:
//   - parse, conversion, marshal, token and request-build failures return an
//     error without writing anything to c;
//   - a transport failure writes a 502 Chat Completions error to c and returns
//     an error carrying the sanitized message;
//   - an upstream status >= 400 that qualifies for failover returns an
//     *UpstreamFailoverError without writing to c;
//   - any other upstream status >= 400 writes a Chat Completions error with the
//     mapped status code and returns an error.
//
// The parsed argument is not read by this function.
func (s *GatewayService) ForwardAsChatCompletions(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	body []byte,
	parsed *ParsedRequest,
) (*ForwardResult, error) {
	startTime := time.Now()
	// 1. Parse Chat Completions request
	var ccReq apicompat.ChatCompletionsRequest
	if err := json.Unmarshal(body, &ccReq); err != nil {
		return nil, fmt.Errorf("parse chat completions request: %w", err)
	}
	// Remember what the client asked for; the upstream request is rewritten below.
	originalModel := ccReq.Model
	clientStream := ccReq.Stream
	includeUsage := ccReq.StreamOptions != nil && ccReq.StreamOptions.IncludeUsage
	// 2. Convert CC → Responses → Anthropic (chained conversion)
	responsesReq, err := apicompat.ChatCompletionsToResponses(&ccReq)
	if err != nil {
		return nil, fmt.Errorf("convert chat completions to responses: %w", err)
	}
	anthropicReq, err := apicompat.ResponsesToAnthropicRequest(responsesReq)
	if err != nil {
		return nil, fmt.Errorf("convert responses to anthropic: %w", err)
	}
	// 3. Force upstream streaming
	anthropicReq.Stream = true
	reqStream := true
	// 4. Model mapping
	// API-key accounts use the account's own mapping; other Anthropic-platform
	// accounts fall back to normalizing the model ID when no mapping changed it.
	mappedModel := originalModel
	if account.Type == AccountTypeAPIKey {
		mappedModel = account.GetMappedModel(originalModel)
	}
	if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
		normalized := claude.NormalizeModelID(originalModel)
		if normalized != originalModel {
			mappedModel = normalized
		}
	}
	anthropicReq.Model = mappedModel
	logger.L().Debug("gateway forward_as_chat_completions: model mapping applied",
		zap.Int64("account_id", account.ID),
		zap.String("original_model", originalModel),
		zap.String("mapped_model", mappedModel),
		zap.Bool("client_stream", clientStream),
	)
	// 5. Marshal Anthropic request body
	anthropicBody, err := json.Marshal(anthropicReq)
	if err != nil {
		return nil, fmt.Errorf("marshal anthropic request: %w", err)
	}
	// 6. Apply Claude Code mimicry for OAuth accounts
	isClaudeCode := false // CC API is never Claude Code
	shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
	if shouldMimicClaudeCode {
		// The prompt is injected only for non-"haiku" models and only when the
		// system field does not already include it.
		if !strings.Contains(strings.ToLower(mappedModel), "haiku") &&
			!systemIncludesClaudeCodePrompt(anthropicReq.System) {
			anthropicBody = injectClaudeCodePrompt(anthropicBody, anthropicReq.System)
		}
	}
	// 7. Enforce cache_control block limit
	anthropicBody = enforceCacheControlLimit(anthropicBody)
	// 8. Get access token
	token, tokenType, err := s.GetAccessToken(ctx, account)
	if err != nil {
		return nil, fmt.Errorf("get access token: %w", err)
	}
	// 9. Get proxy URL
	proxyURL := ""
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}
	// 10. Build upstream request
	// NOTE(review): the release function is invoked right after the request is
	// built and before it is sent; detachStreamUpstreamContext is defined
	// elsewhere, so confirm that releasing here cannot cancel the in-flight call.
	upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
	upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode)
	releaseUpstreamCtx()
	if err != nil {
		return nil, fmt.Errorf("build upstream request: %w", err)
	}
	// 11. Send request
	resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
	if err != nil {
		// Close a body that may have been returned alongside the error.
		if resp != nil && resp.Body != nil {
			_ = resp.Body.Close()
		}
		// Only the sanitized message is recorded and returned; the client gets a
		// generic 502 message.
		safeErr := sanitizeUpstreamErrorMessage(err.Error())
		setOpsUpstreamError(c, 0, safeErr, "")
		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
			Platform: account.Platform,
			AccountID: account.ID,
			AccountName: account.Name,
			UpstreamStatusCode: 0,
			Kind: "request_error",
			Message: safeErr,
		})
		writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
		return nil, fmt.Errorf("upstream request failed: %s", safeErr)
	}
	defer func() { _ = resp.Body.Close() }()
	// 12. Handle error response with failover
	if resp.StatusCode >= 400 {
		// Read at most 2 MiB of the error body, then swap in an in-memory copy so
		// the deferred Close above still has a valid body to close.
		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
		_ = resp.Body.Close()
		resp.Body = io.NopCloser(bytes.NewReader(respBody))
		upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
		upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
		if s.shouldFailoverUpstreamError(resp.StatusCode) {
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform: account.Platform,
				AccountID: account.ID,
				AccountName: account.Name,
				UpstreamStatusCode: resp.StatusCode,
				UpstreamRequestID: resp.Header.Get("x-request-id"),
				Kind: "failover",
				Message: upstreamMsg,
			})
			if s.rateLimitService != nil {
				s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
			}
			// Nothing is written to the client on this path; the caller receives
			// the status code and raw body through the typed error.
			return nil, &UpstreamFailoverError{
				StatusCode: resp.StatusCode,
				ResponseBody: respBody,
			}
		}
		writeGatewayCCError(c, mapUpstreamStatusCode(resp.StatusCode), "server_error", upstreamMsg)
		return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
	}
	// 13. Extract reasoning effort from CC request body
	reasoningEffort := extractCCReasoningEffortFromBody(body)
	// 14. Handle normal response
	// Read Anthropic SSE → convert to Responses events → convert to CC format
	var result *ForwardResult
	var handleErr error
	if clientStream {
		result, handleErr = s.handleCCStreamingFromAnthropic(resp, c, originalModel, mappedModel, reasoningEffort, startTime, includeUsage)
	} else {
		result, handleErr = s.handleCCBufferedFromAnthropic(resp, c, originalModel, mappedModel, reasoningEffort, startTime)
	}
	return result, handleErr
}
// extractCCReasoningEffortFromBody looks up the reasoning effort in a Chat
// Completions request body. The nested field (reasoning.effort) is consulted
// first and the flat field (reasoning_effort) second; the first non-blank value
// is normalized. It returns nil when neither field is set or when
// normalization yields an empty string.
func extractCCReasoningEffortFromBody(body []byte) *string {
	var raw string
	for _, path := range []string{"reasoning.effort", "reasoning_effort"} {
		raw = strings.TrimSpace(gjson.GetBytes(body, path).String())
		if raw != "" {
			break
		}
	}
	if raw == "" {
		return nil
	}
	if effort := normalizeOpenAIReasoningEffort(raw); effort != "" {
		return &effort
	}
	return nil
}
// handleCCBufferedFromAnthropic reads Anthropic SSE events, assembles the full
// response, then converts Anthropic → Responses → Chat Completions and writes
// the result to the client as a single JSON document with status 200.
//
// Events are expected as an "event: " line immediately followed by a "data: "
// line; anything else is skipped, as are payloads that fail to decode. A read
// error from the scanner is logged (unless it is a context cancellation or
// deadline) and whatever was assembled so far is still returned.
//
// If no message_start event was seen, a 502 Chat Completions error is written
// and an error is returned.
func (s *GatewayService) handleCCBufferedFromAnthropic(
	resp *http.Response,
	c *gin.Context,
	originalModel string,
	mappedModel string,
	reasoningEffort *string,
	startTime time.Time,
) (*ForwardResult, error) {
	requestID := resp.Header.Get("x-request-id")

	scanner := bufio.NewScanner(resp.Body)
	maxLineSize := defaultMaxLineSize
	if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
		maxLineSize = s.cfg.Gateway.MaxLineSize
	}
	scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)

	var finalResp *apicompat.AnthropicResponse
	var usage ClaudeUsage

	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "event: ") {
			continue
		}
		if !scanner.Scan() {
			break
		}
		dataLine := scanner.Text()
		if !strings.HasPrefix(dataLine, "data: ") {
			continue
		}
		payload := dataLine[6:]

		var event apicompat.AnthropicStreamEvent
		if err := json.Unmarshal([]byte(payload), &event); err != nil {
			continue
		}

		switch event.Type {
		case "message_start":
			// message_start carries the initial response structure and cache usage.
			if event.Message != nil {
				finalResp = event.Message
				mergeAnthropicUsage(&usage, event.Message.Usage)
			}
		case "message_delta":
			// message_delta carries final usage and stop_reason.
			if event.Usage != nil {
				mergeAnthropicUsage(&usage, *event.Usage)
			}
			if event.Delta != nil && event.Delta.StopReason != "" && finalResp != nil {
				finalResp.StopReason = event.Delta.StopReason
			}
		case "content_block_start":
			if event.ContentBlock != nil && finalResp != nil {
				finalResp.Content = append(finalResp.Content, *event.ContentBlock)
			}
		case "content_block_delta":
			if event.Delta != nil && finalResp != nil && event.Index != nil {
				idx := *event.Index
				// The index comes from the upstream payload; reject negative values
				// as well as values past the end so a malformed event cannot panic.
				if idx >= 0 && idx < len(finalResp.Content) {
					switch event.Delta.Type {
					case "text_delta":
						finalResp.Content[idx].Text += event.Delta.Text
					case "thinking_delta":
						finalResp.Content[idx].Thinking += event.Delta.Thinking
					case "input_json_delta":
						finalResp.Content[idx].Input = appendRawJSON(finalResp.Content[idx].Input, event.Delta.PartialJSON)
					}
				}
			}
		}
	}

	if err := scanner.Err(); err != nil {
		if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
			logger.L().Warn("forward_as_cc buffered: read error",
				zap.Error(err),
				zap.String("request_id", requestID),
			)
		}
	}

	if finalResp == nil {
		writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream stream ended without a response")
		return nil, fmt.Errorf("upstream stream ended without response")
	}

	// Update usage from accumulated delta
	if usage.InputTokens > 0 || usage.OutputTokens > 0 {
		finalResp.Usage = apicompat.AnthropicUsage{
			InputTokens:              usage.InputTokens,
			OutputTokens:             usage.OutputTokens,
			CacheCreationInputTokens: usage.CacheCreationInputTokens,
			CacheReadInputTokens:     usage.CacheReadInputTokens,
		}
	}

	// Chain: Anthropic → Responses → Chat Completions
	responsesResp := apicompat.AnthropicToResponsesResponse(finalResp)
	ccResp := apicompat.ResponsesToChatCompletions(responsesResp, originalModel)

	if s.responseHeaderFilter != nil {
		responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
	}
	c.JSON(http.StatusOK, ccResp)

	return &ForwardResult{
		RequestID:       requestID,
		Usage:           usage,
		Model:           originalModel,
		UpstreamModel:   mappedModel,
		ReasoningEffort: reasoningEffort,
		Stream:          false,
		Duration:        time.Since(startTime),
	}, nil
}
// handleCCStreamingFromAnthropic reads Anthropic SSE events, converts each
// to Responses events, then to Chat Completions chunks, and writes them.
//
// SSE response headers and a 200 status are written before the first upstream
// event is read. Usage is accumulated from message_start and message_delta
// events. When a write to the client fails, the function stops immediately and
// returns the usage gathered so far. Otherwise, once the upstream stream ends,
// both conversion state machines are finalized and a "data: [DONE]" marker is
// written. The returned error is always nil.
func (s *GatewayService) handleCCStreamingFromAnthropic(
	resp *http.Response,
	c *gin.Context,
	originalModel string,
	mappedModel string,
	reasoningEffort *string,
	startTime time.Time,
	includeUsage bool,
) (*ForwardResult, error) {
	requestID := resp.Header.Get("x-request-id")
	if s.responseHeaderFilter != nil {
		responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
	}
	c.Writer.Header().Set("Content-Type", "text/event-stream")
	c.Writer.Header().Set("Cache-Control", "no-cache")
	c.Writer.Header().Set("Connection", "keep-alive")
	c.Writer.Header().Set("X-Accel-Buffering", "no")
	c.Writer.WriteHeader(http.StatusOK)
	// Use Anthropic→Responses state machine, then convert Responses→CC
	// Both states report the client's original model name, not the mapped one.
	anthState := apicompat.NewAnthropicEventToResponsesState()
	anthState.Model = originalModel
	ccState := apicompat.NewResponsesEventToChatState()
	ccState.Model = originalModel
	ccState.IncludeUsage = includeUsage
	var usage ClaudeUsage
	var firstTokenMs *int
	firstChunk := true
	scanner := bufio.NewScanner(resp.Body)
	// Allow lines up to the configured maximum, falling back to the default.
	maxLineSize := defaultMaxLineSize
	if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
		maxLineSize = s.cfg.Gateway.MaxLineSize
	}
	scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
	// resultWithUsage snapshots the captured usage and timing at call time.
	resultWithUsage := func() *ForwardResult {
		return &ForwardResult{
			RequestID: requestID,
			Usage: usage,
			Model: originalModel,
			UpstreamModel: mappedModel,
			ReasoningEffort: reasoningEffort,
			Stream: true,
			Duration: time.Since(startTime),
			FirstTokenMs: firstTokenMs,
		}
	}
	// writeChunk serializes one chunk and writes it to the client. It returns
	// true only when the write itself fails; a chunk that cannot be serialized
	// is dropped and reported as false.
	writeChunk := func(chunk apicompat.ChatCompletionsChunk) bool {
		sse, err := apicompat.ChatChunkToSSE(chunk)
		if err != nil {
			return false
		}
		if _, err := fmt.Fprint(c.Writer, sse); err != nil {
			return true // client disconnected
		}
		return false
	}
	// processAnthropicEvent handles one decoded upstream event and returns true
	// when the client write failed and streaming should stop.
	processAnthropicEvent := func(event *apicompat.AnthropicStreamEvent) bool {
		// The latency marker is taken at the first decoded event of any type.
		if firstChunk {
			firstChunk = false
			ms := int(time.Since(startTime).Milliseconds())
			firstTokenMs = &ms
		}
		// Extract usage from message_delta
		if event.Type == "message_delta" && event.Usage != nil {
			mergeAnthropicUsage(&usage, *event.Usage)
		}
		// Also capture usage from message_start (carries cache fields)
		if event.Type == "message_start" && event.Message != nil {
			mergeAnthropicUsage(&usage, event.Message.Usage)
		}
		// Chain: Anthropic event → Responses events → CC chunks
		responsesEvents := apicompat.AnthropicEventToResponsesEvents(event, anthState)
		for _, resEvt := range responsesEvents {
			ccChunks := apicompat.ResponsesEventToChatChunks(&resEvt, ccState)
			for _, chunk := range ccChunks {
				if disconnected := writeChunk(chunk); disconnected {
					return true
				}
			}
		}
		// Flush once per upstream event so chunks reach the client promptly.
		c.Writer.Flush()
		return false
	}
	// Each event is expected as an "event: " line immediately followed by a
	// "data: " line; other lines and undecodable payloads are skipped.
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "event: ") {
			continue
		}
		if !scanner.Scan() {
			break
		}
		dataLine := scanner.Text()
		if !strings.HasPrefix(dataLine, "data: ") {
			continue
		}
		payload := dataLine[6:]
		var event apicompat.AnthropicStreamEvent
		if err := json.Unmarshal([]byte(payload), &event); err != nil {
			continue
		}
		if processAnthropicEvent(&event) {
			return resultWithUsage(), nil
		}
	}
	// A read error is only logged; finalization below still runs.
	if err := scanner.Err(); err != nil {
		if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
			logger.L().Warn("forward_as_cc stream: read error",
				zap.Error(err),
				zap.String("request_id", requestID),
			)
		}
	}
	// Finalize both state machines
	// Write failures are deliberately ignored from here on.
	finalResEvents := apicompat.FinalizeAnthropicResponsesStream(anthState)
	for _, resEvt := range finalResEvents {
		ccChunks := apicompat.ResponsesEventToChatChunks(&resEvt, ccState)
		for _, chunk := range ccChunks {
			writeChunk(chunk) //nolint:errcheck
		}
	}
	finalCCChunks := apicompat.FinalizeResponsesChatStream(ccState)
	for _, chunk := range finalCCChunks {
		writeChunk(chunk) //nolint:errcheck
	}
	// Write [DONE] marker
	fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck
	c.Writer.Flush()
	return resultWithUsage(), nil
}
// writeGatewayCCError sends a JSON error response shaped like an OpenAI Chat
// Completions error: a top-level "error" object holding "type" and "message".
// It is used by the Anthropic-upstream Chat Completions forwarding path.
func writeGatewayCCError(c *gin.Context, statusCode int, errType, message string) {
	detail := gin.H{
		"type":    errType,
		"message": message,
	}
	c.JSON(statusCode, gin.H{"error": detail})
}
//go:build unit
package service
import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// TestExtractCCReasoningEffortFromBody checks that the effort is read from the
// nested and the flat request fields, normalized, and reported as nil when the
// request carries neither.
func TestExtractCCReasoningEffortFromBody(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name string
		body string
		want string // empty means a nil result is expected
	}{
		{name: "nested reasoning.effort", body: `{"reasoning":{"effort":"HIGH"}}`, want: "high"},
		{name: "flat reasoning_effort", body: `{"reasoning_effort":"x-high"}`, want: "xhigh"},
		{name: "missing effort", body: `{"model":"gpt-5"}`, want: ""},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			got := extractCCReasoningEffortFromBody([]byte(tc.body))
			if tc.want == "" {
				require.Nil(t, got)
				return
			}
			require.NotNil(t, got)
			require.Equal(t, tc.want, *got)
		})
	}
}
// TestHandleCCBufferedFromAnthropic_PreservesMessageStartCacheUsageAndReasoning
// feeds a short Anthropic SSE stream through the buffered handler and checks
// that input/cache token counts from message_start, output tokens from
// message_delta, the reasoning effort and the upstream request ID all reach the
// returned ForwardResult.
func TestHandleCCBufferedFromAnthropic_PreservesMessageStartCacheUsageAndReasoning(t *testing.T) {
	t.Parallel()
	gin.SetMode(gin.TestMode)

	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)

	reasoningEffort := "high"
	resp := &http.Response{
		// The key must be in canonical form: http.Header.Get canonicalizes the
		// name it is given, so a literal "x-request-id" map key is never found.
		Header: http.Header{"X-Request-Id": []string{"rid_cc_buffered"}},
		Body: io.NopCloser(strings.NewReader(strings.Join([]string{
			`event: message_start`,
			`data: {"type":"message_start","message":{"id":"msg_1","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":12,"cache_read_input_tokens":9,"cache_creation_input_tokens":3}}}`,
			``,
			`event: content_block_start`,
			`data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`,
			``,
			`event: message_delta`,
			`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":7}}`,
			``,
		}, "\n"))),
	}

	svc := &GatewayService{}
	result, err := svc.handleCCBufferedFromAnthropic(resp, c, "gpt-5", "claude-sonnet-4.5", &reasoningEffort, time.Now())
	require.NoError(t, err)
	require.NotNil(t, result)
	require.Equal(t, "rid_cc_buffered", result.RequestID)
	require.Equal(t, 12, result.Usage.InputTokens)
	require.Equal(t, 7, result.Usage.OutputTokens)
	require.Equal(t, 9, result.Usage.CacheReadInputTokens)
	require.Equal(t, 3, result.Usage.CacheCreationInputTokens)
	require.NotNil(t, result.ReasoningEffort)
	require.Equal(t, "high", *result.ReasoningEffort)
}
// TestHandleCCStreamingFromAnthropic_PreservesMessageStartCacheUsageAndReasoning
// feeds a short Anthropic SSE stream through the streaming handler and checks
// that input/cache token counts from message_start, output tokens from
// message_delta, the reasoning effort and the upstream request ID all reach the
// returned ForwardResult, and that the client stream ends with a [DONE] marker.
func TestHandleCCStreamingFromAnthropic_PreservesMessageStartCacheUsageAndReasoning(t *testing.T) {
	t.Parallel()
	gin.SetMode(gin.TestMode)

	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)

	reasoningEffort := "medium"
	resp := &http.Response{
		// The key must be in canonical form: http.Header.Get canonicalizes the
		// name it is given, so a literal "x-request-id" map key is never found.
		Header: http.Header{"X-Request-Id": []string{"rid_cc_stream"}},
		Body: io.NopCloser(strings.NewReader(strings.Join([]string{
			`event: message_start`,
			`data: {"type":"message_start","message":{"id":"msg_2","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":20,"cache_read_input_tokens":11,"cache_creation_input_tokens":4}}}`,
			``,
			`event: content_block_start`,
			`data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`,
			``,
			`event: message_delta`,
			`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":8}}`,
			``,
			`event: message_stop`,
			`data: {"type":"message_stop"}`,
			``,
		}, "\n"))),
	}

	svc := &GatewayService{}
	result, err := svc.handleCCStreamingFromAnthropic(resp, c, "gpt-5", "claude-sonnet-4.5", &reasoningEffort, time.Now(), true)
	require.NoError(t, err)
	require.NotNil(t, result)
	require.Equal(t, "rid_cc_stream", result.RequestID)
	require.Equal(t, 20, result.Usage.InputTokens)
	require.Equal(t, 8, result.Usage.OutputTokens)
	require.Equal(t, 11, result.Usage.CacheReadInputTokens)
	require.Equal(t, 4, result.Usage.CacheCreationInputTokens)
	require.NotNil(t, result.ReasoningEffort)
	require.Equal(t, "medium", *result.ReasoningEffort)
	require.Contains(t, rec.Body.String(), `[DONE]`)
}
package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
// ForwardAsResponses accepts an OpenAI Responses API request body, converts it
// to Anthropic Messages format, forwards to the Anthropic upstream, and converts
// the response back to Responses format. This enables OpenAI Responses API
// clients to access Anthropic models through Anthropic platform groups.
//
// The method follows the same pattern as OpenAIGatewayService.ForwardAsAnthropic
// but in reverse direction: Responses → Anthropic upstream → Responses.
//
// The upstream request is always sent with streaming enabled. Whether the
// client receives SSE events or a single JSON document depends on the "stream"
// flag of the incoming request.
//
// Error behavior as implemented here:
//   - parse, conversion, marshal, token and request-build failures return an
//     error without writing anything to c;
//   - a transport failure writes a 502 Responses error to c and returns an
//     error carrying the sanitized message;
//   - an upstream status >= 400 that qualifies for failover returns an
//     *UpstreamFailoverError without writing to c;
//   - any other upstream status >= 400 writes a Responses error with the mapped
//     status code and returns an error.
//
// The parsed argument is not read by this function.
func (s *GatewayService) ForwardAsResponses(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	body []byte,
	parsed *ParsedRequest,
) (*ForwardResult, error) {
	startTime := time.Now()
	// 1. Parse Responses request
	var responsesReq apicompat.ResponsesRequest
	if err := json.Unmarshal(body, &responsesReq); err != nil {
		return nil, fmt.Errorf("parse responses request: %w", err)
	}
	// Remember what the client asked for; the upstream request is rewritten below.
	originalModel := responsesReq.Model
	clientStream := responsesReq.Stream
	// 2. Convert Responses → Anthropic
	anthropicReq, err := apicompat.ResponsesToAnthropicRequest(&responsesReq)
	if err != nil {
		return nil, fmt.Errorf("convert responses to anthropic: %w", err)
	}
	// 3. Force upstream streaming (Anthropic works best with streaming)
	anthropicReq.Stream = true
	reqStream := true
	// 4. Model mapping
	// API-key accounts use the account's own mapping; other Anthropic-platform
	// accounts fall back to normalizing the model ID when no mapping changed it.
	mappedModel := originalModel
	// The reasoning effort is read from the raw client body, not the converted request.
	reasoningEffort := ExtractResponsesReasoningEffortFromBody(body)
	if account.Type == AccountTypeAPIKey {
		mappedModel = account.GetMappedModel(originalModel)
	}
	if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
		normalized := claude.NormalizeModelID(originalModel)
		if normalized != originalModel {
			mappedModel = normalized
		}
	}
	anthropicReq.Model = mappedModel
	logger.L().Debug("gateway forward_as_responses: model mapping applied",
		zap.Int64("account_id", account.ID),
		zap.String("original_model", originalModel),
		zap.String("mapped_model", mappedModel),
		zap.Bool("client_stream", clientStream),
	)
	// 5. Marshal Anthropic request body
	anthropicBody, err := json.Marshal(anthropicReq)
	if err != nil {
		return nil, fmt.Errorf("marshal anthropic request: %w", err)
	}
	// 6. Apply Claude Code mimicry for OAuth accounts (non-Claude-Code endpoints)
	isClaudeCode := false // Responses API is never Claude Code
	shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
	if shouldMimicClaudeCode {
		// The prompt is injected only for non-"haiku" models and only when the
		// system field does not already include it.
		if !strings.Contains(strings.ToLower(mappedModel), "haiku") &&
			!systemIncludesClaudeCodePrompt(anthropicReq.System) {
			anthropicBody = injectClaudeCodePrompt(anthropicBody, anthropicReq.System)
		}
	}
	// 7. Enforce cache_control block limit
	anthropicBody = enforceCacheControlLimit(anthropicBody)
	// 8. Get access token
	token, tokenType, err := s.GetAccessToken(ctx, account)
	if err != nil {
		return nil, fmt.Errorf("get access token: %w", err)
	}
	// 9. Get proxy URL
	proxyURL := ""
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}
	// 10. Build upstream request
	// NOTE(review): the release function is invoked right after the request is
	// built and before it is sent; detachStreamUpstreamContext is defined
	// elsewhere, so confirm that releasing here cannot cancel the in-flight call.
	upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
	upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode)
	releaseUpstreamCtx()
	if err != nil {
		return nil, fmt.Errorf("build upstream request: %w", err)
	}
	// 11. Send request
	resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
	if err != nil {
		// Close a body that may have been returned alongside the error.
		if resp != nil && resp.Body != nil {
			_ = resp.Body.Close()
		}
		// Only the sanitized message is recorded and returned; the client gets a
		// generic 502 message.
		safeErr := sanitizeUpstreamErrorMessage(err.Error())
		setOpsUpstreamError(c, 0, safeErr, "")
		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
			Platform: account.Platform,
			AccountID: account.ID,
			AccountName: account.Name,
			UpstreamStatusCode: 0,
			Kind: "request_error",
			Message: safeErr,
		})
		writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
		return nil, fmt.Errorf("upstream request failed: %s", safeErr)
	}
	defer func() { _ = resp.Body.Close() }()
	// 12. Handle error response with failover
	if resp.StatusCode >= 400 {
		// Read at most 2 MiB of the error body, then swap in an in-memory copy so
		// the deferred Close above still has a valid body to close.
		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
		_ = resp.Body.Close()
		resp.Body = io.NopCloser(bytes.NewReader(respBody))
		upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
		upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
		if s.shouldFailoverUpstreamError(resp.StatusCode) {
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform: account.Platform,
				AccountID: account.ID,
				AccountName: account.Name,
				UpstreamStatusCode: resp.StatusCode,
				UpstreamRequestID: resp.Header.Get("x-request-id"),
				Kind: "failover",
				Message: upstreamMsg,
			})
			if s.rateLimitService != nil {
				s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
			}
			// Nothing is written to the client on this path; the caller receives
			// the status code and raw body through the typed error.
			return nil, &UpstreamFailoverError{
				StatusCode: resp.StatusCode,
				ResponseBody: respBody,
			}
		}
		// Non-failover error: return Responses-formatted error to client
		writeResponsesError(c, mapUpstreamStatusCode(resp.StatusCode), "server_error", upstreamMsg)
		return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
	}
	// 13. Handle normal response (convert Anthropic → Responses)
	var result *ForwardResult
	var handleErr error
	if clientStream {
		result, handleErr = s.handleResponsesStreamingResponse(resp, c, originalModel, mappedModel, reasoningEffort, startTime)
	} else {
		result, handleErr = s.handleResponsesBufferedStreamingResponse(resp, c, originalModel, mappedModel, reasoningEffort, startTime)
	}
	return result, handleErr
}
// ExtractResponsesReasoningEffortFromBody returns the normalized value of the
// Responses API reasoning.effort field for usage logging. It returns nil when
// the field is absent or blank, or when normalization yields an empty string.
func ExtractResponsesReasoningEffortFromBody(body []byte) *string {
	effort := gjson.GetBytes(body, "reasoning.effort").String()
	effort = strings.TrimSpace(effort)
	if effort == "" {
		return nil
	}

	effort = normalizeOpenAIReasoningEffort(effort)
	if effort == "" {
		return nil
	}
	return &effort
}
// mergeAnthropicUsage copies every positive token counter from src into dst,
// leaving dst's existing value in place wherever src reports zero or less.
// A nil dst is a no-op.
func mergeAnthropicUsage(dst *ClaudeUsage, src apicompat.AnthropicUsage) {
	if dst == nil {
		return
	}

	if n := src.InputTokens; n > 0 {
		dst.InputTokens = n
	}
	if n := src.OutputTokens; n > 0 {
		dst.OutputTokens = n
	}
	if n := src.CacheReadInputTokens; n > 0 {
		dst.CacheReadInputTokens = n
	}
	if n := src.CacheCreationInputTokens; n > 0 {
		dst.CacheCreationInputTokens = n
	}
}
// handleResponsesBufferedStreamingResponse reads all Anthropic SSE events from
// the upstream streaming response, assembles them into a complete Anthropic
// response, converts to Responses API JSON format, and writes it to the client
// with status 200.
//
// Events are expected as an "event: " line immediately followed by a "data: "
// line; anything else is skipped. Payloads that fail to decode are logged and
// skipped. A read error from the scanner is logged (unless it is a context
// cancellation or deadline) and whatever was assembled so far is still
// returned.
//
// If no message_start event was seen, a 502 Responses error is written and an
// error is returned.
func (s *GatewayService) handleResponsesBufferedStreamingResponse(
	resp *http.Response,
	c *gin.Context,
	originalModel string,
	mappedModel string,
	reasoningEffort *string,
	startTime time.Time,
) (*ForwardResult, error) {
	requestID := resp.Header.Get("x-request-id")

	scanner := bufio.NewScanner(resp.Body)
	maxLineSize := defaultMaxLineSize
	if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
		maxLineSize = s.cfg.Gateway.MaxLineSize
	}
	scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)

	// Accumulate the final Anthropic response from streaming events
	var finalResp *apicompat.AnthropicResponse
	var usage ClaudeUsage

	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "event: ") {
			continue
		}
		eventType := strings.TrimPrefix(line, "event: ")

		// Read the data line
		if !scanner.Scan() {
			break
		}
		dataLine := scanner.Text()
		if !strings.HasPrefix(dataLine, "data: ") {
			continue
		}
		payload := dataLine[6:]

		var event apicompat.AnthropicStreamEvent
		if err := json.Unmarshal([]byte(payload), &event); err != nil {
			logger.L().Warn("forward_as_responses buffered: failed to parse event",
				zap.Error(err),
				zap.String("request_id", requestID),
				zap.String("event_type", eventType),
			)
			continue
		}

		switch event.Type {
		case "message_start":
			// message_start carries the initial response structure.
			if event.Message != nil {
				finalResp = event.Message
				mergeAnthropicUsage(&usage, event.Message.Usage)
			}
		case "message_delta":
			// message_delta carries final usage and stop_reason.
			if event.Usage != nil {
				mergeAnthropicUsage(&usage, *event.Usage)
			}
			if event.Delta != nil && event.Delta.StopReason != "" && finalResp != nil {
				finalResp.StopReason = event.Delta.StopReason
			}
		case "content_block_start":
			// Accumulate content blocks
			if event.ContentBlock != nil && finalResp != nil {
				finalResp.Content = append(finalResp.Content, *event.ContentBlock)
			}
		case "content_block_delta":
			if event.Delta != nil && finalResp != nil && event.Index != nil {
				idx := *event.Index
				// The index comes from the upstream payload; reject negative values
				// as well as values past the end so a malformed event cannot panic.
				if idx >= 0 && idx < len(finalResp.Content) {
					switch event.Delta.Type {
					case "text_delta":
						finalResp.Content[idx].Text += event.Delta.Text
					case "thinking_delta":
						finalResp.Content[idx].Thinking += event.Delta.Thinking
					case "input_json_delta":
						finalResp.Content[idx].Input = appendRawJSON(finalResp.Content[idx].Input, event.Delta.PartialJSON)
					}
				}
			}
		}
	}

	if err := scanner.Err(); err != nil {
		if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
			logger.L().Warn("forward_as_responses buffered: read error",
				zap.Error(err),
				zap.String("request_id", requestID),
			)
		}
	}

	if finalResp == nil {
		writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream stream ended without a response")
		return nil, fmt.Errorf("upstream stream ended without response")
	}

	// Update usage from accumulated delta
	if usage.InputTokens > 0 || usage.OutputTokens > 0 {
		finalResp.Usage = apicompat.AnthropicUsage{
			InputTokens:              usage.InputTokens,
			OutputTokens:             usage.OutputTokens,
			CacheCreationInputTokens: usage.CacheCreationInputTokens,
			CacheReadInputTokens:     usage.CacheReadInputTokens,
		}
	}

	// Convert to Responses format
	responsesResp := apicompat.AnthropicToResponsesResponse(finalResp)
	responsesResp.Model = originalModel // Use original model name

	if s.responseHeaderFilter != nil {
		responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
	}
	c.JSON(http.StatusOK, responsesResp)

	return &ForwardResult{
		RequestID:       requestID,
		Usage:           usage,
		Model:           originalModel,
		UpstreamModel:   mappedModel,
		ReasoningEffort: reasoningEffort,
		Stream:          false,
		Duration:        time.Since(startTime),
	}, nil
}
// handleResponsesStreamingResponse reads Anthropic SSE events from upstream,
// converts each to Responses SSE events, and writes them to the client.
//
// Parameters:
//   - resp: upstream HTTP response whose Body carries the Anthropic SSE stream.
//   - c: gin context whose Writer receives the converted Responses SSE events.
//   - originalModel: model name reported back to the client and in the result.
//   - mappedModel: model name recorded as UpstreamModel in the result.
//   - reasoningEffort: passed through unchanged into the result.
//   - startTime: reference instant for Duration and FirstTokenMs.
//
// The SSE response headers and a 200 status are written before any upstream
// data is read. The returned error is always nil: a client write failure,
// an upstream read error and normal end-of-stream all return the usage
// accumulated so far.
func (s *GatewayService) handleResponsesStreamingResponse(
resp *http.Response,
c *gin.Context,
originalModel string,
mappedModel string,
reasoningEffort *string,
startTime time.Time,
) (*ForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
// Upstream headers are only forwarded when a filter is configured.
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.WriteHeader(http.StatusOK)
// Conversion state is carried across events; the client sees the original model name.
state := apicompat.NewAnthropicEventToResponsesState()
state.Model = originalModel
var usage ClaudeUsage
var firstTokenMs *int
firstChunk := true
scanner := bufio.NewScanner(resp.Body)
// The per-line limit defaults to defaultMaxLineSize and can be raised via config;
// the scanner starts with a 64 KiB buffer and grows up to that limit.
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
// resultWithUsage snapshots the usage and timing collected so far; it is
// used on every exit path.
resultWithUsage := func() *ForwardResult {
return &ForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
UpstreamModel: mappedModel,
ReasoningEffort: reasoningEffort,
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}
}
// processEvent handles a single parsed Anthropic SSE event.
// It returns true when writing to the client failed and reading should stop.
processEvent := func(event *apicompat.AnthropicStreamEvent) bool {
// FirstTokenMs is stamped on the first successfully parsed event of any type.
if firstChunk {
firstChunk = false
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
// Extract usage from message_delta
if event.Type == "message_delta" && event.Usage != nil {
mergeAnthropicUsage(&usage, *event.Usage)
}
// Also capture usage from message_start
if event.Type == "message_start" && event.Message != nil {
mergeAnthropicUsage(&usage, event.Message.Usage)
}
// Convert to Responses events
events := apicompat.AnthropicEventToResponsesEvents(event, state)
for _, evt := range events {
sse, err := apicompat.ResponsesEventToSSE(evt)
if err != nil {
// An event that cannot be serialized is logged and dropped; the rest are still sent.
logger.L().Warn("forward_as_responses stream: failed to marshal event",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
logger.L().Info("forward_as_responses stream: client disconnected",
zap.String("request_id", requestID),
)
return true // client disconnected
}
}
// Flush once per upstream event, and only when something was produced for it.
if len(events) > 0 {
c.Writer.Flush()
}
return false
}
// finalizeStream writes whatever closing events the converter still owes
// the client. Serialization and write errors are deliberately ignored here.
finalizeStream := func() (*ForwardResult, error) {
if finalEvents := apicompat.FinalizeAnthropicResponsesStream(state); len(finalEvents) > 0 {
for _, evt := range finalEvents {
sse, err := apicompat.ResponsesEventToSSE(evt)
if err != nil {
continue
}
fmt.Fprint(c.Writer, sse) //nolint:errcheck
}
c.Writer.Flush()
}
return resultWithUsage(), nil
}
// Read Anthropic SSE events
// Only an "event: " line immediately followed by a "data: " line is
// processed; every other line is skipped.
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "event: ") {
continue
}
eventType := strings.TrimPrefix(line, "event: ")
// Read data line
if !scanner.Scan() {
break
}
dataLine := scanner.Text()
if !strings.HasPrefix(dataLine, "data: ") {
continue
}
// Drop the 6-byte "data: " prefix checked above.
payload := dataLine[6:]
var event apicompat.AnthropicStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
// eventType is only used for this log entry; dispatch relies on the payload's own type field.
logger.L().Warn("forward_as_responses stream: failed to parse event",
zap.Error(err),
zap.String("request_id", requestID),
zap.String("event_type", eventType),
)
continue
}
// On a client write failure, return immediately without emitting finalization events.
if processEvent(&event) {
return resultWithUsage(), nil
}
}
// Context cancellation and deadline errors are not logged; any other read
// error is logged at Warn level. Either way the stream is still finalized.
if err := scanner.Err(); err != nil {
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("forward_as_responses stream: read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
}
return finalizeStream()
}
// appendRawJSON concatenates a partial-JSON fragment onto the raw JSON
// accumulated so far and returns the combined bytes. The result is a freshly
// allocated message; existing is never modified. No validation is performed,
// so the result is only valid JSON once all fragments have been appended.
func appendRawJSON(existing json.RawMessage, fragment string) json.RawMessage {
	combined := make([]byte, 0, len(existing)+len(fragment))
	combined = append(combined, existing...)
	combined = append(combined, fragment...)
	return json.RawMessage(combined)
}
// writeResponsesError sends a JSON error body shaped like the OpenAI Responses
// API: {"error": {"code": ..., "message": ...}} with the given HTTP status.
func writeResponsesError(c *gin.Context, statusCode int, code, message string) {
	detail := gin.H{
		"code":    code,
		"message": message,
	}
	c.JSON(statusCode, gin.H{"error": detail})
}
// mapUpstreamStatusCode translates an upstream HTTP status into the status
// shown to the client: anything below 500 is passed through unchanged, and
// every 5xx (or higher) value is reported as 502 Bad Gateway.
func mapUpstreamStatusCode(code int) int {
	if code < 500 {
		return code
	}
	return http.StatusBadGateway
}
//go:build unit
package service
import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// TestExtractResponsesReasoningEffortFromBody verifies that the reasoning
// effort is returned lower-cased when present in the request body, and that
// nil is returned when the body carries no reasoning section.
func TestExtractResponsesReasoningEffortFromBody(t *testing.T) {
	t.Parallel()

	withEffort := []byte(`{"model":"claude-sonnet-4.5","reasoning":{"effort":"HIGH"}}`)
	effort := ExtractResponsesReasoningEffortFromBody(withEffort)
	require.NotNil(t, effort)
	require.Equal(t, "high", *effort)

	withoutEffort := []byte(`{"model":"claude-sonnet-4.5"}`)
	require.Nil(t, ExtractResponsesReasoningEffortFromBody(withoutEffort))
}
// TestHandleResponsesBufferedStreamingResponse_PreservesMessageStartCacheUsage
// feeds a buffered Anthropic SSE stream whose message_start carries input and
// cache token counts and whose message_delta carries only output tokens. It
// checks that all four counters survive into the returned usage and that the
// cached token count appears in the JSON written to the client.
func TestHandleResponsesBufferedStreamingResponse_PreservesMessageStartCacheUsage(t *testing.T) {
	t.Parallel()
	gin.SetMode(gin.TestMode)

	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)

	resp := &http.Response{
		// http.Header.Get canonicalizes the lookup key, so a header placed
		// directly in the map literal must use the canonical form
		// ("X-Request-Id"); a lowercase key would never be found.
		Header: http.Header{"X-Request-Id": []string{"rid_buffered"}},
		Body: io.NopCloser(strings.NewReader(strings.Join([]string{
			`event: message_start`,
			`data: {"type":"message_start","message":{"id":"msg_1","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":12,"cache_read_input_tokens":9,"cache_creation_input_tokens":3}}}`,
			``,
			`event: content_block_start`,
			`data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`,
			``,
			`event: message_delta`,
			`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":7}}`,
			``,
		}, "\n"))),
	}

	svc := &GatewayService{}
	result, err := svc.handleResponsesBufferedStreamingResponse(resp, c, "claude-sonnet-4.5", "claude-sonnet-4.5", nil, time.Now())
	require.NoError(t, err)
	require.NotNil(t, result)
	require.Equal(t, 12, result.Usage.InputTokens)
	require.Equal(t, 7, result.Usage.OutputTokens)
	require.Equal(t, 9, result.Usage.CacheReadInputTokens)
	require.Equal(t, 3, result.Usage.CacheCreationInputTokens)
	require.Contains(t, rec.Body.String(), `"cached_tokens":9`)
}
// TestHandleResponsesStreamingResponse_PreservesMessageStartCacheUsage feeds a
// complete Anthropic SSE stream through the streaming handler. It checks that
// the input/cache token counts from message_start and the output tokens from
// message_delta are all present in the returned usage, that the upstream
// request id is propagated, and that a response.completed event was written
// to the client.
func TestHandleResponsesStreamingResponse_PreservesMessageStartCacheUsage(t *testing.T) {
	t.Parallel()
	gin.SetMode(gin.TestMode)

	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)

	resp := &http.Response{
		// http.Header.Get canonicalizes the lookup key, so a header placed
		// directly in the map literal must use the canonical form
		// ("X-Request-Id"); a lowercase key would never be found.
		Header: http.Header{"X-Request-Id": []string{"rid_stream"}},
		Body: io.NopCloser(strings.NewReader(strings.Join([]string{
			`event: message_start`,
			`data: {"type":"message_start","message":{"id":"msg_2","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":20,"cache_read_input_tokens":11,"cache_creation_input_tokens":4}}}`,
			``,
			`event: content_block_start`,
			`data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`,
			``,
			`event: message_delta`,
			`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":8}}`,
			``,
			`event: message_stop`,
			`data: {"type":"message_stop"}`,
			``,
		}, "\n"))),
	}

	svc := &GatewayService{}
	result, err := svc.handleResponsesStreamingResponse(resp, c, "claude-sonnet-4.5", "claude-sonnet-4.5", nil, time.Now())
	require.NoError(t, err)
	require.NotNil(t, result)
	require.Equal(t, "rid_stream", result.RequestID)
	require.Equal(t, 20, result.Usage.InputTokens)
	require.Equal(t, 8, result.Usage.OutputTokens)
	require.Equal(t, 11, result.Usage.CacheReadInputTokens)
	require.Equal(t, 4, result.Usage.CacheCreationInputTokens)
	require.Contains(t, rec.Body.String(), `response.completed`)
}
...@@ -92,7 +92,7 @@ func (m *mockAccountRepoForPlatform) Delete(ctx context.Context, id int64) error ...@@ -92,7 +92,7 @@ func (m *mockAccountRepoForPlatform) Delete(ctx context.Context, id int64) error
func (m *mockAccountRepoForPlatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { func (m *mockAccountRepoForPlatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (m *mockAccountRepoForPlatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) { func (m *mockAccountRepoForPlatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
......
...@@ -5,6 +5,8 @@ import ( ...@@ -5,6 +5,8 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"math" "math"
"regexp"
"sort"
"strings" "strings"
"unsafe" "unsafe"
...@@ -34,6 +36,9 @@ var ( ...@@ -34,6 +36,9 @@ var (
patternEmptyTextSpaced = []byte(`"text": ""`) patternEmptyTextSpaced = []byte(`"text": ""`)
patternEmptyTextSp1 = []byte(`"text" : ""`) patternEmptyTextSp1 = []byte(`"text" : ""`)
patternEmptyTextSp2 = []byte(`"text" :""`) patternEmptyTextSp2 = []byte(`"text" :""`)
sessionUserAgentProductPattern = regexp.MustCompile(`([A-Za-z0-9._-]+)/[A-Za-z0-9._-]+`)
sessionUserAgentVersionPattern = regexp.MustCompile(`\bv?\d+(?:\.\d+){1,3}\b`)
) )
// SessionContext 粘性会话上下文,用于区分不同来源的请求。 // SessionContext 粘性会话上下文,用于区分不同来源的请求。
...@@ -75,6 +80,49 @@ type ParsedRequest struct { ...@@ -75,6 +80,49 @@ type ParsedRequest struct {
OnUpstreamAccepted func() OnUpstreamAccepted func()
} }
// NormalizeSessionUserAgent canonicalizes a User-Agent string so that it can
// be used as a stable input for sticky-session and digest hashing.
//
// Every Product/Version token is reduced to its lower-cased product name; the
// distinct names are sorted and joined with "+". Versions, comments and token
// order therefore do not influence the result. A blank input yields "". When
// no Product/Version token is present, the whole string is normalized by
// normalizeSessionUserAgentFallback instead.
func NormalizeSessionUserAgent(raw string) string {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return ""
	}

	tokens := sessionUserAgentProductPattern.FindAllStringSubmatch(trimmed, -1)
	unique := make(map[string]struct{}, len(tokens))
	names := make([]string, 0, len(tokens))
	for _, token := range tokens {
		// token[1] is the captured product name; skip anything malformed.
		if len(token) < 2 {
			continue
		}
		name := strings.ToLower(strings.TrimSpace(token[1]))
		if name == "" {
			continue
		}
		if _, dup := unique[name]; !dup {
			unique[name] = struct{}{}
			names = append(names, name)
		}
	}

	// Covers both "no token matched" and "every token was discarded".
	if len(names) == 0 {
		return normalizeSessionUserAgentFallback(trimmed)
	}

	sort.Strings(names)
	return strings.Join(names, "+")
}
// normalizeSessionUserAgentFallback normalizes a User-Agent that contains no
// Product/Version tokens: whitespace runs are collapsed, the text is
// lower-cased, anything matching sessionUserAgentVersionPattern is removed,
// and the whitespace left behind by those removals is collapsed again.
func normalizeSessionUserAgentFallback(raw string) string {
	collapsed := strings.Join(strings.Fields(raw), " ")
	lowered := strings.ToLower(collapsed)
	withoutVersions := sessionUserAgentVersionPattern.ReplaceAllString(lowered, "")
	return strings.Join(strings.Fields(withoutVersions), " ")
}
// ParseGatewayRequest 解析网关请求体并返回结构化结果。 // ParseGatewayRequest 解析网关请求体并返回结构化结果。
// protocol 指定请求协议格式(domain.PlatformAnthropic / domain.PlatformGemini), // protocol 指定请求协议格式(domain.PlatformAnthropic / domain.PlatformGemini),
// 不同协议使用不同的 system/messages 字段名。 // 不同协议使用不同的 system/messages 字段名。
...@@ -205,6 +253,118 @@ func sliceRawFromBody(body []byte, r gjson.Result) []byte { ...@@ -205,6 +253,118 @@ func sliceRawFromBody(body []byte, r gjson.Result) []byte {
return []byte(r.Raw) return []byte(r.Raw)
} }
// stripEmptyTextBlocksFromSlice drops every block of type "text" whose "text"
// value is empty (or missing / not a string) from a content slice, descending
// into the "content" array of "tool_result" blocks.
//
// It returns (cleaned, true) when at least one block was removed anywhere, and
// (blocks, false) otherwise. The input is never mutated: the output slice is
// allocated lazily on the first removal, and a tool_result block whose nested
// content changed is replaced by a shallow copy.
func stripEmptyTextBlocksFromSlice(blocks []any) ([]any, bool) {
	// kept stays nil until the first modification, which doubles as the
	// "changed" flag.
	var kept []any

	// beginCopy allocates kept on first use and seeds it with the untouched
	// prefix blocks[:upTo].
	beginCopy := func(upTo int) {
		if kept == nil {
			kept = make([]any, 0, len(blocks))
			kept = append(kept, blocks[:upTo]...)
		}
	}
	// carryOver retains an unmodified block, but only once copying has begun.
	carryOver := func(item any) {
		if kept != nil {
			kept = append(kept, item)
		}
	}

	for idx, raw := range blocks {
		entry, isMap := raw.(map[string]any)
		if !isMap {
			carryOver(raw)
			continue
		}

		kind, _ := entry["type"].(string)
		switch kind {
		case "text":
			if text, _ := entry["text"].(string); text == "" {
				beginCopy(idx)
				continue // drop the empty text block
			}
		case "tool_result":
			nested, isList := entry["content"].([]any)
			if !isList {
				break
			}
			stripped, nestedChanged := stripEmptyTextBlocksFromSlice(nested)
			if !nestedChanged {
				break
			}
			beginCopy(idx)
			// Shallow-copy so the caller's map keeps its original content.
			clone := make(map[string]any, len(entry))
			for key, val := range entry {
				clone[key] = val
			}
			clone["content"] = stripped
			kept = append(kept, clone)
			continue
		}

		carryOver(raw)
	}

	if kept == nil {
		return blocks, false
	}
	return kept, true
}
// StripEmptyTextBlocks removes empty text blocks from the request body (including nested tool_result content).
// This is a lightweight pre-filter for the initial request path to prevent upstream 400 errors.
//
// The original body is returned unchanged when it contains no empty-text
// pattern, when it has no "messages" array, when nothing was actually
// stripped, or when any decode/encode step fails (fail-safe). Otherwise a new
// body is returned in which only the "messages" value has been replaced.
func StripEmptyTextBlocks(body []byte) []byte {
	// Fast path: a cheap byte scan avoids any JSON parsing for the common
	// case where no empty "text" value is present at all.
	hasEmptyTextBlock := bytes.Contains(body, patternEmptyText) ||
		bytes.Contains(body, patternEmptyTextSpaced) ||
		bytes.Contains(body, patternEmptyTextSp1) ||
		bytes.Contains(body, patternEmptyTextSp2)
	if !hasEmptyTextBlock {
		return body
	}

	// Zero-copy view of body as a string for gjson; body is not modified
	// while this view is in use.
	jsonStr := *(*string)(unsafe.Pointer(&body))
	msgsRes := gjson.Get(jsonStr, "messages")
	if !msgsRes.Exists() || !msgsRes.IsArray() {
		return body
	}

	// Decode with UseNumber so numeric literals are kept verbatim as
	// json.Number. Plain json.Unmarshal into `any` turns every number into a
	// float64, which would silently corrupt large integers elsewhere in the
	// messages (e.g. inside tool inputs) when the array is re-encoded below.
	decoder := json.NewDecoder(bytes.NewReader(sliceRawFromBody(body, msgsRes)))
	decoder.UseNumber()
	var messages []any
	if err := decoder.Decode(&messages); err != nil {
		return body
	}

	modified := false
	for _, msg := range messages {
		msgMap, ok := msg.(map[string]any)
		if !ok {
			continue
		}
		// String-valued content has no blocks to strip.
		content, ok := msgMap["content"].([]any)
		if !ok {
			continue
		}
		if cleaned, changed := stripEmptyTextBlocksFromSlice(content); changed {
			modified = true
			msgMap["content"] = cleaned
		}
	}
	// The fast-path patterns can match outside message content (for example
	// in the system prompt), in which case there is nothing to rewrite.
	if !modified {
		return body
	}

	msgsBytes, err := json.Marshal(messages)
	if err != nil {
		return body
	}
	out, err := sjson.SetRawBytes(body, "messages", msgsBytes)
	if err != nil {
		return body
	}
	return out
}
// FilterThinkingBlocks removes thinking blocks from request body // FilterThinkingBlocks removes thinking blocks from request body
// Returns filtered body or original body if filtering fails (fail-safe) // Returns filtered body or original body if filtering fails (fail-safe)
// This prevents 400 errors from invalid thinking block signatures // This prevents 400 errors from invalid thinking block signatures
...@@ -378,6 +538,23 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { ...@@ -378,6 +538,23 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
} }
} }
// Recursively strip empty text blocks from tool_result nested content.
if blockType == "tool_result" {
if nestedContent, ok := blockMap["content"].([]any); ok {
if cleaned, changed := stripEmptyTextBlocksFromSlice(nestedContent); changed {
modifiedThisMsg = true
ensureNewContent(bi)
blockCopy := make(map[string]any, len(blockMap))
for k, v := range blockMap {
blockCopy[k] = v
}
blockCopy["content"] = cleaned
newContent = append(newContent, blockCopy)
continue
}
}
}
if newContent != nil { if newContent != nil {
newContent = append(newContent, block) newContent = append(newContent, block)
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment