Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
1a641392
Commit
1a641392
authored
Jan 10, 2026
by
cyhhao
Browse files
Merge up/main
parents
36b817d0
24d19a5f
Changes
174
Show whitespace changes
Inline
Side-by-side
backend/ent/userattributevalue_query.go
View file @
1a641392
...
...
@@ -8,6 +8,7 @@ import (
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
...
...
@@ -26,6 +27,7 @@ type UserAttributeValueQuery struct {
predicates
[]
predicate
.
UserAttributeValue
withUser
*
UserQuery
withDefinition
*
UserAttributeDefinitionQuery
modifiers
[]
func
(
*
sql
.
Selector
)
// intermediate query (i.e. traversal path).
sql
*
sql
.
Selector
path
func
(
context
.
Context
)
(
*
sql
.
Selector
,
error
)
...
...
@@ -420,6 +422,9 @@ func (_q *UserAttributeValueQuery) sqlAll(ctx context.Context, hooks ...queryHoo
node
.
Edges
.
loadedTypes
=
loadedTypes
return
node
.
assignValues
(
columns
,
values
)
}
if
len
(
_q
.
modifiers
)
>
0
{
_spec
.
Modifiers
=
_q
.
modifiers
}
for
i
:=
range
hooks
{
hooks
[
i
](
ctx
,
_spec
)
}
...
...
@@ -505,6 +510,9 @@ func (_q *UserAttributeValueQuery) loadDefinition(ctx context.Context, query *Us
func
(
_q
*
UserAttributeValueQuery
)
sqlCount
(
ctx
context
.
Context
)
(
int
,
error
)
{
_spec
:=
_q
.
querySpec
()
if
len
(
_q
.
modifiers
)
>
0
{
_spec
.
Modifiers
=
_q
.
modifiers
}
_spec
.
Node
.
Columns
=
_q
.
ctx
.
Fields
if
len
(
_q
.
ctx
.
Fields
)
>
0
{
_spec
.
Unique
=
_q
.
ctx
.
Unique
!=
nil
&&
*
_q
.
ctx
.
Unique
...
...
@@ -573,6 +581,9 @@ func (_q *UserAttributeValueQuery) sqlQuery(ctx context.Context) *sql.Selector {
if
_q
.
ctx
.
Unique
!=
nil
&&
*
_q
.
ctx
.
Unique
{
selector
.
Distinct
()
}
for
_
,
m
:=
range
_q
.
modifiers
{
m
(
selector
)
}
for
_
,
p
:=
range
_q
.
predicates
{
p
(
selector
)
}
...
...
@@ -590,6 +601,32 @@ func (_q *UserAttributeValueQuery) sqlQuery(ctx context.Context) *sql.Selector {
return
selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func
(
_q
*
UserAttributeValueQuery
)
ForUpdate
(
opts
...
sql
.
LockOption
)
*
UserAttributeValueQuery
{
if
_q
.
driver
.
Dialect
()
==
dialect
.
Postgres
{
_q
.
Unique
(
false
)
}
_q
.
modifiers
=
append
(
_q
.
modifiers
,
func
(
s
*
sql
.
Selector
)
{
s
.
ForUpdate
(
opts
...
)
})
return
_q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func
(
_q
*
UserAttributeValueQuery
)
ForShare
(
opts
...
sql
.
LockOption
)
*
UserAttributeValueQuery
{
if
_q
.
driver
.
Dialect
()
==
dialect
.
Postgres
{
_q
.
Unique
(
false
)
}
_q
.
modifiers
=
append
(
_q
.
modifiers
,
func
(
s
*
sql
.
Selector
)
{
s
.
ForShare
(
opts
...
)
})
return
_q
}
// UserAttributeValueGroupBy is the group-by builder for UserAttributeValue entities.
type
UserAttributeValueGroupBy
struct
{
selector
...
...
backend/ent/usersubscription_query.go
View file @
1a641392
...
...
@@ -9,6 +9,7 @@ import (
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
...
...
@@ -30,6 +31,7 @@ type UserSubscriptionQuery struct {
withGroup
*
GroupQuery
withAssignedByUser
*
UserQuery
withUsageLogs
*
UsageLogQuery
modifiers
[]
func
(
*
sql
.
Selector
)
// intermediate query (i.e. traversal path).
sql
*
sql
.
Selector
path
func
(
context
.
Context
)
(
*
sql
.
Selector
,
error
)
...
...
@@ -494,6 +496,9 @@ func (_q *UserSubscriptionQuery) sqlAll(ctx context.Context, hooks ...queryHook)
node
.
Edges
.
loadedTypes
=
loadedTypes
return
node
.
assignValues
(
columns
,
values
)
}
if
len
(
_q
.
modifiers
)
>
0
{
_spec
.
Modifiers
=
_q
.
modifiers
}
for
i
:=
range
hooks
{
hooks
[
i
](
ctx
,
_spec
)
}
...
...
@@ -657,6 +662,9 @@ func (_q *UserSubscriptionQuery) loadUsageLogs(ctx context.Context, query *Usage
func
(
_q
*
UserSubscriptionQuery
)
sqlCount
(
ctx
context
.
Context
)
(
int
,
error
)
{
_spec
:=
_q
.
querySpec
()
if
len
(
_q
.
modifiers
)
>
0
{
_spec
.
Modifiers
=
_q
.
modifiers
}
_spec
.
Node
.
Columns
=
_q
.
ctx
.
Fields
if
len
(
_q
.
ctx
.
Fields
)
>
0
{
_spec
.
Unique
=
_q
.
ctx
.
Unique
!=
nil
&&
*
_q
.
ctx
.
Unique
...
...
@@ -728,6 +736,9 @@ func (_q *UserSubscriptionQuery) sqlQuery(ctx context.Context) *sql.Selector {
if
_q
.
ctx
.
Unique
!=
nil
&&
*
_q
.
ctx
.
Unique
{
selector
.
Distinct
()
}
for
_
,
m
:=
range
_q
.
modifiers
{
m
(
selector
)
}
for
_
,
p
:=
range
_q
.
predicates
{
p
(
selector
)
}
...
...
@@ -745,6 +756,32 @@ func (_q *UserSubscriptionQuery) sqlQuery(ctx context.Context) *sql.Selector {
return
selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func
(
_q
*
UserSubscriptionQuery
)
ForUpdate
(
opts
...
sql
.
LockOption
)
*
UserSubscriptionQuery
{
if
_q
.
driver
.
Dialect
()
==
dialect
.
Postgres
{
_q
.
Unique
(
false
)
}
_q
.
modifiers
=
append
(
_q
.
modifiers
,
func
(
s
*
sql
.
Selector
)
{
s
.
ForUpdate
(
opts
...
)
})
return
_q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func
(
_q
*
UserSubscriptionQuery
)
ForShare
(
opts
...
sql
.
LockOption
)
*
UserSubscriptionQuery
{
if
_q
.
driver
.
Dialect
()
==
dialect
.
Postgres
{
_q
.
Unique
(
false
)
}
_q
.
modifiers
=
append
(
_q
.
modifiers
,
func
(
s
*
sql
.
Selector
)
{
s
.
ForShare
(
opts
...
)
})
return
_q
}
// UserSubscriptionGroupBy is the group-by builder for UserSubscription entities.
type
UserSubscriptionGroupBy
struct
{
selector
...
...
backend/internal/config/config.go
View file @
1a641392
...
...
@@ -6,6 +6,7 @@ import (
"encoding/hex"
"fmt"
"log"
"net/url"
"os"
"strings"
"time"
...
...
@@ -43,6 +44,7 @@ type Config struct {
Database
DatabaseConfig
`mapstructure:"database"`
Redis
RedisConfig
`mapstructure:"redis"`
JWT
JWTConfig
`mapstructure:"jwt"`
LinuxDo
LinuxDoConnectConfig
`mapstructure:"linuxdo_connect"`
Default
DefaultConfig
`mapstructure:"default"`
RateLimit
RateLimitConfig
`mapstructure:"rate_limit"`
Pricing
PricingConfig
`mapstructure:"pricing"`
...
...
@@ -322,6 +324,30 @@ type TurnstileConfig struct {
Required
bool
`mapstructure:"required"`
}
// LinuxDoConnectConfig 用于 LinuxDo Connect OAuth 登录(终端用户 SSO)。
//
// 注意:这与上游账号的 OAuth(例如 OpenAI/Gemini 账号接入)不是一回事。
// 这里是用于登录 Sub2API 本身的用户体系。
type
LinuxDoConnectConfig
struct
{
Enabled
bool
`mapstructure:"enabled"`
ClientID
string
`mapstructure:"client_id"`
ClientSecret
string
`mapstructure:"client_secret"`
AuthorizeURL
string
`mapstructure:"authorize_url"`
TokenURL
string
`mapstructure:"token_url"`
UserInfoURL
string
`mapstructure:"userinfo_url"`
Scopes
string
`mapstructure:"scopes"`
RedirectURL
string
`mapstructure:"redirect_url"`
// 后端回调地址(需在提供方后台登记)
FrontendRedirectURL
string
`mapstructure:"frontend_redirect_url"`
// 前端接收 token 的路由(默认:/auth/linuxdo/callback)
TokenAuthMethod
string
`mapstructure:"token_auth_method"`
// client_secret_post / client_secret_basic / none
UsePKCE
bool
`mapstructure:"use_pkce"`
// 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。
// 为空时,服务端会尝试一组常见字段名。
UserInfoEmailPath
string
`mapstructure:"userinfo_email_path"`
UserInfoIDPath
string
`mapstructure:"userinfo_id_path"`
UserInfoUsernamePath
string
`mapstructure:"userinfo_username_path"`
}
type
DefaultConfig
struct
{
AdminEmail
string
`mapstructure:"admin_email"`
AdminPassword
string
`mapstructure:"admin_password"`
...
...
@@ -388,6 +414,18 @@ func Load() (*Config, error) {
cfg
.
Server
.
Mode
=
"debug"
}
cfg
.
JWT
.
Secret
=
strings
.
TrimSpace
(
cfg
.
JWT
.
Secret
)
cfg
.
LinuxDo
.
ClientID
=
strings
.
TrimSpace
(
cfg
.
LinuxDo
.
ClientID
)
cfg
.
LinuxDo
.
ClientSecret
=
strings
.
TrimSpace
(
cfg
.
LinuxDo
.
ClientSecret
)
cfg
.
LinuxDo
.
AuthorizeURL
=
strings
.
TrimSpace
(
cfg
.
LinuxDo
.
AuthorizeURL
)
cfg
.
LinuxDo
.
TokenURL
=
strings
.
TrimSpace
(
cfg
.
LinuxDo
.
TokenURL
)
cfg
.
LinuxDo
.
UserInfoURL
=
strings
.
TrimSpace
(
cfg
.
LinuxDo
.
UserInfoURL
)
cfg
.
LinuxDo
.
Scopes
=
strings
.
TrimSpace
(
cfg
.
LinuxDo
.
Scopes
)
cfg
.
LinuxDo
.
RedirectURL
=
strings
.
TrimSpace
(
cfg
.
LinuxDo
.
RedirectURL
)
cfg
.
LinuxDo
.
FrontendRedirectURL
=
strings
.
TrimSpace
(
cfg
.
LinuxDo
.
FrontendRedirectURL
)
cfg
.
LinuxDo
.
TokenAuthMethod
=
strings
.
ToLower
(
strings
.
TrimSpace
(
cfg
.
LinuxDo
.
TokenAuthMethod
))
cfg
.
LinuxDo
.
UserInfoEmailPath
=
strings
.
TrimSpace
(
cfg
.
LinuxDo
.
UserInfoEmailPath
)
cfg
.
LinuxDo
.
UserInfoIDPath
=
strings
.
TrimSpace
(
cfg
.
LinuxDo
.
UserInfoIDPath
)
cfg
.
LinuxDo
.
UserInfoUsernamePath
=
strings
.
TrimSpace
(
cfg
.
LinuxDo
.
UserInfoUsernamePath
)
cfg
.
CORS
.
AllowedOrigins
=
normalizeStringSlice
(
cfg
.
CORS
.
AllowedOrigins
)
cfg
.
Security
.
ResponseHeaders
.
AdditionalAllowed
=
normalizeStringSlice
(
cfg
.
Security
.
ResponseHeaders
.
AdditionalAllowed
)
cfg
.
Security
.
ResponseHeaders
.
ForceRemove
=
normalizeStringSlice
(
cfg
.
Security
.
ResponseHeaders
.
ForceRemove
)
...
...
@@ -426,6 +464,81 @@ func Load() (*Config, error) {
return
&
cfg
,
nil
}
// ValidateAbsoluteHTTPURL 校验一个绝对 http(s) URL(禁止 fragment)。
func
ValidateAbsoluteHTTPURL
(
raw
string
)
error
{
raw
=
strings
.
TrimSpace
(
raw
)
if
raw
==
""
{
return
fmt
.
Errorf
(
"empty url"
)
}
u
,
err
:=
url
.
Parse
(
raw
)
if
err
!=
nil
{
return
err
}
if
!
u
.
IsAbs
()
{
return
fmt
.
Errorf
(
"must be absolute"
)
}
if
!
isHTTPScheme
(
u
.
Scheme
)
{
return
fmt
.
Errorf
(
"unsupported scheme: %s"
,
u
.
Scheme
)
}
if
strings
.
TrimSpace
(
u
.
Host
)
==
""
{
return
fmt
.
Errorf
(
"missing host"
)
}
if
u
.
Fragment
!=
""
{
return
fmt
.
Errorf
(
"must not include fragment"
)
}
return
nil
}
// ValidateFrontendRedirectURL 校验前端回调地址:
// - 允许同源相对路径(以 / 开头)
// - 或绝对 http(s) URL(禁止 fragment)
func
ValidateFrontendRedirectURL
(
raw
string
)
error
{
raw
=
strings
.
TrimSpace
(
raw
)
if
raw
==
""
{
return
fmt
.
Errorf
(
"empty url"
)
}
if
strings
.
ContainsAny
(
raw
,
"
\r\n
"
)
{
return
fmt
.
Errorf
(
"contains invalid characters"
)
}
if
strings
.
HasPrefix
(
raw
,
"/"
)
{
if
strings
.
HasPrefix
(
raw
,
"//"
)
{
return
fmt
.
Errorf
(
"must not start with //"
)
}
return
nil
}
u
,
err
:=
url
.
Parse
(
raw
)
if
err
!=
nil
{
return
err
}
if
!
u
.
IsAbs
()
{
return
fmt
.
Errorf
(
"must be absolute http(s) url or relative path"
)
}
if
!
isHTTPScheme
(
u
.
Scheme
)
{
return
fmt
.
Errorf
(
"unsupported scheme: %s"
,
u
.
Scheme
)
}
if
strings
.
TrimSpace
(
u
.
Host
)
==
""
{
return
fmt
.
Errorf
(
"missing host"
)
}
if
u
.
Fragment
!=
""
{
return
fmt
.
Errorf
(
"must not include fragment"
)
}
return
nil
}
func
isHTTPScheme
(
scheme
string
)
bool
{
return
strings
.
EqualFold
(
scheme
,
"http"
)
||
strings
.
EqualFold
(
scheme
,
"https"
)
}
func
warnIfInsecureURL
(
field
,
raw
string
)
{
u
,
err
:=
url
.
Parse
(
strings
.
TrimSpace
(
raw
))
if
err
!=
nil
{
return
}
if
strings
.
EqualFold
(
u
.
Scheme
,
"http"
)
{
log
.
Printf
(
"Warning: %s uses http scheme; use https in production to avoid token leakage."
,
field
)
}
}
func
setDefaults
()
{
viper
.
SetDefault
(
"run_mode"
,
RunModeStandard
)
...
...
@@ -475,6 +588,22 @@ func setDefaults() {
// Turnstile
viper
.
SetDefault
(
"turnstile.required"
,
false
)
// LinuxDo Connect OAuth 登录(终端用户 SSO)
viper
.
SetDefault
(
"linuxdo_connect.enabled"
,
false
)
viper
.
SetDefault
(
"linuxdo_connect.client_id"
,
""
)
viper
.
SetDefault
(
"linuxdo_connect.client_secret"
,
""
)
viper
.
SetDefault
(
"linuxdo_connect.authorize_url"
,
"https://connect.linux.do/oauth2/authorize"
)
viper
.
SetDefault
(
"linuxdo_connect.token_url"
,
"https://connect.linux.do/oauth2/token"
)
viper
.
SetDefault
(
"linuxdo_connect.userinfo_url"
,
"https://connect.linux.do/api/user"
)
viper
.
SetDefault
(
"linuxdo_connect.scopes"
,
"user"
)
viper
.
SetDefault
(
"linuxdo_connect.redirect_url"
,
""
)
viper
.
SetDefault
(
"linuxdo_connect.frontend_redirect_url"
,
"/auth/linuxdo/callback"
)
viper
.
SetDefault
(
"linuxdo_connect.token_auth_method"
,
"client_secret_post"
)
viper
.
SetDefault
(
"linuxdo_connect.use_pkce"
,
false
)
viper
.
SetDefault
(
"linuxdo_connect.userinfo_email_path"
,
""
)
viper
.
SetDefault
(
"linuxdo_connect.userinfo_id_path"
,
""
)
viper
.
SetDefault
(
"linuxdo_connect.userinfo_username_path"
,
""
)
// Database
viper
.
SetDefault
(
"database.host"
,
"localhost"
)
viper
.
SetDefault
(
"database.port"
,
5432
)
...
...
@@ -544,7 +673,7 @@ func setDefaults() {
viper
.
SetDefault
(
"gateway.concurrency_slot_ttl_minutes"
,
30
)
// 并发槽位过期时间(支持超长请求)
viper
.
SetDefault
(
"gateway.stream_data_interval_timeout"
,
180
)
viper
.
SetDefault
(
"gateway.stream_keepalive_interval"
,
10
)
viper
.
SetDefault
(
"gateway.max_line_size"
,
1
0
*
1024
*
1024
)
viper
.
SetDefault
(
"gateway.max_line_size"
,
4
0
*
1024
*
1024
)
viper
.
SetDefault
(
"gateway.scheduling.sticky_session_max_waiting"
,
3
)
viper
.
SetDefault
(
"gateway.scheduling.sticky_session_wait_timeout"
,
45
*
time
.
Second
)
viper
.
SetDefault
(
"gateway.scheduling.fallback_wait_timeout"
,
30
*
time
.
Second
)
...
...
@@ -586,6 +715,60 @@ func (c *Config) Validate() error {
if
c
.
Security
.
CSP
.
Enabled
&&
strings
.
TrimSpace
(
c
.
Security
.
CSP
.
Policy
)
==
""
{
return
fmt
.
Errorf
(
"security.csp.policy is required when CSP is enabled"
)
}
if
c
.
LinuxDo
.
Enabled
{
if
strings
.
TrimSpace
(
c
.
LinuxDo
.
ClientID
)
==
""
{
return
fmt
.
Errorf
(
"linuxdo_connect.client_id is required when linuxdo_connect.enabled=true"
)
}
if
strings
.
TrimSpace
(
c
.
LinuxDo
.
AuthorizeURL
)
==
""
{
return
fmt
.
Errorf
(
"linuxdo_connect.authorize_url is required when linuxdo_connect.enabled=true"
)
}
if
strings
.
TrimSpace
(
c
.
LinuxDo
.
TokenURL
)
==
""
{
return
fmt
.
Errorf
(
"linuxdo_connect.token_url is required when linuxdo_connect.enabled=true"
)
}
if
strings
.
TrimSpace
(
c
.
LinuxDo
.
UserInfoURL
)
==
""
{
return
fmt
.
Errorf
(
"linuxdo_connect.userinfo_url is required when linuxdo_connect.enabled=true"
)
}
if
strings
.
TrimSpace
(
c
.
LinuxDo
.
RedirectURL
)
==
""
{
return
fmt
.
Errorf
(
"linuxdo_connect.redirect_url is required when linuxdo_connect.enabled=true"
)
}
method
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
c
.
LinuxDo
.
TokenAuthMethod
))
switch
method
{
case
""
,
"client_secret_post"
,
"client_secret_basic"
,
"none"
:
default
:
return
fmt
.
Errorf
(
"linuxdo_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none"
)
}
if
method
==
"none"
&&
!
c
.
LinuxDo
.
UsePKCE
{
return
fmt
.
Errorf
(
"linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none"
)
}
if
(
method
==
""
||
method
==
"client_secret_post"
||
method
==
"client_secret_basic"
)
&&
strings
.
TrimSpace
(
c
.
LinuxDo
.
ClientSecret
)
==
""
{
return
fmt
.
Errorf
(
"linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic"
)
}
if
strings
.
TrimSpace
(
c
.
LinuxDo
.
FrontendRedirectURL
)
==
""
{
return
fmt
.
Errorf
(
"linuxdo_connect.frontend_redirect_url is required when linuxdo_connect.enabled=true"
)
}
if
err
:=
ValidateAbsoluteHTTPURL
(
c
.
LinuxDo
.
AuthorizeURL
);
err
!=
nil
{
return
fmt
.
Errorf
(
"linuxdo_connect.authorize_url invalid: %w"
,
err
)
}
if
err
:=
ValidateAbsoluteHTTPURL
(
c
.
LinuxDo
.
TokenURL
);
err
!=
nil
{
return
fmt
.
Errorf
(
"linuxdo_connect.token_url invalid: %w"
,
err
)
}
if
err
:=
ValidateAbsoluteHTTPURL
(
c
.
LinuxDo
.
UserInfoURL
);
err
!=
nil
{
return
fmt
.
Errorf
(
"linuxdo_connect.userinfo_url invalid: %w"
,
err
)
}
if
err
:=
ValidateAbsoluteHTTPURL
(
c
.
LinuxDo
.
RedirectURL
);
err
!=
nil
{
return
fmt
.
Errorf
(
"linuxdo_connect.redirect_url invalid: %w"
,
err
)
}
if
err
:=
ValidateFrontendRedirectURL
(
c
.
LinuxDo
.
FrontendRedirectURL
);
err
!=
nil
{
return
fmt
.
Errorf
(
"linuxdo_connect.frontend_redirect_url invalid: %w"
,
err
)
}
warnIfInsecureURL
(
"linuxdo_connect.authorize_url"
,
c
.
LinuxDo
.
AuthorizeURL
)
warnIfInsecureURL
(
"linuxdo_connect.token_url"
,
c
.
LinuxDo
.
TokenURL
)
warnIfInsecureURL
(
"linuxdo_connect.userinfo_url"
,
c
.
LinuxDo
.
UserInfoURL
)
warnIfInsecureURL
(
"linuxdo_connect.redirect_url"
,
c
.
LinuxDo
.
RedirectURL
)
warnIfInsecureURL
(
"linuxdo_connect.frontend_redirect_url"
,
c
.
LinuxDo
.
FrontendRedirectURL
)
}
if
c
.
Billing
.
CircuitBreaker
.
Enabled
{
if
c
.
Billing
.
CircuitBreaker
.
FailureThreshold
<=
0
{
return
fmt
.
Errorf
(
"billing.circuit_breaker.failure_threshold must be positive"
)
...
...
backend/internal/config/config_test.go
View file @
1a641392
package
config
import
(
"strings"
"testing"
"time"
...
...
@@ -90,3 +91,53 @@ func TestLoadDefaultSecurityToggles(t *testing.T) {
t
.
Fatalf
(
"ResponseHeaders.Enabled = true, want false"
)
}
}
func
TestValidateLinuxDoFrontendRedirectURL
(
t
*
testing
.
T
)
{
viper
.
Reset
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
t
.
Fatalf
(
"Load() error: %v"
,
err
)
}
cfg
.
LinuxDo
.
Enabled
=
true
cfg
.
LinuxDo
.
ClientID
=
"test-client"
cfg
.
LinuxDo
.
ClientSecret
=
"test-secret"
cfg
.
LinuxDo
.
RedirectURL
=
"https://example.com/api/v1/auth/oauth/linuxdo/callback"
cfg
.
LinuxDo
.
TokenAuthMethod
=
"client_secret_post"
cfg
.
LinuxDo
.
UsePKCE
=
false
cfg
.
LinuxDo
.
FrontendRedirectURL
=
"javascript:alert(1)"
err
=
cfg
.
Validate
()
if
err
==
nil
{
t
.
Fatalf
(
"Validate() expected error for javascript scheme, got nil"
)
}
if
!
strings
.
Contains
(
err
.
Error
(),
"linuxdo_connect.frontend_redirect_url"
)
{
t
.
Fatalf
(
"Validate() expected frontend_redirect_url error, got: %v"
,
err
)
}
}
func
TestValidateLinuxDoPKCERequiredForPublicClient
(
t
*
testing
.
T
)
{
viper
.
Reset
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
t
.
Fatalf
(
"Load() error: %v"
,
err
)
}
cfg
.
LinuxDo
.
Enabled
=
true
cfg
.
LinuxDo
.
ClientID
=
"test-client"
cfg
.
LinuxDo
.
ClientSecret
=
""
cfg
.
LinuxDo
.
RedirectURL
=
"https://example.com/api/v1/auth/oauth/linuxdo/callback"
cfg
.
LinuxDo
.
FrontendRedirectURL
=
"/auth/linuxdo/callback"
cfg
.
LinuxDo
.
TokenAuthMethod
=
"none"
cfg
.
LinuxDo
.
UsePKCE
=
false
err
=
cfg
.
Validate
()
if
err
==
nil
{
t
.
Fatalf
(
"Validate() expected error when token_auth_method=none and use_pkce=false, got nil"
)
}
if
!
strings
.
Contains
(
err
.
Error
(),
"linuxdo_connect.use_pkce"
)
{
t
.
Fatalf
(
"Validate() expected use_pkce error, got: %v"
,
err
)
}
}
backend/internal/handler/admin/account_handler.go
View file @
1a641392
...
...
@@ -116,6 +116,7 @@ type BulkUpdateAccountsRequest struct {
Concurrency
*
int
`json:"concurrency"`
Priority
*
int
`json:"priority"`
Status
string
`json:"status" binding:"omitempty,oneof=active inactive error"`
Schedulable
*
bool
`json:"schedulable"`
GroupIDs
*
[]
int64
`json:"group_ids"`
Credentials
map
[
string
]
any
`json:"credentials"`
Extra
map
[
string
]
any
`json:"extra"`
...
...
@@ -136,6 +137,11 @@ func (h *AccountHandler) List(c *gin.Context) {
accountType
:=
c
.
Query
(
"type"
)
status
:=
c
.
Query
(
"status"
)
search
:=
c
.
Query
(
"search"
)
// 标准化和验证 search 参数
search
=
strings
.
TrimSpace
(
search
)
if
len
(
search
)
>
100
{
search
=
search
[
:
100
]
}
accounts
,
total
,
err
:=
h
.
adminService
.
ListAccounts
(
c
.
Request
.
Context
(),
page
,
pageSize
,
platform
,
accountType
,
status
,
search
)
if
err
!=
nil
{
...
...
@@ -655,6 +661,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
req
.
Concurrency
!=
nil
||
req
.
Priority
!=
nil
||
req
.
Status
!=
""
||
req
.
Schedulable
!=
nil
||
req
.
GroupIDs
!=
nil
||
len
(
req
.
Credentials
)
>
0
||
len
(
req
.
Extra
)
>
0
...
...
@@ -671,6 +678,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
Concurrency
:
req
.
Concurrency
,
Priority
:
req
.
Priority
,
Status
:
req
.
Status
,
Schedulable
:
req
.
Schedulable
,
GroupIDs
:
req
.
GroupIDs
,
Credentials
:
req
.
Credentials
,
Extra
:
req
.
Extra
,
...
...
backend/internal/handler/admin/group_handler.go
View file @
1a641392
...
...
@@ -2,6 +2,7 @@ package admin
import
(
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
...
...
@@ -67,6 +68,12 @@ func (h *GroupHandler) List(c *gin.Context) {
page
,
pageSize
:=
response
.
ParsePagination
(
c
)
platform
:=
c
.
Query
(
"platform"
)
status
:=
c
.
Query
(
"status"
)
search
:=
c
.
Query
(
"search"
)
// 标准化和验证 search 参数
search
=
strings
.
TrimSpace
(
search
)
if
len
(
search
)
>
100
{
search
=
search
[
:
100
]
}
isExclusiveStr
:=
c
.
Query
(
"is_exclusive"
)
var
isExclusive
*
bool
...
...
@@ -75,7 +82,7 @@ func (h *GroupHandler) List(c *gin.Context) {
isExclusive
=
&
val
}
groups
,
total
,
err
:=
h
.
adminService
.
ListGroups
(
c
.
Request
.
Context
(),
page
,
pageSize
,
platform
,
status
,
isExclusive
)
groups
,
total
,
err
:=
h
.
adminService
.
ListGroups
(
c
.
Request
.
Context
(),
page
,
pageSize
,
platform
,
status
,
search
,
isExclusive
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
...
...
backend/internal/handler/admin/promo_handler.go
0 → 100644
View file @
1a641392
package
admin
import
(
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// PromoHandler handles admin promo code management
type
PromoHandler
struct
{
promoService
*
service
.
PromoService
}
// NewPromoHandler creates a new admin promo handler
func
NewPromoHandler
(
promoService
*
service
.
PromoService
)
*
PromoHandler
{
return
&
PromoHandler
{
promoService
:
promoService
,
}
}
// CreatePromoCodeRequest represents create promo code request
type
CreatePromoCodeRequest
struct
{
Code
string
`json:"code"`
// 可选,为空则自动生成
BonusAmount
float64
`json:"bonus_amount" binding:"required,min=0"`
// 赠送余额
MaxUses
int
`json:"max_uses" binding:"min=0"`
// 最大使用次数,0=无限
ExpiresAt
*
int64
`json:"expires_at"`
// 过期时间戳(秒)
Notes
string
`json:"notes"`
// 备注
}
// UpdatePromoCodeRequest represents update promo code request
type
UpdatePromoCodeRequest
struct
{
Code
*
string
`json:"code"`
BonusAmount
*
float64
`json:"bonus_amount" binding:"omitempty,min=0"`
MaxUses
*
int
`json:"max_uses" binding:"omitempty,min=0"`
Status
*
string
`json:"status" binding:"omitempty,oneof=active disabled"`
ExpiresAt
*
int64
`json:"expires_at"`
Notes
*
string
`json:"notes"`
}
// List handles listing all promo codes with pagination
// GET /api/v1/admin/promo-codes
func
(
h
*
PromoHandler
)
List
(
c
*
gin
.
Context
)
{
page
,
pageSize
:=
response
.
ParsePagination
(
c
)
status
:=
c
.
Query
(
"status"
)
search
:=
strings
.
TrimSpace
(
c
.
Query
(
"search"
))
if
len
(
search
)
>
100
{
search
=
search
[
:
100
]
}
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
,
}
codes
,
paginationResult
,
err
:=
h
.
promoService
.
List
(
c
.
Request
.
Context
(),
params
,
status
,
search
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
out
:=
make
([]
dto
.
PromoCode
,
0
,
len
(
codes
))
for
i
:=
range
codes
{
out
=
append
(
out
,
*
dto
.
PromoCodeFromService
(
&
codes
[
i
]))
}
response
.
Paginated
(
c
,
out
,
paginationResult
.
Total
,
page
,
pageSize
)
}
// GetByID handles getting a promo code by ID
// GET /api/v1/admin/promo-codes/:id
func
(
h
*
PromoHandler
)
GetByID
(
c
*
gin
.
Context
)
{
codeID
,
err
:=
strconv
.
ParseInt
(
c
.
Param
(
"id"
),
10
,
64
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid promo code ID"
)
return
}
code
,
err
:=
h
.
promoService
.
GetByID
(
c
.
Request
.
Context
(),
codeID
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
dto
.
PromoCodeFromService
(
code
))
}
// Create handles creating a new promo code
// POST /api/v1/admin/promo-codes
func
(
h
*
PromoHandler
)
Create
(
c
*
gin
.
Context
)
{
var
req
CreatePromoCodeRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
input
:=
&
service
.
CreatePromoCodeInput
{
Code
:
req
.
Code
,
BonusAmount
:
req
.
BonusAmount
,
MaxUses
:
req
.
MaxUses
,
Notes
:
req
.
Notes
,
}
if
req
.
ExpiresAt
!=
nil
{
t
:=
time
.
Unix
(
*
req
.
ExpiresAt
,
0
)
input
.
ExpiresAt
=
&
t
}
code
,
err
:=
h
.
promoService
.
Create
(
c
.
Request
.
Context
(),
input
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
dto
.
PromoCodeFromService
(
code
))
}
// Update handles updating a promo code
// PUT /api/v1/admin/promo-codes/:id
func
(
h
*
PromoHandler
)
Update
(
c
*
gin
.
Context
)
{
codeID
,
err
:=
strconv
.
ParseInt
(
c
.
Param
(
"id"
),
10
,
64
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid promo code ID"
)
return
}
var
req
UpdatePromoCodeRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
input
:=
&
service
.
UpdatePromoCodeInput
{
Code
:
req
.
Code
,
BonusAmount
:
req
.
BonusAmount
,
MaxUses
:
req
.
MaxUses
,
Status
:
req
.
Status
,
Notes
:
req
.
Notes
,
}
if
req
.
ExpiresAt
!=
nil
{
if
*
req
.
ExpiresAt
==
0
{
// 0 表示清除过期时间
input
.
ExpiresAt
=
nil
}
else
{
t
:=
time
.
Unix
(
*
req
.
ExpiresAt
,
0
)
input
.
ExpiresAt
=
&
t
}
}
code
,
err
:=
h
.
promoService
.
Update
(
c
.
Request
.
Context
(),
codeID
,
input
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
dto
.
PromoCodeFromService
(
code
))
}
// Delete handles deleting a promo code
// DELETE /api/v1/admin/promo-codes/:id
func
(
h
*
PromoHandler
)
Delete
(
c
*
gin
.
Context
)
{
codeID
,
err
:=
strconv
.
ParseInt
(
c
.
Param
(
"id"
),
10
,
64
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid promo code ID"
)
return
}
err
=
h
.
promoService
.
Delete
(
c
.
Request
.
Context
(),
codeID
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
gin
.
H
{
"message"
:
"Promo code deleted successfully"
})
}
// GetUsages handles getting usage records for a promo code
// GET /api/v1/admin/promo-codes/:id/usages
func
(
h
*
PromoHandler
)
GetUsages
(
c
*
gin
.
Context
)
{
codeID
,
err
:=
strconv
.
ParseInt
(
c
.
Param
(
"id"
),
10
,
64
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid promo code ID"
)
return
}
page
,
pageSize
:=
response
.
ParsePagination
(
c
)
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
,
}
usages
,
paginationResult
,
err
:=
h
.
promoService
.
ListUsages
(
c
.
Request
.
Context
(),
codeID
,
params
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
out
:=
make
([]
dto
.
PromoCodeUsage
,
0
,
len
(
usages
))
for
i
:=
range
usages
{
out
=
append
(
out
,
*
dto
.
PromoCodeUsageFromService
(
&
usages
[
i
]))
}
response
.
Paginated
(
c
,
out
,
paginationResult
.
Total
,
page
,
pageSize
)
}
backend/internal/handler/admin/proxy_handler.go
View file @
1a641392
...
...
@@ -51,6 +51,11 @@ func (h *ProxyHandler) List(c *gin.Context) {
protocol
:=
c
.
Query
(
"protocol"
)
status
:=
c
.
Query
(
"status"
)
search
:=
c
.
Query
(
"search"
)
// 标准化和验证 search 参数
search
=
strings
.
TrimSpace
(
search
)
if
len
(
search
)
>
100
{
search
=
search
[
:
100
]
}
proxies
,
total
,
err
:=
h
.
adminService
.
ListProxiesWithAccountCount
(
c
.
Request
.
Context
(),
page
,
pageSize
,
protocol
,
status
,
search
)
if
err
!=
nil
{
...
...
backend/internal/handler/admin/redeem_handler.go
View file @
1a641392
...
...
@@ -5,6 +5,7 @@ import (
"encoding/csv"
"fmt"
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
...
...
@@ -41,6 +42,11 @@ func (h *RedeemHandler) List(c *gin.Context) {
codeType
:=
c
.
Query
(
"type"
)
status
:=
c
.
Query
(
"status"
)
search
:=
c
.
Query
(
"search"
)
// 标准化和验证 search 参数
search
=
strings
.
TrimSpace
(
search
)
if
len
(
search
)
>
100
{
search
=
search
[
:
100
]
}
codes
,
total
,
err
:=
h
.
adminService
.
ListRedeemCodes
(
c
.
Request
.
Context
(),
page
,
pageSize
,
codeType
,
status
,
search
)
if
err
!=
nil
{
...
...
backend/internal/handler/admin/setting_handler.go
View file @
1a641392
...
...
@@ -2,8 +2,10 @@ package admin
import
(
"log"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
...
...
@@ -50,6 +52,10 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
TurnstileEnabled
:
settings
.
TurnstileEnabled
,
TurnstileSiteKey
:
settings
.
TurnstileSiteKey
,
TurnstileSecretKeyConfigured
:
settings
.
TurnstileSecretKeyConfigured
,
LinuxDoConnectEnabled
:
settings
.
LinuxDoConnectEnabled
,
LinuxDoConnectClientID
:
settings
.
LinuxDoConnectClientID
,
LinuxDoConnectClientSecretConfigured
:
settings
.
LinuxDoConnectClientSecretConfigured
,
LinuxDoConnectRedirectURL
:
settings
.
LinuxDoConnectRedirectURL
,
SiteName
:
settings
.
SiteName
,
SiteLogo
:
settings
.
SiteLogo
,
SiteSubtitle
:
settings
.
SiteSubtitle
,
...
...
@@ -88,6 +94,12 @@ type UpdateSettingsRequest struct {
TurnstileSiteKey
string
`json:"turnstile_site_key"`
TurnstileSecretKey
string
`json:"turnstile_secret_key"`
// LinuxDo Connect OAuth 登录(终端用户 SSO)
LinuxDoConnectEnabled
bool
`json:"linuxdo_connect_enabled"`
LinuxDoConnectClientID
string
`json:"linuxdo_connect_client_id"`
LinuxDoConnectClientSecret
string
`json:"linuxdo_connect_client_secret"`
LinuxDoConnectRedirectURL
string
`json:"linuxdo_connect_redirect_url"`
// OEM设置
SiteName
string
`json:"site_name"`
SiteLogo
string
`json:"site_logo"`
...
...
@@ -165,6 +177,35 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
// LinuxDo Connect 参数验证
if
req
.
LinuxDoConnectEnabled
{
req
.
LinuxDoConnectClientID
=
strings
.
TrimSpace
(
req
.
LinuxDoConnectClientID
)
req
.
LinuxDoConnectClientSecret
=
strings
.
TrimSpace
(
req
.
LinuxDoConnectClientSecret
)
req
.
LinuxDoConnectRedirectURL
=
strings
.
TrimSpace
(
req
.
LinuxDoConnectRedirectURL
)
if
req
.
LinuxDoConnectClientID
==
""
{
response
.
BadRequest
(
c
,
"LinuxDo Client ID is required when enabled"
)
return
}
if
req
.
LinuxDoConnectRedirectURL
==
""
{
response
.
BadRequest
(
c
,
"LinuxDo Redirect URL is required when enabled"
)
return
}
if
err
:=
config
.
ValidateAbsoluteHTTPURL
(
req
.
LinuxDoConnectRedirectURL
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"LinuxDo Redirect URL must be an absolute http(s) URL"
)
return
}
// 如果未提供 client_secret,则保留现有值(如有)。
if
req
.
LinuxDoConnectClientSecret
==
""
{
if
previousSettings
.
LinuxDoConnectClientSecret
==
""
{
response
.
BadRequest
(
c
,
"LinuxDo Client Secret is required when enabled"
)
return
}
req
.
LinuxDoConnectClientSecret
=
previousSettings
.
LinuxDoConnectClientSecret
}
}
settings
:=
&
service
.
SystemSettings
{
RegistrationEnabled
:
req
.
RegistrationEnabled
,
EmailVerifyEnabled
:
req
.
EmailVerifyEnabled
,
...
...
@@ -178,6 +219,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
TurnstileEnabled
:
req
.
TurnstileEnabled
,
TurnstileSiteKey
:
req
.
TurnstileSiteKey
,
TurnstileSecretKey
:
req
.
TurnstileSecretKey
,
LinuxDoConnectEnabled
:
req
.
LinuxDoConnectEnabled
,
LinuxDoConnectClientID
:
req
.
LinuxDoConnectClientID
,
LinuxDoConnectClientSecret
:
req
.
LinuxDoConnectClientSecret
,
LinuxDoConnectRedirectURL
:
req
.
LinuxDoConnectRedirectURL
,
SiteName
:
req
.
SiteName
,
SiteLogo
:
req
.
SiteLogo
,
SiteSubtitle
:
req
.
SiteSubtitle
,
...
...
@@ -222,6 +267,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
TurnstileEnabled
:
updatedSettings
.
TurnstileEnabled
,
TurnstileSiteKey
:
updatedSettings
.
TurnstileSiteKey
,
TurnstileSecretKeyConfigured
:
updatedSettings
.
TurnstileSecretKeyConfigured
,
LinuxDoConnectEnabled
:
updatedSettings
.
LinuxDoConnectEnabled
,
LinuxDoConnectClientID
:
updatedSettings
.
LinuxDoConnectClientID
,
LinuxDoConnectClientSecretConfigured
:
updatedSettings
.
LinuxDoConnectClientSecretConfigured
,
LinuxDoConnectRedirectURL
:
updatedSettings
.
LinuxDoConnectRedirectURL
,
SiteName
:
updatedSettings
.
SiteName
,
SiteLogo
:
updatedSettings
.
SiteLogo
,
SiteSubtitle
:
updatedSettings
.
SiteSubtitle
,
...
...
@@ -298,6 +347,18 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if
req
.
TurnstileSecretKey
!=
""
{
changed
=
append
(
changed
,
"turnstile_secret_key"
)
}
if
before
.
LinuxDoConnectEnabled
!=
after
.
LinuxDoConnectEnabled
{
changed
=
append
(
changed
,
"linuxdo_connect_enabled"
)
}
if
before
.
LinuxDoConnectClientID
!=
after
.
LinuxDoConnectClientID
{
changed
=
append
(
changed
,
"linuxdo_connect_client_id"
)
}
if
req
.
LinuxDoConnectClientSecret
!=
""
{
changed
=
append
(
changed
,
"linuxdo_connect_client_secret"
)
}
if
before
.
LinuxDoConnectRedirectURL
!=
after
.
LinuxDoConnectRedirectURL
{
changed
=
append
(
changed
,
"linuxdo_connect_redirect_url"
)
}
if
before
.
SiteName
!=
after
.
SiteName
{
changed
=
append
(
changed
,
"site_name"
)
}
...
...
@@ -337,6 +398,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if
before
.
FallbackModelAntigravity
!=
after
.
FallbackModelAntigravity
{
changed
=
append
(
changed
,
"fallback_model_antigravity"
)
}
if
before
.
EnableIdentityPatch
!=
after
.
EnableIdentityPatch
{
changed
=
append
(
changed
,
"enable_identity_patch"
)
}
if
before
.
IdentityPatchPrompt
!=
after
.
IdentityPatchPrompt
{
changed
=
append
(
changed
,
"identity_patch_prompt"
)
}
return
changed
}
...
...
backend/internal/handler/admin/user_handler.go
View file @
1a641392
...
...
@@ -2,6 +2,7 @@ package admin
import
(
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
...
...
@@ -63,10 +64,17 @@ type UpdateBalanceRequest struct {
func
(
h
*
UserHandler
)
List
(
c
*
gin
.
Context
)
{
page
,
pageSize
:=
response
.
ParsePagination
(
c
)
search
:=
c
.
Query
(
"search"
)
// 标准化和验证 search 参数
search
=
strings
.
TrimSpace
(
search
)
if
len
(
search
)
>
100
{
search
=
search
[
:
100
]
}
filters
:=
service
.
UserListFilters
{
Status
:
c
.
Query
(
"status"
),
Role
:
c
.
Query
(
"role"
),
Search
:
c
.
Query
(
"
search
"
)
,
Search
:
search
,
Attributes
:
parseAttributeFilters
(
c
),
}
...
...
backend/internal/handler/api_key_handler.go
View file @
1a641392
...
...
@@ -30,6 +30,8 @@ type CreateAPIKeyRequest struct {
Name
string
`json:"name" binding:"required"`
GroupID
*
int64
`json:"group_id"`
// nullable
CustomKey
*
string
`json:"custom_key"`
// 可选的自定义key
IPWhitelist
[]
string
`json:"ip_whitelist"`
// IP 白名单
IPBlacklist
[]
string
`json:"ip_blacklist"`
// IP 黑名单
}
// UpdateAPIKeyRequest represents the update API key request payload
...
...
@@ -37,6 +39,8 @@ type UpdateAPIKeyRequest struct {
Name
string
`json:"name"`
GroupID
*
int64
`json:"group_id"`
Status
string
`json:"status" binding:"omitempty,oneof=active inactive"`
IPWhitelist
[]
string
`json:"ip_whitelist"`
// IP 白名单
IPBlacklist
[]
string
`json:"ip_blacklist"`
// IP 黑名单
}
// List handles listing user's API keys with pagination
...
...
@@ -113,6 +117,8 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
Name
:
req
.
Name
,
GroupID
:
req
.
GroupID
,
CustomKey
:
req
.
CustomKey
,
IPWhitelist
:
req
.
IPWhitelist
,
IPBlacklist
:
req
.
IPBlacklist
,
}
key
,
err
:=
h
.
apiKeyService
.
Create
(
c
.
Request
.
Context
(),
subject
.
UserID
,
svcReq
)
if
err
!=
nil
{
...
...
@@ -144,7 +150,10 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
return
}
svcReq
:=
service
.
UpdateAPIKeyRequest
{}
svcReq
:=
service
.
UpdateAPIKeyRequest
{
IPWhitelist
:
req
.
IPWhitelist
,
IPBlacklist
:
req
.
IPBlacklist
,
}
if
req
.
Name
!=
""
{
svcReq
.
Name
=
&
req
.
Name
}
...
...
backend/internal/handler/auth_handler.go
View file @
1a641392
...
...
@@ -15,14 +15,18 @@ type AuthHandler struct {
cfg
*
config
.
Config
authService
*
service
.
AuthService
userService
*
service
.
UserService
settingSvc
*
service
.
SettingService
promoService
*
service
.
PromoService
}
// NewAuthHandler creates a new AuthHandler
func
NewAuthHandler
(
cfg
*
config
.
Config
,
authService
*
service
.
AuthService
,
userService
*
service
.
UserService
)
*
AuthHandler
{
func
NewAuthHandler
(
cfg
*
config
.
Config
,
authService
*
service
.
AuthService
,
userService
*
service
.
UserService
,
settingService
*
service
.
SettingService
,
promoService
*
service
.
PromoService
)
*
AuthHandler
{
return
&
AuthHandler
{
cfg
:
cfg
,
authService
:
authService
,
userService
:
userService
,
settingSvc
:
settingService
,
promoService
:
promoService
,
}
}
...
...
@@ -32,6 +36,7 @@ type RegisterRequest struct {
Password
string
`json:"password" binding:"required,min=6"`
VerifyCode
string
`json:"verify_code"`
TurnstileToken
string
`json:"turnstile_token"`
PromoCode
string
`json:"promo_code"`
// 注册优惠码
}
// SendVerifyCodeRequest 发送验证码请求
...
...
@@ -77,7 +82,7 @@ func (h *AuthHandler) Register(c *gin.Context) {
}
}
token
,
user
,
err
:=
h
.
authService
.
RegisterWithVerification
(
c
.
Request
.
Context
(),
req
.
Email
,
req
.
Password
,
req
.
VerifyCode
)
token
,
user
,
err
:=
h
.
authService
.
RegisterWithVerification
(
c
.
Request
.
Context
(),
req
.
Email
,
req
.
Password
,
req
.
VerifyCode
,
req
.
PromoCode
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
...
...
@@ -172,3 +177,63 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
response
.
Success
(
c
,
UserResponse
{
User
:
dto
.
UserFromService
(
user
),
RunMode
:
runMode
})
}
// ValidatePromoCodeRequest 验证优惠码请求
type
ValidatePromoCodeRequest
struct
{
Code
string
`json:"code" binding:"required"`
}
// ValidatePromoCodeResponse 验证优惠码响应
type
ValidatePromoCodeResponse
struct
{
Valid
bool
`json:"valid"`
BonusAmount
float64
`json:"bonus_amount,omitempty"`
ErrorCode
string
`json:"error_code,omitempty"`
Message
string
`json:"message,omitempty"`
}
// ValidatePromoCode 验证优惠码(公开接口,注册前调用)
// POST /api/v1/auth/validate-promo-code
func
(
h
*
AuthHandler
)
ValidatePromoCode
(
c
*
gin
.
Context
)
{
var
req
ValidatePromoCodeRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
promoCode
,
err
:=
h
.
promoService
.
ValidatePromoCode
(
c
.
Request
.
Context
(),
req
.
Code
)
if
err
!=
nil
{
// 根据错误类型返回对应的错误码
errorCode
:=
"PROMO_CODE_INVALID"
switch
err
{
case
service
.
ErrPromoCodeNotFound
:
errorCode
=
"PROMO_CODE_NOT_FOUND"
case
service
.
ErrPromoCodeExpired
:
errorCode
=
"PROMO_CODE_EXPIRED"
case
service
.
ErrPromoCodeDisabled
:
errorCode
=
"PROMO_CODE_DISABLED"
case
service
.
ErrPromoCodeMaxUsed
:
errorCode
=
"PROMO_CODE_MAX_USED"
case
service
.
ErrPromoCodeAlreadyUsed
:
errorCode
=
"PROMO_CODE_ALREADY_USED"
}
response
.
Success
(
c
,
ValidatePromoCodeResponse
{
Valid
:
false
,
ErrorCode
:
errorCode
,
})
return
}
if
promoCode
==
nil
{
response
.
Success
(
c
,
ValidatePromoCodeResponse
{
Valid
:
false
,
ErrorCode
:
"PROMO_CODE_INVALID"
,
})
return
}
response
.
Success
(
c
,
ValidatePromoCodeResponse
{
Valid
:
true
,
BonusAmount
:
promoCode
.
BonusAmount
,
})
}
backend/internal/handler/auth_linuxdo_oauth.go
0 → 100644
View file @
1a641392
package
handler
import
(
"context"
"encoding/base64"
"errors"
"fmt"
"log"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"unicode/utf8"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/imroc/req/v3"
"github.com/tidwall/gjson"
)
const
(
linuxDoOAuthCookiePath
=
"/api/v1/auth/oauth/linuxdo"
linuxDoOAuthStateCookieName
=
"linuxdo_oauth_state"
linuxDoOAuthVerifierCookie
=
"linuxdo_oauth_verifier"
linuxDoOAuthRedirectCookie
=
"linuxdo_oauth_redirect"
linuxDoOAuthCookieMaxAgeSec
=
10
*
60
// 10 minutes
linuxDoOAuthDefaultRedirectTo
=
"/dashboard"
linuxDoOAuthDefaultFrontendCB
=
"/auth/linuxdo/callback"
linuxDoOAuthMaxRedirectLen
=
2048
linuxDoOAuthMaxFragmentValueLen
=
512
linuxDoOAuthMaxSubjectLen
=
64
-
len
(
"linuxdo-"
)
)
type
linuxDoTokenResponse
struct
{
AccessToken
string
`json:"access_token"`
TokenType
string
`json:"token_type"`
ExpiresIn
int64
`json:"expires_in"`
RefreshToken
string
`json:"refresh_token,omitempty"`
Scope
string
`json:"scope,omitempty"`
}
type
linuxDoTokenExchangeError
struct
{
StatusCode
int
ProviderError
string
ProviderDescription
string
Body
string
}
func
(
e
*
linuxDoTokenExchangeError
)
Error
()
string
{
if
e
==
nil
{
return
""
}
parts
:=
[]
string
{
fmt
.
Sprintf
(
"token exchange status=%d"
,
e
.
StatusCode
)}
if
strings
.
TrimSpace
(
e
.
ProviderError
)
!=
""
{
parts
=
append
(
parts
,
"error="
+
strings
.
TrimSpace
(
e
.
ProviderError
))
}
if
strings
.
TrimSpace
(
e
.
ProviderDescription
)
!=
""
{
parts
=
append
(
parts
,
"error_description="
+
strings
.
TrimSpace
(
e
.
ProviderDescription
))
}
return
strings
.
Join
(
parts
,
" "
)
}
// LinuxDoOAuthStart 启动 LinuxDo Connect OAuth 登录流程。
// GET /api/v1/auth/oauth/linuxdo/start?redirect=/dashboard
func
(
h
*
AuthHandler
)
LinuxDoOAuthStart
(
c
*
gin
.
Context
)
{
cfg
,
err
:=
h
.
getLinuxDoOAuthConfig
(
c
.
Request
.
Context
())
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
state
,
err
:=
oauth
.
GenerateState
()
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
infraerrors
.
InternalServer
(
"OAUTH_STATE_GEN_FAILED"
,
"failed to generate oauth state"
)
.
WithCause
(
err
))
return
}
redirectTo
:=
sanitizeFrontendRedirectPath
(
c
.
Query
(
"redirect"
))
if
redirectTo
==
""
{
redirectTo
=
linuxDoOAuthDefaultRedirectTo
}
secureCookie
:=
isRequestHTTPS
(
c
)
setCookie
(
c
,
linuxDoOAuthStateCookieName
,
encodeCookieValue
(
state
),
linuxDoOAuthCookieMaxAgeSec
,
secureCookie
)
setCookie
(
c
,
linuxDoOAuthRedirectCookie
,
encodeCookieValue
(
redirectTo
),
linuxDoOAuthCookieMaxAgeSec
,
secureCookie
)
codeChallenge
:=
""
if
cfg
.
UsePKCE
{
verifier
,
err
:=
oauth
.
GenerateCodeVerifier
()
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
infraerrors
.
InternalServer
(
"OAUTH_PKCE_GEN_FAILED"
,
"failed to generate pkce verifier"
)
.
WithCause
(
err
))
return
}
codeChallenge
=
oauth
.
GenerateCodeChallenge
(
verifier
)
setCookie
(
c
,
linuxDoOAuthVerifierCookie
,
encodeCookieValue
(
verifier
),
linuxDoOAuthCookieMaxAgeSec
,
secureCookie
)
}
redirectURI
:=
strings
.
TrimSpace
(
cfg
.
RedirectURL
)
if
redirectURI
==
""
{
response
.
ErrorFrom
(
c
,
infraerrors
.
InternalServer
(
"OAUTH_CONFIG_INVALID"
,
"oauth redirect url not configured"
))
return
}
authURL
,
err
:=
buildLinuxDoAuthorizeURL
(
cfg
,
state
,
codeChallenge
,
redirectURI
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
infraerrors
.
InternalServer
(
"OAUTH_BUILD_URL_FAILED"
,
"failed to build oauth authorization url"
)
.
WithCause
(
err
))
return
}
c
.
Redirect
(
http
.
StatusFound
,
authURL
)
}
// LinuxDoOAuthCallback 处理 OAuth 回调:创建/登录用户,然后重定向到前端。
// GET /api/v1/auth/oauth/linuxdo/callback?code=...&state=...
func
(
h
*
AuthHandler
)
LinuxDoOAuthCallback
(
c
*
gin
.
Context
)
{
cfg
,
cfgErr
:=
h
.
getLinuxDoOAuthConfig
(
c
.
Request
.
Context
())
if
cfgErr
!=
nil
{
response
.
ErrorFrom
(
c
,
cfgErr
)
return
}
frontendCallback
:=
strings
.
TrimSpace
(
cfg
.
FrontendRedirectURL
)
if
frontendCallback
==
""
{
frontendCallback
=
linuxDoOAuthDefaultFrontendCB
}
if
providerErr
:=
strings
.
TrimSpace
(
c
.
Query
(
"error"
));
providerErr
!=
""
{
redirectOAuthError
(
c
,
frontendCallback
,
"provider_error"
,
providerErr
,
c
.
Query
(
"error_description"
))
return
}
code
:=
strings
.
TrimSpace
(
c
.
Query
(
"code"
))
state
:=
strings
.
TrimSpace
(
c
.
Query
(
"state"
))
if
code
==
""
||
state
==
""
{
redirectOAuthError
(
c
,
frontendCallback
,
"missing_params"
,
"missing code/state"
,
""
)
return
}
secureCookie
:=
isRequestHTTPS
(
c
)
defer
func
()
{
clearCookie
(
c
,
linuxDoOAuthStateCookieName
,
secureCookie
)
clearCookie
(
c
,
linuxDoOAuthVerifierCookie
,
secureCookie
)
clearCookie
(
c
,
linuxDoOAuthRedirectCookie
,
secureCookie
)
}()
expectedState
,
err
:=
readCookieDecoded
(
c
,
linuxDoOAuthStateCookieName
)
if
err
!=
nil
||
expectedState
==
""
||
state
!=
expectedState
{
redirectOAuthError
(
c
,
frontendCallback
,
"invalid_state"
,
"invalid oauth state"
,
""
)
return
}
redirectTo
,
_
:=
readCookieDecoded
(
c
,
linuxDoOAuthRedirectCookie
)
redirectTo
=
sanitizeFrontendRedirectPath
(
redirectTo
)
if
redirectTo
==
""
{
redirectTo
=
linuxDoOAuthDefaultRedirectTo
}
codeVerifier
:=
""
if
cfg
.
UsePKCE
{
codeVerifier
,
_
=
readCookieDecoded
(
c
,
linuxDoOAuthVerifierCookie
)
if
codeVerifier
==
""
{
redirectOAuthError
(
c
,
frontendCallback
,
"missing_verifier"
,
"missing pkce verifier"
,
""
)
return
}
}
redirectURI
:=
strings
.
TrimSpace
(
cfg
.
RedirectURL
)
if
redirectURI
==
""
{
redirectOAuthError
(
c
,
frontendCallback
,
"config_error"
,
"oauth redirect url not configured"
,
""
)
return
}
tokenResp
,
err
:=
linuxDoExchangeCode
(
c
.
Request
.
Context
(),
cfg
,
code
,
redirectURI
,
codeVerifier
)
if
err
!=
nil
{
description
:=
""
var
exchangeErr
*
linuxDoTokenExchangeError
if
errors
.
As
(
err
,
&
exchangeErr
)
&&
exchangeErr
!=
nil
{
log
.
Printf
(
"[LinuxDo OAuth] token exchange failed: status=%d provider_error=%q provider_description=%q body=%s"
,
exchangeErr
.
StatusCode
,
exchangeErr
.
ProviderError
,
exchangeErr
.
ProviderDescription
,
truncateLogValue
(
exchangeErr
.
Body
,
2048
),
)
description
=
exchangeErr
.
Error
()
}
else
{
log
.
Printf
(
"[LinuxDo OAuth] token exchange failed: %v"
,
err
)
description
=
err
.
Error
()
}
redirectOAuthError
(
c
,
frontendCallback
,
"token_exchange_failed"
,
"failed to exchange oauth code"
,
singleLine
(
description
))
return
}
email
,
username
,
subject
,
err
:=
linuxDoFetchUserInfo
(
c
.
Request
.
Context
(),
cfg
,
tokenResp
)
if
err
!=
nil
{
log
.
Printf
(
"[LinuxDo OAuth] userinfo fetch failed: %v"
,
err
)
redirectOAuthError
(
c
,
frontendCallback
,
"userinfo_failed"
,
"failed to fetch user info"
,
""
)
return
}
// 安全考虑:不要把第三方返回的 email 直接映射到本地账号(可能与本地邮箱用户冲突导致账号被接管)。
// 统一使用基于 subject 的稳定合成邮箱来做账号绑定。
if
subject
!=
""
{
email
=
linuxDoSyntheticEmail
(
subject
)
}
jwtToken
,
_
,
err
:=
h
.
authService
.
LoginOrRegisterOAuth
(
c
.
Request
.
Context
(),
email
,
username
)
if
err
!=
nil
{
// 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
redirectOAuthError
(
c
,
frontendCallback
,
"login_failed"
,
infraerrors
.
Reason
(
err
),
infraerrors
.
Message
(
err
))
return
}
fragment
:=
url
.
Values
{}
fragment
.
Set
(
"access_token"
,
jwtToken
)
fragment
.
Set
(
"token_type"
,
"Bearer"
)
fragment
.
Set
(
"redirect"
,
redirectTo
)
redirectWithFragment
(
c
,
frontendCallback
,
fragment
)
}
func
(
h
*
AuthHandler
)
getLinuxDoOAuthConfig
(
ctx
context
.
Context
)
(
config
.
LinuxDoConnectConfig
,
error
)
{
if
h
!=
nil
&&
h
.
settingSvc
!=
nil
{
return
h
.
settingSvc
.
GetLinuxDoConnectOAuthConfig
(
ctx
)
}
if
h
==
nil
||
h
.
cfg
==
nil
{
return
config
.
LinuxDoConnectConfig
{},
infraerrors
.
ServiceUnavailable
(
"CONFIG_NOT_READY"
,
"config not loaded"
)
}
if
!
h
.
cfg
.
LinuxDo
.
Enabled
{
return
config
.
LinuxDoConnectConfig
{},
infraerrors
.
NotFound
(
"OAUTH_DISABLED"
,
"oauth login is disabled"
)
}
return
h
.
cfg
.
LinuxDo
,
nil
}
func
linuxDoExchangeCode
(
ctx
context
.
Context
,
cfg
config
.
LinuxDoConnectConfig
,
code
string
,
redirectURI
string
,
codeVerifier
string
,
)
(
*
linuxDoTokenResponse
,
error
)
{
client
:=
req
.
C
()
.
SetTimeout
(
30
*
time
.
Second
)
form
:=
url
.
Values
{}
form
.
Set
(
"grant_type"
,
"authorization_code"
)
form
.
Set
(
"client_id"
,
cfg
.
ClientID
)
form
.
Set
(
"code"
,
code
)
form
.
Set
(
"redirect_uri"
,
redirectURI
)
if
cfg
.
UsePKCE
{
form
.
Set
(
"code_verifier"
,
codeVerifier
)
}
r
:=
client
.
R
()
.
SetContext
(
ctx
)
.
SetHeader
(
"Accept"
,
"application/json"
)
switch
strings
.
ToLower
(
strings
.
TrimSpace
(
cfg
.
TokenAuthMethod
))
{
case
""
,
"client_secret_post"
:
form
.
Set
(
"client_secret"
,
cfg
.
ClientSecret
)
case
"client_secret_basic"
:
r
.
SetBasicAuth
(
cfg
.
ClientID
,
cfg
.
ClientSecret
)
case
"none"
:
default
:
return
nil
,
fmt
.
Errorf
(
"unsupported token_auth_method: %s"
,
cfg
.
TokenAuthMethod
)
}
resp
,
err
:=
r
.
SetFormDataFromValues
(
form
)
.
Post
(
cfg
.
TokenURL
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"request token: %w"
,
err
)
}
body
:=
strings
.
TrimSpace
(
resp
.
String
())
if
!
resp
.
IsSuccessState
()
{
providerErr
,
providerDesc
:=
parseOAuthProviderError
(
body
)
return
nil
,
&
linuxDoTokenExchangeError
{
StatusCode
:
resp
.
StatusCode
,
ProviderError
:
providerErr
,
ProviderDescription
:
providerDesc
,
Body
:
body
,
}
}
tokenResp
,
ok
:=
parseLinuxDoTokenResponse
(
body
)
if
!
ok
||
strings
.
TrimSpace
(
tokenResp
.
AccessToken
)
==
""
{
return
nil
,
&
linuxDoTokenExchangeError
{
StatusCode
:
resp
.
StatusCode
,
Body
:
body
,
}
}
if
strings
.
TrimSpace
(
tokenResp
.
TokenType
)
==
""
{
tokenResp
.
TokenType
=
"Bearer"
}
return
tokenResp
,
nil
}
func
linuxDoFetchUserInfo
(
ctx
context
.
Context
,
cfg
config
.
LinuxDoConnectConfig
,
token
*
linuxDoTokenResponse
,
)
(
email
string
,
username
string
,
subject
string
,
err
error
)
{
client
:=
req
.
C
()
.
SetTimeout
(
30
*
time
.
Second
)
authorization
,
err
:=
buildBearerAuthorization
(
token
.
TokenType
,
token
.
AccessToken
)
if
err
!=
nil
{
return
""
,
""
,
""
,
fmt
.
Errorf
(
"invalid token for userinfo request: %w"
,
err
)
}
resp
,
err
:=
client
.
R
()
.
SetContext
(
ctx
)
.
SetHeader
(
"Accept"
,
"application/json"
)
.
SetHeader
(
"Authorization"
,
authorization
)
.
Get
(
cfg
.
UserInfoURL
)
if
err
!=
nil
{
return
""
,
""
,
""
,
fmt
.
Errorf
(
"request userinfo: %w"
,
err
)
}
if
!
resp
.
IsSuccessState
()
{
return
""
,
""
,
""
,
fmt
.
Errorf
(
"userinfo status=%d"
,
resp
.
StatusCode
)
}
return
linuxDoParseUserInfo
(
resp
.
String
(),
cfg
)
}
func
linuxDoParseUserInfo
(
body
string
,
cfg
config
.
LinuxDoConnectConfig
)
(
email
string
,
username
string
,
subject
string
,
err
error
)
{
email
=
firstNonEmpty
(
getGJSON
(
body
,
cfg
.
UserInfoEmailPath
),
getGJSON
(
body
,
"email"
),
getGJSON
(
body
,
"user.email"
),
getGJSON
(
body
,
"data.email"
),
getGJSON
(
body
,
"attributes.email"
),
)
username
=
firstNonEmpty
(
getGJSON
(
body
,
cfg
.
UserInfoUsernamePath
),
getGJSON
(
body
,
"username"
),
getGJSON
(
body
,
"preferred_username"
),
getGJSON
(
body
,
"name"
),
getGJSON
(
body
,
"user.username"
),
getGJSON
(
body
,
"user.name"
),
)
subject
=
firstNonEmpty
(
getGJSON
(
body
,
cfg
.
UserInfoIDPath
),
getGJSON
(
body
,
"sub"
),
getGJSON
(
body
,
"id"
),
getGJSON
(
body
,
"user_id"
),
getGJSON
(
body
,
"uid"
),
getGJSON
(
body
,
"user.id"
),
)
subject
=
strings
.
TrimSpace
(
subject
)
if
subject
==
""
{
return
""
,
""
,
""
,
errors
.
New
(
"userinfo missing id field"
)
}
if
!
isSafeLinuxDoSubject
(
subject
)
{
return
""
,
""
,
""
,
errors
.
New
(
"userinfo returned invalid id field"
)
}
email
=
strings
.
TrimSpace
(
email
)
if
email
==
""
{
// LinuxDo Connect 的 userinfo 可能不提供 email。为兼容现有用户模型(email 必填且唯一),使用稳定的合成邮箱。
email
=
linuxDoSyntheticEmail
(
subject
)
}
username
=
strings
.
TrimSpace
(
username
)
if
username
==
""
{
username
=
"linuxdo_"
+
subject
}
return
email
,
username
,
subject
,
nil
}
func
buildLinuxDoAuthorizeURL
(
cfg
config
.
LinuxDoConnectConfig
,
state
string
,
codeChallenge
string
,
redirectURI
string
)
(
string
,
error
)
{
u
,
err
:=
url
.
Parse
(
cfg
.
AuthorizeURL
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"parse authorize_url: %w"
,
err
)
}
q
:=
u
.
Query
()
q
.
Set
(
"response_type"
,
"code"
)
q
.
Set
(
"client_id"
,
cfg
.
ClientID
)
q
.
Set
(
"redirect_uri"
,
redirectURI
)
if
strings
.
TrimSpace
(
cfg
.
Scopes
)
!=
""
{
q
.
Set
(
"scope"
,
cfg
.
Scopes
)
}
q
.
Set
(
"state"
,
state
)
if
cfg
.
UsePKCE
{
q
.
Set
(
"code_challenge"
,
codeChallenge
)
q
.
Set
(
"code_challenge_method"
,
"S256"
)
}
u
.
RawQuery
=
q
.
Encode
()
return
u
.
String
(),
nil
}
func
redirectOAuthError
(
c
*
gin
.
Context
,
frontendCallback
string
,
code
string
,
message
string
,
description
string
)
{
fragment
:=
url
.
Values
{}
fragment
.
Set
(
"error"
,
truncateFragmentValue
(
code
))
if
strings
.
TrimSpace
(
message
)
!=
""
{
fragment
.
Set
(
"error_message"
,
truncateFragmentValue
(
message
))
}
if
strings
.
TrimSpace
(
description
)
!=
""
{
fragment
.
Set
(
"error_description"
,
truncateFragmentValue
(
description
))
}
redirectWithFragment
(
c
,
frontendCallback
,
fragment
)
}
func
redirectWithFragment
(
c
*
gin
.
Context
,
frontendCallback
string
,
fragment
url
.
Values
)
{
u
,
err
:=
url
.
Parse
(
frontendCallback
)
if
err
!=
nil
{
// 兜底:尽力跳转到默认页面,避免卡死在回调页。
c
.
Redirect
(
http
.
StatusFound
,
linuxDoOAuthDefaultRedirectTo
)
return
}
if
u
.
Scheme
!=
""
&&
!
strings
.
EqualFold
(
u
.
Scheme
,
"http"
)
&&
!
strings
.
EqualFold
(
u
.
Scheme
,
"https"
)
{
c
.
Redirect
(
http
.
StatusFound
,
linuxDoOAuthDefaultRedirectTo
)
return
}
u
.
Fragment
=
fragment
.
Encode
()
c
.
Header
(
"Cache-Control"
,
"no-store"
)
c
.
Header
(
"Pragma"
,
"no-cache"
)
c
.
Redirect
(
http
.
StatusFound
,
u
.
String
())
}
func
firstNonEmpty
(
values
...
string
)
string
{
for
_
,
v
:=
range
values
{
v
=
strings
.
TrimSpace
(
v
)
if
v
!=
""
{
return
v
}
}
return
""
}
func
parseOAuthProviderError
(
body
string
)
(
providerErr
string
,
providerDesc
string
)
{
body
=
strings
.
TrimSpace
(
body
)
if
body
==
""
{
return
""
,
""
}
providerErr
=
firstNonEmpty
(
getGJSON
(
body
,
"error"
),
getGJSON
(
body
,
"code"
),
getGJSON
(
body
,
"error.code"
),
)
providerDesc
=
firstNonEmpty
(
getGJSON
(
body
,
"error_description"
),
getGJSON
(
body
,
"error.message"
),
getGJSON
(
body
,
"message"
),
getGJSON
(
body
,
"detail"
),
)
if
providerErr
!=
""
||
providerDesc
!=
""
{
return
providerErr
,
providerDesc
}
values
,
err
:=
url
.
ParseQuery
(
body
)
if
err
!=
nil
{
return
""
,
""
}
providerErr
=
firstNonEmpty
(
values
.
Get
(
"error"
),
values
.
Get
(
"code"
))
providerDesc
=
firstNonEmpty
(
values
.
Get
(
"error_description"
),
values
.
Get
(
"error_message"
),
values
.
Get
(
"message"
))
return
providerErr
,
providerDesc
}
func
parseLinuxDoTokenResponse
(
body
string
)
(
*
linuxDoTokenResponse
,
bool
)
{
body
=
strings
.
TrimSpace
(
body
)
if
body
==
""
{
return
nil
,
false
}
accessToken
:=
strings
.
TrimSpace
(
getGJSON
(
body
,
"access_token"
))
if
accessToken
!=
""
{
tokenType
:=
strings
.
TrimSpace
(
getGJSON
(
body
,
"token_type"
))
refreshToken
:=
strings
.
TrimSpace
(
getGJSON
(
body
,
"refresh_token"
))
scope
:=
strings
.
TrimSpace
(
getGJSON
(
body
,
"scope"
))
expiresIn
:=
gjson
.
Get
(
body
,
"expires_in"
)
.
Int
()
return
&
linuxDoTokenResponse
{
AccessToken
:
accessToken
,
TokenType
:
tokenType
,
ExpiresIn
:
expiresIn
,
RefreshToken
:
refreshToken
,
Scope
:
scope
,
},
true
}
values
,
err
:=
url
.
ParseQuery
(
body
)
if
err
!=
nil
{
return
nil
,
false
}
accessToken
=
strings
.
TrimSpace
(
values
.
Get
(
"access_token"
))
if
accessToken
==
""
{
return
nil
,
false
}
expiresIn
:=
int64
(
0
)
if
raw
:=
strings
.
TrimSpace
(
values
.
Get
(
"expires_in"
));
raw
!=
""
{
if
v
,
err
:=
strconv
.
ParseInt
(
raw
,
10
,
64
);
err
==
nil
{
expiresIn
=
v
}
}
return
&
linuxDoTokenResponse
{
AccessToken
:
accessToken
,
TokenType
:
strings
.
TrimSpace
(
values
.
Get
(
"token_type"
)),
ExpiresIn
:
expiresIn
,
RefreshToken
:
strings
.
TrimSpace
(
values
.
Get
(
"refresh_token"
)),
Scope
:
strings
.
TrimSpace
(
values
.
Get
(
"scope"
)),
},
true
}
func
getGJSON
(
body
string
,
path
string
)
string
{
path
=
strings
.
TrimSpace
(
path
)
if
path
==
""
{
return
""
}
res
:=
gjson
.
Get
(
body
,
path
)
if
!
res
.
Exists
()
{
return
""
}
return
res
.
String
()
}
func
truncateLogValue
(
value
string
,
maxLen
int
)
string
{
value
=
strings
.
TrimSpace
(
value
)
if
value
==
""
||
maxLen
<=
0
{
return
""
}
if
len
(
value
)
<=
maxLen
{
return
value
}
value
=
value
[
:
maxLen
]
for
!
utf8
.
ValidString
(
value
)
{
value
=
value
[
:
len
(
value
)
-
1
]
}
return
value
}
func
singleLine
(
value
string
)
string
{
value
=
strings
.
TrimSpace
(
value
)
if
value
==
""
{
return
""
}
return
strings
.
Join
(
strings
.
Fields
(
value
),
" "
)
}
func
sanitizeFrontendRedirectPath
(
path
string
)
string
{
path
=
strings
.
TrimSpace
(
path
)
if
path
==
""
{
return
""
}
if
len
(
path
)
>
linuxDoOAuthMaxRedirectLen
{
return
""
}
// 只允许同源相对路径(避免开放重定向)。
if
!
strings
.
HasPrefix
(
path
,
"/"
)
{
return
""
}
if
strings
.
HasPrefix
(
path
,
"//"
)
{
return
""
}
if
strings
.
Contains
(
path
,
"://"
)
{
return
""
}
if
strings
.
ContainsAny
(
path
,
"
\r\n
"
)
{
return
""
}
return
path
}
func
isRequestHTTPS
(
c
*
gin
.
Context
)
bool
{
if
c
.
Request
.
TLS
!=
nil
{
return
true
}
proto
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
c
.
GetHeader
(
"X-Forwarded-Proto"
)))
return
proto
==
"https"
}
func
encodeCookieValue
(
value
string
)
string
{
return
base64
.
RawURLEncoding
.
EncodeToString
([]
byte
(
value
))
}
func
decodeCookieValue
(
value
string
)
(
string
,
error
)
{
raw
,
err
:=
base64
.
RawURLEncoding
.
DecodeString
(
value
)
if
err
!=
nil
{
return
""
,
err
}
return
string
(
raw
),
nil
}
func
readCookieDecoded
(
c
*
gin
.
Context
,
name
string
)
(
string
,
error
)
{
ck
,
err
:=
c
.
Request
.
Cookie
(
name
)
if
err
!=
nil
{
return
""
,
err
}
return
decodeCookieValue
(
ck
.
Value
)
}
func
setCookie
(
c
*
gin
.
Context
,
name
string
,
value
string
,
maxAgeSec
int
,
secure
bool
)
{
http
.
SetCookie
(
c
.
Writer
,
&
http
.
Cookie
{
Name
:
name
,
Value
:
value
,
Path
:
linuxDoOAuthCookiePath
,
MaxAge
:
maxAgeSec
,
HttpOnly
:
true
,
Secure
:
secure
,
SameSite
:
http
.
SameSiteLaxMode
,
})
}
func
clearCookie
(
c
*
gin
.
Context
,
name
string
,
secure
bool
)
{
http
.
SetCookie
(
c
.
Writer
,
&
http
.
Cookie
{
Name
:
name
,
Value
:
""
,
Path
:
linuxDoOAuthCookiePath
,
MaxAge
:
-
1
,
HttpOnly
:
true
,
Secure
:
secure
,
SameSite
:
http
.
SameSiteLaxMode
,
})
}
func
truncateFragmentValue
(
value
string
)
string
{
value
=
strings
.
TrimSpace
(
value
)
if
value
==
""
{
return
""
}
if
len
(
value
)
>
linuxDoOAuthMaxFragmentValueLen
{
value
=
value
[
:
linuxDoOAuthMaxFragmentValueLen
]
for
!
utf8
.
ValidString
(
value
)
{
value
=
value
[
:
len
(
value
)
-
1
]
}
}
return
value
}
func
buildBearerAuthorization
(
tokenType
,
accessToken
string
)
(
string
,
error
)
{
tokenType
=
strings
.
TrimSpace
(
tokenType
)
if
tokenType
==
""
{
tokenType
=
"Bearer"
}
if
!
strings
.
EqualFold
(
tokenType
,
"Bearer"
)
{
return
""
,
fmt
.
Errorf
(
"unsupported token_type: %s"
,
tokenType
)
}
accessToken
=
strings
.
TrimSpace
(
accessToken
)
if
accessToken
==
""
{
return
""
,
errors
.
New
(
"missing access_token"
)
}
if
strings
.
ContainsAny
(
accessToken
,
"
\t\r\n
"
)
{
return
""
,
errors
.
New
(
"access_token contains whitespace"
)
}
return
"Bearer "
+
accessToken
,
nil
}
func
isSafeLinuxDoSubject
(
subject
string
)
bool
{
subject
=
strings
.
TrimSpace
(
subject
)
if
subject
==
""
||
len
(
subject
)
>
linuxDoOAuthMaxSubjectLen
{
return
false
}
for
_
,
r
:=
range
subject
{
switch
{
case
r
>=
'0'
&&
r
<=
'9'
:
case
r
>=
'a'
&&
r
<=
'z'
:
case
r
>=
'A'
&&
r
<=
'Z'
:
case
r
==
'_'
||
r
==
'-'
:
default
:
return
false
}
}
return
true
}
func
linuxDoSyntheticEmail
(
subject
string
)
string
{
subject
=
strings
.
TrimSpace
(
subject
)
if
subject
==
""
{
return
""
}
return
"linuxdo-"
+
subject
+
service
.
LinuxDoConnectSyntheticEmailDomain
}
backend/internal/handler/auth_linuxdo_oauth_test.go
0 → 100644
View file @
1a641392
package
handler
import
(
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func
TestSanitizeFrontendRedirectPath
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
"/dashboard"
,
sanitizeFrontendRedirectPath
(
"/dashboard"
))
require
.
Equal
(
t
,
"/dashboard"
,
sanitizeFrontendRedirectPath
(
" /dashboard "
))
require
.
Equal
(
t
,
""
,
sanitizeFrontendRedirectPath
(
"dashboard"
))
require
.
Equal
(
t
,
""
,
sanitizeFrontendRedirectPath
(
"//evil.com"
))
require
.
Equal
(
t
,
""
,
sanitizeFrontendRedirectPath
(
"https://evil.com"
))
require
.
Equal
(
t
,
""
,
sanitizeFrontendRedirectPath
(
"/
\n
foo"
))
long
:=
"/"
+
strings
.
Repeat
(
"a"
,
linuxDoOAuthMaxRedirectLen
)
require
.
Equal
(
t
,
""
,
sanitizeFrontendRedirectPath
(
long
))
}
func
TestBuildBearerAuthorization
(
t
*
testing
.
T
)
{
auth
,
err
:=
buildBearerAuthorization
(
""
,
"token123"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"Bearer token123"
,
auth
)
auth
,
err
=
buildBearerAuthorization
(
"bearer"
,
"token123"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"Bearer token123"
,
auth
)
_
,
err
=
buildBearerAuthorization
(
"MAC"
,
"token123"
)
require
.
Error
(
t
,
err
)
_
,
err
=
buildBearerAuthorization
(
"Bearer"
,
"token 123"
)
require
.
Error
(
t
,
err
)
}
func
TestLinuxDoParseUserInfoParsesIDAndUsername
(
t
*
testing
.
T
)
{
cfg
:=
config
.
LinuxDoConnectConfig
{
UserInfoURL
:
"https://connect.linux.do/api/user"
,
}
email
,
username
,
subject
,
err
:=
linuxDoParseUserInfo
(
`{"id":123,"username":"alice"}`
,
cfg
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"123"
,
subject
)
require
.
Equal
(
t
,
"alice"
,
username
)
require
.
Equal
(
t
,
"linuxdo-123@linuxdo-connect.invalid"
,
email
)
}
func
TestLinuxDoParseUserInfoDefaultsUsername
(
t
*
testing
.
T
)
{
cfg
:=
config
.
LinuxDoConnectConfig
{
UserInfoURL
:
"https://connect.linux.do/api/user"
,
}
email
,
username
,
subject
,
err
:=
linuxDoParseUserInfo
(
`{"id":"123"}`
,
cfg
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"123"
,
subject
)
require
.
Equal
(
t
,
"linuxdo_123"
,
username
)
require
.
Equal
(
t
,
"linuxdo-123@linuxdo-connect.invalid"
,
email
)
}
func
TestLinuxDoParseUserInfoRejectsUnsafeSubject
(
t
*
testing
.
T
)
{
cfg
:=
config
.
LinuxDoConnectConfig
{
UserInfoURL
:
"https://connect.linux.do/api/user"
,
}
_
,
_
,
_
,
err
:=
linuxDoParseUserInfo
(
`{"id":"123@456"}`
,
cfg
)
require
.
Error
(
t
,
err
)
tooLong
:=
strings
.
Repeat
(
"a"
,
linuxDoOAuthMaxSubjectLen
+
1
)
_
,
_
,
_
,
err
=
linuxDoParseUserInfo
(
`{"id":"`
+
tooLong
+
`"}`
,
cfg
)
require
.
Error
(
t
,
err
)
}
func
TestParseOAuthProviderErrorJSON
(
t
*
testing
.
T
)
{
code
,
desc
:=
parseOAuthProviderError
(
`{"error":"invalid_client","error_description":"bad secret"}`
)
require
.
Equal
(
t
,
"invalid_client"
,
code
)
require
.
Equal
(
t
,
"bad secret"
,
desc
)
}
func
TestParseOAuthProviderErrorForm
(
t
*
testing
.
T
)
{
code
,
desc
:=
parseOAuthProviderError
(
"error=invalid_request&error_description=Missing+code_verifier"
)
require
.
Equal
(
t
,
"invalid_request"
,
code
)
require
.
Equal
(
t
,
"Missing code_verifier"
,
desc
)
}
func
TestParseLinuxDoTokenResponseJSON
(
t
*
testing
.
T
)
{
token
,
ok
:=
parseLinuxDoTokenResponse
(
`{"access_token":"t1","token_type":"Bearer","expires_in":3600,"scope":"user"}`
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"t1"
,
token
.
AccessToken
)
require
.
Equal
(
t
,
"Bearer"
,
token
.
TokenType
)
require
.
Equal
(
t
,
int64
(
3600
),
token
.
ExpiresIn
)
require
.
Equal
(
t
,
"user"
,
token
.
Scope
)
}
func
TestParseLinuxDoTokenResponseForm
(
t
*
testing
.
T
)
{
token
,
ok
:=
parseLinuxDoTokenResponse
(
"access_token=t2&token_type=bearer&expires_in=60"
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"t2"
,
token
.
AccessToken
)
require
.
Equal
(
t
,
"bearer"
,
token
.
TokenType
)
require
.
Equal
(
t
,
int64
(
60
),
token
.
ExpiresIn
)
}
func
TestSingleLineStripsWhitespace
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
"hello world"
,
singleLine
(
"hello
\r\n
world"
))
require
.
Equal
(
t
,
""
,
singleLine
(
"
\n\t\r
"
))
}
backend/internal/handler/dto/mappers.go
View file @
1a641392
...
...
@@ -59,6 +59,8 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
Name
:
k
.
Name
,
GroupID
:
k
.
GroupID
,
Status
:
k
.
Status
,
IPWhitelist
:
k
.
IPWhitelist
,
IPBlacklist
:
k
.
IPBlacklist
,
CreatedAt
:
k
.
CreatedAt
,
UpdatedAt
:
k
.
UpdatedAt
,
User
:
UserFromServiceShallow
(
k
.
User
),
...
...
@@ -250,11 +252,12 @@ func AccountSummaryFromService(a *service.Account) *AccountSummary {
// usageLogFromServiceBase is a helper that converts service UsageLog to DTO.
// The account parameter allows caller to control what Account info is included.
func
usageLogFromServiceBase
(
l
*
service
.
UsageLog
,
account
*
AccountSummary
)
*
UsageLog
{
// The includeIPAddress parameter controls whether to include the IP address (admin-only).
func
usageLogFromServiceBase
(
l
*
service
.
UsageLog
,
account
*
AccountSummary
,
includeIPAddress
bool
)
*
UsageLog
{
if
l
==
nil
{
return
nil
}
re
turn
&
UsageLog
{
re
sult
:=
&
UsageLog
{
ID
:
l
.
ID
,
UserID
:
l
.
UserID
,
APIKeyID
:
l
.
APIKeyID
,
...
...
@@ -290,21 +293,26 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary) *Usag
Group
:
GroupFromServiceShallow
(
l
.
Group
),
Subscription
:
UserSubscriptionFromService
(
l
.
Subscription
),
}
// IP 地址仅对管理员可见
if
includeIPAddress
{
result
.
IPAddress
=
l
.
IPAddress
}
return
result
}
// UsageLogFromService converts a service UsageLog to DTO for regular users.
// It excludes Account details - users should not see
account information
.
// It excludes Account details
and IP address
- users should not see
these
.
func
UsageLogFromService
(
l
*
service
.
UsageLog
)
*
UsageLog
{
return
usageLogFromServiceBase
(
l
,
nil
)
return
usageLogFromServiceBase
(
l
,
nil
,
false
)
}
// UsageLogFromServiceAdmin converts a service UsageLog to DTO for admin users.
// It includes minimal Account info (ID, Name only).
// It includes minimal Account info (ID, Name only)
and IP address
.
func
UsageLogFromServiceAdmin
(
l
*
service
.
UsageLog
)
*
UsageLog
{
if
l
==
nil
{
return
nil
}
return
usageLogFromServiceBase
(
l
,
AccountSummaryFromService
(
l
.
Account
))
return
usageLogFromServiceBase
(
l
,
AccountSummaryFromService
(
l
.
Account
)
,
true
)
}
func
SettingFromService
(
s
*
service
.
Setting
)
*
Setting
{
...
...
@@ -362,3 +370,35 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult
Errors
:
r
.
Errors
,
}
}
func
PromoCodeFromService
(
pc
*
service
.
PromoCode
)
*
PromoCode
{
if
pc
==
nil
{
return
nil
}
return
&
PromoCode
{
ID
:
pc
.
ID
,
Code
:
pc
.
Code
,
BonusAmount
:
pc
.
BonusAmount
,
MaxUses
:
pc
.
MaxUses
,
UsedCount
:
pc
.
UsedCount
,
Status
:
pc
.
Status
,
ExpiresAt
:
pc
.
ExpiresAt
,
Notes
:
pc
.
Notes
,
CreatedAt
:
pc
.
CreatedAt
,
UpdatedAt
:
pc
.
UpdatedAt
,
}
}
func
PromoCodeUsageFromService
(
u
*
service
.
PromoCodeUsage
)
*
PromoCodeUsage
{
if
u
==
nil
{
return
nil
}
return
&
PromoCodeUsage
{
ID
:
u
.
ID
,
PromoCodeID
:
u
.
PromoCodeID
,
UserID
:
u
.
UserID
,
BonusAmount
:
u
.
BonusAmount
,
UsedAt
:
u
.
UsedAt
,
User
:
UserFromServiceShallow
(
u
.
User
),
}
}
backend/internal/handler/dto/settings.go
View file @
1a641392
...
...
@@ -17,6 +17,11 @@ type SystemSettings struct {
TurnstileSiteKey
string
`json:"turnstile_site_key"`
TurnstileSecretKeyConfigured
bool
`json:"turnstile_secret_key_configured"`
LinuxDoConnectEnabled
bool
`json:"linuxdo_connect_enabled"`
LinuxDoConnectClientID
string
`json:"linuxdo_connect_client_id"`
LinuxDoConnectClientSecretConfigured
bool
`json:"linuxdo_connect_client_secret_configured"`
LinuxDoConnectRedirectURL
string
`json:"linuxdo_connect_redirect_url"`
SiteName
string
`json:"site_name"`
SiteLogo
string
`json:"site_logo"`
SiteSubtitle
string
`json:"site_subtitle"`
...
...
@@ -50,5 +55,6 @@ type PublicSettings struct {
APIBaseURL
string
`json:"api_base_url"`
ContactInfo
string
`json:"contact_info"`
DocURL
string
`json:"doc_url"`
LinuxDoOAuthEnabled
bool
`json:"linuxdo_oauth_enabled"`
Version
string
`json:"version"`
}
backend/internal/handler/dto/types.go
View file @
1a641392
...
...
@@ -26,6 +26,8 @@ type APIKey struct {
Name
string
`json:"name"`
GroupID
*
int64
`json:"group_id"`
Status
string
`json:"status"`
IPWhitelist
[]
string
`json:"ip_whitelist"`
IPBlacklist
[]
string
`json:"ip_blacklist"`
CreatedAt
time
.
Time
`json:"created_at"`
UpdatedAt
time
.
Time
`json:"updated_at"`
...
...
@@ -187,6 +189,9 @@ type UsageLog struct {
// User-Agent
UserAgent
*
string
`json:"user_agent"`
// IP 地址(仅管理员可见)
IPAddress
*
string
`json:"ip_address,omitempty"`
CreatedAt
time
.
Time
`json:"created_at"`
User
*
User
`json:"user,omitempty"`
...
...
@@ -245,3 +250,28 @@ type BulkAssignResult struct {
Subscriptions
[]
UserSubscription
`json:"subscriptions"`
Errors
[]
string
`json:"errors"`
}
// PromoCode 注册优惠码
type
PromoCode
struct
{
ID
int64
`json:"id"`
Code
string
`json:"code"`
BonusAmount
float64
`json:"bonus_amount"`
MaxUses
int
`json:"max_uses"`
UsedCount
int
`json:"used_count"`
Status
string
`json:"status"`
ExpiresAt
*
time
.
Time
`json:"expires_at"`
Notes
string
`json:"notes"`
CreatedAt
time
.
Time
`json:"created_at"`
UpdatedAt
time
.
Time
`json:"updated_at"`
}
// PromoCodeUsage 优惠码使用记录
type
PromoCodeUsage
struct
{
ID
int64
`json:"id"`
PromoCodeID
int64
`json:"promo_code_id"`
UserID
int64
`json:"user_id"`
BonusAmount
float64
`json:"bonus_amount"`
UsedAt
time
.
Time
`json:"used_at"`
User
*
User
`json:"user,omitempty"`
}
backend/internal/handler/gateway_handler.go
View file @
1a641392
...
...
@@ -15,6 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
pkgerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
...
...
@@ -114,6 +115,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 获取 User-Agent
userAgent
:=
c
.
Request
.
UserAgent
()
// 获取客户端 IP
clientIP
:=
ip
.
GetClientIP
(
c
)
// 0. 检查wait队列是否已满
maxWait
:=
service
.
CalculateMaxWait
(
subject
.
Concurrency
)
canWait
,
err
:=
h
.
concurrencyHelper
.
IncrementWaitCount
(
c
.
Request
.
Context
(),
subject
.
UserID
,
maxWait
)
...
...
@@ -273,7 +277,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 异步记录使用量(subscription已在函数开头获取)
go
func
(
result
*
service
.
ForwardResult
,
usedAccount
*
service
.
Account
,
ua
string
)
{
go
func
(
result
*
service
.
ForwardResult
,
usedAccount
*
service
.
Account
,
ua
string
,
cip
string
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
10
*
time
.
Second
)
defer
cancel
()
if
err
:=
h
.
gatewayService
.
RecordUsage
(
ctx
,
&
service
.
RecordUsageInput
{
...
...
@@ -283,10 +287,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
Account
:
usedAccount
,
Subscription
:
subscription
,
UserAgent
:
ua
,
IPAddress
:
cip
,
});
err
!=
nil
{
log
.
Printf
(
"Record usage failed: %v"
,
err
)
}
}(
result
,
account
,
userAgent
)
}(
result
,
account
,
userAgent
,
clientIP
)
return
}
}
...
...
@@ -401,7 +406,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 异步记录使用量(subscription已在函数开头获取)
go
func
(
result
*
service
.
ForwardResult
,
usedAccount
*
service
.
Account
,
ua
string
)
{
go
func
(
result
*
service
.
ForwardResult
,
usedAccount
*
service
.
Account
,
ua
string
,
cip
string
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
10
*
time
.
Second
)
defer
cancel
()
if
err
:=
h
.
gatewayService
.
RecordUsage
(
ctx
,
&
service
.
RecordUsageInput
{
...
...
@@ -411,10 +416,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
Account
:
usedAccount
,
Subscription
:
subscription
,
UserAgent
:
ua
,
IPAddress
:
cip
,
});
err
!=
nil
{
log
.
Printf
(
"Record usage failed: %v"
,
err
)
}
}(
result
,
account
,
userAgent
)
}(
result
,
account
,
userAgent
,
clientIP
)
return
}
}
...
...
backend/internal/handler/gemini_v1beta_handler.go
View file @
1a641392
...
...
@@ -12,6 +12,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
...
...
@@ -167,6 +168,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 获取 User-Agent
userAgent
:=
c
.
Request
.
UserAgent
()
// 获取客户端 IP
clientIP
:=
ip
.
GetClientIP
(
c
)
// For Gemini native API, do not send Claude-style ping frames.
geminiConcurrency
:=
NewConcurrencyHelper
(
h
.
concurrencyHelper
.
concurrencyService
,
SSEPingFormatNone
,
0
)
...
...
@@ -307,7 +311,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
// 6) record usage async
go
func
(
result
*
service
.
ForwardResult
,
usedAccount
*
service
.
Account
,
ua
string
)
{
go
func
(
result
*
service
.
ForwardResult
,
usedAccount
*
service
.
Account
,
ua
string
,
cip
string
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
10
*
time
.
Second
)
defer
cancel
()
if
err
:=
h
.
gatewayService
.
RecordUsage
(
ctx
,
&
service
.
RecordUsageInput
{
...
...
@@ -317,10 +321,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
Account
:
usedAccount
,
Subscription
:
subscription
,
UserAgent
:
ua
,
IPAddress
:
cip
,
});
err
!=
nil
{
log
.
Printf
(
"Record usage failed: %v"
,
err
)
}
}(
result
,
account
,
userAgent
)
}(
result
,
account
,
userAgent
,
clientIP
)
return
}
}
...
...
Prev
1
2
3
4
5
6
7
8
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment