Commit 3b7a5fff authored by 陈曦's avatar 陈曦
Browse files

补充openai、gemini以及流失请求的采集数据以及nfs落库

parent 8519a8eb
Pipeline #82284 failed with stage
in 2 minutes and 21 seconds
......@@ -64,9 +64,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
}
// 异步捕获请求体(仅当该 API Key 开启了 capture_requests)
var captureID int64
if apiKey.CaptureRequests && h.requestCaptureService != nil {
requestID, _ := c.Request.Context().Value(ctxkey.RequestID).(string)
h.requestCaptureService.Capture(
captureID = h.requestCaptureService.Capture(
apiKey.ID, subject.UserID,
requestID,
c.Request.URL.Path,
......@@ -144,6 +145,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
reqModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
false,
)
if err != nil {
reqLog.Warn("openai_chat_completions.account_select_failed",
......@@ -167,6 +169,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
defaultModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
false,
)
if err == nil && selection != nil {
c.Set("openai_chat_completions_fallback_model", defaultModel)
......@@ -277,6 +280,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
}
// 异步写入响应体到捕获记录
if captureID > 0 && h.requestCaptureService != nil && result != nil {
h.requestCaptureService.CaptureResponse(captureID, result.ResponseBody)
}
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
......
......@@ -116,7 +116,7 @@ func TestLogOpenAIRemoteCompactOutcome_Succeeded(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", nil)
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.125.0")
c.Set(opsModelKey, "gpt-5.3-codex")
c.Set(opsAccountIDKey, int64(123))
c.Header("x-request-id", "rid-compact-ok")
......@@ -142,7 +142,7 @@ func TestLogOpenAIRemoteCompactOutcome_Failed(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact", nil)
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.125.0")
c.Status(http.StatusBadGateway)
h := &OpenAIGatewayHandler{}
......@@ -180,7 +180,7 @@ func TestOpenAIResponses_CompactUnauthorizedLogsFailed(t *testing.T) {
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", strings.NewReader(`{"model":"gpt-5.3-codex"}`))
c.Request.Header.Set("Content-Type", "application/json")
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.125.0")
h := &OpenAIGatewayHandler{}
h.Responses(c)
......
......@@ -140,9 +140,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
}
// 异步捕获请求体(仅当该 API Key 开启了 capture_requests)
var captureID int64
if apiKey.CaptureRequests && h.requestCaptureService != nil {
requestID, _ := c.Request.Context().Value(ctxkey.RequestID).(string)
h.requestCaptureService.Capture(
captureID = h.requestCaptureService.Capture(
apiKey.ID, subject.UserID,
requestID,
c.Request.URL.Path,
......@@ -255,6 +256,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Generate session hash (header first; fallback to prompt_cache_key)
sessionHash := h.gatewayService.GenerateSessionHash(c, sessionHashBody)
requireCompact := isOpenAIRemoteCompactPath(c)
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
......@@ -273,6 +275,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
reqModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
requireCompact,
)
if err != nil {
reqLog.Warn("openai.account_select_failed",
......@@ -280,6 +283,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if len(failedAccountIDs) == 0 {
if errors.Is(err, service.ErrNoAvailableCompactAccounts) {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "compact_not_supported", "No available OpenAI accounts support /responses/compact", streamStarted)
return
}
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
}
......@@ -400,6 +407,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
}
// 异步写入响应体到捕获记录
if captureID > 0 && h.requestCaptureService != nil && result != nil {
h.requestCaptureService.CaptureResponse(captureID, result.ResponseBody)
}
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
......@@ -661,6 +673,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
currentRoutingModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
false,
)
if err != nil {
reqLog.Warn("openai_messages.account_select_failed",
......@@ -774,6 +787,11 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
}
// 异步写入响应体到捕获记录
if captureID > 0 && h.requestCaptureService != nil && result != nil {
h.requestCaptureService.CaptureResponse(captureID, result.ResponseBody)
}
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
......@@ -1184,6 +1202,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
reqModel,
nil,
service.OpenAIUpstreamTransportResponsesWebsocketV2,
false,
)
if err != nil {
reqLog.Warn("openai.websocket_account_select_failed", zap.Error(err))
......
......@@ -117,12 +117,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
return
}
sessionHash := ""
if parsed.Multipart {
sessionHash = h.gatewayService.GenerateSessionHashWithFallback(c, nil, parsed.StickySessionSeed())
} else {
sessionHash = h.gatewayService.GenerateSessionHash(c, body)
}
sessionHash := h.gatewayService.GenerateExplicitSessionHash(c, body)
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
......
......@@ -117,7 +117,7 @@ func TestVerifyOrderPublicReturnsLegacyOrderState(t *testing.T) {
Save(context.Background())
require.NoError(t, err)
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil)
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil, nil)
h := NewPaymentHandler(paymentSvc, nil, nil)
recorder := httptest.NewRecorder()
......@@ -215,7 +215,7 @@ func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing
require.NoError(t, err)
configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef"))
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil)
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil, nil)
h := NewPaymentHandler(paymentSvc, nil, nil)
recorder := httptest.NewRecorder()
......@@ -302,7 +302,7 @@ func TestResolveOrderPublicByResumeTokenReturnsBadRequestForMismatchedToken(t *t
require.NoError(t, err)
configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef"))
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil)
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil, nil)
h := NewPaymentHandler(paymentSvc, nil, nil)
recorder := httptest.NewRecorder()
......@@ -342,7 +342,7 @@ func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) {
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil)
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil, nil)
h := NewPaymentHandler(paymentSvc, nil, nil)
recorder := httptest.NewRecorder()
......
......@@ -75,5 +75,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
AffiliateEnabled: settings.AffiliateEnabled,
})
}
......@@ -14,10 +14,11 @@ import (
// UserHandler handles user-related requests
type UserHandler struct {
userService *service.UserService
authService *service.AuthService
emailService *service.EmailService
emailCache service.EmailCache
userService *service.UserService
authService *service.AuthService
emailService *service.EmailService
emailCache service.EmailCache
affiliateService *service.AffiliateService
}
// NewUserHandler creates a new UserHandler
......@@ -26,12 +27,14 @@ func NewUserHandler(
authService *service.AuthService,
emailService *service.EmailService,
emailCache service.EmailCache,
affiliateService *service.AffiliateService,
) *UserHandler {
return &UserHandler{
userService: userService,
authService: authService,
emailService: emailService,
emailCache: emailCache,
userService: userService,
authService: authService,
emailService: emailService,
emailCache: emailCache,
affiliateService: affiliateService,
}
}
......@@ -159,6 +162,44 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
response.Success(c, profileResp)
}
// GetAffiliate returns the current user's affiliate details.
// GET /api/v1/user/aff
func (h *UserHandler) GetAffiliate(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
detail, err := h.affiliateService.GetAffiliateDetail(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, detail)
}
// TransferAffiliateQuota transfers all available affiliate quota into current balance.
// POST /api/v1/user/aff/transfer
func (h *UserHandler) TransferAffiliateQuota(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
transferred, balance, err := h.affiliateService.TransferAffiliateQuota(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{
"transferred_quota": transferred,
"balance": balance,
})
}
type StartIdentityBindingRequest struct {
Provider string `json:"provider" binding:"required"`
RedirectTo string `json:"redirect_to"`
......
......@@ -142,7 +142,7 @@ func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) {
Status: service.StatusActive,
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`)
recorder := httptest.NewRecorder()
......@@ -200,7 +200,7 @@ func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) {
},
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
......@@ -283,7 +283,7 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
},
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
......@@ -362,7 +362,7 @@ func TestUserHandlerGetProfileDoesNotInferEditedProfileSourcesWithoutMatchingIde
},
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
......@@ -511,8 +511,8 @@ func TestUserHandlerBindEmailIdentityReturnsProfileResponse(t *testing.T) {
},
}
emailService := service.NewEmailService(nil, emailCache)
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"new-password"}`)
recorder := httptest.NewRecorder()
......@@ -566,7 +566,7 @@ func TestUserHandlerUnbindIdentityReturnsUpdatedProfile(t *testing.T) {
},
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
......@@ -625,8 +625,8 @@ func TestUserHandlerUnbindIdentityRevokesAllUserSessionsWhenAuthServiceConfigure
ExpireHour: 1,
},
}
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
......@@ -668,8 +668,8 @@ func TestUserHandlerUnbindIdentityDoesNotRevokeSessionsWhenNothingWasUnbound(t *
ExpireHour: 1,
},
}
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
......@@ -712,8 +712,8 @@ func TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail(t
},
}
emailService := service.NewEmailService(nil, emailCache)
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"wrong-password"}`)
recorder := httptest.NewRecorder()
......@@ -750,7 +750,7 @@ func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
Status: service.StatusActive,
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
body := []byte(`{"provider":"wechat","redirect_to":"/settings/profile"}`)
recorder := httptest.NewRecorder()
......
......@@ -37,6 +37,7 @@ func ProvideAdminHandlers(
channelMonitorHandler *admin.ChannelMonitorHandler,
channelMonitorTemplateHandler *admin.ChannelMonitorRequestTemplateHandler,
paymentHandler *admin.PaymentHandler,
affiliateHandler *admin.AffiliateHandler,
) *AdminHandlers {
return &AdminHandlers{
Dashboard: dashboardHandler,
......@@ -67,6 +68,7 @@ func ProvideAdminHandlers(
ChannelMonitor: channelMonitorHandler,
ChannelMonitorTemplate: channelMonitorTemplateHandler,
Payment: paymentHandler,
Affiliate: affiliateHandler,
}
}
......@@ -169,6 +171,7 @@ var ProviderSet = wire.NewSet(
admin.NewChannelMonitorHandler,
admin.NewChannelMonitorRequestTemplateHandler,
admin.NewPaymentHandler,
admin.NewAffiliateHandler,
// AdminHandlers and Handlers constructors
ProvideAdminHandlers,
......
......@@ -25,6 +25,7 @@ const (
easypayStatusPaid = 1
easypayHTTPTimeout = 10 * time.Second
maxEasypayResponseSize = 1 << 20 // 1MB
maxEasypayErrorSummary = 512
tradeStatusSuccess = "TRADE_SUCCESS"
signTypeMD5 = "MD5"
paymentModePopup = "popup"
......@@ -42,17 +43,55 @@ type EasyPay struct {
// config keys: pid, pkey, apiBase, notifyUrl, returnUrl, cid, cidAlipay, cidWxpay
func NewEasyPay(instanceID string, config map[string]string) (*EasyPay, error) {
for _, k := range []string{"pid", "pkey", "apiBase", "notifyUrl", "returnUrl"} {
if config[k] == "" {
if strings.TrimSpace(config[k]) == "" {
return nil, fmt.Errorf("easypay config missing required key: %s", k)
}
}
cfg := make(map[string]string, len(config))
for k, v := range config {
cfg[k] = v
}
cfg["apiBase"] = normalizeEasyPayAPIBase(cfg["apiBase"])
return &EasyPay{
instanceID: instanceID,
config: config,
config: cfg,
httpClient: &http.Client{Timeout: easypayHTTPTimeout},
}, nil
}
func normalizeEasyPayAPIBase(apiBase string) string {
base := strings.TrimSpace(apiBase)
if base == "" {
return ""
}
if parsed, err := url.Parse(base); err == nil && parsed.Scheme != "" && parsed.Host != "" {
parsed.RawQuery = ""
parsed.Fragment = ""
parsed.RawPath = ""
parsed.Path = trimEasyPayEndpointPath(parsed.Path)
return strings.TrimRight(parsed.String(), "/")
}
return strings.TrimRight(trimEasyPayEndpointPath(base), "/")
}
func trimEasyPayEndpointPath(path string) string {
path = strings.TrimRight(strings.TrimSpace(path), "/")
lower := strings.ToLower(path)
for _, endpoint := range []string{"/submit.php", "/mapi.php", "/api.php"} {
if strings.HasSuffix(lower, endpoint) {
return strings.TrimRight(path[:len(path)-len(endpoint)], "/")
}
}
return path
}
func (e *EasyPay) apiBase() string {
if e == nil {
return ""
}
return normalizeEasyPayAPIBase(e.config["apiBase"])
}
func (e *EasyPay) Name() string { return "EasyPay" }
func (e *EasyPay) ProviderKey() string { return payment.TypeEasyPay }
func (e *EasyPay) SupportedTypes() []payment.PaymentType {
......@@ -104,8 +143,7 @@ func (e *EasyPay) createRedirectPayment(req payment.CreatePaymentRequest) (*paym
for k, v := range params {
q.Set(k, v)
}
base := strings.TrimRight(e.config["apiBase"], "/")
payURL := base + "/submit.php?" + q.Encode()
payURL := e.apiBase() + "/submit.php?" + q.Encode()
return &payment.CreatePaymentResponse{PayURL: payURL}, nil
}
......@@ -127,7 +165,7 @@ func (e *EasyPay) createAPIPayment(ctx context.Context, req payment.CreatePaymen
params["sign"] = easyPaySign(params, e.config["pkey"])
params["sign_type"] = signTypeMD5
body, err := e.post(ctx, strings.TrimRight(e.config["apiBase"], "/")+"/mapi.php", params)
body, err := e.post(ctx, e.apiBase()+"/mapi.php", params)
if err != nil {
return nil, fmt.Errorf("easypay create: %w", err)
}
......@@ -171,7 +209,7 @@ func (e *EasyPay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Quer
"act": "order", "pid": e.config["pid"],
"key": e.config["pkey"], "out_trade_no": tradeNo,
}
body, err := e.post(ctx, e.config["apiBase"]+"/api.php", params)
body, err := e.post(ctx, e.apiBase()+"/api.php", params)
if err != nil {
return nil, fmt.Errorf("easypay query: %w", err)
}
......@@ -234,25 +272,128 @@ func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[st
}
func (e *EasyPay) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) {
params := map[string]string{
"pid": e.config["pid"], "key": e.config["pkey"],
"trade_no": req.TradeNo, "out_trade_no": req.OrderID, "money": req.Amount,
attempts := e.refundAttempts(req)
if len(attempts) == 0 {
return nil, fmt.Errorf("easypay refund missing order identifier")
}
var firstErr error
for i, attempt := range attempts {
body, status, err := e.postRaw(ctx, e.apiBase()+"/api.php?act=refund", attempt.params)
if err != nil {
return nil, fmt.Errorf("easypay refund request: %w", err)
}
if err := parseEasyPayRefundResponse(status, body); err != nil {
if firstErr == nil {
firstErr = err
}
if i+1 < len(attempts) && isEasyPayRefundOrderNotFound(err) {
continue
}
return nil, err
}
return &payment.RefundResponse{RefundID: attempt.refundID, Status: payment.ProviderStatusSuccess}, nil
}
body, err := e.post(ctx, e.config["apiBase"]+"/api.php?act=refund", params)
if err != nil {
return nil, fmt.Errorf("easypay refund: %w", err)
return nil, firstErr
}
type easyPayRefundAttempt struct {
params map[string]string
refundID string
}
func (e *EasyPay) refundAttempts(req payment.RefundRequest) []easyPayRefundAttempt {
base := map[string]string{
"pid": e.config["pid"], "key": e.config["pkey"], "money": req.Amount,
}
var attempts []easyPayRefundAttempt
if orderID := strings.TrimSpace(req.OrderID); orderID != "" {
params := cloneStringMap(base)
params["out_trade_no"] = orderID
attempts = append(attempts, easyPayRefundAttempt{params: params, refundID: orderID})
}
if tradeNo := strings.TrimSpace(req.TradeNo); tradeNo != "" {
params := cloneStringMap(base)
params["trade_no"] = tradeNo
attempts = append(attempts, easyPayRefundAttempt{params: params, refundID: tradeNo})
}
return attempts
}
func cloneStringMap(in map[string]string) map[string]string {
out := make(map[string]string, len(in))
for k, v := range in {
out[k] = v
}
return out
}
func isEasyPayRefundOrderNotFound(err error) bool {
if err == nil {
return false
}
msg := err.Error()
lower := strings.ToLower(msg)
return strings.Contains(msg, "订单编号不存在") ||
strings.Contains(msg, "订单不存在") ||
strings.Contains(lower, "order not found") ||
strings.Contains(lower, "not exist")
}
func parseEasyPayRefundResponse(status int, body []byte) error {
summary := summarizeEasyPayResponse(body)
if status < http.StatusOK || status >= http.StatusMultipleChoices {
return fmt.Errorf("easypay refund HTTP %d: %s", status, summary)
}
trimmed := strings.TrimSpace(string(body))
if trimmed == "" {
return fmt.Errorf("easypay refund empty response (HTTP %d): %s", status, summary)
}
lower := strings.ToLower(trimmed)
if strings.HasPrefix(lower, "<!doctype html") || strings.HasPrefix(lower, "<html") ||
(strings.HasPrefix(lower, "<") && strings.Contains(lower, "html")) {
return fmt.Errorf("easypay refund non-JSON response (HTTP %d): %s", status, summary)
}
var resp struct {
Code int `json:"code"`
Code any `json:"code"`
Msg string `json:"msg"`
}
if err := json.Unmarshal(body, &resp); err != nil {
return nil, fmt.Errorf("easypay parse refund: %w", err)
return fmt.Errorf("easypay refund non-JSON response (HTTP %d): %s", status, summary)
}
if resp.Code != easypayCodeSuccess {
return nil, fmt.Errorf("easypay refund failed: %s", resp.Msg)
if !easyPayResponseCodeIsSuccess(resp.Code) {
msg := strings.TrimSpace(resp.Msg)
if msg == "" {
msg = summary
}
return fmt.Errorf("easypay refund failed (HTTP %d): %s", status, msg)
}
return nil
}
func easyPayResponseCodeIsSuccess(code any) bool {
switch v := code.(type) {
case float64:
return int(v) == easypayCodeSuccess
case string:
n, err := strconv.Atoi(strings.TrimSpace(v))
return err == nil && n == easypayCodeSuccess
default:
return false
}
}
func summarizeEasyPayResponse(body []byte) string {
summary := strings.Join(strings.Fields(string(body)), " ")
if summary == "" {
return "<empty>"
}
if len(summary) > maxEasypayErrorSummary {
return summary[:maxEasypayErrorSummary] + "..."
}
return &payment.RefundResponse{RefundID: req.TradeNo, Status: payment.ProviderStatusSuccess}, nil
return summary
}
func (e *EasyPay) resolveCID(paymentType string) string {
......@@ -269,21 +410,34 @@ func (e *EasyPay) resolveCID(paymentType string) string {
}
func (e *EasyPay) post(ctx context.Context, endpoint string, params map[string]string) ([]byte, error) {
body, _, err := e.postRaw(ctx, endpoint, params)
return body, err
}
func (e *EasyPay) postRaw(ctx context.Context, endpoint string, params map[string]string) ([]byte, int, error) {
form := url.Values{}
for k, v := range params {
form.Set(k, v)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
if err != nil {
return nil, err
return nil, 0, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := e.httpClient.Do(req)
client := e.httpClient
if client == nil {
client = &http.Client{Timeout: easypayHTTPTimeout}
}
resp, err := client.Do(req)
if err != nil {
return nil, err
return nil, 0, err
}
defer func() { _ = resp.Body.Close() }()
return io.ReadAll(io.LimitReader(resp.Body, maxEasypayResponseSize))
body, err := io.ReadAll(io.LimitReader(resp.Body, maxEasypayResponseSize))
if err != nil {
return nil, resp.StatusCode, err
}
return body, resp.StatusCode, nil
}
func easyPaySign(params map[string]string, pkey string) string {
......
package provider
import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
func TestNormalizeEasyPayAPIBase(t *testing.T) {
t.Parallel()
tests := []struct {
input string
want string
}{
{input: "https://zpayz.cn", want: "https://zpayz.cn"},
{input: "https://zpayz.cn/", want: "https://zpayz.cn"},
{input: "https://zpayz.cn/mapi.php", want: "https://zpayz.cn"},
{input: "https://zpayz.cn/submit.php", want: "https://zpayz.cn"},
{input: "https://zpayz.cn/api.php", want: "https://zpayz.cn"},
{input: "https://zpayz.cn/api.php?act=refund", want: "https://zpayz.cn"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
t.Parallel()
if got := normalizeEasyPayAPIBase(tt.input); got != tt.want {
t.Fatalf("normalizeEasyPayAPIBase(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestEasyPayRefundNormalizesAPIBaseAndSendsOutTradeNoOnly(t *testing.T) {
t.Parallel()
var gotPath string
var gotQuery url.Values
var gotForm url.Values
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
gotQuery = r.URL.Query()
if err := r.ParseForm(); err != nil {
t.Errorf("ParseForm: %v", err)
}
gotForm = r.PostForm
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"code":1,"msg":"ok"}`))
}))
defer server.Close()
provider := newTestEasyPay(t, server.URL+"/mapi.php")
resp, err := provider.Refund(context.Background(), payment.RefundRequest{
TradeNo: "trade-123",
OrderID: "out-456",
Amount: "1.50",
})
if err != nil {
t.Fatalf("Refund returned error: %v", err)
}
if resp == nil || resp.Status != payment.ProviderStatusSuccess {
t.Fatalf("Refund response = %+v, want success", resp)
}
if gotPath != "/api.php" {
t.Fatalf("refund path = %q, want /api.php", gotPath)
}
if gotQuery.Get("act") != "refund" {
t.Fatalf("refund act query = %q, want refund", gotQuery.Get("act"))
}
for key, want := range map[string]string{
"pid": "pid-1",
"key": "pkey-1",
"out_trade_no": "out-456",
"money": "1.50",
} {
if got := gotForm.Get(key); got != want {
t.Fatalf("form[%s] = %q, want %q (form=%v)", key, got, want, gotForm)
}
}
if got := gotForm.Get("trade_no"); got != "" {
t.Fatalf("form[trade_no] = %q, want empty (form=%v)", got, gotForm)
}
}
func TestEasyPayRefundRetriesWithTradeNoWhenOutTradeNoNotFound(t *testing.T) {
t.Parallel()
var gotForms []url.Values
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api.php" {
t.Errorf("refund path = %q, want /api.php", r.URL.Path)
}
if r.URL.Query().Get("act") != "refund" {
t.Errorf("refund act query = %q, want refund", r.URL.Query().Get("act"))
}
if err := r.ParseForm(); err != nil {
t.Errorf("ParseForm: %v", err)
}
gotForms = append(gotForms, r.PostForm)
w.Header().Set("Content-Type", "application/json")
if len(gotForms) == 1 {
_, _ = w.Write([]byte(`{"code":0,"msg":"订单编号不存在!"}`))
return
}
_, _ = w.Write([]byte(`{"code":1,"msg":"ok"}`))
}))
defer server.Close()
provider := newTestEasyPay(t, server.URL+"/mapi.php")
resp, err := provider.Refund(context.Background(), payment.RefundRequest{
TradeNo: "trade-123",
OrderID: "out-456",
Amount: "1.50",
})
if err != nil {
t.Fatalf("Refund returned error: %v", err)
}
if resp == nil || resp.Status != payment.ProviderStatusSuccess || resp.RefundID != "trade-123" {
t.Fatalf("Refund response = %+v, want success with trade refund id", resp)
}
if len(gotForms) != 2 {
t.Fatalf("refund attempts = %d, want 2", len(gotForms))
}
if got := gotForms[0].Get("out_trade_no"); got != "out-456" {
t.Fatalf("first form[out_trade_no] = %q, want out-456 (form=%v)", got, gotForms[0])
}
if got := gotForms[0].Get("trade_no"); got != "" {
t.Fatalf("first form[trade_no] = %q, want empty (form=%v)", got, gotForms[0])
}
if got := gotForms[1].Get("trade_no"); got != "trade-123" {
t.Fatalf("second form[trade_no] = %q, want trade-123 (form=%v)", got, gotForms[1])
}
if got := gotForms[1].Get("out_trade_no"); got != "" {
t.Fatalf("second form[out_trade_no] = %q, want empty (form=%v)", got, gotForms[1])
}
}
func TestEasyPayRefundResponseErrors(t *testing.T) {
t.Parallel()
tests := []struct {
name string
statusCode int
body string
want string
}{
{name: "html response", statusCode: http.StatusOK, body: "<html>bad config</html>", want: "non-JSON response (HTTP 200): <html>bad config</html>"},
{name: "non json response", statusCode: http.StatusOK, body: "not json", want: "non-JSON response (HTTP 200): not json"},
{name: "non 2xx response", statusCode: http.StatusBadGateway, body: "bad gateway", want: "HTTP 502: bad gateway"},
{name: "empty response", statusCode: http.StatusOK, body: "", want: "empty response (HTTP 200): <empty>"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(tt.statusCode)
_, _ = w.Write([]byte(tt.body))
}))
defer server.Close()
provider := newTestEasyPay(t, server.URL)
_, err := provider.Refund(context.Background(), payment.RefundRequest{
OrderID: "out-456",
Amount: "1.50",
})
if err == nil {
t.Fatal("Refund returned nil error")
}
if !strings.Contains(err.Error(), tt.want) {
t.Fatalf("Refund error = %q, want substring %q", err.Error(), tt.want)
}
})
}
}
func newTestEasyPay(t *testing.T, apiBase string) *EasyPay {
t.Helper()
provider, err := NewEasyPay("test-instance", map[string]string{
"pid": "pid-1",
"pkey": "pkey-1",
"apiBase": apiBase,
"notifyUrl": "https://example.com/notify",
"returnUrl": "https://example.com/return",
})
if err != nil {
t.Fatalf("NewEasyPay: %v", err)
}
return provider
}
......@@ -181,6 +181,55 @@ func TestResponsesToAnthropic_TextOnly(t *testing.T) {
assert.Equal(t, 5, anth.Usage.OutputTokens)
}
func TestResponsesToAnthropic_CachedTokensUseAnthropicInputSemantics(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_cached",
Model: "gpt-5.2",
Status: "completed",
Output: []ResponsesOutput{
{
Type: "message",
Content: []ResponsesContentPart{
{Type: "output_text", Text: "Cached response"},
},
},
},
Usage: &ResponsesUsage{
InputTokens: 54006,
OutputTokens: 123,
TotalTokens: 54129,
InputTokensDetails: &ResponsesInputTokensDetails{
CachedTokens: 50688,
},
},
}
anth := ResponsesToAnthropic(resp, "claude-sonnet-4-5-20250929")
assert.Equal(t, 3318, anth.Usage.InputTokens)
assert.Equal(t, 50688, anth.Usage.CacheReadInputTokens)
assert.Equal(t, 123, anth.Usage.OutputTokens)
}
func TestResponsesToAnthropic_CachedTokensClampInputTokens(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_cached_clamp",
Model: "gpt-5.2",
Status: "completed",
Usage: &ResponsesUsage{
InputTokens: 100,
OutputTokens: 5,
InputTokensDetails: &ResponsesInputTokensDetails{
CachedTokens: 150,
},
},
}
anth := ResponsesToAnthropic(resp, "claude-sonnet-4-5-20250929")
assert.Equal(t, 0, anth.Usage.InputTokens)
assert.Equal(t, 150, anth.Usage.CacheReadInputTokens)
assert.Equal(t, 5, anth.Usage.OutputTokens)
}
func TestResponsesToAnthropic_ToolUse(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_456",
......@@ -209,6 +258,48 @@ func TestResponsesToAnthropic_ToolUse(t *testing.T) {
assert.Equal(t, "tool_use", anth.Content[1].Type)
assert.Equal(t, "call_1", anth.Content[1].ID)
assert.Equal(t, "get_weather", anth.Content[1].Name)
assert.JSONEq(t, `{"city":"NYC"}`, string(anth.Content[1].Input))
}
func TestResponsesToAnthropic_ReadToolDropsEmptyPages(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_read",
Model: "gpt-5.5",
Status: "completed",
Output: []ResponsesOutput{
{
Type: "function_call",
CallID: "call_read",
Name: "Read",
Arguments: `{"file_path":"/tmp/demo.py","limit":2000,"offset":0,"pages":""}`,
},
},
}
anth := ResponsesToAnthropic(resp, "claude-opus-4-6")
require.Len(t, anth.Content, 1)
assert.Equal(t, "tool_use", anth.Content[0].Type)
assert.JSONEq(t, `{"file_path":"/tmp/demo.py","limit":2000,"offset":0}`, string(anth.Content[0].Input))
}
func TestResponsesToAnthropic_PreservesEmptyStringsForOtherTools(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_other",
Model: "gpt-5.5",
Status: "completed",
Output: []ResponsesOutput{
{
Type: "function_call",
CallID: "call_other",
Name: "Search",
Arguments: `{"query":""}`,
},
},
}
anth := ResponsesToAnthropic(resp, "claude-opus-4-6")
require.Len(t, anth.Content, 1)
assert.JSONEq(t, `{"query":""}`, string(anth.Content[0].Input))
}
func TestResponsesToAnthropic_Reasoning(t *testing.T) {
......@@ -343,6 +434,36 @@ func TestStreamingTextOnly(t *testing.T) {
assert.Equal(t, "message_stop", events[1].Type)
}
func TestStreamingCachedTokensUseAnthropicInputSemantics(t *testing.T) {
state := NewResponsesEventToAnthropicState()
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
Type: "response.created",
Response: &ResponsesResponse{ID: "resp_cached_stream", Model: "gpt-5.2"},
}, state)
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
Type: "response.completed",
Response: &ResponsesResponse{
Status: "completed",
Usage: &ResponsesUsage{
InputTokens: 54006,
OutputTokens: 123,
TotalTokens: 54129,
InputTokensDetails: &ResponsesInputTokensDetails{
CachedTokens: 50688,
},
},
},
}, state)
require.Len(t, events, 2)
assert.Equal(t, "message_delta", events[0].Type)
assert.Equal(t, 3318, events[0].Usage.InputTokens)
assert.Equal(t, 50688, events[0].Usage.CacheReadInputTokens)
assert.Equal(t, 123, events[0].Usage.OutputTokens)
assert.Equal(t, "message_stop", events[1].Type)
}
func TestStreamingToolCall(t *testing.T) {
state := NewResponsesEventToAnthropicState()
......@@ -393,6 +514,41 @@ func TestStreamingToolCall(t *testing.T) {
assert.Equal(t, "tool_use", events[0].Delta.StopReason)
}
func TestStreamingReadToolDropsEmptyPages(t *testing.T) {
state := NewResponsesEventToAnthropicState()
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
Type: "response.created",
Response: &ResponsesResponse{ID: "resp_read_stream", Model: "gpt-5.5"},
}, state)
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
Type: "response.output_item.added",
OutputIndex: 0,
Item: &ResponsesOutput{Type: "function_call", CallID: "call_read", Name: "Read"},
}, state)
require.Len(t, events, 1)
assert.Equal(t, "content_block_start", events[0].Type)
events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
Type: "response.function_call_arguments.delta",
OutputIndex: 0,
Delta: `{"file_path":"/tmp/demo.py","limit":2000,"offset":0,"pages":""}`,
}, state)
assert.Len(t, events, 0)
events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
Type: "response.function_call_arguments.done",
OutputIndex: 0,
Arguments: `{"file_path":"/tmp/demo.py","limit":2000,"offset":0,"pages":""}`,
}, state)
require.Len(t, events, 2)
assert.Equal(t, "content_block_delta", events[0].Type)
assert.Equal(t, "input_json_delta", events[0].Delta.Type)
assert.JSONEq(t, `{"file_path":"/tmp/demo.py","limit":2000,"offset":0}`, events[0].Delta.PartialJSON)
assert.Equal(t, "content_block_stop", events[1].Type)
}
func TestStreamingReasoning(t *testing.T) {
state := NewResponsesEventToAnthropicState()
......
......@@ -52,7 +52,7 @@ func ResponsesToAnthropic(resp *ResponsesResponse, model string) *AnthropicRespo
Type: "tool_use",
ID: fromResponsesCallID(item.CallID),
Name: item.Name,
Input: json.RawMessage(item.Arguments),
Input: sanitizeAnthropicToolUseInput(item.Name, item.Arguments),
})
case "web_search_call":
toolUseID := "srvtoolu_" + item.ID
......@@ -84,18 +84,34 @@ func ResponsesToAnthropic(resp *ResponsesResponse, model string) *AnthropicRespo
out.StopReason = responsesStatusToAnthropicStopReason(resp.Status, resp.IncompleteDetails, blocks)
if resp.Usage != nil {
out.Usage = AnthropicUsage{
InputTokens: resp.Usage.InputTokens,
OutputTokens: resp.Usage.OutputTokens,
}
if resp.Usage.InputTokensDetails != nil {
out.Usage.CacheReadInputTokens = resp.Usage.InputTokensDetails.CachedTokens
}
out.Usage = anthropicUsageFromResponsesUsage(resp.Usage)
}
return out
}
func anthropicUsageFromResponsesUsage(usage *ResponsesUsage) AnthropicUsage {
if usage == nil {
return AnthropicUsage{}
}
cachedTokens := 0
if usage.InputTokensDetails != nil {
cachedTokens = usage.InputTokensDetails.CachedTokens
}
inputTokens := usage.InputTokens - cachedTokens
if inputTokens < 0 {
inputTokens = 0
}
return AnthropicUsage{
InputTokens: inputTokens,
OutputTokens: usage.OutputTokens,
CacheReadInputTokens: cachedTokens,
}
}
func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncompleteDetails, blocks []AnthropicContentBlock) string {
switch status {
case "incomplete":
......@@ -113,6 +129,28 @@ func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncom
}
}
func sanitizeAnthropicToolUseInput(name string, raw string) json.RawMessage {
if name != "Read" || raw == "" {
return json.RawMessage(raw)
}
var input map[string]json.RawMessage
if err := json.Unmarshal([]byte(raw), &input); err != nil {
return json.RawMessage(raw)
}
if pages, ok := input["pages"]; !ok || string(pages) != `""` {
return json.RawMessage(raw)
}
delete(input, "pages")
sanitized, err := json.Marshal(input)
if err != nil {
return json.RawMessage(raw)
}
return sanitized
}
// ---------------------------------------------------------------------------
// Streaming: ResponsesStreamEvent → []AnthropicStreamEvent (stateful converter)
// ---------------------------------------------------------------------------
......@@ -126,6 +164,8 @@ type ResponsesEventToAnthropicState struct {
ContentBlockIndex int
ContentBlockOpen bool
CurrentBlockType string // "text" | "thinking" | "tool_use"
CurrentToolName string
CurrentToolArgs string
// OutputIndexToBlockIdx maps Responses output_index → Anthropic content block index.
OutputIndexToBlockIdx map[int]int
......@@ -165,7 +205,7 @@ func ResponsesEventToAnthropicEvents(
case "response.function_call_arguments.delta":
return resToAnthHandleFuncArgsDelta(evt, state)
case "response.function_call_arguments.done":
return resToAnthHandleBlockDone(state)
return resToAnthHandleFuncArgsDone(evt, state)
case "response.output_item.done":
return resToAnthHandleOutputItemDone(evt, state)
case "response.reasoning_summary_text.delta":
......@@ -262,6 +302,8 @@ func resToAnthHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesE
state.OutputIndexToBlockIdx[evt.OutputIndex] = idx
state.ContentBlockOpen = true
state.CurrentBlockType = "tool_use"
state.CurrentToolName = evt.Item.Name
state.CurrentToolArgs = ""
events = append(events, AnthropicStreamEvent{
Type: "content_block_start",
......@@ -342,6 +384,11 @@ func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEve
return nil
}
if state.CurrentBlockType == "tool_use" && state.CurrentToolName == "Read" {
state.CurrentToolArgs += evt.Delta
return nil
}
blockIdx, ok := state.OutputIndexToBlockIdx[evt.OutputIndex]
if !ok {
return nil
......@@ -357,6 +404,33 @@ func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEve
}}
}
func resToAnthHandleFuncArgsDone(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
if state.CurrentBlockType != "tool_use" || state.CurrentToolName != "Read" {
return resToAnthHandleBlockDone(state)
}
raw := evt.Arguments
if raw == "" {
raw = state.CurrentToolArgs
}
sanitized := sanitizeAnthropicToolUseInput(state.CurrentToolName, raw)
if len(sanitized) == 0 {
return closeCurrentBlock(state)
}
idx := state.ContentBlockIndex
events := []AnthropicStreamEvent{{
Type: "content_block_delta",
Index: &idx,
Delta: &AnthropicDelta{
Type: "input_json_delta",
PartialJSON: string(sanitized),
},
}}
events = append(events, closeCurrentBlock(state)...)
return events
}
func resToAnthHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
if evt.Delta == "" {
return nil
......@@ -466,11 +540,10 @@ func resToAnthHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventTo
stopReason := "end_turn"
if evt.Response != nil {
if evt.Response.Usage != nil {
state.InputTokens = evt.Response.Usage.InputTokens
state.OutputTokens = evt.Response.Usage.OutputTokens
if evt.Response.Usage.InputTokensDetails != nil {
state.CacheReadInputTokens = evt.Response.Usage.InputTokensDetails.CachedTokens
}
usage := anthropicUsageFromResponsesUsage(evt.Response.Usage)
state.InputTokens = usage.InputTokens
state.OutputTokens = usage.OutputTokens
state.CacheReadInputTokens = usage.CacheReadInputTokens
}
switch evt.Response.Status {
case "incomplete":
......@@ -509,6 +582,8 @@ func closeCurrentBlock(state *ResponsesEventToAnthropicState) []AnthropicStreamE
idx := state.ContentBlockIndex
state.ContentBlockOpen = false
state.ContentBlockIndex++
state.CurrentToolName = ""
state.CurrentToolArgs = ""
return []AnthropicStreamEvent{{
Type: "content_block_stop",
Index: &idx,
......
......@@ -390,7 +390,7 @@ func convertResponsesToAnthropicTools(tools []ResponsesTool) []AnthropicTool {
var out []AnthropicTool
for _, t := range tools {
switch t.Type {
case "web_search":
case "web_search", "google_search", "web_search_20250305":
out = append(out, AnthropicTool{
Type: "web_search_20250305",
Name: "web_search",
......
......@@ -12,17 +12,23 @@ import "encoding/json"
// AnthropicRequest is the request body for POST /v1/messages.
type AnthropicRequest struct {
Model string `json:"model"`
MaxTokens int `json:"max_tokens"`
System json.RawMessage `json:"system,omitempty"` // string or []AnthropicContentBlock
Messages []AnthropicMessage `json:"messages"`
Tools []AnthropicTool `json:"tools,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
StopSeqs []string `json:"stop_sequences,omitempty"`
Thinking *AnthropicThinking `json:"thinking,omitempty"`
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
Model string `json:"model"`
MaxTokens int `json:"max_tokens"`
System json.RawMessage `json:"system,omitempty"` // string or []AnthropicContentBlock
Messages []AnthropicMessage `json:"messages"`
Tools []AnthropicTool `json:"tools,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
StopSeqs []string `json:"stop_sequences,omitempty"`
Thinking *AnthropicThinking `json:"thinking,omitempty"`
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
// Metadata 会被原样透传给上游。OAuth/Claude-Code 路径依赖 metadata.user_id
// 参与上游的"是否为官方 Claude Code 请求"判定;如果经由本结构体重新序列化
// 时丢弃该字段,网关侧后续的 metadata 重写(ensureClaudeOAuthMetadataUserID/
// RewriteUserIDWithMasking) 在 body 里拿不到起点,就无法重建一个合法的
// user_id,进而导致请求被归类为第三方 app。
Metadata json.RawMessage `json:"metadata,omitempty"`
OutputConfig *AnthropicOutputConfig `json:"output_config,omitempty"`
}
......@@ -76,10 +82,18 @@ type AnthropicImageSource struct {
// AnthropicTool describes a tool available to the model.
type AnthropicTool struct {
Type string `json:"type,omitempty"` // e.g. "web_search_20250305" for server tools
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema json.RawMessage `json:"input_schema"` // JSON Schema object
Type string `json:"type,omitempty"` // e.g. "web_search_20250305" for server tools
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema json.RawMessage `json:"input_schema"` // JSON Schema object
CacheControl *AnthropicCacheControl `json:"cache_control,omitempty"`
}
// AnthropicCacheControl 对应 Anthropic API 的 cache_control 字段。
// ttl 默认由调用方决定;本项目策略见 claude.DefaultCacheControlTTL。
type AnthropicCacheControl struct {
Type string `json:"type"` // "ephemeral"
TTL string `json:"ttl,omitempty"` // "5m" / "1h" / 省略=默认 5m(由 Anthropic 判定)
}
// AnthropicResponse is the non-streaming response from POST /v1/messages.
......
......@@ -4,6 +4,12 @@ package claude
// Claude Code 客户端相关常量
// Beta header 常量
//
// 这里的常量对齐真实 Claude Code CLI 的最新流量(截至 2026-04)。
// 选型参考:与 Parrot (src/transform/cc_mimicry.py) 的 BETAS 保持一致,
// 原因:Anthropic 上游会基于 anthropic-beta 的完整集合判定请求来源;
// 缺少任何"官方 Claude Code 请求才会带"的 beta,都会被降级到第三方额度,
// 对应报错:`Third-party apps now draw from your extra usage, not your plan limits.`
const (
BetaOAuth = "oauth-2025-04-20"
BetaClaudeCode = "claude-code-20250219"
......@@ -12,6 +18,13 @@ const (
BetaTokenCounting = "token-counting-2024-11-01"
BetaContext1M = "context-1m-2025-08-07"
BetaFastMode = "fast-mode-2026-02-01"
// 新增(对齐官方 CLI 2.1.9x 以来的流量)
BetaPromptCachingScope = "prompt-caching-scope-2026-01-05"
BetaEffort = "effort-2025-11-24"
BetaRedactThinking = "redact-thinking-2026-02-12"
BetaContextManagement = "context-management-2025-06-27"
BetaExtendedCacheTTL = "extended-cache-ttl-2025-04-11"
)
// DroppedBetas 是转发时需要从 anthropic-beta header 中移除的 beta token 列表。
......@@ -44,11 +57,43 @@ const APIKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," +
// APIKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不包含 oauth / claude-code)
const APIKeyHaikuBetaHeader = BetaInterleavedThinking
// DefaultCacheControlTTL 是网关代理为自己生成的 cache_control 块默认使用的 ttl。
// 真实 Claude Code CLI 当前使用 "1h",但本仓策略是"客户端透传 ttl 优先;
// 客户端缺省时统一使用 5m",这样既不浪费 1h 缓存额度,也保留客户端自定义能力。
const DefaultCacheControlTTL = "5m"
// CLICurrentVersion 是 sub2api 当前对外伪装的 Claude Code CLI 版本号(三段 semver)。
// 用于 billing attribution block 中的 cc_version=X.Y.Z.{fp} 前缀以及 fingerprint 计算。
// 必须与 DefaultHeaders["User-Agent"] 中的版本号严格一致;不一致会被 Anthropic 判第三方。
const CLICurrentVersion = "2.1.92"
// FullClaudeCodeMimicryBetas 返回最"像"真实 Claude Code CLI 的完整 beta 列表,
// 用于 OAuth 账号伪装成 Claude Code 时使用。
// 顺序与真实 CLI 抓包一致。
//
// 使用建议:
// - OAuth 账号 + 非 haiku:追加这整份列表,再按需保留 client 带来的 beta。
// - OAuth 账号 + haiku:Anthropic 对 haiku 不做 third-party 判定,使用 HaikuBetaHeader 即可。
// - API-key 账号:不要使用本函数,参见 APIKeyBetaHeader。
func FullClaudeCodeMimicryBetas() []string {
return []string{
BetaClaudeCode,
BetaOAuth,
BetaInterleavedThinking,
BetaPromptCachingScope,
BetaEffort,
BetaRedactThinking,
BetaContextManagement,
BetaExtendedCacheTTL,
}
}
// DefaultHeaders 是 Claude Code 客户端默认请求头。
var DefaultHeaders = map[string]string{
// Keep these in sync with recent Claude CLI traffic to reduce the chance
// that Claude Code-scoped OAuth credentials are rejected as "non-CLI" usage.
"User-Agent": "claude-cli/2.1.22 (external, cli)",
// 版本参考:对齐 Parrot (src/transform/cc_mimicry.py:49) 的 CLI_USER_AGENT。
"User-Agent": "claude-cli/2.1.92 (external, cli)",
"X-Stainless-Lang": "js",
"X-Stainless-Package-Version": "0.70.0",
"X-Stainless-OS": "Linux",
......
......@@ -55,4 +55,8 @@ const (
// ClaudeCodeVersion stores the extracted Claude Code version from User-Agent (e.g. "2.1.22")
ClaudeCodeVersion Key = "ctx_claude_code_version"
// ResponseCaptureBuffer 用于在 streaming 响应中收集 assistant 文本,供 request_capture 功能使用。
// 值类型为 *strings.Builder,由 handler 层注入,service 层只负责追加文本。
ResponseCaptureBuffer Key = "ctx_response_capture_buffer"
)
package repository
import "testing"
func TestShouldEnqueueSchedulerOutboxForExtraUpdates_CompactCapabilityKeysAreRelevant(t *testing.T) {
updates := map[string]any{
"openai_compact_supported": true,
"openai_compact_checked_at": "2026-04-10T10:00:00Z",
}
if !shouldEnqueueSchedulerOutboxForExtraUpdates(updates) {
t.Fatalf("expected compact capability updates to enqueue scheduler outbox")
}
}
This diff is collapsed.
//go:build integration
package repository
import (
"context"
"fmt"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func querySingleFloat(t *testing.T, ctx context.Context, client *dbent.Client, query string, args ...any) float64 {
t.Helper()
rows, err := client.QueryContext(ctx, query, args...)
require.NoError(t, err)
defer func() { _ = rows.Close() }()
require.True(t, rows.Next(), "expected one row")
var value float64
require.NoError(t, rows.Scan(&value))
require.NoError(t, rows.Err())
return value
}
func querySingleInt(t *testing.T, ctx context.Context, client *dbent.Client, query string, args ...any) int {
t.Helper()
rows, err := client.QueryContext(ctx, query, args...)
require.NoError(t, err)
defer func() { _ = rows.Close() }()
require.True(t, rows.Next(), "expected one row")
var value int
require.NoError(t, rows.Scan(&value))
require.NoError(t, rows.Err())
return value
}
func TestAffiliateRepository_TransferQuotaToBalance_UsesClaimedQuotaBeforeClear(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
txCtx := dbent.NewTxContext(ctx, tx)
client := tx.Client()
repo := NewAffiliateRepository(client, integrationDB)
u := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("affiliate-transfer-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 5.5,
Concurrency: 5,
})
affCode := fmt.Sprintf("AFF%09d", time.Now().UnixNano()%1_000_000_000)
_, err := client.ExecContext(txCtx, `
INSERT INTO user_affiliates (user_id, aff_code, aff_quota, aff_history_quota, created_at, updated_at)
VALUES ($1, $2, $3, $3, NOW(), NOW())`, u.ID, affCode, 12.34)
require.NoError(t, err)
transferred, balance, err := repo.TransferQuotaToBalance(txCtx, u.ID)
require.NoError(t, err)
require.InDelta(t, 12.34, transferred, 1e-9)
require.InDelta(t, 17.84, balance, 1e-9)
affQuota := querySingleFloat(t, txCtx, client,
"SELECT aff_quota::double precision FROM user_affiliates WHERE user_id = $1", u.ID)
require.InDelta(t, 0.0, affQuota, 1e-9)
persistedBalance := querySingleFloat(t, txCtx, client,
"SELECT balance::double precision FROM users WHERE id = $1", u.ID)
require.InDelta(t, 17.84, persistedBalance, 1e-9)
ledgerCount := querySingleInt(t, txCtx, client,
"SELECT COUNT(*) FROM user_affiliate_ledger WHERE user_id = $1 AND action = 'transfer'", u.ID)
require.Equal(t, 1, ledgerCount)
}
// TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction guards the
// cross-layer tx propagation invariant: when AccrueQuota is called with a ctx
// that already carries a transaction (via dbent.NewTxContext), repo.withTx
// must reuse that tx rather than opening a nested one. If this invariant
// breaks, AccrueQuota would commit independently and survive a rollback of
// the outer tx, which would violate payment_fulfillment's all-or-nothing
// semantics.
func TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction(t *testing.T) {
ctx := context.Background()
outerTx, err := integrationEntClient.Tx(ctx)
require.NoError(t, err, "begin outer tx")
// Defensive cleanup: if any require.* below fires before the explicit
// Rollback, this prevents the tx from leaking until container teardown.
// Rollback is idempotent at the driver level (extra rollback returns an
// error we ignore).
t.Cleanup(func() { _ = outerTx.Rollback() })
client := outerTx.Client()
txCtx := dbent.NewTxContext(ctx, outerTx)
inviter := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("affiliate-inviter-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
Concurrency: 5,
})
invitee := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("affiliate-invitee-%d@example.com", time.Now().UnixNano()+1),
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
Concurrency: 5,
})
repo := NewAffiliateRepository(client, integrationDB)
_, err = repo.EnsureUserAffiliate(txCtx, inviter.ID)
require.NoError(t, err)
_, err = repo.EnsureUserAffiliate(txCtx, invitee.ID)
require.NoError(t, err)
bound, err := repo.BindInviter(txCtx, invitee.ID, inviter.ID)
require.NoError(t, err)
require.True(t, bound, "invitee must bind to inviter")
applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5, 0)
require.NoError(t, err)
require.True(t, applied, "AccrueQuota must report applied=true")
// Visible inside the outer tx.
innerQuota := querySingleFloat(t, txCtx, client,
"SELECT aff_quota::double precision FROM user_affiliates WHERE user_id = $1", inviter.ID)
require.InDelta(t, 3.5, innerQuota, 1e-9)
// Roll back the outer tx; if AccrueQuota had opened its own inner tx and
// committed it, the rows would still be visible to the global client.
require.NoError(t, outerTx.Rollback())
rows, err := integrationEntClient.QueryContext(ctx,
"SELECT COUNT(*) FROM user_affiliates WHERE user_id IN ($1, $2)",
inviter.ID, invitee.ID)
require.NoError(t, err)
defer func() { _ = rows.Close() }()
require.True(t, rows.Next())
var postRollbackCount int
require.NoError(t, rows.Scan(&postRollbackCount))
require.Equal(t, 0, postRollbackCount,
"AccrueQuota must propagate the outer tx — found persisted rows after rollback")
}
func TestAffiliateRepository_TransferQuotaToBalance_EmptyQuota(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
txCtx := dbent.NewTxContext(ctx, tx)
client := tx.Client()
repo := NewAffiliateRepository(client, integrationDB)
u := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("affiliate-empty-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 3.21,
Concurrency: 5,
})
affCode := fmt.Sprintf("AFF%09d", time.Now().UnixNano()%1_000_000_000)
_, err := client.ExecContext(txCtx, `
INSERT INTO user_affiliates (user_id, aff_code, aff_quota, aff_history_quota, created_at, updated_at)
VALUES ($1, $2, 0, 0, NOW(), NOW())`, u.ID, affCode)
require.NoError(t, err)
transferred, balance, err := repo.TransferQuotaToBalance(txCtx, u.ID)
require.ErrorIs(t, err, service.ErrAffiliateQuotaEmpty)
require.InDelta(t, 0.0, transferred, 1e-9)
require.InDelta(t, 0.0, balance, 1e-9)
persistedBalance := querySingleFloat(t, txCtx, client,
"SELECT balance::double precision FROM users WHERE id = $1", u.ID)
require.InDelta(t, 3.21, persistedBalance, 1e-9)
}
// TestAffiliateRepository_AdminCustomCode covers the success path of admin
// invite-code rewrite + reset within a shared test transaction:
// - UpdateUserAffCode replaces aff_code, sets aff_code_custom=true, lookup works
// - the old code can no longer be found
// - ResetUserAffCode reverts aff_code_custom and assigns a new system-format code
//
// The conflict path (duplicate code → ErrAffiliateCodeTaken) lives in its own
// test because a unique-violation aborts the surrounding Postgres tx, which
// would poison subsequent assertions in the same transaction.
func TestAffiliateRepository_AdminCustomCode(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
txCtx := dbent.NewTxContext(ctx, tx)
client := tx.Client()
repo := NewAffiliateRepository(client, integrationDB)
u := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("affiliate-custom-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
})
original, err := repo.EnsureUserAffiliate(txCtx, u.ID)
require.NoError(t, err)
require.False(t, original.AffCodeCustom, "system-generated codes start as non-custom")
originalCode := original.AffCode
// Rewrite to a custom code
customCode := fmt.Sprintf("VIP%09d", time.Now().UnixNano()%1_000_000_000)
require.NoError(t, repo.UpdateUserAffCode(txCtx, u.ID, customCode))
updated, err := repo.EnsureUserAffiliate(txCtx, u.ID)
require.NoError(t, err)
require.Equal(t, customCode, updated.AffCode)
require.True(t, updated.AffCodeCustom)
// Lookup by new custom code finds the user
byCode, err := repo.GetAffiliateByCode(txCtx, customCode)
require.NoError(t, err)
require.Equal(t, u.ID, byCode.UserID)
// Old system code should no longer match
_, err = repo.GetAffiliateByCode(txCtx, originalCode)
require.ErrorIs(t, err, service.ErrAffiliateProfileNotFound)
// Reset back to a fresh system code, clears custom flag
newSysCode, err := repo.ResetUserAffCode(txCtx, u.ID)
require.NoError(t, err)
require.NotEqual(t, customCode, newSysCode)
reset, err := repo.EnsureUserAffiliate(txCtx, u.ID)
require.NoError(t, err)
require.Equal(t, newSysCode, reset.AffCode)
require.False(t, reset.AffCodeCustom)
// The old custom code is now free again
_, err = repo.GetAffiliateByCode(txCtx, customCode)
require.ErrorIs(t, err, service.ErrAffiliateProfileNotFound)
}
// TestAffiliateRepository_AdminCustomCode_Conflict isolates the unique-violation
// path. PostgreSQL aborts the enclosing tx when a unique constraint fires, so
// this test must be the only assertion and run in its own tx — production
// callers each have their own outer tx, so this matches real behavior.
func TestAffiliateRepository_AdminCustomCode_Conflict(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
txCtx := dbent.NewTxContext(ctx, tx)
client := tx.Client()
repo := NewAffiliateRepository(client, integrationDB)
taker := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("affiliate-conflict-taker-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Role: service.RoleUser, Status: service.StatusActive,
})
requester := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("affiliate-conflict-req-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Role: service.RoleUser, Status: service.StatusActive,
})
takenCode := fmt.Sprintf("HOT%09d", time.Now().UnixNano()%1_000_000_000)
require.NoError(t, repo.UpdateUserAffCode(txCtx, taker.ID, takenCode))
// Now requester tries to grab the same code → conflict.
err := repo.UpdateUserAffCode(txCtx, requester.ID, takenCode)
require.ErrorIs(t, err, service.ErrAffiliateCodeTaken)
}
// TestAffiliateRepository_AdminRebateRate covers per-user exclusive rate
// set/clear and the Batch variant including NULL semantics.
func TestAffiliateRepository_AdminRebateRate(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
txCtx := dbent.NewTxContext(ctx, tx)
client := tx.Client()
repo := NewAffiliateRepository(client, integrationDB)
u1 := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("affiliate-rate-%d-a@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
})
u2 := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("affiliate-rate-%d-b@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
})
// Set exclusive rate for u1
rate := 42.5
require.NoError(t, repo.SetUserRebateRate(txCtx, u1.ID, &rate))
got, err := repo.EnsureUserAffiliate(txCtx, u1.ID)
require.NoError(t, err)
require.NotNil(t, got.AffRebateRatePercent)
require.InDelta(t, 42.5, *got.AffRebateRatePercent, 1e-9)
// Clear exclusive rate
require.NoError(t, repo.SetUserRebateRate(txCtx, u1.ID, nil))
cleared, err := repo.EnsureUserAffiliate(txCtx, u1.ID)
require.NoError(t, err)
require.Nil(t, cleared.AffRebateRatePercent)
// Batch set both users
batchRate := 15.0
require.NoError(t, repo.BatchSetUserRebateRate(txCtx, []int64{u1.ID, u2.ID}, &batchRate))
for _, uid := range []int64{u1.ID, u2.ID} {
v, err := repo.EnsureUserAffiliate(txCtx, uid)
require.NoError(t, err)
require.NotNil(t, v.AffRebateRatePercent)
require.InDelta(t, 15.0, *v.AffRebateRatePercent, 1e-9)
}
// Batch clear
require.NoError(t, repo.BatchSetUserRebateRate(txCtx, []int64{u1.ID, u2.ID}, nil))
for _, uid := range []int64{u1.ID, u2.ID} {
v, err := repo.EnsureUserAffiliate(txCtx, uid)
require.NoError(t, err)
require.Nil(t, v.AffRebateRatePercent)
}
}
// TestAffiliateRepository_ListUsersWithCustomSettings verifies the admin list
// only includes users with at least one override applied.
func TestAffiliateRepository_ListUsersWithCustomSettings(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
txCtx := dbent.NewTxContext(ctx, tx)
client := tx.Client()
repo := NewAffiliateRepository(client, integrationDB)
// User without any custom config — should NOT appear in the list.
plainEmail := fmt.Sprintf("affiliate-plain-%d@example.com", time.Now().UnixNano())
uPlain := mustCreateUser(t, client, &service.User{
Email: plainEmail, PasswordHash: "hash",
Role: service.RoleUser, Status: service.StatusActive,
})
_, err := repo.EnsureUserAffiliate(txCtx, uPlain.ID)
require.NoError(t, err)
// User with a custom code — should appear.
uCode := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("affiliate-codeonly-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Role: service.RoleUser, Status: service.StatusActive,
})
require.NoError(t, repo.UpdateUserAffCode(txCtx, uCode.ID, fmt.Sprintf("VIP%09d", time.Now().UnixNano()%1_000_000_000)))
// User with only an exclusive rate — should appear.
uRate := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("affiliate-rateonly-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Role: service.RoleUser, Status: service.StatusActive,
})
r := 33.3
require.NoError(t, repo.SetUserRebateRate(txCtx, uRate.ID, &r))
entries, total, err := repo.ListUsersWithCustomSettings(txCtx, service.AffiliateAdminFilter{
Page: 1, PageSize: 100,
})
require.NoError(t, err)
// Build a quick lookup to assert per-user attributes (other tests may have
// inserted custom rows in the same DB; we only care about our 3).
byUserID := make(map[int64]service.AffiliateAdminEntry, len(entries))
for _, e := range entries {
byUserID[e.UserID] = e
}
require.NotContains(t, byUserID, uPlain.ID, "users without overrides must not appear")
codeEntry, ok := byUserID[uCode.ID]
require.True(t, ok, "custom-code user missing from list")
require.True(t, codeEntry.AffCodeCustom)
require.Nil(t, codeEntry.AffRebateRatePercent)
rateEntry, ok := byUserID[uRate.ID]
require.True(t, ok, "custom-rate user missing from list")
require.False(t, rateEntry.AffCodeCustom)
require.NotNil(t, rateEntry.AffRebateRatePercent)
require.InDelta(t, 33.3, *rateEntry.AffRebateRatePercent, 1e-9)
require.GreaterOrEqual(t, total, int64(2), "total must include at least our 2 custom rows")
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment