"vscode:/vscode.git/clone" did not exist on "6cd7c60549d565bd07ed4fdfd9f31ff4d60d920c"
Unverified Commit 3f05ef2a authored by Oliver Li's avatar Oliver Li Committed by GitHub
Browse files

Merge branch 'Wei-Shaw:main' into vertex

parents 6d11f9ed c056db74
docs/claude-relay-service/ docs/claude-relay-service/
.codex
# =================== # ===================
# Go 后端 # Go 后端
......
...@@ -186,6 +186,9 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { ...@@ -186,6 +186,9 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
DefaultConcurrency: settings.DefaultConcurrency, DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance, DefaultBalance: settings.DefaultBalance,
AffiliateRebateRate: settings.AffiliateRebateRate, AffiliateRebateRate: settings.AffiliateRebateRate,
AffiliateRebateFreezeHours: settings.AffiliateRebateFreezeHours,
AffiliateRebateDurationDays: settings.AffiliateRebateDurationDays,
AffiliateRebatePerInviteeCap: settings.AffiliateRebatePerInviteeCap,
DefaultUserRPMLimit: settings.DefaultUserRPMLimit, DefaultUserRPMLimit: settings.DefaultUserRPMLimit,
DefaultSubscriptions: defaultSubscriptions, DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: settings.EnableModelFallback, EnableModelFallback: settings.EnableModelFallback,
...@@ -342,6 +345,9 @@ type UpdateSettingsRequest struct { ...@@ -342,6 +345,9 @@ type UpdateSettingsRequest struct {
DefaultConcurrency int `json:"default_concurrency"` DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"` DefaultBalance float64 `json:"default_balance"`
AffiliateRebateRate *float64 `json:"affiliate_rebate_rate"` AffiliateRebateRate *float64 `json:"affiliate_rebate_rate"`
AffiliateRebateFreezeHours *int `json:"affiliate_rebate_freeze_hours"`
AffiliateRebateDurationDays *int `json:"affiliate_rebate_duration_days"`
AffiliateRebatePerInviteeCap *float64 `json:"affiliate_rebate_per_invitee_cap"`
DefaultUserRPMLimit int `json:"default_user_rpm_limit"` DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"` DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"` AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
...@@ -485,6 +491,33 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -485,6 +491,33 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
if affiliateRebateRate > service.AffiliateRebateRateMax { if affiliateRebateRate > service.AffiliateRebateRateMax {
affiliateRebateRate = service.AffiliateRebateRateMax affiliateRebateRate = service.AffiliateRebateRateMax
} }
affiliateRebateFreezeHours := previousSettings.AffiliateRebateFreezeHours
if req.AffiliateRebateFreezeHours != nil {
affiliateRebateFreezeHours = *req.AffiliateRebateFreezeHours
}
if affiliateRebateFreezeHours < 0 {
affiliateRebateFreezeHours = service.AffiliateRebateFreezeHoursDefault
}
if affiliateRebateFreezeHours > service.AffiliateRebateFreezeHoursMax {
affiliateRebateFreezeHours = service.AffiliateRebateFreezeHoursMax
}
affiliateRebateDurationDays := previousSettings.AffiliateRebateDurationDays
if req.AffiliateRebateDurationDays != nil {
affiliateRebateDurationDays = *req.AffiliateRebateDurationDays
}
if affiliateRebateDurationDays < 0 {
affiliateRebateDurationDays = service.AffiliateRebateDurationDaysDefault
}
if affiliateRebateDurationDays > service.AffiliateRebateDurationDaysMax {
affiliateRebateDurationDays = service.AffiliateRebateDurationDaysMax
}
affiliateRebatePerInviteeCap := previousSettings.AffiliateRebatePerInviteeCap
if req.AffiliateRebatePerInviteeCap != nil {
affiliateRebatePerInviteeCap = *req.AffiliateRebatePerInviteeCap
}
if affiliateRebatePerInviteeCap < 0 {
affiliateRebatePerInviteeCap = service.AffiliateRebatePerInviteeCapDefault
}
// 通用表格配置:兼容旧客户端未传字段时保留当前值。 // 通用表格配置:兼容旧客户端未传字段时保留当前值。
if req.TableDefaultPageSize <= 0 { if req.TableDefaultPageSize <= 0 {
req.TableDefaultPageSize = previousSettings.TableDefaultPageSize req.TableDefaultPageSize = previousSettings.TableDefaultPageSize
...@@ -1137,6 +1170,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -1137,6 +1170,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
DefaultConcurrency: req.DefaultConcurrency, DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance, DefaultBalance: req.DefaultBalance,
AffiliateRebateRate: affiliateRebateRate, AffiliateRebateRate: affiliateRebateRate,
AffiliateRebateFreezeHours: affiliateRebateFreezeHours,
AffiliateRebateDurationDays: affiliateRebateDurationDays,
AffiliateRebatePerInviteeCap: affiliateRebatePerInviteeCap,
DefaultUserRPMLimit: req.DefaultUserRPMLimit, DefaultUserRPMLimit: req.DefaultUserRPMLimit,
DefaultSubscriptions: defaultSubscriptions, DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: req.EnableModelFallback, EnableModelFallback: req.EnableModelFallback,
...@@ -1458,6 +1494,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -1458,6 +1494,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
DefaultConcurrency: updatedSettings.DefaultConcurrency, DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance, DefaultBalance: updatedSettings.DefaultBalance,
AffiliateRebateRate: updatedSettings.AffiliateRebateRate, AffiliateRebateRate: updatedSettings.AffiliateRebateRate,
AffiliateRebateFreezeHours: updatedSettings.AffiliateRebateFreezeHours,
AffiliateRebateDurationDays: updatedSettings.AffiliateRebateDurationDays,
AffiliateRebatePerInviteeCap: updatedSettings.AffiliateRebatePerInviteeCap,
DefaultUserRPMLimit: updatedSettings.DefaultUserRPMLimit, DefaultUserRPMLimit: updatedSettings.DefaultUserRPMLimit,
DefaultSubscriptions: updatedDefaultSubscriptions, DefaultSubscriptions: updatedDefaultSubscriptions,
EnableModelFallback: updatedSettings.EnableModelFallback, EnableModelFallback: updatedSettings.EnableModelFallback,
...@@ -1768,6 +1807,15 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, ...@@ -1768,6 +1807,15 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.AffiliateRebateRate != after.AffiliateRebateRate { if before.AffiliateRebateRate != after.AffiliateRebateRate {
changed = append(changed, "affiliate_rebate_rate") changed = append(changed, "affiliate_rebate_rate")
} }
if before.AffiliateRebateFreezeHours != after.AffiliateRebateFreezeHours {
changed = append(changed, "affiliate_rebate_freeze_hours")
}
if before.AffiliateRebateDurationDays != after.AffiliateRebateDurationDays {
changed = append(changed, "affiliate_rebate_duration_days")
}
if before.AffiliateRebatePerInviteeCap != after.AffiliateRebatePerInviteeCap {
changed = append(changed, "affiliate_rebate_per_invitee_cap")
}
if !equalDefaultSubscriptions(before.DefaultSubscriptions, after.DefaultSubscriptions) { if !equalDefaultSubscriptions(before.DefaultSubscriptions, after.DefaultSubscriptions) {
changed = append(changed, "default_subscriptions") changed = append(changed, "default_subscriptions")
} }
......
...@@ -435,6 +435,7 @@ func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession( ...@@ -435,6 +435,7 @@ func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession(
type completeLinuxDoOAuthRequest struct { type completeLinuxDoOAuthRequest struct {
InvitationCode string `json:"invitation_code" binding:"required"` InvitationCode string `json:"invitation_code" binding:"required"`
AffCode string `json:"aff_code,omitempty"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
AdoptAvatar *bool `json:"adopt_avatar,omitempty"` AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
} }
...@@ -518,7 +519,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { ...@@ -518,7 +519,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
......
...@@ -67,6 +67,7 @@ type createPendingOAuthAccountRequest struct { ...@@ -67,6 +67,7 @@ type createPendingOAuthAccountRequest struct {
VerifyCode string `json:"verify_code,omitempty"` VerifyCode string `json:"verify_code,omitempty"`
Password string `json:"password" binding:"required,min=6"` Password string `json:"password" binding:"required,min=6"`
InvitationCode string `json:"invitation_code,omitempty"` InvitationCode string `json:"invitation_code,omitempty"`
AffCode string `json:"aff_code,omitempty"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
AdoptAvatar *bool `json:"adopt_avatar,omitempty"` AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
} }
...@@ -1751,6 +1752,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) ...@@ -1751,6 +1752,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
user, user,
strings.TrimSpace(req.InvitationCode), strings.TrimSpace(req.InvitationCode),
strings.TrimSpace(session.ProviderType), strings.TrimSpace(session.ProviderType),
strings.TrimSpace(req.AffCode),
); err != nil { ); err != nil {
_ = tx.Rollback() _ = tx.Rollback()
if rollbackCreatedUser(err) { if rollbackCreatedUser(err) {
......
...@@ -582,6 +582,7 @@ func (h *AuthHandler) createOIDCOAuthChoicePendingSession( ...@@ -582,6 +582,7 @@ func (h *AuthHandler) createOIDCOAuthChoicePendingSession(
type completeOIDCOAuthRequest struct { type completeOIDCOAuthRequest struct {
InvitationCode string `json:"invitation_code" binding:"required"` InvitationCode string `json:"invitation_code" binding:"required"`
AffCode string `json:"aff_code,omitempty"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
AdoptAvatar *bool `json:"adopt_avatar,omitempty"` AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
} }
...@@ -665,7 +666,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) { ...@@ -665,7 +666,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
......
...@@ -481,6 +481,7 @@ func (h *AuthHandler) wechatPaymentResumeService() *service.PaymentResumeService ...@@ -481,6 +481,7 @@ func (h *AuthHandler) wechatPaymentResumeService() *service.PaymentResumeService
type completeWeChatOAuthRequest struct { type completeWeChatOAuthRequest struct {
InvitationCode string `json:"invitation_code" binding:"required"` InvitationCode string `json:"invitation_code" binding:"required"`
AffCode string `json:"aff_code,omitempty"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
AdoptAvatar *bool `json:"adopt_avatar,omitempty"` AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
} }
...@@ -547,7 +548,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) { ...@@ -547,7 +548,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
return return
} }
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
......
...@@ -109,6 +109,9 @@ type SystemSettings struct { ...@@ -109,6 +109,9 @@ type SystemSettings struct {
DefaultConcurrency int `json:"default_concurrency"` DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"` DefaultBalance float64 `json:"default_balance"`
AffiliateRebateRate float64 `json:"affiliate_rebate_rate"` AffiliateRebateRate float64 `json:"affiliate_rebate_rate"`
AffiliateRebateFreezeHours int `json:"affiliate_rebate_freeze_hours"`
AffiliateRebateDurationDays int `json:"affiliate_rebate_duration_days"`
AffiliateRebatePerInviteeCap float64 `json:"affiliate_rebate_per_invitee_cap"`
DefaultUserRPMLimit int `json:"default_user_rpm_limit"` DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"` DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`
......
...@@ -25,6 +25,7 @@ const ( ...@@ -25,6 +25,7 @@ const (
easypayStatusPaid = 1 easypayStatusPaid = 1
easypayHTTPTimeout = 10 * time.Second easypayHTTPTimeout = 10 * time.Second
maxEasypayResponseSize = 1 << 20 // 1MB maxEasypayResponseSize = 1 << 20 // 1MB
maxEasypayErrorSummary = 512
tradeStatusSuccess = "TRADE_SUCCESS" tradeStatusSuccess = "TRADE_SUCCESS"
signTypeMD5 = "MD5" signTypeMD5 = "MD5"
paymentModePopup = "popup" paymentModePopup = "popup"
...@@ -42,17 +43,55 @@ type EasyPay struct { ...@@ -42,17 +43,55 @@ type EasyPay struct {
// config keys: pid, pkey, apiBase, notifyUrl, returnUrl, cid, cidAlipay, cidWxpay // config keys: pid, pkey, apiBase, notifyUrl, returnUrl, cid, cidAlipay, cidWxpay
func NewEasyPay(instanceID string, config map[string]string) (*EasyPay, error) { func NewEasyPay(instanceID string, config map[string]string) (*EasyPay, error) {
for _, k := range []string{"pid", "pkey", "apiBase", "notifyUrl", "returnUrl"} { for _, k := range []string{"pid", "pkey", "apiBase", "notifyUrl", "returnUrl"} {
if config[k] == "" { if strings.TrimSpace(config[k]) == "" {
return nil, fmt.Errorf("easypay config missing required key: %s", k) return nil, fmt.Errorf("easypay config missing required key: %s", k)
} }
} }
cfg := make(map[string]string, len(config))
for k, v := range config {
cfg[k] = v
}
cfg["apiBase"] = normalizeEasyPayAPIBase(cfg["apiBase"])
return &EasyPay{ return &EasyPay{
instanceID: instanceID, instanceID: instanceID,
config: config, config: cfg,
httpClient: &http.Client{Timeout: easypayHTTPTimeout}, httpClient: &http.Client{Timeout: easypayHTTPTimeout},
}, nil }, nil
} }
func normalizeEasyPayAPIBase(apiBase string) string {
base := strings.TrimSpace(apiBase)
if base == "" {
return ""
}
if parsed, err := url.Parse(base); err == nil && parsed.Scheme != "" && parsed.Host != "" {
parsed.RawQuery = ""
parsed.Fragment = ""
parsed.RawPath = ""
parsed.Path = trimEasyPayEndpointPath(parsed.Path)
return strings.TrimRight(parsed.String(), "/")
}
return strings.TrimRight(trimEasyPayEndpointPath(base), "/")
}
func trimEasyPayEndpointPath(path string) string {
path = strings.TrimRight(strings.TrimSpace(path), "/")
lower := strings.ToLower(path)
for _, endpoint := range []string{"/submit.php", "/mapi.php", "/api.php"} {
if strings.HasSuffix(lower, endpoint) {
return strings.TrimRight(path[:len(path)-len(endpoint)], "/")
}
}
return path
}
func (e *EasyPay) apiBase() string {
if e == nil {
return ""
}
return normalizeEasyPayAPIBase(e.config["apiBase"])
}
func (e *EasyPay) Name() string { return "EasyPay" } func (e *EasyPay) Name() string { return "EasyPay" }
func (e *EasyPay) ProviderKey() string { return payment.TypeEasyPay } func (e *EasyPay) ProviderKey() string { return payment.TypeEasyPay }
func (e *EasyPay) SupportedTypes() []payment.PaymentType { func (e *EasyPay) SupportedTypes() []payment.PaymentType {
...@@ -104,8 +143,7 @@ func (e *EasyPay) createRedirectPayment(req payment.CreatePaymentRequest) (*paym ...@@ -104,8 +143,7 @@ func (e *EasyPay) createRedirectPayment(req payment.CreatePaymentRequest) (*paym
for k, v := range params { for k, v := range params {
q.Set(k, v) q.Set(k, v)
} }
base := strings.TrimRight(e.config["apiBase"], "/") payURL := e.apiBase() + "/submit.php?" + q.Encode()
payURL := base + "/submit.php?" + q.Encode()
return &payment.CreatePaymentResponse{PayURL: payURL}, nil return &payment.CreatePaymentResponse{PayURL: payURL}, nil
} }
...@@ -127,7 +165,7 @@ func (e *EasyPay) createAPIPayment(ctx context.Context, req payment.CreatePaymen ...@@ -127,7 +165,7 @@ func (e *EasyPay) createAPIPayment(ctx context.Context, req payment.CreatePaymen
params["sign"] = easyPaySign(params, e.config["pkey"]) params["sign"] = easyPaySign(params, e.config["pkey"])
params["sign_type"] = signTypeMD5 params["sign_type"] = signTypeMD5
body, err := e.post(ctx, strings.TrimRight(e.config["apiBase"], "/")+"/mapi.php", params) body, err := e.post(ctx, e.apiBase()+"/mapi.php", params)
if err != nil { if err != nil {
return nil, fmt.Errorf("easypay create: %w", err) return nil, fmt.Errorf("easypay create: %w", err)
} }
...@@ -171,7 +209,7 @@ func (e *EasyPay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Quer ...@@ -171,7 +209,7 @@ func (e *EasyPay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Quer
"act": "order", "pid": e.config["pid"], "act": "order", "pid": e.config["pid"],
"key": e.config["pkey"], "out_trade_no": tradeNo, "key": e.config["pkey"], "out_trade_no": tradeNo,
} }
body, err := e.post(ctx, e.config["apiBase"]+"/api.php", params) body, err := e.post(ctx, e.apiBase()+"/api.php", params)
if err != nil { if err != nil {
return nil, fmt.Errorf("easypay query: %w", err) return nil, fmt.Errorf("easypay query: %w", err)
} }
...@@ -234,25 +272,128 @@ func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[st ...@@ -234,25 +272,128 @@ func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[st
} }
func (e *EasyPay) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) { func (e *EasyPay) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) {
params := map[string]string{ attempts := e.refundAttempts(req)
"pid": e.config["pid"], "key": e.config["pkey"], if len(attempts) == 0 {
"trade_no": req.TradeNo, "out_trade_no": req.OrderID, "money": req.Amount, return nil, fmt.Errorf("easypay refund missing order identifier")
} }
body, err := e.post(ctx, e.config["apiBase"]+"/api.php?act=refund", params) var firstErr error
for i, attempt := range attempts {
body, status, err := e.postRaw(ctx, e.apiBase()+"/api.php?act=refund", attempt.params)
if err != nil { if err != nil {
return nil, fmt.Errorf("easypay refund: %w", err) return nil, fmt.Errorf("easypay refund request: %w", err)
}
if err := parseEasyPayRefundResponse(status, body); err != nil {
if firstErr == nil {
firstErr = err
} }
if i+1 < len(attempts) && isEasyPayRefundOrderNotFound(err) {
continue
}
return nil, err
}
return &payment.RefundResponse{RefundID: attempt.refundID, Status: payment.ProviderStatusSuccess}, nil
}
return nil, firstErr
}
type easyPayRefundAttempt struct {
params map[string]string
refundID string
}
func (e *EasyPay) refundAttempts(req payment.RefundRequest) []easyPayRefundAttempt {
base := map[string]string{
"pid": e.config["pid"], "key": e.config["pkey"], "money": req.Amount,
}
var attempts []easyPayRefundAttempt
if orderID := strings.TrimSpace(req.OrderID); orderID != "" {
params := cloneStringMap(base)
params["out_trade_no"] = orderID
attempts = append(attempts, easyPayRefundAttempt{params: params, refundID: orderID})
}
if tradeNo := strings.TrimSpace(req.TradeNo); tradeNo != "" {
params := cloneStringMap(base)
params["trade_no"] = tradeNo
attempts = append(attempts, easyPayRefundAttempt{params: params, refundID: tradeNo})
}
return attempts
}
func cloneStringMap(in map[string]string) map[string]string {
out := make(map[string]string, len(in))
for k, v := range in {
out[k] = v
}
return out
}
func isEasyPayRefundOrderNotFound(err error) bool {
if err == nil {
return false
}
msg := err.Error()
lower := strings.ToLower(msg)
return strings.Contains(msg, "订单编号不存在") ||
strings.Contains(msg, "订单不存在") ||
strings.Contains(lower, "order not found") ||
strings.Contains(lower, "not exist")
}
func parseEasyPayRefundResponse(status int, body []byte) error {
summary := summarizeEasyPayResponse(body)
if status < http.StatusOK || status >= http.StatusMultipleChoices {
return fmt.Errorf("easypay refund HTTP %d: %s", status, summary)
}
trimmed := strings.TrimSpace(string(body))
if trimmed == "" {
return fmt.Errorf("easypay refund empty response (HTTP %d): %s", status, summary)
}
lower := strings.ToLower(trimmed)
if strings.HasPrefix(lower, "<!doctype html") || strings.HasPrefix(lower, "<html") ||
(strings.HasPrefix(lower, "<") && strings.Contains(lower, "html")) {
return fmt.Errorf("easypay refund non-JSON response (HTTP %d): %s", status, summary)
}
var resp struct { var resp struct {
Code int `json:"code"` Code any `json:"code"`
Msg string `json:"msg"` Msg string `json:"msg"`
} }
if err := json.Unmarshal(body, &resp); err != nil { if err := json.Unmarshal(body, &resp); err != nil {
return nil, fmt.Errorf("easypay parse refund: %w", err) return fmt.Errorf("easypay refund non-JSON response (HTTP %d): %s", status, summary)
} }
if resp.Code != easypayCodeSuccess { if !easyPayResponseCodeIsSuccess(resp.Code) {
return nil, fmt.Errorf("easypay refund failed: %s", resp.Msg) msg := strings.TrimSpace(resp.Msg)
if msg == "" {
msg = summary
}
return fmt.Errorf("easypay refund failed (HTTP %d): %s", status, msg)
}
return nil
}
func easyPayResponseCodeIsSuccess(code any) bool {
switch v := code.(type) {
case float64:
return int(v) == easypayCodeSuccess
case string:
n, err := strconv.Atoi(strings.TrimSpace(v))
return err == nil && n == easypayCodeSuccess
default:
return false
} }
return &payment.RefundResponse{RefundID: req.TradeNo, Status: payment.ProviderStatusSuccess}, nil }
func summarizeEasyPayResponse(body []byte) string {
summary := strings.Join(strings.Fields(string(body)), " ")
if summary == "" {
return "<empty>"
}
if len(summary) > maxEasypayErrorSummary {
return summary[:maxEasypayErrorSummary] + "..."
}
return summary
} }
func (e *EasyPay) resolveCID(paymentType string) string { func (e *EasyPay) resolveCID(paymentType string) string {
...@@ -269,21 +410,34 @@ func (e *EasyPay) resolveCID(paymentType string) string { ...@@ -269,21 +410,34 @@ func (e *EasyPay) resolveCID(paymentType string) string {
} }
func (e *EasyPay) post(ctx context.Context, endpoint string, params map[string]string) ([]byte, error) { func (e *EasyPay) post(ctx context.Context, endpoint string, params map[string]string) ([]byte, error) {
body, _, err := e.postRaw(ctx, endpoint, params)
return body, err
}
func (e *EasyPay) postRaw(ctx context.Context, endpoint string, params map[string]string) ([]byte, int, error) {
form := url.Values{} form := url.Values{}
for k, v := range params { for k, v := range params {
form.Set(k, v) form.Set(k, v)
} }
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode())) req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
if err != nil { if err != nil {
return nil, err return nil, 0, err
} }
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := e.httpClient.Do(req) client := e.httpClient
if client == nil {
client = &http.Client{Timeout: easypayHTTPTimeout}
}
resp, err := client.Do(req)
if err != nil { if err != nil {
return nil, err return nil, 0, err
} }
defer func() { _ = resp.Body.Close() }() defer func() { _ = resp.Body.Close() }()
return io.ReadAll(io.LimitReader(resp.Body, maxEasypayResponseSize)) body, err := io.ReadAll(io.LimitReader(resp.Body, maxEasypayResponseSize))
if err != nil {
return nil, resp.StatusCode, err
}
return body, resp.StatusCode, nil
} }
func easyPaySign(params map[string]string, pkey string) string { func easyPaySign(params map[string]string, pkey string) string {
......
package provider
import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
func TestNormalizeEasyPayAPIBase(t *testing.T) {
t.Parallel()
tests := []struct {
input string
want string
}{
{input: "https://zpayz.cn", want: "https://zpayz.cn"},
{input: "https://zpayz.cn/", want: "https://zpayz.cn"},
{input: "https://zpayz.cn/mapi.php", want: "https://zpayz.cn"},
{input: "https://zpayz.cn/submit.php", want: "https://zpayz.cn"},
{input: "https://zpayz.cn/api.php", want: "https://zpayz.cn"},
{input: "https://zpayz.cn/api.php?act=refund", want: "https://zpayz.cn"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
t.Parallel()
if got := normalizeEasyPayAPIBase(tt.input); got != tt.want {
t.Fatalf("normalizeEasyPayAPIBase(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestEasyPayRefundNormalizesAPIBaseAndSendsOutTradeNoOnly(t *testing.T) {
t.Parallel()
var gotPath string
var gotQuery url.Values
var gotForm url.Values
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
gotQuery = r.URL.Query()
if err := r.ParseForm(); err != nil {
t.Errorf("ParseForm: %v", err)
}
gotForm = r.PostForm
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"code":1,"msg":"ok"}`))
}))
defer server.Close()
provider := newTestEasyPay(t, server.URL+"/mapi.php")
resp, err := provider.Refund(context.Background(), payment.RefundRequest{
TradeNo: "trade-123",
OrderID: "out-456",
Amount: "1.50",
})
if err != nil {
t.Fatalf("Refund returned error: %v", err)
}
if resp == nil || resp.Status != payment.ProviderStatusSuccess {
t.Fatalf("Refund response = %+v, want success", resp)
}
if gotPath != "/api.php" {
t.Fatalf("refund path = %q, want /api.php", gotPath)
}
if gotQuery.Get("act") != "refund" {
t.Fatalf("refund act query = %q, want refund", gotQuery.Get("act"))
}
for key, want := range map[string]string{
"pid": "pid-1",
"key": "pkey-1",
"out_trade_no": "out-456",
"money": "1.50",
} {
if got := gotForm.Get(key); got != want {
t.Fatalf("form[%s] = %q, want %q (form=%v)", key, got, want, gotForm)
}
}
if got := gotForm.Get("trade_no"); got != "" {
t.Fatalf("form[trade_no] = %q, want empty (form=%v)", got, gotForm)
}
}
func TestEasyPayRefundRetriesWithTradeNoWhenOutTradeNoNotFound(t *testing.T) {
t.Parallel()
var gotForms []url.Values
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api.php" {
t.Errorf("refund path = %q, want /api.php", r.URL.Path)
}
if r.URL.Query().Get("act") != "refund" {
t.Errorf("refund act query = %q, want refund", r.URL.Query().Get("act"))
}
if err := r.ParseForm(); err != nil {
t.Errorf("ParseForm: %v", err)
}
gotForms = append(gotForms, r.PostForm)
w.Header().Set("Content-Type", "application/json")
if len(gotForms) == 1 {
_, _ = w.Write([]byte(`{"code":0,"msg":"订单编号不存在!"}`))
return
}
_, _ = w.Write([]byte(`{"code":1,"msg":"ok"}`))
}))
defer server.Close()
provider := newTestEasyPay(t, server.URL+"/mapi.php")
resp, err := provider.Refund(context.Background(), payment.RefundRequest{
TradeNo: "trade-123",
OrderID: "out-456",
Amount: "1.50",
})
if err != nil {
t.Fatalf("Refund returned error: %v", err)
}
if resp == nil || resp.Status != payment.ProviderStatusSuccess || resp.RefundID != "trade-123" {
t.Fatalf("Refund response = %+v, want success with trade refund id", resp)
}
if len(gotForms) != 2 {
t.Fatalf("refund attempts = %d, want 2", len(gotForms))
}
if got := gotForms[0].Get("out_trade_no"); got != "out-456" {
t.Fatalf("first form[out_trade_no] = %q, want out-456 (form=%v)", got, gotForms[0])
}
if got := gotForms[0].Get("trade_no"); got != "" {
t.Fatalf("first form[trade_no] = %q, want empty (form=%v)", got, gotForms[0])
}
if got := gotForms[1].Get("trade_no"); got != "trade-123" {
t.Fatalf("second form[trade_no] = %q, want trade-123 (form=%v)", got, gotForms[1])
}
if got := gotForms[1].Get("out_trade_no"); got != "" {
t.Fatalf("second form[out_trade_no] = %q, want empty (form=%v)", got, gotForms[1])
}
}
func TestEasyPayRefundResponseErrors(t *testing.T) {
t.Parallel()
tests := []struct {
name string
statusCode int
body string
want string
}{
{name: "html response", statusCode: http.StatusOK, body: "<html>bad config</html>", want: "non-JSON response (HTTP 200): <html>bad config</html>"},
{name: "non json response", statusCode: http.StatusOK, body: "not json", want: "non-JSON response (HTTP 200): not json"},
{name: "non 2xx response", statusCode: http.StatusBadGateway, body: "bad gateway", want: "HTTP 502: bad gateway"},
{name: "empty response", statusCode: http.StatusOK, body: "", want: "empty response (HTTP 200): <empty>"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(tt.statusCode)
_, _ = w.Write([]byte(tt.body))
}))
defer server.Close()
provider := newTestEasyPay(t, server.URL)
_, err := provider.Refund(context.Background(), payment.RefundRequest{
OrderID: "out-456",
Amount: "1.50",
})
if err == nil {
t.Fatal("Refund returned nil error")
}
if !strings.Contains(err.Error(), tt.want) {
t.Fatalf("Refund error = %q, want substring %q", err.Error(), tt.want)
}
})
}
}
func newTestEasyPay(t *testing.T, apiBase string) *EasyPay {
t.Helper()
provider, err := NewEasyPay("test-instance", map[string]string{
"pid": "pid-1",
"pkey": "pkey-1",
"apiBase": apiBase,
"notifyUrl": "https://example.com/notify",
"returnUrl": "https://example.com/return",
})
if err != nil {
t.Fatalf("NewEasyPay: %v", err)
}
return provider
}
...@@ -181,6 +181,55 @@ func TestResponsesToAnthropic_TextOnly(t *testing.T) { ...@@ -181,6 +181,55 @@ func TestResponsesToAnthropic_TextOnly(t *testing.T) {
assert.Equal(t, 5, anth.Usage.OutputTokens) assert.Equal(t, 5, anth.Usage.OutputTokens)
} }
func TestResponsesToAnthropic_CachedTokensUseAnthropicInputSemantics(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_cached",
Model: "gpt-5.2",
Status: "completed",
Output: []ResponsesOutput{
{
Type: "message",
Content: []ResponsesContentPart{
{Type: "output_text", Text: "Cached response"},
},
},
},
Usage: &ResponsesUsage{
InputTokens: 54006,
OutputTokens: 123,
TotalTokens: 54129,
InputTokensDetails: &ResponsesInputTokensDetails{
CachedTokens: 50688,
},
},
}
anth := ResponsesToAnthropic(resp, "claude-sonnet-4-5-20250929")
assert.Equal(t, 3318, anth.Usage.InputTokens)
assert.Equal(t, 50688, anth.Usage.CacheReadInputTokens)
assert.Equal(t, 123, anth.Usage.OutputTokens)
}
func TestResponsesToAnthropic_CachedTokensClampInputTokens(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_cached_clamp",
Model: "gpt-5.2",
Status: "completed",
Usage: &ResponsesUsage{
InputTokens: 100,
OutputTokens: 5,
InputTokensDetails: &ResponsesInputTokensDetails{
CachedTokens: 150,
},
},
}
anth := ResponsesToAnthropic(resp, "claude-sonnet-4-5-20250929")
assert.Equal(t, 0, anth.Usage.InputTokens)
assert.Equal(t, 150, anth.Usage.CacheReadInputTokens)
assert.Equal(t, 5, anth.Usage.OutputTokens)
}
func TestResponsesToAnthropic_ToolUse(t *testing.T) { func TestResponsesToAnthropic_ToolUse(t *testing.T) {
resp := &ResponsesResponse{ resp := &ResponsesResponse{
ID: "resp_456", ID: "resp_456",
...@@ -343,6 +392,36 @@ func TestStreamingTextOnly(t *testing.T) { ...@@ -343,6 +392,36 @@ func TestStreamingTextOnly(t *testing.T) {
assert.Equal(t, "message_stop", events[1].Type) assert.Equal(t, "message_stop", events[1].Type)
} }
func TestStreamingCachedTokensUseAnthropicInputSemantics(t *testing.T) {
state := NewResponsesEventToAnthropicState()
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
Type: "response.created",
Response: &ResponsesResponse{ID: "resp_cached_stream", Model: "gpt-5.2"},
}, state)
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
Type: "response.completed",
Response: &ResponsesResponse{
Status: "completed",
Usage: &ResponsesUsage{
InputTokens: 54006,
OutputTokens: 123,
TotalTokens: 54129,
InputTokensDetails: &ResponsesInputTokensDetails{
CachedTokens: 50688,
},
},
},
}, state)
require.Len(t, events, 2)
assert.Equal(t, "message_delta", events[0].Type)
assert.Equal(t, 3318, events[0].Usage.InputTokens)
assert.Equal(t, 50688, events[0].Usage.CacheReadInputTokens)
assert.Equal(t, 123, events[0].Usage.OutputTokens)
assert.Equal(t, "message_stop", events[1].Type)
}
func TestStreamingToolCall(t *testing.T) { func TestStreamingToolCall(t *testing.T) {
state := NewResponsesEventToAnthropicState() state := NewResponsesEventToAnthropicState()
......
...@@ -84,16 +84,32 @@ func ResponsesToAnthropic(resp *ResponsesResponse, model string) *AnthropicRespo ...@@ -84,16 +84,32 @@ func ResponsesToAnthropic(resp *ResponsesResponse, model string) *AnthropicRespo
out.StopReason = responsesStatusToAnthropicStopReason(resp.Status, resp.IncompleteDetails, blocks) out.StopReason = responsesStatusToAnthropicStopReason(resp.Status, resp.IncompleteDetails, blocks)
if resp.Usage != nil { if resp.Usage != nil {
out.Usage = AnthropicUsage{ out.Usage = anthropicUsageFromResponsesUsage(resp.Usage)
InputTokens: resp.Usage.InputTokens,
OutputTokens: resp.Usage.OutputTokens,
} }
if resp.Usage.InputTokensDetails != nil {
out.Usage.CacheReadInputTokens = resp.Usage.InputTokensDetails.CachedTokens return out
}
func anthropicUsageFromResponsesUsage(usage *ResponsesUsage) AnthropicUsage {
if usage == nil {
return AnthropicUsage{}
}
cachedTokens := 0
if usage.InputTokensDetails != nil {
cachedTokens = usage.InputTokensDetails.CachedTokens
} }
inputTokens := usage.InputTokens - cachedTokens
if inputTokens < 0 {
inputTokens = 0
} }
return out return AnthropicUsage{
InputTokens: inputTokens,
OutputTokens: usage.OutputTokens,
CacheReadInputTokens: cachedTokens,
}
} }
func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncompleteDetails, blocks []AnthropicContentBlock) string { func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncompleteDetails, blocks []AnthropicContentBlock) string {
...@@ -466,11 +482,10 @@ func resToAnthHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventTo ...@@ -466,11 +482,10 @@ func resToAnthHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventTo
stopReason := "end_turn" stopReason := "end_turn"
if evt.Response != nil { if evt.Response != nil {
if evt.Response.Usage != nil { if evt.Response.Usage != nil {
state.InputTokens = evt.Response.Usage.InputTokens usage := anthropicUsageFromResponsesUsage(evt.Response.Usage)
state.OutputTokens = evt.Response.Usage.OutputTokens state.InputTokens = usage.InputTokens
if evt.Response.Usage.InputTokensDetails != nil { state.OutputTokens = usage.OutputTokens
state.CacheReadInputTokens = evt.Response.Usage.InputTokensDetails.CachedTokens state.CacheReadInputTokens = usage.CacheReadInputTokens
}
} }
switch evt.Response.Status { switch evt.Response.Status {
case "incomplete": case "incomplete":
......
...@@ -86,17 +86,21 @@ func (r *affiliateRepository) BindInviter(ctx context.Context, userID, inviterID ...@@ -86,17 +86,21 @@ func (r *affiliateRepository) BindInviter(ctx context.Context, userID, inviterID
return bound, nil return bound, nil
} }
func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64) (bool, error) { func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error) {
if amount <= 0 { if amount <= 0 {
return false, nil return false, nil
} }
var applied bool var applied bool
err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error { err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
res, err := txClient.ExecContext(txCtx, // freezeHours > 0: add to frozen quota; == 0: add to available quota directly
"UPDATE user_affiliates SET aff_quota = aff_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2", var updateSQL string
amount, inviterID, if freezeHours > 0 {
) updateSQL = "UPDATE user_affiliates SET aff_frozen_quota = aff_frozen_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2"
} else {
updateSQL = "UPDATE user_affiliates SET aff_quota = aff_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2"
}
res, err := txClient.ExecContext(txCtx, updateSQL, amount, inviterID)
if err != nil { if err != nil {
return err return err
} }
...@@ -106,11 +110,20 @@ func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, invite ...@@ -106,11 +110,20 @@ func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, invite
return nil return nil
} }
if freezeHours > 0 {
if _, err = txClient.ExecContext(txCtx, `
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, frozen_until, created_at, updated_at)
VALUES ($1, 'accrue', $2, $3, NOW() + make_interval(hours => $4), NOW(), NOW())`,
inviterID, amount, inviteeUserID, freezeHours); err != nil {
return fmt.Errorf("insert affiliate accrue ledger: %w", err)
}
} else {
if _, err = txClient.ExecContext(txCtx, ` if _, err = txClient.ExecContext(txCtx, `
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at) INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at)
VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID); err != nil { VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID); err != nil {
return fmt.Errorf("insert affiliate accrue ledger: %w", err) return fmt.Errorf("insert affiliate accrue ledger: %w", err)
} }
}
applied = true applied = true
return nil return nil
...@@ -121,6 +134,76 @@ VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID); ...@@ -121,6 +134,76 @@ VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID);
return applied, nil return applied, nil
} }
func (r *affiliateRepository) GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error) {
client := clientFromContext(ctx, r.client)
rows, err := client.QueryContext(ctx,
`SELECT COALESCE(SUM(amount), 0)::double precision FROM user_affiliate_ledger WHERE user_id = $1 AND source_user_id = $2 AND action = 'accrue'`,
inviterID, inviteeUserID)
if err != nil {
return 0, fmt.Errorf("query accrued rebate from invitee: %w", err)
}
defer func() { _ = rows.Close() }()
var total float64
if rows.Next() {
if err := rows.Scan(&total); err != nil {
return 0, err
}
}
return total, rows.Close()
}
func (r *affiliateRepository) ThawFrozenQuota(ctx context.Context, userID int64) (float64, error) {
var thawed float64
err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
var err error
thawed, err = thawFrozenQuotaTx(txCtx, txClient, userID)
return err
})
return thawed, err
}
// thawFrozenQuotaTx moves matured frozen quota to available quota within an existing tx.
func thawFrozenQuotaTx(txCtx context.Context, txClient *dbent.Client, userID int64) (float64, error) {
rows, err := txClient.QueryContext(txCtx, `
WITH matured AS (
UPDATE user_affiliate_ledger
SET frozen_until = NULL, updated_at = NOW()
WHERE user_id = $1
AND frozen_until IS NOT NULL
AND frozen_until <= NOW()
RETURNING amount
)
SELECT COALESCE(SUM(amount), 0) FROM matured`, userID)
if err != nil {
return 0, fmt.Errorf("thaw frozen quota: %w", err)
}
defer func() { _ = rows.Close() }()
var thawed float64
if rows.Next() {
if err := rows.Scan(&thawed); err != nil {
return 0, err
}
}
if err := rows.Close(); err != nil {
return 0, err
}
if thawed <= 0 {
return 0, nil
}
_, err = txClient.ExecContext(txCtx, `
UPDATE user_affiliates
SET aff_quota = aff_quota + $1,
aff_frozen_quota = GREATEST(aff_frozen_quota - $1, 0),
updated_at = NOW()
WHERE user_id = $2`, thawed, userID)
if err != nil {
return 0, fmt.Errorf("move thawed quota: %w", err)
}
return thawed, nil
}
func (r *affiliateRepository) TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error) { func (r *affiliateRepository) TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error) {
var transferred float64 var transferred float64
var newBalance float64 var newBalance float64
...@@ -130,6 +213,11 @@ func (r *affiliateRepository) TransferQuotaToBalance(ctx context.Context, userID ...@@ -130,6 +213,11 @@ func (r *affiliateRepository) TransferQuotaToBalance(ctx context.Context, userID
return err return err
} }
// Thaw any matured frozen quota before transfer.
if _, err := thawFrozenQuotaTx(txCtx, txClient, userID); err != nil {
return fmt.Errorf("thaw before transfer: %w", err)
}
rows, err := txClient.QueryContext(txCtx, ` rows, err := txClient.QueryContext(txCtx, `
WITH claimed AS ( WITH claimed AS (
SELECT aff_quota::double precision AS amount SELECT aff_quota::double precision AS amount
...@@ -211,10 +299,16 @@ func (r *affiliateRepository) ListInvitees(ctx context.Context, inviterID int64, ...@@ -211,10 +299,16 @@ func (r *affiliateRepository) ListInvitees(ctx context.Context, inviterID int64,
SELECT ua.user_id, SELECT ua.user_id,
COALESCE(u.email, ''), COALESCE(u.email, ''),
COALESCE(u.username, ''), COALESCE(u.username, ''),
ua.created_at ua.created_at,
COALESCE(SUM(ual.amount), 0)::double precision AS total_rebate
FROM user_affiliates ua FROM user_affiliates ua
LEFT JOIN users u ON u.id = ua.user_id LEFT JOIN users u ON u.id = ua.user_id
LEFT JOIN user_affiliate_ledger ual
ON ual.user_id = $1
AND ual.source_user_id = ua.user_id
AND ual.action = 'accrue'
WHERE ua.inviter_id = $1 WHERE ua.inviter_id = $1
GROUP BY ua.user_id, u.email, u.username, ua.created_at
ORDER BY ua.created_at DESC ORDER BY ua.created_at DESC
LIMIT $2`, inviterID, limit) LIMIT $2`, inviterID, limit)
if err != nil { if err != nil {
...@@ -226,7 +320,7 @@ LIMIT $2`, inviterID, limit) ...@@ -226,7 +320,7 @@ LIMIT $2`, inviterID, limit)
for rows.Next() { for rows.Next() {
var item service.AffiliateInvitee var item service.AffiliateInvitee
var createdAt time.Time var createdAt time.Time
if err := rows.Scan(&item.UserID, &item.Email, &item.Username, &createdAt); err != nil { if err := rows.Scan(&item.UserID, &item.Email, &item.Username, &createdAt, &item.TotalRebate); err != nil {
return nil, err return nil, err
} }
item.CreatedAt = &createdAt item.CreatedAt = &createdAt
...@@ -299,6 +393,7 @@ SELECT user_id, ...@@ -299,6 +393,7 @@ SELECT user_id,
inviter_id, inviter_id,
aff_count, aff_count,
aff_quota::double precision, aff_quota::double precision,
aff_frozen_quota::double precision,
aff_history_quota::double precision, aff_history_quota::double precision,
created_at, created_at,
updated_at updated_at
...@@ -326,6 +421,7 @@ WHERE user_id = $1`, userID) ...@@ -326,6 +421,7 @@ WHERE user_id = $1`, userID)
&inviterID, &inviterID,
&out.AffCount, &out.AffCount,
&out.AffQuota, &out.AffQuota,
&out.AffFrozenQuota,
&out.AffHistoryQuota, &out.AffHistoryQuota,
&out.CreatedAt, &out.CreatedAt,
&out.UpdatedAt, &out.UpdatedAt,
...@@ -351,6 +447,7 @@ SELECT user_id, ...@@ -351,6 +447,7 @@ SELECT user_id,
inviter_id, inviter_id,
aff_count, aff_count,
aff_quota::double precision, aff_quota::double precision,
aff_frozen_quota::double precision,
aff_history_quota::double precision, aff_history_quota::double precision,
created_at, created_at,
updated_at updated_at
...@@ -380,6 +477,7 @@ LIMIT 1`, strings.ToUpper(strings.TrimSpace(code))) ...@@ -380,6 +477,7 @@ LIMIT 1`, strings.ToUpper(strings.TrimSpace(code)))
&inviterID, &inviterID,
&out.AffCount, &out.AffCount,
&out.AffQuota, &out.AffQuota,
&out.AffFrozenQuota,
&out.AffHistoryQuota, &out.AffHistoryQuota,
&out.CreatedAt, &out.CreatedAt,
&out.UpdatedAt, &out.UpdatedAt,
......
...@@ -125,7 +125,7 @@ func TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction(t *testing.T) { ...@@ -125,7 +125,7 @@ func TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.True(t, bound, "invitee must bind to inviter") require.True(t, bound, "invitee must bind to inviter")
applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5) applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5, 0)
require.NoError(t, err) require.NoError(t, err)
require.True(t, applied, "AccrueQuota must report applied=true") require.True(t, applied, "AccrueQuota must report applied=true")
......
...@@ -716,6 +716,9 @@ func TestAPIContracts(t *testing.T) { ...@@ -716,6 +716,9 @@ func TestAPIContracts(t *testing.T) {
"default_concurrency": 5, "default_concurrency": 5,
"default_balance": 1.25, "default_balance": 1.25,
"affiliate_rebate_rate": 20, "affiliate_rebate_rate": 20,
"affiliate_rebate_freeze_hours": 0,
"affiliate_rebate_duration_days": 0,
"affiliate_rebate_per_invitee_cap": 0,
"default_user_rpm_limit": 0, "default_user_rpm_limit": 0,
"default_subscriptions": [], "default_subscriptions": [],
"enable_model_fallback": false, "enable_model_fallback": false,
...@@ -898,6 +901,9 @@ func TestAPIContracts(t *testing.T) { ...@@ -898,6 +901,9 @@ func TestAPIContracts(t *testing.T) {
"default_concurrency": 0, "default_concurrency": 0,
"default_balance": 0, "default_balance": 0,
"affiliate_rebate_rate": 20, "affiliate_rebate_rate": 20,
"affiliate_rebate_freeze_hours": 0,
"affiliate_rebate_duration_days": 0,
"affiliate_rebate_per_invitee_cap": 0,
"default_user_rpm_limit": 0, "default_user_rpm_limit": 0,
"default_subscriptions": [], "default_subscriptions": [],
"enable_model_fallback": false, "enable_model_fallback": false,
......
...@@ -65,6 +65,7 @@ type AffiliateSummary struct { ...@@ -65,6 +65,7 @@ type AffiliateSummary struct {
InviterID *int64 `json:"inviter_id,omitempty"` InviterID *int64 `json:"inviter_id,omitempty"`
AffCount int `json:"aff_count"` AffCount int `json:"aff_count"`
AffQuota float64 `json:"aff_quota"` AffQuota float64 `json:"aff_quota"`
AffFrozenQuota float64 `json:"aff_frozen_quota"`
AffHistoryQuota float64 `json:"aff_history_quota"` AffHistoryQuota float64 `json:"aff_history_quota"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
...@@ -75,6 +76,7 @@ type AffiliateInvitee struct { ...@@ -75,6 +76,7 @@ type AffiliateInvitee struct {
Email string `json:"email"` Email string `json:"email"`
Username string `json:"username"` Username string `json:"username"`
CreatedAt *time.Time `json:"created_at,omitempty"` CreatedAt *time.Time `json:"created_at,omitempty"`
TotalRebate float64 `json:"total_rebate"`
} }
type AffiliateDetail struct { type AffiliateDetail struct {
...@@ -83,6 +85,7 @@ type AffiliateDetail struct { ...@@ -83,6 +85,7 @@ type AffiliateDetail struct {
InviterID *int64 `json:"inviter_id,omitempty"` InviterID *int64 `json:"inviter_id,omitempty"`
AffCount int `json:"aff_count"` AffCount int `json:"aff_count"`
AffQuota float64 `json:"aff_quota"` AffQuota float64 `json:"aff_quota"`
AffFrozenQuota float64 `json:"aff_frozen_quota"`
AffHistoryQuota float64 `json:"aff_history_quota"` AffHistoryQuota float64 `json:"aff_history_quota"`
// EffectiveRebateRatePercent 是当前用户作为邀请人时实际生效的返利比例: // EffectiveRebateRatePercent 是当前用户作为邀请人时实际生效的返利比例:
// 优先用户自己的专属比例(aff_rebate_rate_percent),否则回退到全局比例。 // 优先用户自己的专属比例(aff_rebate_rate_percent),否则回退到全局比例。
...@@ -95,7 +98,9 @@ type AffiliateRepository interface { ...@@ -95,7 +98,9 @@ type AffiliateRepository interface {
EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error) EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error)
GetAffiliateByCode(ctx context.Context, code string) (*AffiliateSummary, error) GetAffiliateByCode(ctx context.Context, code string) (*AffiliateSummary, error)
BindInviter(ctx context.Context, userID, inviterID int64) (bool, error) BindInviter(ctx context.Context, userID, inviterID int64) (bool, error)
AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64) (bool, error) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error)
GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error)
ThawFrozenQuota(ctx context.Context, userID int64) (float64, error)
TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error) TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error)
ListInvitees(ctx context.Context, inviterID int64, limit int) ([]AffiliateInvitee, error) ListInvitees(ctx context.Context, inviterID int64, limit int) ([]AffiliateInvitee, error)
...@@ -160,6 +165,12 @@ func (s *AffiliateService) EnsureUserAffiliate(ctx context.Context, userID int64 ...@@ -160,6 +165,12 @@ func (s *AffiliateService) EnsureUserAffiliate(ctx context.Context, userID int64
} }
func (s *AffiliateService) GetAffiliateDetail(ctx context.Context, userID int64) (*AffiliateDetail, error) { func (s *AffiliateService) GetAffiliateDetail(ctx context.Context, userID int64) (*AffiliateDetail, error) {
// Lazy thaw: move any matured frozen quota to available before reading.
if s != nil && s.repo != nil {
// best-effort: thaw failure is non-fatal
_, _ = s.repo.ThawFrozenQuota(ctx, userID)
}
summary, err := s.EnsureUserAffiliate(ctx, userID) summary, err := s.EnsureUserAffiliate(ctx, userID)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -174,6 +185,7 @@ func (s *AffiliateService) GetAffiliateDetail(ctx context.Context, userID int64) ...@@ -174,6 +185,7 @@ func (s *AffiliateService) GetAffiliateDetail(ctx context.Context, userID int64)
InviterID: summary.InviterID, InviterID: summary.InviterID,
AffCount: summary.AffCount, AffCount: summary.AffCount,
AffQuota: summary.AffQuota, AffQuota: summary.AffQuota,
AffFrozenQuota: summary.AffFrozenQuota,
AffHistoryQuota: summary.AffHistoryQuota, AffHistoryQuota: summary.AffHistoryQuota,
EffectiveRebateRatePercent: s.resolveRebateRatePercent(ctx, summary), EffectiveRebateRatePercent: s.resolveRebateRatePercent(ctx, summary),
Invitees: invitees, Invitees: invitees,
...@@ -250,13 +262,43 @@ func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID ...@@ -250,13 +262,43 @@ func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID
if err != nil { if err != nil {
return 0, err return 0, err
} }
// 有效期检查:超过返利有效期后不再产生返利
if s.settingService != nil {
if durationDays := s.settingService.GetAffiliateRebateDurationDays(ctx); durationDays > 0 {
if time.Now().After(inviteeSummary.CreatedAt.AddDate(0, 0, durationDays)) {
return 0, nil
}
}
}
rebateRatePercent := s.resolveRebateRatePercent(ctx, inviterSummary) rebateRatePercent := s.resolveRebateRatePercent(ctx, inviterSummary)
rebate := roundTo(baseRechargeAmount*(rebateRatePercent/100), 8) rebate := roundTo(baseRechargeAmount*(rebateRatePercent/100), 8)
if rebate <= 0 { if rebate <= 0 {
return 0, nil return 0, nil
} }
applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate) // 单人上限检查:精确截断到剩余额度
if s.settingService != nil {
if perInviteeCap := s.settingService.GetAffiliateRebatePerInviteeCap(ctx); perInviteeCap > 0 {
existing, err := s.repo.GetAccruedRebateFromInvitee(ctx, *inviteeSummary.InviterID, inviteeUserID)
if err != nil {
return 0, err
}
if existing >= perInviteeCap {
return 0, nil
}
if remaining := perInviteeCap - existing; rebate > remaining {
rebate = roundTo(remaining, 8)
}
}
}
var freezeHours int
if s.settingService != nil {
freezeHours = s.settingService.GetAffiliateRebateFreezeHours(ctx)
}
applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate, freezeHours)
if err != nil { if err != nil {
return 0, err return 0, err
} }
......
...@@ -175,6 +175,7 @@ func (s *AuthService) FinalizeOAuthEmailAccount( ...@@ -175,6 +175,7 @@ func (s *AuthService) FinalizeOAuthEmailAccount(
user *User, user *User,
invitationCode string, invitationCode string,
signupSource string, signupSource string,
affiliateCode string,
) error { ) error {
if s == nil || user == nil || user.ID <= 0 { if s == nil || user == nil || user.ID <= 0 {
return ErrServiceUnavailable return ErrServiceUnavailable
...@@ -194,6 +195,7 @@ func (s *AuthService) FinalizeOAuthEmailAccount( ...@@ -194,6 +195,7 @@ func (s *AuthService) FinalizeOAuthEmailAccount(
s.updateOAuthSignupSource(ctx, user.ID, signupSource) s.updateOAuthSignupSource(ctx, user.ID, signupSource)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
return nil return nil
} }
......
...@@ -563,7 +563,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username ...@@ -563,7 +563,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair。 // LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair。
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token。 // 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token。
// invitationCode 仅在邀请码注册模式下新用户注册时使用;已有账号登录时忽略。 // invitationCode 仅在邀请码注册模式下新用户注册时使用;已有账号登录时忽略。
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode string) (*TokenPair, *User, error) { // affiliateCode 用于邀请返利绑定,仅在新用户注册时使用。
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode, affiliateCode string) (*TokenPair, *User, error) {
// 检查 refreshTokenCache 是否可用 // 检查 refreshTokenCache 是否可用
if s.refreshTokenCache == nil { if s.refreshTokenCache == nil {
return nil, nil, errors.New("refresh token cache not configured") return nil, nil, errors.New("refresh token cache not configured")
...@@ -666,6 +667,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema ...@@ -666,6 +667,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
user = newUser user = newUser
s.postAuthUserBootstrap(ctx, user, signupSource, false) s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
} }
} else { } else {
if err := s.userRepo.Create(ctx, newUser); err != nil { if err := s.userRepo.Create(ctx, newUser); err != nil {
...@@ -683,6 +685,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema ...@@ -683,6 +685,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
user = newUser user = newUser
s.postAuthUserBootstrap(ctx, user, signupSource, false) s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
if invitationRedeemCode != nil { if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
return nil, nil, ErrInvitationCodeInvalid return nil, nil, ErrInvitationCodeInvalid
...@@ -777,6 +780,22 @@ func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource ...@@ -777,6 +780,22 @@ func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource
} }
} }
// bindOAuthAffiliate initializes the affiliate profile and binds the inviter
// for an OAuth-registered user. Failures are logged but never block registration.
func (s *AuthService) bindOAuthAffiliate(ctx context.Context, userID int64, affiliateCode string) {
if s.affiliateService == nil || userID <= 0 {
return
}
if _, err := s.affiliateService.EnsureUserAffiliate(ctx, userID); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", userID, err)
}
if code := strings.TrimSpace(affiliateCode); code != "" {
if err := s.affiliateService.BindInviterByCode(ctx, userID, code); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to bind affiliate inviter for user %d: %v", userID, err)
}
}
}
func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) { func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) {
if user == nil || user.ID <= 0 { if user == nil || user.ID <= 0 {
return return
......
...@@ -622,7 +622,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefa ...@@ -622,7 +622,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefa
service.defaultSubAssigner = assigner service.defaultSubAssigner = assigner
service.refreshTokenCache = &refreshTokenCacheStub{} service.refreshTokenCache = &refreshTokenCacheStub{}
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "") tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "", "")
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, tokenPair) require.NotNil(t, tokenPair)
require.NotNil(t, user) require.NotNil(t, user)
...@@ -658,7 +658,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantA ...@@ -658,7 +658,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantA
service.defaultSubAssigner = assigner service.defaultSubAssigner = assigner
service.refreshTokenCache = &refreshTokenCacheStub{} service.refreshTokenCache = &refreshTokenCacheStub{}
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "") tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "", "")
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, tokenPair) require.NotNil(t, tokenPair)
require.Equal(t, existing.ID, user.ID) require.Equal(t, existing.ID, user.ID)
......
...@@ -24,6 +24,11 @@ const ( ...@@ -24,6 +24,11 @@ const (
AffiliateRebateRateMin = 0.0 AffiliateRebateRateMin = 0.0
AffiliateRebateRateMax = 100.0 AffiliateRebateRateMax = 100.0
AffiliateEnabledDefault = false // 邀请返利总开关默认关闭 AffiliateEnabledDefault = false // 邀请返利总开关默认关闭
AffiliateRebateFreezeHoursDefault = 0 // 0 = 不冻结(向后兼容)
AffiliateRebateFreezeHoursMax = 720 // 最大 30 天
AffiliateRebateDurationDaysDefault = 0 // 0 = 永久有效
AffiliateRebateDurationDaysMax = 3650 // ~10 年
AffiliateRebatePerInviteeCapDefault = 0.0 // 0 = 无上限
) )
// Platform constants // Platform constants
...@@ -98,6 +103,9 @@ const ( ...@@ -98,6 +103,9 @@ const (
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册 SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
SettingKeyAffiliateEnabled = "affiliate_enabled" // 邀请返利功能总开关 SettingKeyAffiliateEnabled = "affiliate_enabled" // 邀请返利功能总开关
SettingKeyAffiliateRebateRate = "affiliate_rebate_rate" // 邀请返利比例(百分比,0-100) SettingKeyAffiliateRebateRate = "affiliate_rebate_rate" // 邀请返利比例(百分比,0-100)
SettingKeyAffiliateRebateFreezeHours = "affiliate_rebate_freeze_hours" // 返利冻结期(小时,0=不冻结)
SettingKeyAffiliateRebateDurationDays = "affiliate_rebate_duration_days" // 返利有效期(天,0=永久)
SettingKeyAffiliateRebatePerInviteeCap = "affiliate_rebate_per_invitee_cap" // 单人返利上限(0=无上限)
// 邮件服务设置 // 邮件服务设置
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址 SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment