Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
a04ae28a
Commit
a04ae28a
authored
Apr 13, 2026
by
陈曦
Browse files
merge v0.1.111
parents
68f67198
ad64190b
Changes
302
Show whitespace changes
Inline
Side-by-side
Too many changes to show.
To preserve performance only
302 of 302+
files are displayed.
Plain diff
Email patch
backend/internal/payment/provider/stripe.go
0 → 100644
View file @
a04ae28a
package
provider
import
(
"context"
"encoding/json"
"fmt"
"strings"
"sync"
"github.com/Wei-Shaw/sub2api/internal/payment"
stripe
"github.com/stripe/stripe-go/v85"
"github.com/stripe/stripe-go/v85/webhook"
)
// Stripe constants.
const
(
stripeCurrency
=
"cny"
stripeEventPaymentSuccess
=
"payment_intent.succeeded"
stripeEventPaymentFailed
=
"payment_intent.payment_failed"
)
// Stripe implements the payment.CancelableProvider interface for Stripe payments.
type
Stripe
struct
{
instanceID
string
config
map
[
string
]
string
mu
sync
.
Mutex
initialized
bool
sc
*
stripe
.
Client
}
// NewStripe creates a new Stripe provider instance.
func
NewStripe
(
instanceID
string
,
config
map
[
string
]
string
)
(
*
Stripe
,
error
)
{
if
config
[
"secretKey"
]
==
""
{
return
nil
,
fmt
.
Errorf
(
"stripe config missing required key: secretKey"
)
}
return
&
Stripe
{
instanceID
:
instanceID
,
config
:
config
,
},
nil
}
func
(
s
*
Stripe
)
ensureInit
()
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
if
!
s
.
initialized
{
s
.
sc
=
stripe
.
NewClient
(
s
.
config
[
"secretKey"
])
s
.
initialized
=
true
}
}
// GetPublishableKey returns the publishable key for frontend use.
func
(
s
*
Stripe
)
GetPublishableKey
()
string
{
return
s
.
config
[
"publishableKey"
]
}
func
(
s
*
Stripe
)
Name
()
string
{
return
"Stripe"
}
func
(
s
*
Stripe
)
ProviderKey
()
string
{
return
payment
.
TypeStripe
}
func
(
s
*
Stripe
)
SupportedTypes
()
[]
payment
.
PaymentType
{
return
[]
payment
.
PaymentType
{
payment
.
TypeStripe
}
}
// stripePaymentMethodTypes maps our PaymentType to Stripe payment_method_types.
var
stripePaymentMethodTypes
=
map
[
string
][]
string
{
payment
.
TypeCard
:
{
"card"
},
payment
.
TypeAlipay
:
{
"alipay"
},
payment
.
TypeWxpay
:
{
"wechat_pay"
},
payment
.
TypeLink
:
{
"link"
},
}
// CreatePayment creates a Stripe PaymentIntent.
func
(
s
*
Stripe
)
CreatePayment
(
ctx
context
.
Context
,
req
payment
.
CreatePaymentRequest
)
(
*
payment
.
CreatePaymentResponse
,
error
)
{
s
.
ensureInit
()
amountInCents
,
err
:=
payment
.
YuanToFen
(
req
.
Amount
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"stripe create payment: %w"
,
err
)
}
// Collect all Stripe payment_method_types from the instance's configured sub-methods
methods
:=
resolveStripeMethodTypes
(
req
.
InstanceSubMethods
)
pmTypes
:=
make
([]
*
string
,
len
(
methods
))
for
i
,
m
:=
range
methods
{
pmTypes
[
i
]
=
stripe
.
String
(
m
)
}
params
:=
&
stripe
.
PaymentIntentCreateParams
{
Amount
:
stripe
.
Int64
(
amountInCents
),
Currency
:
stripe
.
String
(
stripeCurrency
),
PaymentMethodTypes
:
pmTypes
,
Description
:
stripe
.
String
(
req
.
Subject
),
Metadata
:
map
[
string
]
string
{
"orderId"
:
req
.
OrderID
},
}
// WeChat Pay requires payment_method_options with client type
if
hasStripeMethod
(
methods
,
"wechat_pay"
)
{
params
.
PaymentMethodOptions
=
&
stripe
.
PaymentIntentCreatePaymentMethodOptionsParams
{
WeChatPay
:
&
stripe
.
PaymentIntentCreatePaymentMethodOptionsWeChatPayParams
{
Client
:
stripe
.
String
(
"web"
),
},
}
}
params
.
SetIdempotencyKey
(
fmt
.
Sprintf
(
"pi-%s"
,
req
.
OrderID
))
params
.
Context
=
ctx
pi
,
err
:=
s
.
sc
.
V1PaymentIntents
.
Create
(
ctx
,
params
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"stripe create payment: %w"
,
err
)
}
return
&
payment
.
CreatePaymentResponse
{
TradeNo
:
pi
.
ID
,
ClientSecret
:
pi
.
ClientSecret
,
},
nil
}
// QueryOrder retrieves a PaymentIntent by ID.
func
(
s
*
Stripe
)
QueryOrder
(
ctx
context
.
Context
,
tradeNo
string
)
(
*
payment
.
QueryOrderResponse
,
error
)
{
s
.
ensureInit
()
pi
,
err
:=
s
.
sc
.
V1PaymentIntents
.
Retrieve
(
ctx
,
tradeNo
,
nil
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"stripe query order: %w"
,
err
)
}
status
:=
payment
.
ProviderStatusPending
switch
pi
.
Status
{
case
stripe
.
PaymentIntentStatusSucceeded
:
status
=
payment
.
ProviderStatusPaid
case
stripe
.
PaymentIntentStatusCanceled
:
status
=
payment
.
ProviderStatusFailed
}
return
&
payment
.
QueryOrderResponse
{
TradeNo
:
pi
.
ID
,
Status
:
status
,
Amount
:
payment
.
FenToYuan
(
pi
.
Amount
),
},
nil
}
// VerifyNotification verifies a Stripe webhook event.
func
(
s
*
Stripe
)
VerifyNotification
(
_
context
.
Context
,
rawBody
string
,
headers
map
[
string
]
string
)
(
*
payment
.
PaymentNotification
,
error
)
{
s
.
ensureInit
()
webhookSecret
:=
s
.
config
[
"webhookSecret"
]
if
webhookSecret
==
""
{
return
nil
,
fmt
.
Errorf
(
"stripe webhookSecret not configured"
)
}
sig
:=
headers
[
"stripe-signature"
]
if
sig
==
""
{
return
nil
,
fmt
.
Errorf
(
"stripe notification missing stripe-signature header"
)
}
event
,
err
:=
webhook
.
ConstructEvent
([]
byte
(
rawBody
),
sig
,
webhookSecret
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"stripe verify notification: %w"
,
err
)
}
switch
event
.
Type
{
case
stripeEventPaymentSuccess
:
return
parseStripePaymentIntent
(
&
event
,
payment
.
ProviderStatusSuccess
,
rawBody
)
case
stripeEventPaymentFailed
:
return
parseStripePaymentIntent
(
&
event
,
payment
.
ProviderStatusFailed
,
rawBody
)
}
return
nil
,
nil
}
func
parseStripePaymentIntent
(
event
*
stripe
.
Event
,
status
string
,
rawBody
string
)
(
*
payment
.
PaymentNotification
,
error
)
{
var
pi
stripe
.
PaymentIntent
if
err
:=
json
.
Unmarshal
(
event
.
Data
.
Raw
,
&
pi
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"stripe parse payment_intent: %w"
,
err
)
}
return
&
payment
.
PaymentNotification
{
TradeNo
:
pi
.
ID
,
OrderID
:
pi
.
Metadata
[
"orderId"
],
Amount
:
payment
.
FenToYuan
(
pi
.
Amount
),
Status
:
status
,
RawData
:
rawBody
,
},
nil
}
// Refund creates a Stripe refund.
func
(
s
*
Stripe
)
Refund
(
ctx
context
.
Context
,
req
payment
.
RefundRequest
)
(
*
payment
.
RefundResponse
,
error
)
{
s
.
ensureInit
()
amountInCents
,
err
:=
payment
.
YuanToFen
(
req
.
Amount
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"stripe refund: %w"
,
err
)
}
params
:=
&
stripe
.
RefundCreateParams
{
PaymentIntent
:
stripe
.
String
(
req
.
TradeNo
),
Amount
:
stripe
.
Int64
(
amountInCents
),
Reason
:
stripe
.
String
(
string
(
stripe
.
RefundReasonRequestedByCustomer
)),
}
params
.
Context
=
ctx
r
,
err
:=
s
.
sc
.
V1Refunds
.
Create
(
ctx
,
params
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"stripe refund: %w"
,
err
)
}
refundStatus
:=
payment
.
ProviderStatusPending
if
r
.
Status
==
stripe
.
RefundStatusSucceeded
{
refundStatus
=
payment
.
ProviderStatusSuccess
}
return
&
payment
.
RefundResponse
{
RefundID
:
r
.
ID
,
Status
:
refundStatus
,
},
nil
}
// resolveStripeMethodTypes converts instance supported_types (comma-separated)
// into Stripe API payment_method_types. Falls back to ["card"] if empty.
func
resolveStripeMethodTypes
(
instanceSubMethods
string
)
[]
string
{
if
instanceSubMethods
==
""
{
return
[]
string
{
"card"
}
}
var
methods
[]
string
for
_
,
t
:=
range
strings
.
Split
(
instanceSubMethods
,
","
)
{
t
=
strings
.
TrimSpace
(
t
)
if
mapped
,
ok
:=
stripePaymentMethodTypes
[
t
];
ok
{
methods
=
append
(
methods
,
mapped
...
)
}
}
if
len
(
methods
)
==
0
{
return
[]
string
{
"card"
}
}
return
methods
}
// hasStripeMethod checks if the given Stripe method list contains the target method.
func
hasStripeMethod
(
methods
[]
string
,
target
string
)
bool
{
for
_
,
m
:=
range
methods
{
if
m
==
target
{
return
true
}
}
return
false
}
// CancelPayment cancels a pending PaymentIntent.
func
(
s
*
Stripe
)
CancelPayment
(
ctx
context
.
Context
,
tradeNo
string
)
error
{
s
.
ensureInit
()
_
,
err
:=
s
.
sc
.
V1PaymentIntents
.
Cancel
(
ctx
,
tradeNo
,
nil
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"stripe cancel payment: %w"
,
err
)
}
return
nil
}
// Ensure interface compliance.
var
(
_
payment
.
Provider
=
(
*
Stripe
)(
nil
)
_
payment
.
CancelableProvider
=
(
*
Stripe
)(
nil
)
)
backend/internal/payment/provider/wxpay.go
0 → 100644
View file @
a04ae28a
package
provider
import
(
"bytes"
"context"
"crypto/rsa"
"fmt"
"io"
"log/slog"
"net/http"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/wechatpay-apiv3/wechatpay-go/core"
"github.com/wechatpay-apiv3/wechatpay-go/core/auth/verifiers"
"github.com/wechatpay-apiv3/wechatpay-go/core/notify"
"github.com/wechatpay-apiv3/wechatpay-go/core/option"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments/h5"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments/native"
"github.com/wechatpay-apiv3/wechatpay-go/services/refunddomestic"
"github.com/wechatpay-apiv3/wechatpay-go/utils"
)
// WeChat Pay constants.
const
(
wxpayCurrency
=
"CNY"
wxpayH5Type
=
"Wap"
)
// WeChat Pay trade states.
const
(
wxpayTradeStateSuccess
=
"SUCCESS"
wxpayTradeStateRefund
=
"REFUND"
wxpayTradeStateClosed
=
"CLOSED"
wxpayTradeStatePayError
=
"PAYERROR"
)
// WeChat Pay notification event types.
const
(
wxpayEventTransactionSuccess
=
"TRANSACTION.SUCCESS"
)
// WeChat Pay error codes.
const
(
wxpayErrNoAuth
=
"NO_AUTH"
)
type
Wxpay
struct
{
instanceID
string
config
map
[
string
]
string
mu
sync
.
Mutex
coreClient
*
core
.
Client
notifyHandler
*
notify
.
Handler
}
func
NewWxpay
(
instanceID
string
,
config
map
[
string
]
string
)
(
*
Wxpay
,
error
)
{
required
:=
[]
string
{
"appId"
,
"mchId"
,
"privateKey"
,
"apiV3Key"
,
"publicKey"
,
"publicKeyId"
,
"certSerial"
}
for
_
,
k
:=
range
required
{
if
config
[
k
]
==
""
{
return
nil
,
fmt
.
Errorf
(
"wxpay config missing required key: %s"
,
k
)
}
}
if
len
(
config
[
"apiV3Key"
])
!=
32
{
return
nil
,
fmt
.
Errorf
(
"wxpay apiV3Key must be exactly 32 bytes, got %d"
,
len
(
config
[
"apiV3Key"
]))
}
return
&
Wxpay
{
instanceID
:
instanceID
,
config
:
config
},
nil
}
func
(
w
*
Wxpay
)
Name
()
string
{
return
"Wxpay"
}
func
(
w
*
Wxpay
)
ProviderKey
()
string
{
return
payment
.
TypeWxpay
}
func
(
w
*
Wxpay
)
SupportedTypes
()
[]
payment
.
PaymentType
{
return
[]
payment
.
PaymentType
{
payment
.
TypeWxpayDirect
}
}
func
formatPEM
(
key
,
keyType
string
)
string
{
key
=
strings
.
TrimSpace
(
key
)
if
strings
.
HasPrefix
(
key
,
"-----BEGIN"
)
{
return
key
}
return
fmt
.
Sprintf
(
"-----BEGIN %s-----
\n
%s
\n
-----END %s-----"
,
keyType
,
key
,
keyType
)
}
func
(
w
*
Wxpay
)
ensureClient
()
(
*
core
.
Client
,
error
)
{
w
.
mu
.
Lock
()
defer
w
.
mu
.
Unlock
()
if
w
.
coreClient
!=
nil
{
return
w
.
coreClient
,
nil
}
privateKey
,
publicKey
,
err
:=
w
.
loadKeyPair
()
if
err
!=
nil
{
return
nil
,
err
}
certSerial
:=
w
.
config
[
"certSerial"
]
verifier
:=
verifiers
.
NewSHA256WithRSAPubkeyVerifier
(
w
.
config
[
"publicKeyId"
],
*
publicKey
)
client
,
err
:=
core
.
NewClient
(
context
.
Background
(),
option
.
WithMerchantCredential
(
w
.
config
[
"mchId"
],
certSerial
,
privateKey
),
option
.
WithVerifier
(
verifier
))
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"wxpay init client: %w"
,
err
)
}
handler
,
err
:=
notify
.
NewRSANotifyHandler
(
w
.
config
[
"apiV3Key"
],
verifier
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"wxpay init notify handler: %w"
,
err
)
}
w
.
notifyHandler
=
handler
w
.
coreClient
=
client
return
w
.
coreClient
,
nil
}
func
(
w
*
Wxpay
)
loadKeyPair
()
(
*
rsa
.
PrivateKey
,
*
rsa
.
PublicKey
,
error
)
{
privateKey
,
err
:=
utils
.
LoadPrivateKey
(
formatPEM
(
w
.
config
[
"privateKey"
],
"PRIVATE KEY"
))
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"wxpay load private key: %w"
,
err
)
}
publicKey
,
err
:=
utils
.
LoadPublicKey
(
formatPEM
(
w
.
config
[
"publicKey"
],
"PUBLIC KEY"
))
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"wxpay load public key: %w"
,
err
)
}
return
privateKey
,
publicKey
,
nil
}
func
(
w
*
Wxpay
)
CreatePayment
(
ctx
context
.
Context
,
req
payment
.
CreatePaymentRequest
)
(
*
payment
.
CreatePaymentResponse
,
error
)
{
client
,
err
:=
w
.
ensureClient
()
if
err
!=
nil
{
return
nil
,
err
}
// Request-first, config-fallback (consistent with EasyPay/Alipay)
notifyURL
:=
req
.
NotifyURL
if
notifyURL
==
""
{
notifyURL
=
w
.
config
[
"notifyUrl"
]
}
if
notifyURL
==
""
{
return
nil
,
fmt
.
Errorf
(
"wxpay notifyUrl is required"
)
}
totalFen
,
err
:=
payment
.
YuanToFen
(
req
.
Amount
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"wxpay create payment: %w"
,
err
)
}
if
req
.
IsMobile
&&
req
.
ClientIP
!=
""
{
resp
,
err
:=
w
.
createOrder
(
ctx
,
client
,
req
,
notifyURL
,
totalFen
,
true
)
if
err
==
nil
{
return
resp
,
nil
}
if
!
strings
.
Contains
(
err
.
Error
(),
wxpayErrNoAuth
)
{
return
nil
,
err
}
slog
.
Warn
(
"wxpay H5 payment not authorized, falling back to native"
,
"order"
,
req
.
OrderID
)
}
return
w
.
createOrder
(
ctx
,
client
,
req
,
notifyURL
,
totalFen
,
false
)
}
func
(
w
*
Wxpay
)
createOrder
(
ctx
context
.
Context
,
c
*
core
.
Client
,
req
payment
.
CreatePaymentRequest
,
notifyURL
string
,
totalFen
int64
,
useH5
bool
)
(
*
payment
.
CreatePaymentResponse
,
error
)
{
if
useH5
{
return
w
.
prepayH5
(
ctx
,
c
,
req
,
notifyURL
,
totalFen
)
}
return
w
.
prepayNative
(
ctx
,
c
,
req
,
notifyURL
,
totalFen
)
}
func
(
w
*
Wxpay
)
prepayNative
(
ctx
context
.
Context
,
c
*
core
.
Client
,
req
payment
.
CreatePaymentRequest
,
notifyURL
string
,
totalFen
int64
)
(
*
payment
.
CreatePaymentResponse
,
error
)
{
svc
:=
native
.
NativeApiService
{
Client
:
c
}
cur
:=
wxpayCurrency
resp
,
_
,
err
:=
svc
.
Prepay
(
ctx
,
native
.
PrepayRequest
{
Appid
:
core
.
String
(
w
.
config
[
"appId"
]),
Mchid
:
core
.
String
(
w
.
config
[
"mchId"
]),
Description
:
core
.
String
(
req
.
Subject
),
OutTradeNo
:
core
.
String
(
req
.
OrderID
),
NotifyUrl
:
core
.
String
(
notifyURL
),
Amount
:
&
native
.
Amount
{
Total
:
core
.
Int64
(
totalFen
),
Currency
:
&
cur
},
})
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"wxpay native prepay: %w"
,
err
)
}
codeURL
:=
""
if
resp
.
CodeUrl
!=
nil
{
codeURL
=
*
resp
.
CodeUrl
}
return
&
payment
.
CreatePaymentResponse
{
TradeNo
:
req
.
OrderID
,
QRCode
:
codeURL
},
nil
}
func
(
w
*
Wxpay
)
prepayH5
(
ctx
context
.
Context
,
c
*
core
.
Client
,
req
payment
.
CreatePaymentRequest
,
notifyURL
string
,
totalFen
int64
)
(
*
payment
.
CreatePaymentResponse
,
error
)
{
svc
:=
h5
.
H5ApiService
{
Client
:
c
}
cur
:=
wxpayCurrency
tp
:=
wxpayH5Type
resp
,
_
,
err
:=
svc
.
Prepay
(
ctx
,
h5
.
PrepayRequest
{
Appid
:
core
.
String
(
w
.
config
[
"appId"
]),
Mchid
:
core
.
String
(
w
.
config
[
"mchId"
]),
Description
:
core
.
String
(
req
.
Subject
),
OutTradeNo
:
core
.
String
(
req
.
OrderID
),
NotifyUrl
:
core
.
String
(
notifyURL
),
Amount
:
&
h5
.
Amount
{
Total
:
core
.
Int64
(
totalFen
),
Currency
:
&
cur
},
SceneInfo
:
&
h5
.
SceneInfo
{
PayerClientIp
:
core
.
String
(
req
.
ClientIP
),
H5Info
:
&
h5
.
H5Info
{
Type
:
&
tp
}},
})
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"wxpay h5 prepay: %w"
,
err
)
}
h5URL
:=
""
if
resp
.
H5Url
!=
nil
{
h5URL
=
*
resp
.
H5Url
}
return
&
payment
.
CreatePaymentResponse
{
TradeNo
:
req
.
OrderID
,
PayURL
:
h5URL
},
nil
}
func
wxSV
(
s
*
string
)
string
{
if
s
==
nil
{
return
""
}
return
*
s
}
func
mapWxState
(
s
string
)
string
{
switch
s
{
case
wxpayTradeStateSuccess
:
return
payment
.
ProviderStatusPaid
case
wxpayTradeStateRefund
:
return
payment
.
ProviderStatusRefunded
case
wxpayTradeStateClosed
,
wxpayTradeStatePayError
:
return
payment
.
ProviderStatusFailed
default
:
return
payment
.
ProviderStatusPending
}
}
func
(
w
*
Wxpay
)
QueryOrder
(
ctx
context
.
Context
,
tradeNo
string
)
(
*
payment
.
QueryOrderResponse
,
error
)
{
c
,
err
:=
w
.
ensureClient
()
if
err
!=
nil
{
return
nil
,
err
}
svc
:=
native
.
NativeApiService
{
Client
:
c
}
tx
,
_
,
err
:=
svc
.
QueryOrderByOutTradeNo
(
ctx
,
native
.
QueryOrderByOutTradeNoRequest
{
OutTradeNo
:
core
.
String
(
tradeNo
),
Mchid
:
core
.
String
(
w
.
config
[
"mchId"
]),
})
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"wxpay query order: %w"
,
err
)
}
var
amt
float64
if
tx
.
Amount
!=
nil
&&
tx
.
Amount
.
Total
!=
nil
{
amt
=
payment
.
FenToYuan
(
*
tx
.
Amount
.
Total
)
}
id
:=
tradeNo
if
tx
.
TransactionId
!=
nil
{
id
=
*
tx
.
TransactionId
}
pa
:=
""
if
tx
.
SuccessTime
!=
nil
{
pa
=
*
tx
.
SuccessTime
}
return
&
payment
.
QueryOrderResponse
{
TradeNo
:
id
,
Status
:
mapWxState
(
wxSV
(
tx
.
TradeState
)),
Amount
:
amt
,
PaidAt
:
pa
},
nil
}
func
(
w
*
Wxpay
)
VerifyNotification
(
ctx
context
.
Context
,
rawBody
string
,
headers
map
[
string
]
string
)
(
*
payment
.
PaymentNotification
,
error
)
{
if
_
,
err
:=
w
.
ensureClient
();
err
!=
nil
{
return
nil
,
err
}
r
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodPost
,
"/"
,
io
.
NopCloser
(
bytes
.
NewBufferString
(
rawBody
)))
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"wxpay construct request: %w"
,
err
)
}
for
k
,
v
:=
range
headers
{
r
.
Header
.
Set
(
k
,
v
)
}
var
tx
payments
.
Transaction
nr
,
err
:=
w
.
notifyHandler
.
ParseNotifyRequest
(
ctx
,
r
,
&
tx
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"wxpay verify notification: %w"
,
err
)
}
if
nr
.
EventType
!=
wxpayEventTransactionSuccess
{
return
nil
,
nil
}
var
amt
float64
if
tx
.
Amount
!=
nil
&&
tx
.
Amount
.
Total
!=
nil
{
amt
=
payment
.
FenToYuan
(
*
tx
.
Amount
.
Total
)
}
st
:=
payment
.
ProviderStatusFailed
if
wxSV
(
tx
.
TradeState
)
==
wxpayTradeStateSuccess
{
st
=
payment
.
ProviderStatusSuccess
}
return
&
payment
.
PaymentNotification
{
TradeNo
:
wxSV
(
tx
.
TransactionId
),
OrderID
:
wxSV
(
tx
.
OutTradeNo
),
Amount
:
amt
,
Status
:
st
,
RawData
:
rawBody
,
},
nil
}
func
(
w
*
Wxpay
)
Refund
(
ctx
context
.
Context
,
req
payment
.
RefundRequest
)
(
*
payment
.
RefundResponse
,
error
)
{
c
,
err
:=
w
.
ensureClient
()
if
err
!=
nil
{
return
nil
,
err
}
rf
,
err
:=
payment
.
YuanToFen
(
req
.
Amount
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"wxpay refund amount: %w"
,
err
)
}
tf
,
err
:=
w
.
queryOrderTotalFen
(
ctx
,
c
,
req
.
OrderID
)
if
err
!=
nil
{
return
nil
,
err
}
rs
:=
refunddomestic
.
RefundsApiService
{
Client
:
c
}
cur
:=
wxpayCurrency
res
,
_
,
err
:=
rs
.
Create
(
ctx
,
refunddomestic
.
CreateRequest
{
OutTradeNo
:
core
.
String
(
req
.
OrderID
),
OutRefundNo
:
core
.
String
(
fmt
.
Sprintf
(
"%s-refund-%d"
,
req
.
OrderID
,
time
.
Now
()
.
UnixNano
())),
Reason
:
core
.
String
(
req
.
Reason
),
Amount
:
&
refunddomestic
.
AmountReq
{
Refund
:
core
.
Int64
(
rf
),
Total
:
core
.
Int64
(
tf
),
Currency
:
&
cur
},
})
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"wxpay refund: %w"
,
err
)
}
rid
:=
wxSV
(
res
.
RefundId
)
if
rid
==
""
{
rid
=
fmt
.
Sprintf
(
"%s-refund"
,
req
.
OrderID
)
}
st
:=
payment
.
ProviderStatusPending
if
res
.
Status
!=
nil
&&
*
res
.
Status
==
refunddomestic
.
STATUS_SUCCESS
{
st
=
payment
.
ProviderStatusSuccess
}
return
&
payment
.
RefundResponse
{
RefundID
:
rid
,
Status
:
st
},
nil
}
func
(
w
*
Wxpay
)
queryOrderTotalFen
(
ctx
context
.
Context
,
c
*
core
.
Client
,
orderID
string
)
(
int64
,
error
)
{
svc
:=
native
.
NativeApiService
{
Client
:
c
}
tx
,
_
,
err
:=
svc
.
QueryOrderByOutTradeNo
(
ctx
,
native
.
QueryOrderByOutTradeNoRequest
{
OutTradeNo
:
core
.
String
(
orderID
),
Mchid
:
core
.
String
(
w
.
config
[
"mchId"
]),
})
if
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"wxpay refund query order: %w"
,
err
)
}
var
tf
int64
if
tx
.
Amount
!=
nil
&&
tx
.
Amount
.
Total
!=
nil
{
tf
=
*
tx
.
Amount
.
Total
}
return
tf
,
nil
}
func
(
w
*
Wxpay
)
CancelPayment
(
ctx
context
.
Context
,
tradeNo
string
)
error
{
c
,
err
:=
w
.
ensureClient
()
if
err
!=
nil
{
return
err
}
svc
:=
native
.
NativeApiService
{
Client
:
c
}
_
,
err
=
svc
.
CloseOrder
(
ctx
,
native
.
CloseOrderRequest
{
OutTradeNo
:
core
.
String
(
tradeNo
),
Mchid
:
core
.
String
(
w
.
config
[
"mchId"
]),
})
if
err
!=
nil
{
return
fmt
.
Errorf
(
"wxpay cancel payment: %w"
,
err
)
}
return
nil
}
var
(
_
payment
.
Provider
=
(
*
Wxpay
)(
nil
)
_
payment
.
CancelableProvider
=
(
*
Wxpay
)(
nil
)
)
backend/internal/payment/provider/wxpay_test.go
0 → 100644
View file @
a04ae28a
//go:build unit
package
provider
import
(
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
func
TestMapWxState
(
t
*
testing
.
T
)
{
t
.
Parallel
()
tests
:=
[]
struct
{
name
string
input
string
want
string
}{
{
name
:
"SUCCESS maps to paid"
,
input
:
wxpayTradeStateSuccess
,
want
:
payment
.
ProviderStatusPaid
,
},
{
name
:
"REFUND maps to refunded"
,
input
:
wxpayTradeStateRefund
,
want
:
payment
.
ProviderStatusRefunded
,
},
{
name
:
"CLOSED maps to failed"
,
input
:
wxpayTradeStateClosed
,
want
:
payment
.
ProviderStatusFailed
,
},
{
name
:
"PAYERROR maps to failed"
,
input
:
wxpayTradeStatePayError
,
want
:
payment
.
ProviderStatusFailed
,
},
{
name
:
"unknown state maps to pending"
,
input
:
"NOTPAY"
,
want
:
payment
.
ProviderStatusPending
,
},
{
name
:
"empty string maps to pending"
,
input
:
""
,
want
:
payment
.
ProviderStatusPending
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
got
:=
mapWxState
(
tt
.
input
)
if
got
!=
tt
.
want
{
t
.
Errorf
(
"mapWxState(%q) = %q, want %q"
,
tt
.
input
,
got
,
tt
.
want
)
}
})
}
}
func
TestWxSV
(
t
*
testing
.
T
)
{
t
.
Parallel
()
tests
:=
[]
struct
{
name
string
input
*
string
want
string
}{
{
name
:
"nil pointer returns empty string"
,
input
:
nil
,
want
:
""
,
},
{
name
:
"non-nil pointer returns value"
,
input
:
strPtr
(
"hello"
),
want
:
"hello"
,
},
{
name
:
"pointer to empty string returns empty string"
,
input
:
strPtr
(
""
),
want
:
""
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
got
:=
wxSV
(
tt
.
input
)
if
got
!=
tt
.
want
{
t
.
Errorf
(
"wxSV() = %q, want %q"
,
got
,
tt
.
want
)
}
})
}
}
func
strPtr
(
s
string
)
*
string
{
return
&
s
}
func
TestFormatPEM
(
t
*
testing
.
T
)
{
t
.
Parallel
()
tests
:=
[]
struct
{
name
string
key
string
keyType
string
want
string
}{
{
name
:
"raw key gets wrapped with headers"
,
key
:
"MIIBIjANBgkqhki..."
,
keyType
:
"PUBLIC KEY"
,
want
:
"-----BEGIN PUBLIC KEY-----
\n
MIIBIjANBgkqhki...
\n
-----END PUBLIC KEY-----"
,
},
{
name
:
"already formatted key is returned as-is"
,
key
:
"-----BEGIN PRIVATE KEY-----
\n
MIIEvQIBADANBg...
\n
-----END PRIVATE KEY-----"
,
keyType
:
"PRIVATE KEY"
,
want
:
"-----BEGIN PRIVATE KEY-----
\n
MIIEvQIBADANBg...
\n
-----END PRIVATE KEY-----"
,
},
{
name
:
"key with leading/trailing whitespace is trimmed before check"
,
key
:
"
\n
MIIBIjANBgkqhki...
\n
"
,
keyType
:
"PUBLIC KEY"
,
want
:
"-----BEGIN PUBLIC KEY-----
\n
MIIBIjANBgkqhki...
\n
-----END PUBLIC KEY-----"
,
},
{
name
:
"already formatted key with whitespace is trimmed and returned"
,
key
:
" -----BEGIN RSA PRIVATE KEY-----
\n
data
\n
-----END RSA PRIVATE KEY----- "
,
keyType
:
"RSA PRIVATE KEY"
,
want
:
"-----BEGIN RSA PRIVATE KEY-----
\n
data
\n
-----END RSA PRIVATE KEY-----"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
got
:=
formatPEM
(
tt
.
key
,
tt
.
keyType
)
if
got
!=
tt
.
want
{
t
.
Errorf
(
"formatPEM(%q, %q) =
\n
%s
\n
want:
\n
%s"
,
tt
.
key
,
tt
.
keyType
,
got
,
tt
.
want
)
}
})
}
}
func
TestNewWxpay
(
t
*
testing
.
T
)
{
t
.
Parallel
()
validConfig
:=
map
[
string
]
string
{
"appId"
:
"wx1234567890"
,
"mchId"
:
"1234567890"
,
"privateKey"
:
"fake-private-key"
,
"apiV3Key"
:
"12345678901234567890123456789012"
,
// exactly 32 bytes
"publicKey"
:
"fake-public-key"
,
"publicKeyId"
:
"key-id-001"
,
"certSerial"
:
"SERIAL001"
,
}
// helper to clone and override config fields
withOverride
:=
func
(
overrides
map
[
string
]
string
)
map
[
string
]
string
{
cfg
:=
make
(
map
[
string
]
string
,
len
(
validConfig
))
for
k
,
v
:=
range
validConfig
{
cfg
[
k
]
=
v
}
for
k
,
v
:=
range
overrides
{
cfg
[
k
]
=
v
}
return
cfg
}
tests
:=
[]
struct
{
name
string
config
map
[
string
]
string
wantErr
bool
errSubstr
string
}{
{
name
:
"valid config succeeds"
,
config
:
validConfig
,
wantErr
:
false
,
},
{
name
:
"missing appId"
,
config
:
withOverride
(
map
[
string
]
string
{
"appId"
:
""
}),
wantErr
:
true
,
errSubstr
:
"appId"
,
},
{
name
:
"missing mchId"
,
config
:
withOverride
(
map
[
string
]
string
{
"mchId"
:
""
}),
wantErr
:
true
,
errSubstr
:
"mchId"
,
},
{
name
:
"missing privateKey"
,
config
:
withOverride
(
map
[
string
]
string
{
"privateKey"
:
""
}),
wantErr
:
true
,
errSubstr
:
"privateKey"
,
},
{
name
:
"missing apiV3Key"
,
config
:
withOverride
(
map
[
string
]
string
{
"apiV3Key"
:
""
}),
wantErr
:
true
,
errSubstr
:
"apiV3Key"
,
},
{
name
:
"missing publicKey"
,
config
:
withOverride
(
map
[
string
]
string
{
"publicKey"
:
""
}),
wantErr
:
true
,
errSubstr
:
"publicKey"
,
},
{
name
:
"missing publicKeyId"
,
config
:
withOverride
(
map
[
string
]
string
{
"publicKeyId"
:
""
}),
wantErr
:
true
,
errSubstr
:
"publicKeyId"
,
},
{
name
:
"apiV3Key too short"
,
config
:
withOverride
(
map
[
string
]
string
{
"apiV3Key"
:
"short"
}),
wantErr
:
true
,
errSubstr
:
"exactly 32 bytes"
,
},
{
name
:
"apiV3Key too long"
,
config
:
withOverride
(
map
[
string
]
string
{
"apiV3Key"
:
"123456789012345678901234567890123"
}),
// 33 bytes
wantErr
:
true
,
errSubstr
:
"exactly 32 bytes"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
got
,
err
:=
NewWxpay
(
"test-instance"
,
tt
.
config
)
if
tt
.
wantErr
{
if
err
==
nil
{
t
.
Fatal
(
"expected error, got nil"
)
}
if
tt
.
errSubstr
!=
""
&&
!
strings
.
Contains
(
err
.
Error
(),
tt
.
errSubstr
)
{
t
.
Errorf
(
"error %q should contain %q"
,
err
.
Error
(),
tt
.
errSubstr
)
}
return
}
if
err
!=
nil
{
t
.
Fatalf
(
"unexpected error: %v"
,
err
)
}
if
got
==
nil
{
t
.
Fatal
(
"expected non-nil Wxpay instance"
)
}
if
got
.
instanceID
!=
"test-instance"
{
t
.
Errorf
(
"instanceID = %q, want %q"
,
got
.
instanceID
,
"test-instance"
)
}
})
}
}
backend/internal/payment/registry.go
0 → 100644
View file @
a04ae28a
package
payment
import
(
"sync"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// Registry is a thread-safe registry mapping PaymentType to Provider.
type
Registry
struct
{
mu
sync
.
RWMutex
providers
map
[
PaymentType
]
Provider
}
// ErrProviderNotFound is returned when a requested payment provider is not registered.
var
ErrProviderNotFound
=
infraerrors
.
NotFound
(
"PROVIDER_NOT_FOUND"
,
"payment provider not registered"
)
// NewRegistry creates a new empty provider registry.
func
NewRegistry
()
*
Registry
{
return
&
Registry
{
providers
:
make
(
map
[
PaymentType
]
Provider
),
}
}
// Register adds a provider for each of its supported payment types.
// If a type was previously registered, it is overwritten.
func
(
r
*
Registry
)
Register
(
p
Provider
)
{
r
.
mu
.
Lock
()
defer
r
.
mu
.
Unlock
()
for
_
,
t
:=
range
p
.
SupportedTypes
()
{
r
.
providers
[
t
]
=
p
}
}
// GetProvider returns the provider registered for the given payment type.
func
(
r
*
Registry
)
GetProvider
(
t
PaymentType
)
(
Provider
,
error
)
{
r
.
mu
.
RLock
()
defer
r
.
mu
.
RUnlock
()
p
,
ok
:=
r
.
providers
[
t
]
if
!
ok
{
return
nil
,
ErrProviderNotFound
}
return
p
,
nil
}
// GetProviderByKey returns the first provider whose ProviderKey matches the given key.
func
(
r
*
Registry
)
GetProviderByKey
(
key
string
)
(
Provider
,
error
)
{
r
.
mu
.
RLock
()
defer
r
.
mu
.
RUnlock
()
for
_
,
p
:=
range
r
.
providers
{
if
p
.
ProviderKey
()
==
key
{
return
p
,
nil
}
}
return
nil
,
ErrProviderNotFound
}
// GetProviderKey returns the provider key for the given payment type, or empty string if not found.
func
(
r
*
Registry
)
GetProviderKey
(
t
PaymentType
)
string
{
r
.
mu
.
RLock
()
defer
r
.
mu
.
RUnlock
()
p
,
ok
:=
r
.
providers
[
t
]
if
!
ok
{
return
""
}
return
p
.
ProviderKey
()
}
// SupportedTypes returns all currently registered payment types.
func
(
r
*
Registry
)
SupportedTypes
()
[]
PaymentType
{
r
.
mu
.
RLock
()
defer
r
.
mu
.
RUnlock
()
types
:=
make
([]
PaymentType
,
0
,
len
(
r
.
providers
))
for
t
:=
range
r
.
providers
{
types
=
append
(
types
,
t
)
}
return
types
}
// Clear removes all registered providers.
func
(
r
*
Registry
)
Clear
()
{
r
.
mu
.
Lock
()
defer
r
.
mu
.
Unlock
()
r
.
providers
=
make
(
map
[
PaymentType
]
Provider
)
}
backend/internal/payment/registry_test.go
0 → 100644
View file @
a04ae28a
package
payment
import
(
"context"
"fmt"
"sync"
"testing"
)
// mockProvider implements the Provider interface for testing.
type
mockProvider
struct
{
name
string
key
string
supportedTypes
[]
PaymentType
}
func
(
m
*
mockProvider
)
Name
()
string
{
return
m
.
name
}
func
(
m
*
mockProvider
)
ProviderKey
()
string
{
return
m
.
key
}
func
(
m
*
mockProvider
)
SupportedTypes
()
[]
PaymentType
{
return
m
.
supportedTypes
}
func
(
m
*
mockProvider
)
CreatePayment
(
_
context
.
Context
,
_
CreatePaymentRequest
)
(
*
CreatePaymentResponse
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockProvider
)
QueryOrder
(
_
context
.
Context
,
_
string
)
(
*
QueryOrderResponse
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockProvider
)
VerifyNotification
(
_
context
.
Context
,
_
string
,
_
map
[
string
]
string
)
(
*
PaymentNotification
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockProvider
)
Refund
(
_
context
.
Context
,
_
RefundRequest
)
(
*
RefundResponse
,
error
)
{
return
nil
,
nil
}
func
TestRegistryRegisterAndGetProvider
(
t
*
testing
.
T
)
{
t
.
Parallel
()
r
:=
NewRegistry
()
p
:=
&
mockProvider
{
name
:
"TestPay"
,
key
:
"testpay"
,
supportedTypes
:
[]
PaymentType
{
TypeAlipay
,
TypeWxpay
},
}
r
.
Register
(
p
)
got
,
err
:=
r
.
GetProvider
(
TypeAlipay
)
if
err
!=
nil
{
t
.
Fatalf
(
"GetProvider(alipay) error: %v"
,
err
)
}
if
got
.
ProviderKey
()
!=
"testpay"
{
t
.
Fatalf
(
"GetProvider(alipay) key = %q, want %q"
,
got
.
ProviderKey
(),
"testpay"
)
}
got2
,
err
:=
r
.
GetProvider
(
TypeWxpay
)
if
err
!=
nil
{
t
.
Fatalf
(
"GetProvider(wxpay) error: %v"
,
err
)
}
if
got2
.
ProviderKey
()
!=
"testpay"
{
t
.
Fatalf
(
"GetProvider(wxpay) key = %q, want %q"
,
got2
.
ProviderKey
(),
"testpay"
)
}
}
func
TestRegistryGetProviderNotFound
(
t
*
testing
.
T
)
{
t
.
Parallel
()
r
:=
NewRegistry
()
_
,
err
:=
r
.
GetProvider
(
"nonexistent"
)
if
err
==
nil
{
t
.
Fatal
(
"GetProvider for unregistered type should return error"
)
}
}
func
TestRegistryGetProviderByKey
(
t
*
testing
.
T
)
{
t
.
Parallel
()
r
:=
NewRegistry
()
p
:=
&
mockProvider
{
name
:
"EasyPay"
,
key
:
"easypay"
,
supportedTypes
:
[]
PaymentType
{
TypeAlipay
},
}
r
.
Register
(
p
)
got
,
err
:=
r
.
GetProviderByKey
(
"easypay"
)
if
err
!=
nil
{
t
.
Fatalf
(
"GetProviderByKey error: %v"
,
err
)
}
if
got
.
Name
()
!=
"EasyPay"
{
t
.
Fatalf
(
"GetProviderByKey name = %q, want %q"
,
got
.
Name
(),
"EasyPay"
)
}
}
func
TestRegistryGetProviderByKeyNotFound
(
t
*
testing
.
T
)
{
t
.
Parallel
()
r
:=
NewRegistry
()
_
,
err
:=
r
.
GetProviderByKey
(
"nonexistent"
)
if
err
==
nil
{
t
.
Fatal
(
"GetProviderByKey for unknown key should return error"
)
}
}
func
TestRegistryGetProviderKeyUnknownType
(
t
*
testing
.
T
)
{
t
.
Parallel
()
r
:=
NewRegistry
()
key
:=
r
.
GetProviderKey
(
"unknown_type"
)
if
key
!=
""
{
t
.
Fatalf
(
"GetProviderKey for unknown type should return empty, got %q"
,
key
)
}
}
func
TestRegistryGetProviderKeyKnownType
(
t
*
testing
.
T
)
{
t
.
Parallel
()
r
:=
NewRegistry
()
p
:=
&
mockProvider
{
name
:
"Stripe"
,
key
:
"stripe"
,
supportedTypes
:
[]
PaymentType
{
TypeStripe
},
}
r
.
Register
(
p
)
key
:=
r
.
GetProviderKey
(
TypeStripe
)
if
key
!=
"stripe"
{
t
.
Fatalf
(
"GetProviderKey(stripe) = %q, want %q"
,
key
,
"stripe"
)
}
}
func
TestRegistrySupportedTypes
(
t
*
testing
.
T
)
{
t
.
Parallel
()
r
:=
NewRegistry
()
p1
:=
&
mockProvider
{
name
:
"EasyPay"
,
key
:
"easypay"
,
supportedTypes
:
[]
PaymentType
{
TypeAlipay
,
TypeWxpay
},
}
p2
:=
&
mockProvider
{
name
:
"Stripe"
,
key
:
"stripe"
,
supportedTypes
:
[]
PaymentType
{
TypeStripe
},
}
r
.
Register
(
p1
)
r
.
Register
(
p2
)
types
:=
r
.
SupportedTypes
()
if
len
(
types
)
!=
3
{
t
.
Fatalf
(
"SupportedTypes() len = %d, want 3"
,
len
(
types
))
}
typeSet
:=
make
(
map
[
PaymentType
]
bool
)
for
_
,
tp
:=
range
types
{
typeSet
[
tp
]
=
true
}
for
_
,
expected
:=
range
[]
PaymentType
{
TypeAlipay
,
TypeWxpay
,
TypeStripe
}
{
if
!
typeSet
[
expected
]
{
t
.
Fatalf
(
"SupportedTypes() missing %q"
,
expected
)
}
}
}
func
TestRegistrySupportedTypesEmpty
(
t
*
testing
.
T
)
{
t
.
Parallel
()
r
:=
NewRegistry
()
types
:=
r
.
SupportedTypes
()
if
len
(
types
)
!=
0
{
t
.
Fatalf
(
"SupportedTypes() on empty registry should be empty, got %d"
,
len
(
types
))
}
}
func
TestRegistryOverwriteExisting
(
t
*
testing
.
T
)
{
t
.
Parallel
()
r
:=
NewRegistry
()
p1
:=
&
mockProvider
{
name
:
"OldPay"
,
key
:
"old"
,
supportedTypes
:
[]
PaymentType
{
TypeAlipay
},
}
p2
:=
&
mockProvider
{
name
:
"NewPay"
,
key
:
"new"
,
supportedTypes
:
[]
PaymentType
{
TypeAlipay
},
}
r
.
Register
(
p1
)
r
.
Register
(
p2
)
got
,
err
:=
r
.
GetProvider
(
TypeAlipay
)
if
err
!=
nil
{
t
.
Fatalf
(
"GetProvider error: %v"
,
err
)
}
if
got
.
Name
()
!=
"NewPay"
{
t
.
Fatalf
(
"expected overwritten provider, got %q"
,
got
.
Name
())
}
}
func
TestRegistryConcurrentAccess
(
t
*
testing
.
T
)
{
t
.
Parallel
()
r
:=
NewRegistry
()
const
goroutines
=
50
var
wg
sync
.
WaitGroup
wg
.
Add
(
goroutines
*
2
)
// Concurrent writers
for
i
:=
0
;
i
<
goroutines
;
i
++
{
go
func
(
idx
int
)
{
defer
wg
.
Done
()
p
:=
&
mockProvider
{
name
:
fmt
.
Sprintf
(
"Provider-%d"
,
idx
),
key
:
fmt
.
Sprintf
(
"key-%d"
,
idx
),
supportedTypes
:
[]
PaymentType
{
PaymentType
(
fmt
.
Sprintf
(
"type-%d"
,
idx
))},
}
r
.
Register
(
p
)
}(
i
)
}
// Concurrent readers
for
i
:=
0
;
i
<
goroutines
;
i
++
{
go
func
()
{
defer
wg
.
Done
()
_
=
r
.
SupportedTypes
()
_
,
_
=
r
.
GetProvider
(
"some-type"
)
_
=
r
.
GetProviderKey
(
"some-type"
)
}()
}
wg
.
Wait
()
types
:=
r
.
SupportedTypes
()
if
len
(
types
)
!=
goroutines
{
t
.
Fatalf
(
"after concurrent registration, expected %d types, got %d"
,
goroutines
,
len
(
types
))
}
}
backend/internal/payment/types.go
0 → 100644
View file @
a04ae28a
// Package payment provides the core payment provider abstraction,
// registry, load balancing, and shared utilities for the payment subsystem.
package
payment
import
"context"
// PaymentType represents a supported payment method.
type
PaymentType
=
string
// Supported payment type constants.
const
(
TypeAlipay
PaymentType
=
"alipay"
TypeWxpay
PaymentType
=
"wxpay"
TypeAlipayDirect
PaymentType
=
"alipay_direct"
TypeWxpayDirect
PaymentType
=
"wxpay_direct"
TypeStripe
PaymentType
=
"stripe"
TypeCard
PaymentType
=
"card"
TypeLink
PaymentType
=
"link"
TypeEasyPay
PaymentType
=
"easypay"
)
// Order status constants shared across payment and service layers.
const
(
OrderStatusPending
=
"PENDING"
OrderStatusPaid
=
"PAID"
OrderStatusRecharging
=
"RECHARGING"
OrderStatusCompleted
=
"COMPLETED"
OrderStatusExpired
=
"EXPIRED"
OrderStatusCancelled
=
"CANCELLED"
OrderStatusFailed
=
"FAILED"
OrderStatusRefundRequested
=
"REFUND_REQUESTED"
OrderStatusRefunding
=
"REFUNDING"
OrderStatusPartiallyRefunded
=
"PARTIALLY_REFUNDED"
OrderStatusRefunded
=
"REFUNDED"
OrderStatusRefundFailed
=
"REFUND_FAILED"
)
// Order types distinguish balance recharges from subscription purchases.
const
(
OrderTypeBalance
=
"balance"
OrderTypeSubscription
=
"subscription"
)
// Entity statuses shared across users, groups, etc.
const
(
EntityStatusActive
=
"active"
)
// Deduction types for refund flow.
const
(
DeductionTypeBalance
=
"balance"
DeductionTypeSubscription
=
"subscription"
DeductionTypeNone
=
"none"
)
// Payment notification status values.
const
(
NotificationStatusSuccess
=
"success"
NotificationStatusPaid
=
"paid"
)
// Provider-level status constants returned by provider implementations
// to the service layer (lowercase, distinct from OrderStatus uppercase constants).
const
(
ProviderStatusPending
=
"pending"
ProviderStatusPaid
=
"paid"
ProviderStatusSuccess
=
"success"
ProviderStatusFailed
=
"failed"
ProviderStatusRefunded
=
"refunded"
)
// DefaultLoadBalanceStrategy is the default load-balancing strategy
// used when no strategy is configured.
const
DefaultLoadBalanceStrategy
=
"round-robin"
// ConfigKeyPublishableKey is the config map key for Stripe's publishable key.
const
ConfigKeyPublishableKey
=
"publishableKey"
// GetBasePaymentType extracts the base payment method from a composite key.
// For example, "alipay_direct" -> "alipay".
func
GetBasePaymentType
(
t
string
)
string
{
switch
{
case
t
==
TypeEasyPay
:
return
TypeEasyPay
case
t
==
TypeStripe
||
t
==
TypeCard
||
t
==
TypeLink
:
return
TypeStripe
case
len
(
t
)
>=
len
(
TypeAlipay
)
&&
t
[
:
len
(
TypeAlipay
)]
==
TypeAlipay
:
return
TypeAlipay
case
len
(
t
)
>=
len
(
TypeWxpay
)
&&
t
[
:
len
(
TypeWxpay
)]
==
TypeWxpay
:
return
TypeWxpay
default
:
return
t
}
}
// CreatePaymentRequest holds the parameters for creating a new payment.
type
CreatePaymentRequest
struct
{
OrderID
string
// Internal order ID
Amount
string
// Pay amount in CNY (formatted to 2 decimal places)
PaymentType
string
// e.g. "alipay", "wxpay", "stripe"
Subject
string
// Product description
NotifyURL
string
// Webhook callback URL
ReturnURL
string
// Browser redirect URL after payment
ClientIP
string
// Payer's IP address
IsMobile
bool
// Whether the request comes from a mobile device
InstanceSubMethods
string
// Comma-separated sub-methods from instance supported_types (for Stripe)
}
// CreatePaymentResponse is returned after successfully initiating a payment.
type
CreatePaymentResponse
struct
{
TradeNo
string
// Third-party transaction ID
PayURL
string
// H5 payment URL (alipay/wxpay)
QRCode
string
// QR code content for scanning
ClientSecret
string
// Stripe PaymentIntent client secret
}
// QueryOrderResponse describes the payment status from the upstream provider.
type
QueryOrderResponse
struct
{
TradeNo
string
Status
string
// "pending", "paid", "failed", "refunded"
Amount
float64
// Amount in CNY
PaidAt
string
// RFC3339 timestamp or empty
}
// PaymentNotification is the parsed result of a webhook/notify callback.
type
PaymentNotification
struct
{
TradeNo
string
OrderID
string
Amount
float64
Status
string
// "success" or "failed"
RawData
string
// Raw notification body for audit
}
// RefundRequest contains the parameters for requesting a refund.
type
RefundRequest
struct
{
TradeNo
string
OrderID
string
Amount
string
// Refund amount formatted to 2 decimal places
Reason
string
}
// RefundResponse is returned after a refund request.
type
RefundResponse
struct
{
RefundID
string
Status
string
// "success", "pending", "failed"
}
// InstanceSelection holds the selected provider instance and its decrypted config.
type
InstanceSelection
struct
{
InstanceID
string
Config
map
[
string
]
string
SupportedTypes
string
// Comma-separated list of supported payment types from the instance
PaymentMode
string
// Payment display mode: "qrcode", "redirect", "popup"
}
// Provider defines the interface that all payment providers must implement.
type
Provider
interface
{
// Name returns a human-readable name for this provider.
Name
()
string
// ProviderKey returns the unique key identifying this provider type (e.g. "easypay").
ProviderKey
()
string
// SupportedTypes returns the list of payment types this provider handles.
SupportedTypes
()
[]
PaymentType
// CreatePayment initiates a payment and returns the upstream response.
CreatePayment
(
ctx
context
.
Context
,
req
CreatePaymentRequest
)
(
*
CreatePaymentResponse
,
error
)
// QueryOrder queries the payment status of the given trade number.
QueryOrder
(
ctx
context
.
Context
,
tradeNo
string
)
(
*
QueryOrderResponse
,
error
)
// VerifyNotification parses and verifies a webhook callback.
// Returns nil for unrecognized or irrelevant events (caller should return 200).
VerifyNotification
(
ctx
context
.
Context
,
rawBody
string
,
headers
map
[
string
]
string
)
(
*
PaymentNotification
,
error
)
// Refund requests a refund from the upstream provider.
Refund
(
ctx
context
.
Context
,
req
RefundRequest
)
(
*
RefundResponse
,
error
)
}
// CancelableProvider extends Provider with the ability to cancel pending payments.
type
CancelableProvider
interface
{
Provider
// CancelPayment cancels/expires a pending payment on the upstream platform.
CancelPayment
(
ctx
context
.
Context
,
tradeNo
string
)
error
}
backend/internal/payment/wire.go
0 → 100644
View file @
a04ae28a
package
payment
import
(
"encoding/hex"
"fmt"
"log/slog"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/wire"
)
// EncryptionKey is a named type for the payment encryption key (AES-256, 32 bytes).
// Using a named type avoids Wire ambiguity with other []byte parameters.
type
EncryptionKey
[]
byte
// ProvideEncryptionKey derives the payment encryption key from the TOTP encryption key in config.
// When the key is empty, nil is returned (payment features that need encryption will be disabled).
// When the key is non-empty but invalid (bad hex or wrong length), an error is returned
// to prevent startup with a misconfigured encryption key.
func
ProvideEncryptionKey
(
cfg
*
config
.
Config
)
(
EncryptionKey
,
error
)
{
if
cfg
.
Totp
.
EncryptionKey
==
""
{
slog
.
Warn
(
"payment encryption key not configured — encrypted payment config will be unavailable"
)
return
nil
,
nil
}
key
,
err
:=
hex
.
DecodeString
(
cfg
.
Totp
.
EncryptionKey
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"invalid payment encryption key (hex decode): %w"
,
err
)
}
if
len
(
key
)
!=
32
{
return
nil
,
fmt
.
Errorf
(
"payment encryption key must be 32 bytes, got %d"
,
len
(
key
))
}
return
EncryptionKey
(
key
),
nil
}
// ProvideRegistry creates an empty payment provider registry.
// Providers are registered at runtime after application startup.
func
ProvideRegistry
()
*
Registry
{
return
NewRegistry
()
}
// ProvideDefaultLoadBalancer creates a DefaultLoadBalancer backed by the ent client.
func
ProvideDefaultLoadBalancer
(
client
*
dbent
.
Client
,
key
EncryptionKey
)
*
DefaultLoadBalancer
{
return
NewDefaultLoadBalancer
(
client
,
[]
byte
(
key
))
}
// ProviderSet is the Wire provider set for the payment package.
var
ProviderSet
=
wire
.
NewSet
(
ProvideEncryptionKey
,
ProvideRegistry
,
ProvideDefaultLoadBalancer
,
wire
.
Bind
(
new
(
LoadBalancer
),
new
(
*
DefaultLoadBalancer
)),
)
backend/internal/pkg/apicompat/types.go
View file @
a04ae28a
...
...
@@ -28,7 +28,7 @@ type AnthropicRequest struct {
// AnthropicOutputConfig controls output generation parameters.
type
AnthropicOutputConfig
struct
{
Effort
string
`json:"effort,omitempty"`
// "low" | "medium" | "high"
Effort
string
`json:"effort,omitempty"`
// "low" | "medium" | "high"
| "max"
}
// AnthropicThinking configures extended thinking in the Anthropic API.
...
...
@@ -167,7 +167,7 @@ type ResponsesRequest struct {
// ResponsesReasoning configures reasoning effort in the Responses API.
type
ResponsesReasoning
struct
{
Effort
string
`json:"effort"`
// "low" | "medium" | "high"
Effort
string
`json:"effort"`
// "low" | "medium" | "high"
| "xhigh"
Summary
string
`json:"summary,omitempty"`
// "auto" | "concise" | "detailed"
}
...
...
@@ -345,7 +345,7 @@ type ChatCompletionsRequest struct {
StreamOptions
*
ChatStreamOptions
`json:"stream_options,omitempty"`
Tools
[]
ChatTool
`json:"tools,omitempty"`
ToolChoice
json
.
RawMessage
`json:"tool_choice,omitempty"`
ReasoningEffort
string
`json:"reasoning_effort,omitempty"`
// "low" | "medium" | "high"
ReasoningEffort
string
`json:"reasoning_effort,omitempty"`
// "low" | "medium" | "high"
| "xhigh"
ServiceTier
string
`json:"service_tier,omitempty"`
Stop
json
.
RawMessage
`json:"stop,omitempty"`
// string or []string
...
...
backend/internal/pkg/pagination/pagination.go
View file @
a04ae28a
// Package pagination provides types and helpers for paginated responses.
package
pagination
import
"strings"
const
(
SortOrderAsc
=
"asc"
SortOrderDesc
=
"desc"
)
// PaginationParams 分页参数
type
PaginationParams
struct
{
Page
int
PageSize
int
SortBy
string
SortOrder
string
}
// PaginationResult 分页结果
...
...
@@ -20,6 +29,7 @@ func DefaultPagination() PaginationParams {
return
PaginationParams
{
Page
:
1
,
PageSize
:
20
,
SortOrder
:
SortOrderDesc
,
}
}
...
...
@@ -36,8 +46,32 @@ func (p PaginationParams) Limit() int {
if
p
.
PageSize
<
1
{
return
20
}
if
p
.
PageSize
>
100
{
return
100
if
p
.
PageSize
>
100
0
{
return
100
0
}
return
p
.
PageSize
}
// NormalizeSortOrder normalizes sort order to asc/desc and falls back to defaultOrder.
func
NormalizeSortOrder
(
order
string
,
defaultOrder
string
)
string
{
switch
strings
.
ToLower
(
strings
.
TrimSpace
(
defaultOrder
))
{
case
SortOrderAsc
:
defaultOrder
=
SortOrderAsc
default
:
defaultOrder
=
SortOrderDesc
}
switch
strings
.
ToLower
(
strings
.
TrimSpace
(
order
))
{
case
SortOrderAsc
:
return
SortOrderAsc
case
SortOrderDesc
:
return
SortOrderDesc
default
:
return
defaultOrder
}
}
// NormalizedSortOrder returns the normalized sort order using defaultOrder as fallback.
func
(
p
PaginationParams
)
NormalizedSortOrder
(
defaultOrder
string
)
string
{
return
NormalizeSortOrder
(
p
.
SortOrder
,
defaultOrder
)
}
backend/internal/pkg/pagination/pagination_test.go
0 → 100644
View file @
a04ae28a
package
pagination
import
"testing"
func
TestNormalizeSortOrder
(
t
*
testing
.
T
)
{
t
.
Parallel
()
tests
:=
[]
struct
{
name
string
input
string
defaultOrder
string
want
string
}{
{
name
:
"asc"
,
input
:
"asc"
,
defaultOrder
:
"desc"
,
want
:
"asc"
},
{
name
:
"uppercase asc"
,
input
:
"ASC"
,
defaultOrder
:
"desc"
,
want
:
"asc"
},
{
name
:
"desc"
,
input
:
"desc"
,
defaultOrder
:
"asc"
,
want
:
"desc"
},
{
name
:
"trim spaces"
,
input
:
" desc "
,
defaultOrder
:
"asc"
,
want
:
"desc"
},
{
name
:
"invalid falls back"
,
input
:
"sideways"
,
defaultOrder
:
"asc"
,
want
:
"asc"
},
{
name
:
"empty falls back"
,
input
:
""
,
defaultOrder
:
"desc"
,
want
:
"desc"
},
{
name
:
"invalid default falls back to desc"
,
input
:
""
,
defaultOrder
:
"wat"
,
want
:
"desc"
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
if
got
:=
NormalizeSortOrder
(
tt
.
input
,
tt
.
defaultOrder
);
got
!=
tt
.
want
{
t
.
Fatalf
(
"NormalizeSortOrder(%q, %q) = %q, want %q"
,
tt
.
input
,
tt
.
defaultOrder
,
got
,
tt
.
want
)
}
})
}
}
func
TestPaginationParamsNormalizedSortOrder
(
t
*
testing
.
T
)
{
t
.
Parallel
()
params
:=
PaginationParams
{
SortOrder
:
"ASC"
}
if
got
:=
params
.
NormalizedSortOrder
(
"desc"
);
got
!=
"asc"
{
t
.
Fatalf
(
"NormalizedSortOrder = %q, want asc"
,
got
)
}
params
=
PaginationParams
{
SortOrder
:
"bad"
}
if
got
:=
params
.
NormalizedSortOrder
(
"asc"
);
got
!=
"asc"
{
t
.
Fatalf
(
"NormalizedSortOrder invalid fallback = %q, want asc"
,
got
)
}
}
func
TestPaginationParamsLimit
(
t
*
testing
.
T
)
{
t
.
Parallel
()
tests
:=
[]
struct
{
name
string
pageSize
int
want
int
}{
{
name
:
"non-positive falls back to default"
,
pageSize
:
0
,
want
:
20
},
{
name
:
"negative falls back to default"
,
pageSize
:
-
1
,
want
:
20
},
{
name
:
"normal value keeps"
,
pageSize
:
50
,
want
:
50
},
{
name
:
"max value keeps"
,
pageSize
:
1000
,
want
:
1000
},
{
name
:
"beyond max clamps to 1000"
,
pageSize
:
1500
,
want
:
1000
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
p
:=
PaginationParams
{
PageSize
:
tt
.
pageSize
}
if
got
:=
p
.
Limit
();
got
!=
tt
.
want
{
t
.
Fatalf
(
"Limit() for PageSize=%d = %d, want %d"
,
tt
.
pageSize
,
got
,
tt
.
want
)
}
})
}
}
backend/internal/repository/account_repo.go
View file @
a04ae28a
...
...
@@ -471,21 +471,58 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
case
service
.
StatusActive
:
q
=
q
.
Where
(
dbaccount
.
StatusEQ
(
status
),
dbaccount
.
SchedulableEQ
(
true
),
dbaccount
.
Or
(
dbaccount
.
RateLimitResetAtIsNil
(),
dbaccount
.
RateLimitResetAtLTE
(
time
.
Now
()),
),
dbpredicate
.
Account
(
func
(
s
*
entsql
.
Selector
)
{
col
:=
s
.
C
(
"temp_unschedulable_until"
)
s
.
Where
(
entsql
.
Or
(
entsql
.
IsNull
(
col
),
entsql
.
LTE
(
col
,
entsql
.
Expr
(
"NOW()"
)),
))
}),
)
case
"rate_limited"
:
q
=
q
.
Where
(
dbaccount
.
RateLimitResetAtGT
(
time
.
Now
()))
q
=
q
.
Where
(
dbaccount
.
StatusEQ
(
service
.
StatusActive
),
dbaccount
.
RateLimitResetAtGT
(
time
.
Now
()),
dbpredicate
.
Account
(
func
(
s
*
entsql
.
Selector
)
{
col
:=
s
.
C
(
"temp_unschedulable_until"
)
s
.
Where
(
entsql
.
Or
(
entsql
.
IsNull
(
col
),
entsql
.
LTE
(
col
,
entsql
.
Expr
(
"NOW()"
)),
))
}),
)
case
"temp_unschedulable"
:
q
=
q
.
Where
(
dbpredicate
.
Account
(
func
(
s
*
entsql
.
Selector
)
{
q
=
q
.
Where
(
dbaccount
.
StatusEQ
(
service
.
StatusActive
),
dbpredicate
.
Account
(
func
(
s
*
entsql
.
Selector
)
{
col
:=
s
.
C
(
"temp_unschedulable_until"
)
s
.
Where
(
entsql
.
And
(
entsql
.
Not
(
entsql
.
IsNull
(
col
)),
entsql
.
GT
(
col
,
entsql
.
Expr
(
"NOW()"
)),
))
}))
}),
)
case
"unschedulable"
:
q
=
q
.
Where
(
dbaccount
.
StatusEQ
(
service
.
StatusActive
),
dbaccount
.
SchedulableEQ
(
false
),
dbaccount
.
Or
(
dbaccount
.
RateLimitResetAtIsNil
(),
dbaccount
.
RateLimitResetAtLTE
(
time
.
Now
()),
),
dbpredicate
.
Account
(
func
(
s
*
entsql
.
Selector
)
{
col
:=
s
.
C
(
"temp_unschedulable_until"
)
s
.
Where
(
entsql
.
Or
(
entsql
.
IsNull
(
col
),
entsql
.
LTE
(
col
,
entsql
.
Expr
(
"NOW()"
)),
))
}),
)
default
:
q
=
q
.
Where
(
dbaccount
.
StatusEQ
(
status
))
}
...
...
@@ -518,11 +555,14 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
return
nil
,
nil
,
err
}
accounts
,
er
r
:=
q
.
accounts
Qu
er
y
:=
q
.
Offset
(
params
.
Offset
())
.
Limit
(
params
.
Limit
())
.
Order
(
dbent
.
Desc
(
dbaccount
.
FieldID
))
.
All
(
ctx
)
Limit
(
params
.
Limit
())
for
_
,
order
:=
range
accountListOrder
(
params
)
{
accountsQuery
=
accountsQuery
.
Order
(
order
)
}
accounts
,
err
:=
accountsQuery
.
All
(
ctx
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
...
...
@@ -534,6 +574,50 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
return
outAccounts
,
paginationResultFromTotal
(
int64
(
total
),
params
),
nil
}
func
accountListOrder
(
params
pagination
.
PaginationParams
)
[]
func
(
*
entsql
.
Selector
)
{
sortBy
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
params
.
SortBy
))
sortOrder
:=
params
.
NormalizedSortOrder
(
pagination
.
SortOrderAsc
)
field
:=
dbaccount
.
FieldName
defaultOrder
:=
true
switch
sortBy
{
case
""
,
"name"
:
field
=
dbaccount
.
FieldName
case
"id"
:
field
=
dbaccount
.
FieldID
defaultOrder
=
false
case
"status"
:
field
=
dbaccount
.
FieldStatus
defaultOrder
=
false
case
"schedulable"
:
field
=
dbaccount
.
FieldSchedulable
defaultOrder
=
false
case
"priority"
:
field
=
dbaccount
.
FieldPriority
defaultOrder
=
false
case
"rate_multiplier"
:
field
=
dbaccount
.
FieldRateMultiplier
defaultOrder
=
false
case
"last_used_at"
:
field
=
dbaccount
.
FieldLastUsedAt
defaultOrder
=
false
case
"expires_at"
:
field
=
dbaccount
.
FieldExpiresAt
defaultOrder
=
false
case
"created_at"
:
field
=
dbaccount
.
FieldCreatedAt
defaultOrder
=
false
}
if
sortOrder
==
pagination
.
SortOrderDesc
{
return
[]
func
(
*
entsql
.
Selector
){
dbent
.
Desc
(
field
),
dbent
.
Desc
(
dbaccount
.
FieldID
)}
}
if
defaultOrder
{
return
[]
func
(
*
entsql
.
Selector
){
dbent
.
Asc
(
dbaccount
.
FieldName
),
dbent
.
Asc
(
dbaccount
.
FieldID
)}
}
return
[]
func
(
*
entsql
.
Selector
){
dbent
.
Asc
(
field
),
dbent
.
Asc
(
dbaccount
.
FieldID
)}
}
func
(
r
*
accountRepository
)
ListByGroup
(
ctx
context
.
Context
,
groupID
int64
)
([]
service
.
Account
,
error
)
{
accounts
,
err
:=
r
.
queryAccountsByGroup
(
ctx
,
groupID
,
accountGroupQueryOptions
{
status
:
service
.
StatusActive
,
...
...
backend/internal/repository/account_repo_integration_test.go
View file @
a04ae28a
...
...
@@ -256,7 +256,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
},
},
{
name
:
"filter_by_status_active_excludes_r
ate_limited
"
,
name
:
"filter_by_status_active_excludes_r
untime_blocked_accounts
"
,
setup
:
func
(
client
*
dbent
.
Client
)
{
mustCreateAccount
(
s
.
T
(),
client
,
&
service
.
Account
{
Name
:
"active-normal"
,
Status
:
service
.
StatusActive
})
rateLimited
:=
mustCreateAccount
(
s
.
T
(),
client
,
&
service
.
Account
{
Name
:
"active-rate-limited"
,
Status
:
service
.
StatusActive
})
...
...
@@ -264,6 +264,16 @@ func (s *AccountRepoSuite) TestListWithFilters() {
SetRateLimitResetAt
(
time
.
Now
()
.
Add
(
10
*
time
.
Minute
))
.
Exec
(
context
.
Background
())
s
.
Require
()
.
NoError
(
err
)
tempUnsched
:=
mustCreateAccount
(
s
.
T
(),
client
,
&
service
.
Account
{
Name
:
"active-temp-unsched"
,
Status
:
service
.
StatusActive
})
err
=
client
.
Account
.
UpdateOneID
(
tempUnsched
.
ID
)
.
SetTempUnschedulableUntil
(
time
.
Now
()
.
Add
(
15
*
time
.
Minute
))
.
Exec
(
context
.
Background
())
s
.
Require
()
.
NoError
(
err
)
unsched
:=
mustCreateAccount
(
s
.
T
(),
client
,
&
service
.
Account
{
Name
:
"active-unsched"
,
Status
:
service
.
StatusActive
})
err
=
client
.
Account
.
UpdateOneID
(
unsched
.
ID
)
.
SetSchedulable
(
false
)
.
Exec
(
context
.
Background
())
s
.
Require
()
.
NoError
(
err
)
},
status
:
service
.
StatusActive
,
wantCount
:
1
,
...
...
@@ -271,6 +281,75 @@ func (s *AccountRepoSuite) TestListWithFilters() {
s
.
Require
()
.
Equal
(
"active-normal"
,
accounts
[
0
]
.
Name
)
},
},
{
name
:
"filter_by_status_unschedulable_excludes_rate_limited_and_temp_unschedulable"
,
setup
:
func
(
client
*
dbent
.
Client
)
{
mustCreateAccount
(
s
.
T
(),
client
,
&
service
.
Account
{
Name
:
"active-normal"
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
})
unsched
:=
mustCreateAccount
(
s
.
T
(),
client
,
&
service
.
Account
{
Name
:
"active-unsched"
,
Status
:
service
.
StatusActive
})
err
:=
client
.
Account
.
UpdateOneID
(
unsched
.
ID
)
.
SetSchedulable
(
false
)
.
Exec
(
context
.
Background
())
s
.
Require
()
.
NoError
(
err
)
rateLimited
:=
mustCreateAccount
(
s
.
T
(),
client
,
&
service
.
Account
{
Name
:
"active-rate-limited"
,
Status
:
service
.
StatusActive
})
err
=
client
.
Account
.
UpdateOneID
(
rateLimited
.
ID
)
.
SetSchedulable
(
false
)
.
SetRateLimitResetAt
(
time
.
Now
()
.
Add
(
10
*
time
.
Minute
))
.
Exec
(
context
.
Background
())
s
.
Require
()
.
NoError
(
err
)
tempUnsched
:=
mustCreateAccount
(
s
.
T
(),
client
,
&
service
.
Account
{
Name
:
"active-temp-unsched"
,
Status
:
service
.
StatusActive
})
err
=
client
.
Account
.
UpdateOneID
(
tempUnsched
.
ID
)
.
SetSchedulable
(
false
)
.
SetTempUnschedulableUntil
(
time
.
Now
()
.
Add
(
15
*
time
.
Minute
))
.
Exec
(
context
.
Background
())
s
.
Require
()
.
NoError
(
err
)
},
status
:
"unschedulable"
,
wantCount
:
1
,
validate
:
func
(
accounts
[]
service
.
Account
)
{
s
.
Require
()
.
Equal
(
"active-unsched"
,
accounts
[
0
]
.
Name
)
},
},
{
name
:
"filter_by_status_rate_limited_excludes_temp_unschedulable"
,
setup
:
func
(
client
*
dbent
.
Client
)
{
rateLimited
:=
mustCreateAccount
(
s
.
T
(),
client
,
&
service
.
Account
{
Name
:
"active-rate-limited"
,
Status
:
service
.
StatusActive
})
err
:=
client
.
Account
.
UpdateOneID
(
rateLimited
.
ID
)
.
SetRateLimitResetAt
(
time
.
Now
()
.
Add
(
10
*
time
.
Minute
))
.
Exec
(
context
.
Background
())
s
.
Require
()
.
NoError
(
err
)
tempUnsched
:=
mustCreateAccount
(
s
.
T
(),
client
,
&
service
.
Account
{
Name
:
"active-temp-unsched"
,
Status
:
service
.
StatusActive
})
err
=
client
.
Account
.
UpdateOneID
(
tempUnsched
.
ID
)
.
SetRateLimitResetAt
(
time
.
Now
()
.
Add
(
20
*
time
.
Minute
))
.
SetTempUnschedulableUntil
(
time
.
Now
()
.
Add
(
15
*
time
.
Minute
))
.
Exec
(
context
.
Background
())
s
.
Require
()
.
NoError
(
err
)
},
status
:
"rate_limited"
,
wantCount
:
1
,
validate
:
func
(
accounts
[]
service
.
Account
)
{
s
.
Require
()
.
Equal
(
"active-rate-limited"
,
accounts
[
0
]
.
Name
)
},
},
{
name
:
"filter_by_status_temp_unschedulable_excludes_manually_unschedulable"
,
setup
:
func
(
client
*
dbent
.
Client
)
{
tempUnsched
:=
mustCreateAccount
(
s
.
T
(),
client
,
&
service
.
Account
{
Name
:
"active-temp-unsched"
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
})
err
:=
client
.
Account
.
UpdateOneID
(
tempUnsched
.
ID
)
.
SetTempUnschedulableUntil
(
time
.
Now
()
.
Add
(
15
*
time
.
Minute
))
.
Exec
(
context
.
Background
())
s
.
Require
()
.
NoError
(
err
)
unsched
:=
mustCreateAccount
(
s
.
T
(),
client
,
&
service
.
Account
{
Name
:
"active-unsched"
,
Status
:
service
.
StatusActive
})
err
=
client
.
Account
.
UpdateOneID
(
unsched
.
ID
)
.
SetSchedulable
(
false
)
.
Exec
(
context
.
Background
())
s
.
Require
()
.
NoError
(
err
)
},
status
:
"temp_unschedulable"
,
wantCount
:
1
,
validate
:
func
(
accounts
[]
service
.
Account
)
{
s
.
Require
()
.
Equal
(
"active-temp-unsched"
,
accounts
[
0
]
.
Name
)
},
},
{
name
:
"filter_by_search"
,
setup
:
func
(
client
*
dbent
.
Client
)
{
...
...
backend/internal/repository/account_repo_sort_integration_test.go
0 → 100644
View file @
a04ae28a
//go:build integration
package
repository
import
(
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func
(
s
*
AccountRepoSuite
)
TestList_DefaultSortByNameAsc
()
{
mustCreateAccount
(
s
.
T
(),
s
.
client
,
&
service
.
Account
{
Name
:
"z-account"
})
mustCreateAccount
(
s
.
T
(),
s
.
client
,
&
service
.
Account
{
Name
:
"a-account"
})
accounts
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
})
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
accounts
,
2
)
s
.
Require
()
.
Equal
(
"a-account"
,
accounts
[
0
]
.
Name
)
s
.
Require
()
.
Equal
(
"z-account"
,
accounts
[
1
]
.
Name
)
}
func
(
s
*
AccountRepoSuite
)
TestListWithFilters_SortByPriorityDesc
()
{
mustCreateAccount
(
s
.
T
(),
s
.
client
,
&
service
.
Account
{
Name
:
"low-priority"
,
Priority
:
10
})
mustCreateAccount
(
s
.
T
(),
s
.
client
,
&
service
.
Account
{
Name
:
"high-priority"
,
Priority
:
90
})
accounts
,
_
,
err
:=
s
.
repo
.
ListWithFilters
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
,
SortBy
:
"priority"
,
SortOrder
:
"desc"
,
},
""
,
""
,
""
,
""
,
0
,
""
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
accounts
,
2
)
s
.
Require
()
.
Equal
(
"high-priority"
,
accounts
[
0
]
.
Name
)
s
.
Require
()
.
Equal
(
"low-priority"
,
accounts
[
1
]
.
Name
)
}
backend/internal/repository/announcement_repo.go
View file @
a04ae28a
...
...
@@ -2,12 +2,15 @@ package repository
import
(
"context"
"strings"
"time"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
entsql
"entgo.io/ent/dialect/sql"
)
type
announcementRepository
struct
{
...
...
@@ -128,11 +131,14 @@ func (r *announcementRepository) List(
return
nil
,
nil
,
err
}
items
,
er
r
:=
q
.
items
Qu
er
y
:=
q
.
Offset
(
params
.
Offset
())
.
Limit
(
params
.
Limit
())
.
Order
(
dbent
.
Desc
(
announcement
.
FieldID
))
.
All
(
ctx
)
Limit
(
params
.
Limit
())
for
_
,
order
:=
range
announcementListOrders
(
params
)
{
itemsQuery
=
itemsQuery
.
Order
(
order
)
}
items
,
err
:=
itemsQuery
.
All
(
ctx
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
...
...
@@ -141,6 +147,56 @@ func (r *announcementRepository) List(
return
out
,
paginationResultFromTotal
(
int64
(
total
),
params
),
nil
}
func
announcementListOrder
(
params
pagination
.
PaginationParams
)
(
string
,
string
)
{
sortBy
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
params
.
SortBy
))
sortOrder
:=
params
.
NormalizedSortOrder
(
pagination
.
SortOrderDesc
)
switch
sortBy
{
case
"title"
:
return
announcement
.
FieldTitle
,
sortOrder
case
"status"
:
return
announcement
.
FieldStatus
,
sortOrder
case
"notify_mode"
:
return
announcement
.
FieldNotifyMode
,
sortOrder
case
"starts_at"
:
return
announcement
.
FieldStartsAt
,
sortOrder
case
"ends_at"
:
return
announcement
.
FieldEndsAt
,
sortOrder
case
"id"
:
return
announcement
.
FieldID
,
sortOrder
case
""
,
"created_at"
:
return
announcement
.
FieldCreatedAt
,
sortOrder
default
:
return
announcement
.
FieldCreatedAt
,
pagination
.
SortOrderDesc
}
}
func
announcementListOrders
(
params
pagination
.
PaginationParams
)
[]
func
(
*
entsql
.
Selector
)
{
field
,
sortOrder
:=
announcementListOrder
(
params
)
if
sortOrder
==
pagination
.
SortOrderAsc
{
if
field
==
announcement
.
FieldID
{
return
[]
func
(
*
entsql
.
Selector
){
dbent
.
Asc
(
field
),
}
}
return
[]
func
(
*
entsql
.
Selector
){
dbent
.
Asc
(
field
),
dbent
.
Asc
(
announcement
.
FieldID
),
}
}
if
field
==
announcement
.
FieldID
{
return
[]
func
(
*
entsql
.
Selector
){
dbent
.
Desc
(
field
),
}
}
return
[]
func
(
*
entsql
.
Selector
){
dbent
.
Desc
(
field
),
dbent
.
Desc
(
announcement
.
FieldID
),
}
}
func
(
r
*
announcementRepository
)
ListActive
(
ctx
context
.
Context
,
now
time
.
Time
)
([]
service
.
Announcement
,
error
)
{
q
:=
r
.
client
.
Announcement
.
Query
()
.
Where
(
...
...
backend/internal/repository/announcement_repo_sort_test.go
0 → 100644
View file @
a04ae28a
package
repository
import
(
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
func
TestAnnouncementListOrder
(
t
*
testing
.
T
)
{
t
.
Parallel
()
tests
:=
[]
struct
{
name
string
params
pagination
.
PaginationParams
wantBy
string
want
string
}{
{
name
:
"default created_at desc"
,
params
:
pagination
.
PaginationParams
{},
wantBy
:
"created_at"
,
want
:
"desc"
,
},
{
name
:
"title asc"
,
params
:
pagination
.
PaginationParams
{
SortBy
:
"title"
,
SortOrder
:
"ASC"
,
},
wantBy
:
"title"
,
want
:
"asc"
,
},
{
name
:
"status desc"
,
params
:
pagination
.
PaginationParams
{
SortBy
:
"status"
,
SortOrder
:
"desc"
,
},
wantBy
:
"status"
,
want
:
"desc"
,
},
{
name
:
"invalid falls back"
,
params
:
pagination
.
PaginationParams
{
SortBy
:
"sideways"
,
SortOrder
:
"wat"
,
},
wantBy
:
"created_at"
,
want
:
"desc"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Parallel
()
gotBy
,
gotOrder
:=
announcementListOrder
(
tt
.
params
)
if
gotBy
!=
tt
.
wantBy
||
gotOrder
!=
tt
.
want
{
t
.
Fatalf
(
"announcementListOrder(%+v) = (%q, %q), want (%q, %q)"
,
tt
.
params
,
gotBy
,
gotOrder
,
tt
.
wantBy
,
tt
.
want
)
}
})
}
}
backend/internal/repository/api_key_repo.go
View file @
a04ae28a
...
...
@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"fmt"
"strings"
"time"
dbent
"github.com/Wei-Shaw/sub2api/ent"
...
...
@@ -14,6 +15,8 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
entsql
"entgo.io/ent/dialect/sql"
)
type
apiKeyRepository
struct
{
...
...
@@ -164,6 +167,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group
.
FieldSupportedModelScopes
,
group
.
FieldAllowMessagesDispatch
,
group
.
FieldDefaultMappedModel
,
group
.
FieldMessagesDispatchModelConfig
,
)
})
.
Only
(
ctx
)
...
...
@@ -309,12 +313,15 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
return
nil
,
nil
,
err
}
keys
,
er
r
:=
q
.
keys
Qu
er
y
:=
q
.
WithGroup
()
.
Offset
(
params
.
Offset
())
.
Limit
(
params
.
Limit
())
.
Order
(
dbent
.
Desc
(
apikey
.
FieldID
))
.
All
(
ctx
)
Limit
(
params
.
Limit
())
for
_
,
order
:=
range
apiKeyListOrder
(
params
)
{
keysQuery
=
keysQuery
.
Order
(
order
)
}
keys
,
err
:=
keysQuery
.
All
(
ctx
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
...
...
@@ -359,12 +366,15 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
return
nil
,
nil
,
err
}
keys
,
er
r
:=
q
.
keys
Qu
er
y
:=
q
.
WithUser
()
.
Offset
(
params
.
Offset
())
.
Limit
(
params
.
Limit
())
.
Order
(
dbent
.
Desc
(
apikey
.
FieldID
))
.
All
(
ctx
)
Limit
(
params
.
Limit
())
for
_
,
order
:=
range
apiKeyListOrder
(
params
)
{
keysQuery
=
keysQuery
.
Order
(
order
)
}
keys
,
err
:=
keysQuery
.
All
(
ctx
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
...
...
@@ -377,6 +387,32 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
return
outKeys
,
paginationResultFromTotal
(
int64
(
total
),
params
),
nil
}
func
apiKeyListOrder
(
params
pagination
.
PaginationParams
)
[]
func
(
*
entsql
.
Selector
)
{
sortBy
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
params
.
SortBy
))
sortOrder
:=
params
.
NormalizedSortOrder
(
pagination
.
SortOrderDesc
)
var
field
string
switch
sortBy
{
case
"name"
:
field
=
apikey
.
FieldName
case
"status"
:
field
=
apikey
.
FieldStatus
case
"expires_at"
:
field
=
apikey
.
FieldExpiresAt
case
"last_used_at"
:
field
=
apikey
.
FieldLastUsedAt
case
"created_at"
:
field
=
apikey
.
FieldCreatedAt
default
:
field
=
apikey
.
FieldID
}
if
sortOrder
==
pagination
.
SortOrderAsc
{
return
[]
func
(
*
entsql
.
Selector
){
dbent
.
Asc
(
field
),
dbent
.
Asc
(
apikey
.
FieldID
)}
}
return
[]
func
(
*
entsql
.
Selector
){
dbent
.
Desc
(
field
),
dbent
.
Desc
(
apikey
.
FieldID
)}
}
// SearchAPIKeys searches API keys by user ID and/or keyword (name)
func
(
r
*
apiKeyRepository
)
SearchAPIKeys
(
ctx
context
.
Context
,
userID
int64
,
keyword
string
,
limit
int
)
([]
service
.
APIKey
,
error
)
{
q
:=
r
.
activeQuery
()
...
...
@@ -654,6 +690,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
RequireOAuthOnly
:
g
.
RequireOauthOnly
,
RequirePrivacySet
:
g
.
RequirePrivacySet
,
DefaultMappedModel
:
g
.
DefaultMappedModel
,
MessagesDispatchModelConfig
:
g
.
MessagesDispatchModelConfig
,
CreatedAt
:
g
.
CreatedAt
,
UpdatedAt
:
g
.
UpdatedAt
,
}
...
...
backend/internal/repository/api_key_repo_integration_test.go
View file @
a04ae28a
...
...
@@ -86,6 +86,45 @@ func (s *APIKeyRepoSuite) TestGetByKey_NotFound() {
s
.
Require
()
.
Error
(
err
,
"expected error for non-existent key"
)
}
func
(
s
*
APIKeyRepoSuite
)
TestGetByKeyForAuth_PreservesMessagesDispatchModelConfig
()
{
user
:=
s
.
mustCreateUser
(
"getbykey-auth-dispatch@test.com"
)
group
,
err
:=
s
.
client
.
Group
.
Create
()
.
SetName
(
"g-auth-dispatch"
)
.
SetPlatform
(
service
.
PlatformOpenAI
)
.
SetStatus
(
service
.
StatusActive
)
.
SetSubscriptionType
(
service
.
SubscriptionTypeStandard
)
.
SetRateMultiplier
(
1
)
.
SetAllowMessagesDispatch
(
true
)
.
SetDefaultMappedModel
(
"gpt-5.4"
)
.
SetMessagesDispatchModelConfig
(
service
.
OpenAIMessagesDispatchModelConfig
{
OpusMappedModel
:
"gpt-5.4-nano"
,
SonnetMappedModel
:
"gpt-5.3-codex"
,
HaikuMappedModel
:
"gpt-5.4-mini"
,
ExactModelMappings
:
map
[
string
]
string
{
"claude-sonnet-4.5"
:
"gpt-5.4-nano"
,
},
})
.
Save
(
s
.
ctx
)
s
.
Require
()
.
NoError
(
err
)
key
:=
&
service
.
APIKey
{
UserID
:
user
.
ID
,
Key
:
"sk-getbykey-auth-dispatch"
,
Name
:
"Dispatch Key"
,
GroupID
:
&
group
.
ID
,
Status
:
service
.
StatusActive
,
}
s
.
Require
()
.
NoError
(
s
.
repo
.
Create
(
s
.
ctx
,
key
))
got
,
err
:=
s
.
repo
.
GetByKeyForAuth
(
s
.
ctx
,
key
.
Key
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
NotNil
(
got
.
Group
)
s
.
Require
()
.
True
(
got
.
Group
.
AllowMessagesDispatch
)
s
.
Require
()
.
Equal
(
"gpt-5.4"
,
got
.
Group
.
DefaultMappedModel
)
s
.
Require
()
.
Equal
(
"gpt-5.4-nano"
,
got
.
Group
.
MessagesDispatchModelConfig
.
OpusMappedModel
)
s
.
Require
()
.
Equal
(
"gpt-5.4-nano"
,
got
.
Group
.
MessagesDispatchModelConfig
.
ExactModelMappings
[
"claude-sonnet-4.5"
])
}
// --- Update ---
func
(
s
*
APIKeyRepoSuite
)
TestUpdate
()
{
...
...
backend/internal/repository/api_key_repo_messages_dispatch_unit_test.go
0 → 100644
View file @
a04ae28a
package
repository
import
(
"context"
"testing"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func
TestGroupEntityToService_PreservesMessagesDispatchModelConfig
(
t
*
testing
.
T
)
{
group
:=
&
dbent
.
Group
{
ID
:
1
,
Name
:
"openai-dispatch"
,
Platform
:
service
.
PlatformOpenAI
,
Status
:
service
.
StatusActive
,
SubscriptionType
:
service
.
SubscriptionTypeStandard
,
RateMultiplier
:
1
,
AllowMessagesDispatch
:
true
,
DefaultMappedModel
:
"gpt-5.4"
,
MessagesDispatchModelConfig
:
service
.
OpenAIMessagesDispatchModelConfig
{
OpusMappedModel
:
"gpt-5.4-nano"
,
SonnetMappedModel
:
"gpt-5.3-codex"
,
HaikuMappedModel
:
"gpt-5.4-mini"
,
ExactModelMappings
:
map
[
string
]
string
{
"claude-sonnet-4.5"
:
"gpt-5.4-nano"
,
},
},
}
got
:=
groupEntityToService
(
group
)
require
.
NotNil
(
t
,
got
)
require
.
Equal
(
t
,
group
.
MessagesDispatchModelConfig
,
got
.
MessagesDispatchModelConfig
)
}
func
TestAPIKeyRepository_GetByKeyForAuth_PreservesMessagesDispatchModelConfig_SQLite
(
t
*
testing
.
T
)
{
repo
,
client
:=
newAPIKeyRepoSQLite
(
t
)
ctx
:=
context
.
Background
()
user
:=
mustCreateAPIKeyRepoUser
(
t
,
ctx
,
client
,
"getbykey-auth-dispatch-unit@test.com"
)
group
,
err
:=
client
.
Group
.
Create
()
.
SetName
(
"g-auth-dispatch-unit"
)
.
SetPlatform
(
service
.
PlatformOpenAI
)
.
SetStatus
(
service
.
StatusActive
)
.
SetSubscriptionType
(
service
.
SubscriptionTypeStandard
)
.
SetRateMultiplier
(
1
)
.
SetAllowMessagesDispatch
(
true
)
.
SetDefaultMappedModel
(
"gpt-5.4"
)
.
SetMessagesDispatchModelConfig
(
service
.
OpenAIMessagesDispatchModelConfig
{
OpusMappedModel
:
"gpt-5.4-nano"
,
SonnetMappedModel
:
"gpt-5.3-codex"
,
HaikuMappedModel
:
"gpt-5.4-mini"
,
ExactModelMappings
:
map
[
string
]
string
{
"claude-sonnet-4.5"
:
"gpt-5.4-nano"
,
},
})
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
key
:=
&
service
.
APIKey
{
UserID
:
user
.
ID
,
Key
:
"sk-getbykey-auth-dispatch-unit"
,
Name
:
"Dispatch Key Unit"
,
GroupID
:
&
group
.
ID
,
Status
:
service
.
StatusActive
,
}
require
.
NoError
(
t
,
repo
.
Create
(
ctx
,
key
))
got
,
err
:=
repo
.
GetByKeyForAuth
(
ctx
,
key
.
Key
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
got
.
Group
)
require
.
Equal
(
t
,
group
.
MessagesDispatchModelConfig
,
got
.
Group
.
MessagesDispatchModelConfig
)
}
backend/internal/repository/api_key_repo_sort_integration_test.go
0 → 100644
View file @
a04ae28a
//go:build integration
package
repository
import
(
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func
(
s
*
APIKeyRepoSuite
)
TestListByUserID_SortByNameAsc
()
{
user
:=
s
.
mustCreateUser
(
"sort-name@example.com"
)
s
.
mustCreateApiKey
(
user
.
ID
,
"sk-z"
,
"z-key"
,
nil
)
s
.
mustCreateApiKey
(
user
.
ID
,
"sk-a"
,
"a-key"
,
nil
)
keys
,
_
,
err
:=
s
.
repo
.
ListByUserID
(
s
.
ctx
,
user
.
ID
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
10
,
SortBy
:
"name"
,
SortOrder
:
"asc"
,
},
service
.
APIKeyListFilters
{})
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
keys
,
2
)
s
.
Require
()
.
Equal
(
"a-key"
,
keys
[
0
]
.
Name
)
s
.
Require
()
.
Equal
(
"z-key"
,
keys
[
1
]
.
Name
)
}
backend/internal/repository/channel_repo.go
View file @
a04ae28a
...
...
@@ -188,8 +188,8 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
// 查询 channel 列表
dataQuery
:=
fmt
.
Sprintf
(
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.created_at, c.updated_at
FROM channels c WHERE %s ORDER BY
c.id ASC
LIMIT $%d OFFSET $%d`
,
whereClause
,
argIdx
,
argIdx
+
1
,
FROM channels c WHERE %s ORDER BY
%s
LIMIT $%d OFFSET $%d`
,
whereClause
,
channelListOrderBy
(
params
),
argIdx
,
argIdx
+
1
,
)
args
=
append
(
args
,
pageSize
,
offset
)
...
...
@@ -246,6 +246,31 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
return
channels
,
paginationResult
,
nil
}
func
channelListOrderBy
(
params
pagination
.
PaginationParams
)
string
{
sortBy
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
params
.
SortBy
))
sortOrder
:=
strings
.
ToUpper
(
params
.
NormalizedSortOrder
(
pagination
.
SortOrderAsc
))
var
column
string
switch
sortBy
{
case
""
:
column
=
"c.id"
sortOrder
=
"ASC"
case
"id"
:
column
=
"c.id"
case
"name"
:
column
=
"c.name"
case
"status"
:
column
=
"c.status"
case
"created_at"
:
column
=
"c.created_at"
default
:
column
=
"c.id"
sortOrder
=
"ASC"
}
return
fmt
.
Sprintf
(
"%s %s, c.id %s"
,
column
,
sortOrder
,
sortOrder
)
}
func
(
r
*
channelRepository
)
ListAll
(
ctx
context
.
Context
)
([]
service
.
Channel
,
error
)
{
rows
,
err
:=
r
.
db
.
QueryContext
(
ctx
,
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at FROM channels ORDER BY id`
,
...
...
Prev
1
…
3
4
5
6
7
8
9
10
11
…
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment