Unverified Commit 3f05ef2a authored by Oliver Li's avatar Oliver Li Committed by GitHub
Browse files

Merge branch 'Wei-Shaw:main' into vertex

parents 6d11f9ed c056db74
docs/claude-relay-service/
.codex
# ===================
# Go 后端
......
......@@ -186,6 +186,9 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
AffiliateRebateRate: settings.AffiliateRebateRate,
AffiliateRebateFreezeHours: settings.AffiliateRebateFreezeHours,
AffiliateRebateDurationDays: settings.AffiliateRebateDurationDays,
AffiliateRebatePerInviteeCap: settings.AffiliateRebatePerInviteeCap,
DefaultUserRPMLimit: settings.DefaultUserRPMLimit,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: settings.EnableModelFallback,
......@@ -342,6 +345,9 @@ type UpdateSettingsRequest struct {
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
AffiliateRebateRate *float64 `json:"affiliate_rebate_rate"`
AffiliateRebateFreezeHours *int `json:"affiliate_rebate_freeze_hours"`
AffiliateRebateDurationDays *int `json:"affiliate_rebate_duration_days"`
AffiliateRebatePerInviteeCap *float64 `json:"affiliate_rebate_per_invitee_cap"`
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
......@@ -485,6 +491,33 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
if affiliateRebateRate > service.AffiliateRebateRateMax {
affiliateRebateRate = service.AffiliateRebateRateMax
}
affiliateRebateFreezeHours := previousSettings.AffiliateRebateFreezeHours
if req.AffiliateRebateFreezeHours != nil {
affiliateRebateFreezeHours = *req.AffiliateRebateFreezeHours
}
if affiliateRebateFreezeHours < 0 {
affiliateRebateFreezeHours = service.AffiliateRebateFreezeHoursDefault
}
if affiliateRebateFreezeHours > service.AffiliateRebateFreezeHoursMax {
affiliateRebateFreezeHours = service.AffiliateRebateFreezeHoursMax
}
affiliateRebateDurationDays := previousSettings.AffiliateRebateDurationDays
if req.AffiliateRebateDurationDays != nil {
affiliateRebateDurationDays = *req.AffiliateRebateDurationDays
}
if affiliateRebateDurationDays < 0 {
affiliateRebateDurationDays = service.AffiliateRebateDurationDaysDefault
}
if affiliateRebateDurationDays > service.AffiliateRebateDurationDaysMax {
affiliateRebateDurationDays = service.AffiliateRebateDurationDaysMax
}
affiliateRebatePerInviteeCap := previousSettings.AffiliateRebatePerInviteeCap
if req.AffiliateRebatePerInviteeCap != nil {
affiliateRebatePerInviteeCap = *req.AffiliateRebatePerInviteeCap
}
if affiliateRebatePerInviteeCap < 0 {
affiliateRebatePerInviteeCap = service.AffiliateRebatePerInviteeCapDefault
}
// 通用表格配置:兼容旧客户端未传字段时保留当前值。
if req.TableDefaultPageSize <= 0 {
req.TableDefaultPageSize = previousSettings.TableDefaultPageSize
......@@ -1137,6 +1170,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
AffiliateRebateRate: affiliateRebateRate,
AffiliateRebateFreezeHours: affiliateRebateFreezeHours,
AffiliateRebateDurationDays: affiliateRebateDurationDays,
AffiliateRebatePerInviteeCap: affiliateRebatePerInviteeCap,
DefaultUserRPMLimit: req.DefaultUserRPMLimit,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: req.EnableModelFallback,
......@@ -1458,6 +1494,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
AffiliateRebateRate: updatedSettings.AffiliateRebateRate,
AffiliateRebateFreezeHours: updatedSettings.AffiliateRebateFreezeHours,
AffiliateRebateDurationDays: updatedSettings.AffiliateRebateDurationDays,
AffiliateRebatePerInviteeCap: updatedSettings.AffiliateRebatePerInviteeCap,
DefaultUserRPMLimit: updatedSettings.DefaultUserRPMLimit,
DefaultSubscriptions: updatedDefaultSubscriptions,
EnableModelFallback: updatedSettings.EnableModelFallback,
......@@ -1768,6 +1807,15 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.AffiliateRebateRate != after.AffiliateRebateRate {
changed = append(changed, "affiliate_rebate_rate")
}
if before.AffiliateRebateFreezeHours != after.AffiliateRebateFreezeHours {
changed = append(changed, "affiliate_rebate_freeze_hours")
}
if before.AffiliateRebateDurationDays != after.AffiliateRebateDurationDays {
changed = append(changed, "affiliate_rebate_duration_days")
}
if before.AffiliateRebatePerInviteeCap != after.AffiliateRebatePerInviteeCap {
changed = append(changed, "affiliate_rebate_per_invitee_cap")
}
if !equalDefaultSubscriptions(before.DefaultSubscriptions, after.DefaultSubscriptions) {
changed = append(changed, "default_subscriptions")
}
......
......@@ -435,6 +435,7 @@ func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession(
type completeLinuxDoOAuthRequest struct {
InvitationCode string `json:"invitation_code" binding:"required"`
AffCode string `json:"aff_code,omitempty"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
......@@ -518,7 +519,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
if err != nil {
response.ErrorFrom(c, err)
return
......
......@@ -67,6 +67,7 @@ type createPendingOAuthAccountRequest struct {
VerifyCode string `json:"verify_code,omitempty"`
Password string `json:"password" binding:"required,min=6"`
InvitationCode string `json:"invitation_code,omitempty"`
AffCode string `json:"aff_code,omitempty"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
......@@ -1751,6 +1752,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
user,
strings.TrimSpace(req.InvitationCode),
strings.TrimSpace(session.ProviderType),
strings.TrimSpace(req.AffCode),
); err != nil {
_ = tx.Rollback()
if rollbackCreatedUser(err) {
......
......@@ -582,6 +582,7 @@ func (h *AuthHandler) createOIDCOAuthChoicePendingSession(
type completeOIDCOAuthRequest struct {
InvitationCode string `json:"invitation_code" binding:"required"`
AffCode string `json:"aff_code,omitempty"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
......@@ -665,7 +666,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
if err != nil {
response.ErrorFrom(c, err)
return
......
......@@ -481,6 +481,7 @@ func (h *AuthHandler) wechatPaymentResumeService() *service.PaymentResumeService
type completeWeChatOAuthRequest struct {
InvitationCode string `json:"invitation_code" binding:"required"`
AffCode string `json:"aff_code,omitempty"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
......@@ -547,7 +548,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
return
}
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
if err != nil {
response.ErrorFrom(c, err)
return
......
......@@ -109,6 +109,9 @@ type SystemSettings struct {
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
AffiliateRebateRate float64 `json:"affiliate_rebate_rate"`
AffiliateRebateFreezeHours int `json:"affiliate_rebate_freeze_hours"`
AffiliateRebateDurationDays int `json:"affiliate_rebate_duration_days"`
AffiliateRebatePerInviteeCap float64 `json:"affiliate_rebate_per_invitee_cap"`
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`
......
......@@ -25,6 +25,7 @@ const (
easypayStatusPaid = 1
easypayHTTPTimeout = 10 * time.Second
maxEasypayResponseSize = 1 << 20 // 1MB
maxEasypayErrorSummary = 512
tradeStatusSuccess = "TRADE_SUCCESS"
signTypeMD5 = "MD5"
paymentModePopup = "popup"
......@@ -42,17 +43,55 @@ type EasyPay struct {
// config keys: pid, pkey, apiBase, notifyUrl, returnUrl, cid, cidAlipay, cidWxpay
func NewEasyPay(instanceID string, config map[string]string) (*EasyPay, error) {
for _, k := range []string{"pid", "pkey", "apiBase", "notifyUrl", "returnUrl"} {
if config[k] == "" {
if strings.TrimSpace(config[k]) == "" {
return nil, fmt.Errorf("easypay config missing required key: %s", k)
}
}
cfg := make(map[string]string, len(config))
for k, v := range config {
cfg[k] = v
}
cfg["apiBase"] = normalizeEasyPayAPIBase(cfg["apiBase"])
return &EasyPay{
instanceID: instanceID,
config: config,
config: cfg,
httpClient: &http.Client{Timeout: easypayHTTPTimeout},
}, nil
}
func normalizeEasyPayAPIBase(apiBase string) string {
base := strings.TrimSpace(apiBase)
if base == "" {
return ""
}
if parsed, err := url.Parse(base); err == nil && parsed.Scheme != "" && parsed.Host != "" {
parsed.RawQuery = ""
parsed.Fragment = ""
parsed.RawPath = ""
parsed.Path = trimEasyPayEndpointPath(parsed.Path)
return strings.TrimRight(parsed.String(), "/")
}
return strings.TrimRight(trimEasyPayEndpointPath(base), "/")
}
func trimEasyPayEndpointPath(path string) string {
path = strings.TrimRight(strings.TrimSpace(path), "/")
lower := strings.ToLower(path)
for _, endpoint := range []string{"/submit.php", "/mapi.php", "/api.php"} {
if strings.HasSuffix(lower, endpoint) {
return strings.TrimRight(path[:len(path)-len(endpoint)], "/")
}
}
return path
}
func (e *EasyPay) apiBase() string {
if e == nil {
return ""
}
return normalizeEasyPayAPIBase(e.config["apiBase"])
}
func (e *EasyPay) Name() string { return "EasyPay" }
func (e *EasyPay) ProviderKey() string { return payment.TypeEasyPay }
func (e *EasyPay) SupportedTypes() []payment.PaymentType {
......@@ -104,8 +143,7 @@ func (e *EasyPay) createRedirectPayment(req payment.CreatePaymentRequest) (*paym
for k, v := range params {
q.Set(k, v)
}
base := strings.TrimRight(e.config["apiBase"], "/")
payURL := base + "/submit.php?" + q.Encode()
payURL := e.apiBase() + "/submit.php?" + q.Encode()
return &payment.CreatePaymentResponse{PayURL: payURL}, nil
}
......@@ -127,7 +165,7 @@ func (e *EasyPay) createAPIPayment(ctx context.Context, req payment.CreatePaymen
params["sign"] = easyPaySign(params, e.config["pkey"])
params["sign_type"] = signTypeMD5
body, err := e.post(ctx, strings.TrimRight(e.config["apiBase"], "/")+"/mapi.php", params)
body, err := e.post(ctx, e.apiBase()+"/mapi.php", params)
if err != nil {
return nil, fmt.Errorf("easypay create: %w", err)
}
......@@ -171,7 +209,7 @@ func (e *EasyPay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Quer
"act": "order", "pid": e.config["pid"],
"key": e.config["pkey"], "out_trade_no": tradeNo,
}
body, err := e.post(ctx, e.config["apiBase"]+"/api.php", params)
body, err := e.post(ctx, e.apiBase()+"/api.php", params)
if err != nil {
return nil, fmt.Errorf("easypay query: %w", err)
}
......@@ -234,25 +272,128 @@ func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[st
}
func (e *EasyPay) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) {
params := map[string]string{
"pid": e.config["pid"], "key": e.config["pkey"],
"trade_no": req.TradeNo, "out_trade_no": req.OrderID, "money": req.Amount,
attempts := e.refundAttempts(req)
if len(attempts) == 0 {
return nil, fmt.Errorf("easypay refund missing order identifier")
}
body, err := e.post(ctx, e.config["apiBase"]+"/api.php?act=refund", params)
var firstErr error
for i, attempt := range attempts {
body, status, err := e.postRaw(ctx, e.apiBase()+"/api.php?act=refund", attempt.params)
if err != nil {
return nil, fmt.Errorf("easypay refund: %w", err)
return nil, fmt.Errorf("easypay refund request: %w", err)
}
if err := parseEasyPayRefundResponse(status, body); err != nil {
if firstErr == nil {
firstErr = err
}
if i+1 < len(attempts) && isEasyPayRefundOrderNotFound(err) {
continue
}
return nil, err
}
return &payment.RefundResponse{RefundID: attempt.refundID, Status: payment.ProviderStatusSuccess}, nil
}
return nil, firstErr
}
type easyPayRefundAttempt struct {
params map[string]string
refundID string
}
func (e *EasyPay) refundAttempts(req payment.RefundRequest) []easyPayRefundAttempt {
base := map[string]string{
"pid": e.config["pid"], "key": e.config["pkey"], "money": req.Amount,
}
var attempts []easyPayRefundAttempt
if orderID := strings.TrimSpace(req.OrderID); orderID != "" {
params := cloneStringMap(base)
params["out_trade_no"] = orderID
attempts = append(attempts, easyPayRefundAttempt{params: params, refundID: orderID})
}
if tradeNo := strings.TrimSpace(req.TradeNo); tradeNo != "" {
params := cloneStringMap(base)
params["trade_no"] = tradeNo
attempts = append(attempts, easyPayRefundAttempt{params: params, refundID: tradeNo})
}
return attempts
}
func cloneStringMap(in map[string]string) map[string]string {
out := make(map[string]string, len(in))
for k, v := range in {
out[k] = v
}
return out
}
func isEasyPayRefundOrderNotFound(err error) bool {
if err == nil {
return false
}
msg := err.Error()
lower := strings.ToLower(msg)
return strings.Contains(msg, "订单编号不存在") ||
strings.Contains(msg, "订单不存在") ||
strings.Contains(lower, "order not found") ||
strings.Contains(lower, "not exist")
}
func parseEasyPayRefundResponse(status int, body []byte) error {
summary := summarizeEasyPayResponse(body)
if status < http.StatusOK || status >= http.StatusMultipleChoices {
return fmt.Errorf("easypay refund HTTP %d: %s", status, summary)
}
trimmed := strings.TrimSpace(string(body))
if trimmed == "" {
return fmt.Errorf("easypay refund empty response (HTTP %d): %s", status, summary)
}
lower := strings.ToLower(trimmed)
if strings.HasPrefix(lower, "<!doctype html") || strings.HasPrefix(lower, "<html") ||
(strings.HasPrefix(lower, "<") && strings.Contains(lower, "html")) {
return fmt.Errorf("easypay refund non-JSON response (HTTP %d): %s", status, summary)
}
var resp struct {
Code int `json:"code"`
Code any `json:"code"`
Msg string `json:"msg"`
}
if err := json.Unmarshal(body, &resp); err != nil {
return nil, fmt.Errorf("easypay parse refund: %w", err)
return fmt.Errorf("easypay refund non-JSON response (HTTP %d): %s", status, summary)
}
if resp.Code != easypayCodeSuccess {
return nil, fmt.Errorf("easypay refund failed: %s", resp.Msg)
if !easyPayResponseCodeIsSuccess(resp.Code) {
msg := strings.TrimSpace(resp.Msg)
if msg == "" {
msg = summary
}
return fmt.Errorf("easypay refund failed (HTTP %d): %s", status, msg)
}
return nil
}
func easyPayResponseCodeIsSuccess(code any) bool {
switch v := code.(type) {
case float64:
return int(v) == easypayCodeSuccess
case string:
n, err := strconv.Atoi(strings.TrimSpace(v))
return err == nil && n == easypayCodeSuccess
default:
return false
}
return &payment.RefundResponse{RefundID: req.TradeNo, Status: payment.ProviderStatusSuccess}, nil
}
func summarizeEasyPayResponse(body []byte) string {
summary := strings.Join(strings.Fields(string(body)), " ")
if summary == "" {
return "<empty>"
}
if len(summary) > maxEasypayErrorSummary {
return summary[:maxEasypayErrorSummary] + "..."
}
return summary
}
func (e *EasyPay) resolveCID(paymentType string) string {
......@@ -269,21 +410,34 @@ func (e *EasyPay) resolveCID(paymentType string) string {
}
func (e *EasyPay) post(ctx context.Context, endpoint string, params map[string]string) ([]byte, error) {
body, _, err := e.postRaw(ctx, endpoint, params)
return body, err
}
func (e *EasyPay) postRaw(ctx context.Context, endpoint string, params map[string]string) ([]byte, int, error) {
form := url.Values{}
for k, v := range params {
form.Set(k, v)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
if err != nil {
return nil, err
return nil, 0, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := e.httpClient.Do(req)
client := e.httpClient
if client == nil {
client = &http.Client{Timeout: easypayHTTPTimeout}
}
resp, err := client.Do(req)
if err != nil {
return nil, err
return nil, 0, err
}
defer func() { _ = resp.Body.Close() }()
return io.ReadAll(io.LimitReader(resp.Body, maxEasypayResponseSize))
body, err := io.ReadAll(io.LimitReader(resp.Body, maxEasypayResponseSize))
if err != nil {
return nil, resp.StatusCode, err
}
return body, resp.StatusCode, nil
}
func easyPaySign(params map[string]string, pkey string) string {
......
package provider
import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
func TestNormalizeEasyPayAPIBase(t *testing.T) {
t.Parallel()
tests := []struct {
input string
want string
}{
{input: "https://zpayz.cn", want: "https://zpayz.cn"},
{input: "https://zpayz.cn/", want: "https://zpayz.cn"},
{input: "https://zpayz.cn/mapi.php", want: "https://zpayz.cn"},
{input: "https://zpayz.cn/submit.php", want: "https://zpayz.cn"},
{input: "https://zpayz.cn/api.php", want: "https://zpayz.cn"},
{input: "https://zpayz.cn/api.php?act=refund", want: "https://zpayz.cn"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
t.Parallel()
if got := normalizeEasyPayAPIBase(tt.input); got != tt.want {
t.Fatalf("normalizeEasyPayAPIBase(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestEasyPayRefundNormalizesAPIBaseAndSendsOutTradeNoOnly(t *testing.T) {
t.Parallel()
var gotPath string
var gotQuery url.Values
var gotForm url.Values
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
gotQuery = r.URL.Query()
if err := r.ParseForm(); err != nil {
t.Errorf("ParseForm: %v", err)
}
gotForm = r.PostForm
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"code":1,"msg":"ok"}`))
}))
defer server.Close()
provider := newTestEasyPay(t, server.URL+"/mapi.php")
resp, err := provider.Refund(context.Background(), payment.RefundRequest{
TradeNo: "trade-123",
OrderID: "out-456",
Amount: "1.50",
})
if err != nil {
t.Fatalf("Refund returned error: %v", err)
}
if resp == nil || resp.Status != payment.ProviderStatusSuccess {
t.Fatalf("Refund response = %+v, want success", resp)
}
if gotPath != "/api.php" {
t.Fatalf("refund path = %q, want /api.php", gotPath)
}
if gotQuery.Get("act") != "refund" {
t.Fatalf("refund act query = %q, want refund", gotQuery.Get("act"))
}
for key, want := range map[string]string{
"pid": "pid-1",
"key": "pkey-1",
"out_trade_no": "out-456",
"money": "1.50",
} {
if got := gotForm.Get(key); got != want {
t.Fatalf("form[%s] = %q, want %q (form=%v)", key, got, want, gotForm)
}
}
if got := gotForm.Get("trade_no"); got != "" {
t.Fatalf("form[trade_no] = %q, want empty (form=%v)", got, gotForm)
}
}
func TestEasyPayRefundRetriesWithTradeNoWhenOutTradeNoNotFound(t *testing.T) {
t.Parallel()
var gotForms []url.Values
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api.php" {
t.Errorf("refund path = %q, want /api.php", r.URL.Path)
}
if r.URL.Query().Get("act") != "refund" {
t.Errorf("refund act query = %q, want refund", r.URL.Query().Get("act"))
}
if err := r.ParseForm(); err != nil {
t.Errorf("ParseForm: %v", err)
}
gotForms = append(gotForms, r.PostForm)
w.Header().Set("Content-Type", "application/json")
if len(gotForms) == 1 {
_, _ = w.Write([]byte(`{"code":0,"msg":"订单编号不存在!"}`))
return
}
_, _ = w.Write([]byte(`{"code":1,"msg":"ok"}`))
}))
defer server.Close()
provider := newTestEasyPay(t, server.URL+"/mapi.php")
resp, err := provider.Refund(context.Background(), payment.RefundRequest{
TradeNo: "trade-123",
OrderID: "out-456",
Amount: "1.50",
})
if err != nil {
t.Fatalf("Refund returned error: %v", err)
}
if resp == nil || resp.Status != payment.ProviderStatusSuccess || resp.RefundID != "trade-123" {
t.Fatalf("Refund response = %+v, want success with trade refund id", resp)
}
if len(gotForms) != 2 {
t.Fatalf("refund attempts = %d, want 2", len(gotForms))
}
if got := gotForms[0].Get("out_trade_no"); got != "out-456" {
t.Fatalf("first form[out_trade_no] = %q, want out-456 (form=%v)", got, gotForms[0])
}
if got := gotForms[0].Get("trade_no"); got != "" {
t.Fatalf("first form[trade_no] = %q, want empty (form=%v)", got, gotForms[0])
}
if got := gotForms[1].Get("trade_no"); got != "trade-123" {
t.Fatalf("second form[trade_no] = %q, want trade-123 (form=%v)", got, gotForms[1])
}
if got := gotForms[1].Get("out_trade_no"); got != "" {
t.Fatalf("second form[out_trade_no] = %q, want empty (form=%v)", got, gotForms[1])
}
}
func TestEasyPayRefundResponseErrors(t *testing.T) {
t.Parallel()
tests := []struct {
name string
statusCode int
body string
want string
}{
{name: "html response", statusCode: http.StatusOK, body: "<html>bad config</html>", want: "non-JSON response (HTTP 200): <html>bad config</html>"},
{name: "non json response", statusCode: http.StatusOK, body: "not json", want: "non-JSON response (HTTP 200): not json"},
{name: "non 2xx response", statusCode: http.StatusBadGateway, body: "bad gateway", want: "HTTP 502: bad gateway"},
{name: "empty response", statusCode: http.StatusOK, body: "", want: "empty response (HTTP 200): <empty>"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(tt.statusCode)
_, _ = w.Write([]byte(tt.body))
}))
defer server.Close()
provider := newTestEasyPay(t, server.URL)
_, err := provider.Refund(context.Background(), payment.RefundRequest{
OrderID: "out-456",
Amount: "1.50",
})
if err == nil {
t.Fatal("Refund returned nil error")
}
if !strings.Contains(err.Error(), tt.want) {
t.Fatalf("Refund error = %q, want substring %q", err.Error(), tt.want)
}
})
}
}
func newTestEasyPay(t *testing.T, apiBase string) *EasyPay {
t.Helper()
provider, err := NewEasyPay("test-instance", map[string]string{
"pid": "pid-1",
"pkey": "pkey-1",
"apiBase": apiBase,
"notifyUrl": "https://example.com/notify",
"returnUrl": "https://example.com/return",
})
if err != nil {
t.Fatalf("NewEasyPay: %v", err)
}
return provider
}
......@@ -181,6 +181,55 @@ func TestResponsesToAnthropic_TextOnly(t *testing.T) {
assert.Equal(t, 5, anth.Usage.OutputTokens)
}
func TestResponsesToAnthropic_CachedTokensUseAnthropicInputSemantics(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_cached",
Model: "gpt-5.2",
Status: "completed",
Output: []ResponsesOutput{
{
Type: "message",
Content: []ResponsesContentPart{
{Type: "output_text", Text: "Cached response"},
},
},
},
Usage: &ResponsesUsage{
InputTokens: 54006,
OutputTokens: 123,
TotalTokens: 54129,
InputTokensDetails: &ResponsesInputTokensDetails{
CachedTokens: 50688,
},
},
}
anth := ResponsesToAnthropic(resp, "claude-sonnet-4-5-20250929")
assert.Equal(t, 3318, anth.Usage.InputTokens)
assert.Equal(t, 50688, anth.Usage.CacheReadInputTokens)
assert.Equal(t, 123, anth.Usage.OutputTokens)
}
func TestResponsesToAnthropic_CachedTokensClampInputTokens(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_cached_clamp",
Model: "gpt-5.2",
Status: "completed",
Usage: &ResponsesUsage{
InputTokens: 100,
OutputTokens: 5,
InputTokensDetails: &ResponsesInputTokensDetails{
CachedTokens: 150,
},
},
}
anth := ResponsesToAnthropic(resp, "claude-sonnet-4-5-20250929")
assert.Equal(t, 0, anth.Usage.InputTokens)
assert.Equal(t, 150, anth.Usage.CacheReadInputTokens)
assert.Equal(t, 5, anth.Usage.OutputTokens)
}
func TestResponsesToAnthropic_ToolUse(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_456",
......@@ -343,6 +392,36 @@ func TestStreamingTextOnly(t *testing.T) {
assert.Equal(t, "message_stop", events[1].Type)
}
func TestStreamingCachedTokensUseAnthropicInputSemantics(t *testing.T) {
state := NewResponsesEventToAnthropicState()
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
Type: "response.created",
Response: &ResponsesResponse{ID: "resp_cached_stream", Model: "gpt-5.2"},
}, state)
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
Type: "response.completed",
Response: &ResponsesResponse{
Status: "completed",
Usage: &ResponsesUsage{
InputTokens: 54006,
OutputTokens: 123,
TotalTokens: 54129,
InputTokensDetails: &ResponsesInputTokensDetails{
CachedTokens: 50688,
},
},
},
}, state)
require.Len(t, events, 2)
assert.Equal(t, "message_delta", events[0].Type)
assert.Equal(t, 3318, events[0].Usage.InputTokens)
assert.Equal(t, 50688, events[0].Usage.CacheReadInputTokens)
assert.Equal(t, 123, events[0].Usage.OutputTokens)
assert.Equal(t, "message_stop", events[1].Type)
}
func TestStreamingToolCall(t *testing.T) {
state := NewResponsesEventToAnthropicState()
......
......@@ -84,16 +84,32 @@ func ResponsesToAnthropic(resp *ResponsesResponse, model string) *AnthropicRespo
out.StopReason = responsesStatusToAnthropicStopReason(resp.Status, resp.IncompleteDetails, blocks)
if resp.Usage != nil {
out.Usage = AnthropicUsage{
InputTokens: resp.Usage.InputTokens,
OutputTokens: resp.Usage.OutputTokens,
out.Usage = anthropicUsageFromResponsesUsage(resp.Usage)
}
if resp.Usage.InputTokensDetails != nil {
out.Usage.CacheReadInputTokens = resp.Usage.InputTokensDetails.CachedTokens
return out
}
func anthropicUsageFromResponsesUsage(usage *ResponsesUsage) AnthropicUsage {
if usage == nil {
return AnthropicUsage{}
}
cachedTokens := 0
if usage.InputTokensDetails != nil {
cachedTokens = usage.InputTokensDetails.CachedTokens
}
inputTokens := usage.InputTokens - cachedTokens
if inputTokens < 0 {
inputTokens = 0
}
return out
return AnthropicUsage{
InputTokens: inputTokens,
OutputTokens: usage.OutputTokens,
CacheReadInputTokens: cachedTokens,
}
}
func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncompleteDetails, blocks []AnthropicContentBlock) string {
......@@ -466,11 +482,10 @@ func resToAnthHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventTo
stopReason := "end_turn"
if evt.Response != nil {
if evt.Response.Usage != nil {
state.InputTokens = evt.Response.Usage.InputTokens
state.OutputTokens = evt.Response.Usage.OutputTokens
if evt.Response.Usage.InputTokensDetails != nil {
state.CacheReadInputTokens = evt.Response.Usage.InputTokensDetails.CachedTokens
}
usage := anthropicUsageFromResponsesUsage(evt.Response.Usage)
state.InputTokens = usage.InputTokens
state.OutputTokens = usage.OutputTokens
state.CacheReadInputTokens = usage.CacheReadInputTokens
}
switch evt.Response.Status {
case "incomplete":
......
......@@ -86,17 +86,21 @@ func (r *affiliateRepository) BindInviter(ctx context.Context, userID, inviterID
return bound, nil
}
func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64) (bool, error) {
func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error) {
if amount <= 0 {
return false, nil
}
var applied bool
err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
res, err := txClient.ExecContext(txCtx,
"UPDATE user_affiliates SET aff_quota = aff_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2",
amount, inviterID,
)
// freezeHours > 0: add to frozen quota; == 0: add to available quota directly
var updateSQL string
if freezeHours > 0 {
updateSQL = "UPDATE user_affiliates SET aff_frozen_quota = aff_frozen_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2"
} else {
updateSQL = "UPDATE user_affiliates SET aff_quota = aff_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2"
}
res, err := txClient.ExecContext(txCtx, updateSQL, amount, inviterID)
if err != nil {
return err
}
......@@ -106,11 +110,20 @@ func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, invite
return nil
}
if freezeHours > 0 {
if _, err = txClient.ExecContext(txCtx, `
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, frozen_until, created_at, updated_at)
VALUES ($1, 'accrue', $2, $3, NOW() + make_interval(hours => $4), NOW(), NOW())`,
inviterID, amount, inviteeUserID, freezeHours); err != nil {
return fmt.Errorf("insert affiliate accrue ledger: %w", err)
}
} else {
if _, err = txClient.ExecContext(txCtx, `
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at)
VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID); err != nil {
return fmt.Errorf("insert affiliate accrue ledger: %w", err)
}
}
applied = true
return nil
......@@ -121,6 +134,76 @@ VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID);
return applied, nil
}
func (r *affiliateRepository) GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error) {
client := clientFromContext(ctx, r.client)
rows, err := client.QueryContext(ctx,
`SELECT COALESCE(SUM(amount), 0)::double precision FROM user_affiliate_ledger WHERE user_id = $1 AND source_user_id = $2 AND action = 'accrue'`,
inviterID, inviteeUserID)
if err != nil {
return 0, fmt.Errorf("query accrued rebate from invitee: %w", err)
}
defer func() { _ = rows.Close() }()
var total float64
if rows.Next() {
if err := rows.Scan(&total); err != nil {
return 0, err
}
}
return total, rows.Close()
}
func (r *affiliateRepository) ThawFrozenQuota(ctx context.Context, userID int64) (float64, error) {
var thawed float64
err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
var err error
thawed, err = thawFrozenQuotaTx(txCtx, txClient, userID)
return err
})
return thawed, err
}
// thawFrozenQuotaTx moves matured frozen quota to available quota within an existing tx.
func thawFrozenQuotaTx(txCtx context.Context, txClient *dbent.Client, userID int64) (float64, error) {
rows, err := txClient.QueryContext(txCtx, `
WITH matured AS (
UPDATE user_affiliate_ledger
SET frozen_until = NULL, updated_at = NOW()
WHERE user_id = $1
AND frozen_until IS NOT NULL
AND frozen_until <= NOW()
RETURNING amount
)
SELECT COALESCE(SUM(amount), 0) FROM matured`, userID)
if err != nil {
return 0, fmt.Errorf("thaw frozen quota: %w", err)
}
defer func() { _ = rows.Close() }()
var thawed float64
if rows.Next() {
if err := rows.Scan(&thawed); err != nil {
return 0, err
}
}
if err := rows.Close(); err != nil {
return 0, err
}
if thawed <= 0 {
return 0, nil
}
_, err = txClient.ExecContext(txCtx, `
UPDATE user_affiliates
SET aff_quota = aff_quota + $1,
aff_frozen_quota = GREATEST(aff_frozen_quota - $1, 0),
updated_at = NOW()
WHERE user_id = $2`, thawed, userID)
if err != nil {
return 0, fmt.Errorf("move thawed quota: %w", err)
}
return thawed, nil
}
func (r *affiliateRepository) TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error) {
var transferred float64
var newBalance float64
......@@ -130,6 +213,11 @@ func (r *affiliateRepository) TransferQuotaToBalance(ctx context.Context, userID
return err
}
// Thaw any matured frozen quota before transfer.
if _, err := thawFrozenQuotaTx(txCtx, txClient, userID); err != nil {
return fmt.Errorf("thaw before transfer: %w", err)
}
rows, err := txClient.QueryContext(txCtx, `
WITH claimed AS (
SELECT aff_quota::double precision AS amount
......@@ -211,10 +299,16 @@ func (r *affiliateRepository) ListInvitees(ctx context.Context, inviterID int64,
SELECT ua.user_id,
COALESCE(u.email, ''),
COALESCE(u.username, ''),
ua.created_at
ua.created_at,
COALESCE(SUM(ual.amount), 0)::double precision AS total_rebate
FROM user_affiliates ua
LEFT JOIN users u ON u.id = ua.user_id
LEFT JOIN user_affiliate_ledger ual
ON ual.user_id = $1
AND ual.source_user_id = ua.user_id
AND ual.action = 'accrue'
WHERE ua.inviter_id = $1
GROUP BY ua.user_id, u.email, u.username, ua.created_at
ORDER BY ua.created_at DESC
LIMIT $2`, inviterID, limit)
if err != nil {
......@@ -226,7 +320,7 @@ LIMIT $2`, inviterID, limit)
for rows.Next() {
var item service.AffiliateInvitee
var createdAt time.Time
if err := rows.Scan(&item.UserID, &item.Email, &item.Username, &createdAt); err != nil {
if err := rows.Scan(&item.UserID, &item.Email, &item.Username, &createdAt, &item.TotalRebate); err != nil {
return nil, err
}
item.CreatedAt = &createdAt
......@@ -299,6 +393,7 @@ SELECT user_id,
inviter_id,
aff_count,
aff_quota::double precision,
aff_frozen_quota::double precision,
aff_history_quota::double precision,
created_at,
updated_at
......@@ -326,6 +421,7 @@ WHERE user_id = $1`, userID)
&inviterID,
&out.AffCount,
&out.AffQuota,
&out.AffFrozenQuota,
&out.AffHistoryQuota,
&out.CreatedAt,
&out.UpdatedAt,
......@@ -351,6 +447,7 @@ SELECT user_id,
inviter_id,
aff_count,
aff_quota::double precision,
aff_frozen_quota::double precision,
aff_history_quota::double precision,
created_at,
updated_at
......@@ -380,6 +477,7 @@ LIMIT 1`, strings.ToUpper(strings.TrimSpace(code)))
&inviterID,
&out.AffCount,
&out.AffQuota,
&out.AffFrozenQuota,
&out.AffHistoryQuota,
&out.CreatedAt,
&out.UpdatedAt,
......
......@@ -125,7 +125,7 @@ func TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction(t *testing.T) {
require.NoError(t, err)
require.True(t, bound, "invitee must bind to inviter")
applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5)
applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5, 0)
require.NoError(t, err)
require.True(t, applied, "AccrueQuota must report applied=true")
......
......@@ -716,6 +716,9 @@ func TestAPIContracts(t *testing.T) {
"default_concurrency": 5,
"default_balance": 1.25,
"affiliate_rebate_rate": 20,
"affiliate_rebate_freeze_hours": 0,
"affiliate_rebate_duration_days": 0,
"affiliate_rebate_per_invitee_cap": 0,
"default_user_rpm_limit": 0,
"default_subscriptions": [],
"enable_model_fallback": false,
......@@ -898,6 +901,9 @@ func TestAPIContracts(t *testing.T) {
"default_concurrency": 0,
"default_balance": 0,
"affiliate_rebate_rate": 20,
"affiliate_rebate_freeze_hours": 0,
"affiliate_rebate_duration_days": 0,
"affiliate_rebate_per_invitee_cap": 0,
"default_user_rpm_limit": 0,
"default_subscriptions": [],
"enable_model_fallback": false,
......
......@@ -65,6 +65,7 @@ type AffiliateSummary struct {
InviterID *int64 `json:"inviter_id,omitempty"`
AffCount int `json:"aff_count"`
AffQuota float64 `json:"aff_quota"`
AffFrozenQuota float64 `json:"aff_frozen_quota"`
AffHistoryQuota float64 `json:"aff_history_quota"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
......@@ -75,6 +76,7 @@ type AffiliateInvitee struct {
Email string `json:"email"`
Username string `json:"username"`
CreatedAt *time.Time `json:"created_at,omitempty"`
TotalRebate float64 `json:"total_rebate"`
}
type AffiliateDetail struct {
......@@ -83,6 +85,7 @@ type AffiliateDetail struct {
InviterID *int64 `json:"inviter_id,omitempty"`
AffCount int `json:"aff_count"`
AffQuota float64 `json:"aff_quota"`
AffFrozenQuota float64 `json:"aff_frozen_quota"`
AffHistoryQuota float64 `json:"aff_history_quota"`
// EffectiveRebateRatePercent 是当前用户作为邀请人时实际生效的返利比例:
// 优先用户自己的专属比例(aff_rebate_rate_percent),否则回退到全局比例。
......@@ -95,7 +98,9 @@ type AffiliateRepository interface {
EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error)
GetAffiliateByCode(ctx context.Context, code string) (*AffiliateSummary, error)
BindInviter(ctx context.Context, userID, inviterID int64) (bool, error)
AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64) (bool, error)
AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error)
GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error)
ThawFrozenQuota(ctx context.Context, userID int64) (float64, error)
TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error)
ListInvitees(ctx context.Context, inviterID int64, limit int) ([]AffiliateInvitee, error)
......@@ -160,6 +165,12 @@ func (s *AffiliateService) EnsureUserAffiliate(ctx context.Context, userID int64
}
func (s *AffiliateService) GetAffiliateDetail(ctx context.Context, userID int64) (*AffiliateDetail, error) {
// Lazy thaw: move any matured frozen quota to available before reading.
if s != nil && s.repo != nil {
// best-effort: thaw failure is non-fatal
_, _ = s.repo.ThawFrozenQuota(ctx, userID)
}
summary, err := s.EnsureUserAffiliate(ctx, userID)
if err != nil {
return nil, err
......@@ -174,6 +185,7 @@ func (s *AffiliateService) GetAffiliateDetail(ctx context.Context, userID int64)
InviterID: summary.InviterID,
AffCount: summary.AffCount,
AffQuota: summary.AffQuota,
AffFrozenQuota: summary.AffFrozenQuota,
AffHistoryQuota: summary.AffHistoryQuota,
EffectiveRebateRatePercent: s.resolveRebateRatePercent(ctx, summary),
Invitees: invitees,
......@@ -250,13 +262,43 @@ func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID
if err != nil {
return 0, err
}
// 有效期检查:超过返利有效期后不再产生返利
if s.settingService != nil {
if durationDays := s.settingService.GetAffiliateRebateDurationDays(ctx); durationDays > 0 {
if time.Now().After(inviteeSummary.CreatedAt.AddDate(0, 0, durationDays)) {
return 0, nil
}
}
}
rebateRatePercent := s.resolveRebateRatePercent(ctx, inviterSummary)
rebate := roundTo(baseRechargeAmount*(rebateRatePercent/100), 8)
if rebate <= 0 {
return 0, nil
}
applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate)
// 单人上限检查:精确截断到剩余额度
if s.settingService != nil {
if perInviteeCap := s.settingService.GetAffiliateRebatePerInviteeCap(ctx); perInviteeCap > 0 {
existing, err := s.repo.GetAccruedRebateFromInvitee(ctx, *inviteeSummary.InviterID, inviteeUserID)
if err != nil {
return 0, err
}
if existing >= perInviteeCap {
return 0, nil
}
if remaining := perInviteeCap - existing; rebate > remaining {
rebate = roundTo(remaining, 8)
}
}
}
var freezeHours int
if s.settingService != nil {
freezeHours = s.settingService.GetAffiliateRebateFreezeHours(ctx)
}
applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate, freezeHours)
if err != nil {
return 0, err
}
......
......@@ -175,6 +175,7 @@ func (s *AuthService) FinalizeOAuthEmailAccount(
user *User,
invitationCode string,
signupSource string,
affiliateCode string,
) error {
if s == nil || user == nil || user.ID <= 0 {
return ErrServiceUnavailable
......@@ -194,6 +195,7 @@ func (s *AuthService) FinalizeOAuthEmailAccount(
s.updateOAuthSignupSource(ctx, user.ID, signupSource)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
return nil
}
......
......@@ -563,7 +563,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair。
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token。
// invitationCode 仅在邀请码注册模式下新用户注册时使用;已有账号登录时忽略。
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode string) (*TokenPair, *User, error) {
// affiliateCode 用于邀请返利绑定,仅在新用户注册时使用。
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode, affiliateCode string) (*TokenPair, *User, error) {
// 检查 refreshTokenCache 是否可用
if s.refreshTokenCache == nil {
return nil, nil, errors.New("refresh token cache not configured")
......@@ -666,6 +667,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
user = newUser
s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
}
} else {
if err := s.userRepo.Create(ctx, newUser); err != nil {
......@@ -683,6 +685,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
user = newUser
s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
return nil, nil, ErrInvitationCodeInvalid
......@@ -777,6 +780,22 @@ func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource
}
}
// bindOAuthAffiliate initializes the affiliate profile and binds the inviter
// for an OAuth-registered user. Failures are logged but never block registration.
func (s *AuthService) bindOAuthAffiliate(ctx context.Context, userID int64, affiliateCode string) {
if s.affiliateService == nil || userID <= 0 {
return
}
if _, err := s.affiliateService.EnsureUserAffiliate(ctx, userID); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", userID, err)
}
if code := strings.TrimSpace(affiliateCode); code != "" {
if err := s.affiliateService.BindInviterByCode(ctx, userID, code); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to bind affiliate inviter for user %d: %v", userID, err)
}
}
}
func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) {
if user == nil || user.ID <= 0 {
return
......
......@@ -622,7 +622,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefa
service.defaultSubAssigner = assigner
service.refreshTokenCache = &refreshTokenCacheStub{}
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "")
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "", "")
require.NoError(t, err)
require.NotNil(t, tokenPair)
require.NotNil(t, user)
......@@ -658,7 +658,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantA
service.defaultSubAssigner = assigner
service.refreshTokenCache = &refreshTokenCacheStub{}
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "")
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "", "")
require.NoError(t, err)
require.NotNil(t, tokenPair)
require.Equal(t, existing.ID, user.ID)
......
......@@ -24,6 +24,11 @@ const (
AffiliateRebateRateMin = 0.0
AffiliateRebateRateMax = 100.0
AffiliateEnabledDefault = false // 邀请返利总开关默认关闭
AffiliateRebateFreezeHoursDefault = 0 // 0 = 不冻结(向后兼容)
AffiliateRebateFreezeHoursMax = 720 // 最大 30 天
AffiliateRebateDurationDaysDefault = 0 // 0 = 永久有效
AffiliateRebateDurationDaysMax = 3650 // ~10 年
AffiliateRebatePerInviteeCapDefault = 0.0 // 0 = 无上限
)
// Platform constants
......@@ -98,6 +103,9 @@ const (
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
SettingKeyAffiliateEnabled = "affiliate_enabled" // 邀请返利功能总开关
SettingKeyAffiliateRebateRate = "affiliate_rebate_rate" // 邀请返利比例(百分比,0-100)
SettingKeyAffiliateRebateFreezeHours = "affiliate_rebate_freeze_hours" // 返利冻结期(小时,0=不冻结)
SettingKeyAffiliateRebateDurationDays = "affiliate_rebate_duration_days" // 返利有效期(天,0=永久)
SettingKeyAffiliateRebatePerInviteeCap = "affiliate_rebate_per_invitee_cap" // 单人返利上限(0=无上限)
// 邮件服务设置
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment