Commit 11bfc807 authored by song's avatar song
Browse files

merge upstream/main

parents c2a6ca8d 62dc0b95
# Linux DO Connect
OAuth(Open Authorization)是一个开放的网络授权标准,目前最新版本为 OAuth 2.0。我们日常使用的第三方登录(如 Google 账号登录)就采用了该标准。OAuth 允许用户授权第三方应用访问存储在其他服务提供商(如 Google)上的信息,无需在不同平台上重复填写注册信息。用户授权后,平台可以直接访问用户的账户信息进行身份验证,而用户无需向第三方应用提供密码。
目前系统已实现完整的 OAuth2 授权码(code)方式鉴权,但界面等配套功能还在持续完善中。让我们一起打造一个更完善的共享方案。
## 基本介绍
这是一套标准的 OAuth2 鉴权系统,可以让开发者共享论坛的用户基本信息。
- 可获取字段:
| 参数 | 说明 |
| ----------------- | ------------------------------- |
| `id` | 用户唯一标识(不可变) |
| `username` | 论坛用户名 |
| `name` | 论坛用户昵称(可变) |
| `avatar_template` | 用户头像模板URL(支持多种尺寸) |
| `active` | 账号活跃状态 |
| `trust_level` | 信任等级(0-4) |
| `silenced` | 禁言状态 |
| `external_ids` | 外部ID关联信息 |
| `api_key` | API访问密钥 |
通过这些信息,公益网站/接口可以实现:
1. 基于 `id` 的服务频率限制
2. 基于 `trust_level` 的服务额度分配
3. 基于用户信息的滥用举报机制
## 相关端点
- Authorize 端点: `https://connect.linux.do/oauth2/authorize`
- Token 端点:`https://connect.linux.do/oauth2/token`
- 用户信息 端点:`https://connect.linux.do/api/user`
## 申请使用
- 访问 [Connect.Linux.Do](https://connect.linux.do/) 申请接入你的应用。
![linuxdoconnect_1](https://wiki.linux.do/_next/image?url=%2Flinuxdoconnect_1.png&w=1080&q=75)
- 点击 **`我的应用接入`** - **`申请新接入`**,填写相关信息。其中 **`回调地址`** 是你的应用接收用户信息的地址。
![linuxdoconnect_2](https://wiki.linux.do/_next/image?url=%2Flinuxdoconnect_2.png&w=1080&q=75)
- 申请成功后,你将获得 **`Client Id`****`Client Secret`**,这是你应用的唯一身份凭证。
![linuxdoconnect_3](https://wiki.linux.do/_next/image?url=%2Flinuxdoconnect_3.png&w=1080&q=75)
## 接入 Linux Do
JavaScript
```JavaScript
// 安装第三方请求库(或使用原生的 Fetch API),本例中使用 axios
// npm install axios
// 通过 OAuth2 获取 Linux Do 用户信息的参考流程
const axios = require('axios');
const readline = require('readline');
// 配置信息(建议通过环境变量配置,避免使用硬编码)
const CLIENT_ID = '你的 Client ID';
const CLIENT_SECRET = '你的 Client Secret';
const REDIRECT_URI = '你的回调地址';
const AUTH_URL = 'https://connect.linux.do/oauth2/authorize';
const TOKEN_URL = 'https://connect.linux.do/oauth2/token';
const USER_INFO_URL = 'https://connect.linux.do/api/user';
// 第一步:生成授权 URL
function getAuthUrl() {
const params = new URLSearchParams({
client_id: CLIENT_ID,
redirect_uri: REDIRECT_URI,
response_type: 'code',
scope: 'user'
});
return `${AUTH_URL}?${params.toString()}`;
}
// 第二步:获取 code 参数
function getCode() {
return new Promise((resolve) => {
// 本例中使用终端输入来模拟流程,仅供本地测试
// 请在实际应用中替换为真实的处理逻辑
const rl = readline.createInterface({ input: process.stdin, output: process.stdout });
rl.question('从回调 URL 中提取出 code,粘贴到此处并按回车:', (answer) => {
rl.close();
resolve(answer.trim());
});
});
}
// 第三步:使用 code 参数获取访问令牌
async function getAccessToken(code) {
try {
const form = new URLSearchParams({
client_id: CLIENT_ID,
client_secret: CLIENT_SECRET,
code: code,
redirect_uri: REDIRECT_URI,
grant_type: 'authorization_code'
}).toString();
const response = await axios.post(TOKEN_URL, form, {
// 提醒:需正确配置请求头,否则无法正常获取访问令牌
headers: {
'Content-Type': 'application/x-www-form-urlencoded',
'Accept': 'application/json'
}
});
return response.data;
} catch (error) {
console.error(`获取访问令牌失败:${error.response ? JSON.stringify(error.response.data) : error.message}`);
throw error;
}
}
// 第四步:使用访问令牌获取用户信息
async function getUserInfo(accessToken) {
try {
const response = await axios.get(USER_INFO_URL, {
headers: {
Authorization: `Bearer ${accessToken}`
}
});
return response.data;
} catch (error) {
console.error(`获取用户信息失败:${error.response ? JSON.stringify(error.response.data) : error.message}`);
throw error;
}
}
// 主流程
async function main() {
// 1. 生成授权 URL,前端引导用户访问授权页
const authUrl = getAuthUrl();
console.log(`请访问此 URL 授权:${authUrl}
`);
// 2. 用户授权后,从回调 URL 获取 code 参数
const code = await getCode();
try {
// 3. 使用 code 参数获取访问令牌
const tokenData = await getAccessToken(code);
const accessToken = tokenData.access_token;
// 4. 使用访问令牌获取用户信息
if (accessToken) {
const userInfo = await getUserInfo(accessToken);
console.log(`
获取用户信息成功:${JSON.stringify(userInfo, null, 2)}`);
} else {
console.log(`
获取访问令牌失败:${JSON.stringify(tokenData)}`);
}
} catch (error) {
console.error('发生错误:', error);
}
}
```
Python
```python
# 安装第三方请求库,本例中使用 requests
# pip install requests
# 通过 OAuth2 获取 Linux Do 用户信息的参考流程
import requests
import json
# 配置信息(建议通过环境变量配置,避免使用硬编码)
CLIENT_ID = '你的 Client ID'
CLIENT_SECRET = '你的 Client Secret'
REDIRECT_URI = '你的回调地址'
AUTH_URL = 'https://connect.linux.do/oauth2/authorize'
TOKEN_URL = 'https://connect.linux.do/oauth2/token'
USER_INFO_URL = 'https://connect.linux.do/api/user'
# 第一步:生成授权 URL
def get_auth_url():
params = {
'client_id': CLIENT_ID,
'redirect_uri': REDIRECT_URI,
'response_type': 'code',
'scope': 'user'
}
auth_url = f"{AUTH_URL}?{'&'.join(f'{k}={v}' for k, v in params.items())}"
return auth_url
# 第二步:获取 code 参数
def get_code():
# 本例中使用终端输入来模拟流程,仅供本地测试
# 请在实际应用中替换为真实的处理逻辑
return input('从回调 URL 中提取出 code,粘贴到此处并按回车:').strip()
# 第三步:使用 code 参数获取访问令牌
def get_access_token(code):
try:
data = {
'client_id': CLIENT_ID,
'client_secret': CLIENT_SECRET,
'code': code,
'redirect_uri': REDIRECT_URI,
'grant_type': 'authorization_code'
}
# 提醒:需正确配置请求头,否则无法正常获取访问令牌
headers = {
'Content-Type': 'application/x-www-form-urlencoded',
'Accept': 'application/json'
}
response = requests.post(TOKEN_URL, data=data, headers=headers)
response.raise_for_status()
return response.json()
except requests.exceptions.RequestException as e:
print(f"获取访问令牌失败:{e}")
return None
# 第四步:使用访问令牌获取用户信息
def get_user_info(access_token):
try:
headers = {
'Authorization': f'Bearer {access_token}'
}
response = requests.get(USER_INFO_URL, headers=headers)
response.raise_for_status()
return response.json()
except requests.exceptions.RequestException as e:
print(f"获取用户信息失败:{e}")
return None
# 主流程
if __name__ == '__main__':
# 1. 生成授权 URL,前端引导用户访问授权页
auth_url = get_auth_url()
print(f'请访问此 URL 授权:{auth_url}
')
# 2. 用户授权后,从回调 URL 获取 code 参数
code = get_code()
# 3. 使用 code 参数获取访问令牌
token_data = get_access_token(code)
if token_data:
access_token = token_data.get('access_token')
# 4. 使用访问令牌获取用户信息
if access_token:
user_info = get_user_info(access_token)
if user_info:
print(f"
获取用户信息成功:{json.dumps(user_info, indent=2)}")
else:
print("
获取用户信息失败")
else:
print(f"
获取访问令牌失败:{json.dumps(token_data, indent=2)}")
else:
print("
获取访问令牌失败")
```
PHP
```php
// 通过 OAuth2 获取 Linux Do 用户信息的参考流程
// 配置信息
$CLIENT_ID = '你的 Client ID';
$CLIENT_SECRET = '你的 Client Secret';
$REDIRECT_URI = '你的回调地址';
$AUTH_URL = 'https://connect.linux.do/oauth2/authorize';
$TOKEN_URL = 'https://connect.linux.do/oauth2/token';
$USER_INFO_URL = 'https://connect.linux.do/api/user';
// 生成授权 URL
function getAuthUrl($clientId, $redirectUri) {
global $AUTH_URL;
return $AUTH_URL . '?' . http_build_query([
'client_id' => $clientId,
'redirect_uri' => $redirectUri,
'response_type' => 'code',
'scope' => 'user'
]);
}
// 使用 code 参数获取用户信息(合并获取令牌和获取用户信息的步骤)
function getUserInfoWithCode($code, $clientId, $clientSecret, $redirectUri) {
global $TOKEN_URL, $USER_INFO_URL;
// 1. 获取访问令牌
$ch = curl_init($TOKEN_URL);
curl_setopt($ch, CURLOPT_RETURNTRANSFER, true);
curl_setopt($ch, CURLOPT_POST, true);
curl_setopt($ch, CURLOPT_POSTFIELDS, http_build_query([
'client_id' => $clientId,
'client_secret' => $clientSecret,
'code' => $code,
'redirect_uri' => $redirectUri,
'grant_type' => 'authorization_code'
]));
curl_setopt($ch, CURLOPT_HTTPHEADER, [
'Content-Type: application/x-www-form-urlencoded',
'Accept: application/json'
]);
$tokenResponse = curl_exec($ch);
curl_close($ch);
$tokenData = json_decode($tokenResponse, true);
if (!isset($tokenData['access_token'])) {
return ['error' => '获取访问令牌失败', 'details' => $tokenData];
}
// 2. 获取用户信息
$ch = curl_init($USER_INFO_URL);
curl_setopt($ch, CURLOPT_RETURNTRANSFER, true);
curl_setopt($ch, CURLOPT_HTTPHEADER, [
'Authorization: Bearer ' . $tokenData['access_token']
]);
$userResponse = curl_exec($ch);
curl_close($ch);
return json_decode($userResponse, true);
}
// 主流程
// 1. 生成授权 URL
$authUrl = getAuthUrl($CLIENT_ID, $REDIRECT_URI);
echo "<a href='$authUrl'>使用 Linux Do 登录</a>";
// 2. 处理回调并获取用户信息
if (isset($_GET['code'])) {
$userInfo = getUserInfoWithCode(
$_GET['code'],
$CLIENT_ID,
$CLIENT_SECRET,
$REDIRECT_URI
);
if (isset($userInfo['error'])) {
echo '错误: ' . $userInfo['error'];
} else {
echo '欢迎, ' . $userInfo['name'] . '!';
// 处理用户登录逻辑...
}
}
```
## 使用说明
### 授权流程
1. 用户点击应用中的’使用 Linux Do 登录’按钮
2. 系统将用户重定向至 Linux Do 的授权页面
3. 用户完成授权后,系统自动重定向回应用并携带授权码
4. 应用使用授权码获取访问令牌
5. 使用访问令牌获取用户信息
### 安全建议
- 切勿在前端代码中暴露 Client Secret
- 对所有用户输入数据进行严格验证
- 确保使用 HTTPS 协议传输数据
- 定期更新并妥善保管 Client Secret
\ No newline at end of file
...@@ -53,7 +53,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -53,7 +53,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
emailQueueService := service.ProvideEmailQueueService(emailService) emailQueueService := service.ProvideEmailQueueService(emailService)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService) authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
userService := service.NewUserService(userRepository) userService := service.NewUserService(userRepository)
authHandler := handler.NewAuthHandler(configConfig, authService, userService) authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService)
userHandler := handler.NewUserHandler(userService) userHandler := handler.NewUserHandler(userService)
apiKeyRepository := repository.NewAPIKeyRepository(client) apiKeyRepository := repository.NewAPIKeyRepository(client)
groupRepository := repository.NewGroupRepository(client, db) groupRepository := repository.NewGroupRepository(client, db)
......
...@@ -51,6 +51,10 @@ type Group struct { ...@@ -51,6 +51,10 @@ type Group struct {
ImagePrice2k *float64 `json:"image_price_2k,omitempty"` ImagePrice2k *float64 `json:"image_price_2k,omitempty"`
// ImagePrice4k holds the value of the "image_price_4k" field. // ImagePrice4k holds the value of the "image_price_4k" field.
ImagePrice4k *float64 `json:"image_price_4k,omitempty"` ImagePrice4k *float64 `json:"image_price_4k,omitempty"`
// 是否仅允许 Claude Code 客户端
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
// 非 Claude Code 请求降级使用的分组 ID
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
// Edges holds the relations/edges for other nodes in the graph. // Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the GroupQuery when eager-loading is set. // The values are being populated by the GroupQuery when eager-loading is set.
Edges GroupEdges `json:"edges"` Edges GroupEdges `json:"edges"`
...@@ -157,11 +161,11 @@ func (*Group) scanValues(columns []string) ([]any, error) { ...@@ -157,11 +161,11 @@ func (*Group) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns)) values := make([]any, len(columns))
for i := range columns { for i := range columns {
switch columns[i] { switch columns[i] {
case group.FieldIsExclusive: case group.FieldIsExclusive, group.FieldClaudeCodeOnly:
values[i] = new(sql.NullBool) values[i] = new(sql.NullBool)
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k: case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
values[i] = new(sql.NullFloat64) values[i] = new(sql.NullFloat64)
case group.FieldID, group.FieldDefaultValidityDays: case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID:
values[i] = new(sql.NullInt64) values[i] = new(sql.NullInt64)
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType: case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType:
values[i] = new(sql.NullString) values[i] = new(sql.NullString)
...@@ -298,6 +302,19 @@ func (_m *Group) assignValues(columns []string, values []any) error { ...@@ -298,6 +302,19 @@ func (_m *Group) assignValues(columns []string, values []any) error {
_m.ImagePrice4k = new(float64) _m.ImagePrice4k = new(float64)
*_m.ImagePrice4k = value.Float64 *_m.ImagePrice4k = value.Float64
} }
case group.FieldClaudeCodeOnly:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field claude_code_only", values[i])
} else if value.Valid {
_m.ClaudeCodeOnly = value.Bool
}
case group.FieldFallbackGroupID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field fallback_group_id", values[i])
} else if value.Valid {
_m.FallbackGroupID = new(int64)
*_m.FallbackGroupID = value.Int64
}
default: default:
_m.selectValues.Set(columns[i], values[i]) _m.selectValues.Set(columns[i], values[i])
} }
...@@ -440,6 +457,14 @@ func (_m *Group) String() string { ...@@ -440,6 +457,14 @@ func (_m *Group) String() string {
builder.WriteString("image_price_4k=") builder.WriteString("image_price_4k=")
builder.WriteString(fmt.Sprintf("%v", *v)) builder.WriteString(fmt.Sprintf("%v", *v))
} }
builder.WriteString(", ")
builder.WriteString("claude_code_only=")
builder.WriteString(fmt.Sprintf("%v", _m.ClaudeCodeOnly))
builder.WriteString(", ")
if v := _m.FallbackGroupID; v != nil {
builder.WriteString("fallback_group_id=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteByte(')') builder.WriteByte(')')
return builder.String() return builder.String()
} }
......
...@@ -49,6 +49,10 @@ const ( ...@@ -49,6 +49,10 @@ const (
FieldImagePrice2k = "image_price_2k" FieldImagePrice2k = "image_price_2k"
// FieldImagePrice4k holds the string denoting the image_price_4k field in the database. // FieldImagePrice4k holds the string denoting the image_price_4k field in the database.
FieldImagePrice4k = "image_price_4k" FieldImagePrice4k = "image_price_4k"
// FieldClaudeCodeOnly holds the string denoting the claude_code_only field in the database.
FieldClaudeCodeOnly = "claude_code_only"
// FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database.
FieldFallbackGroupID = "fallback_group_id"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys" EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
...@@ -141,6 +145,8 @@ var Columns = []string{ ...@@ -141,6 +145,8 @@ var Columns = []string{
FieldImagePrice1k, FieldImagePrice1k,
FieldImagePrice2k, FieldImagePrice2k,
FieldImagePrice4k, FieldImagePrice4k,
FieldClaudeCodeOnly,
FieldFallbackGroupID,
} }
var ( var (
...@@ -196,6 +202,8 @@ var ( ...@@ -196,6 +202,8 @@ var (
SubscriptionTypeValidator func(string) error SubscriptionTypeValidator func(string) error
// DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field. // DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field.
DefaultDefaultValidityDays int DefaultDefaultValidityDays int
// DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field.
DefaultClaudeCodeOnly bool
) )
// OrderOption defines the ordering options for the Group queries. // OrderOption defines the ordering options for the Group queries.
...@@ -291,6 +299,16 @@ func ByImagePrice4k(opts ...sql.OrderTermOption) OrderOption { ...@@ -291,6 +299,16 @@ func ByImagePrice4k(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImagePrice4k, opts...).ToFunc() return sql.OrderByField(FieldImagePrice4k, opts...).ToFunc()
} }
// ByClaudeCodeOnly orders the results by the claude_code_only field.
func ByClaudeCodeOnly(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldClaudeCodeOnly, opts...).ToFunc()
}
// ByFallbackGroupID orders the results by the fallback_group_id field.
func ByFallbackGroupID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldFallbackGroupID, opts...).ToFunc()
}
// ByAPIKeysCount orders the results by api_keys count. // ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption { func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) { return func(s *sql.Selector) {
......
...@@ -140,6 +140,16 @@ func ImagePrice4k(v float64) predicate.Group { ...@@ -140,6 +140,16 @@ func ImagePrice4k(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldImagePrice4k, v)) return predicate.Group(sql.FieldEQ(FieldImagePrice4k, v))
} }
// ClaudeCodeOnly applies equality check predicate on the "claude_code_only" field. It's identical to ClaudeCodeOnlyEQ.
func ClaudeCodeOnly(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))
}
// FallbackGroupID applies equality check predicate on the "fallback_group_id" field. It's identical to FallbackGroupIDEQ.
func FallbackGroupID(v int64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field. // CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.Group { func CreatedAtEQ(v time.Time) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v)) return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
...@@ -995,6 +1005,66 @@ func ImagePrice4kNotNil() predicate.Group { ...@@ -995,6 +1005,66 @@ func ImagePrice4kNotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldImagePrice4k)) return predicate.Group(sql.FieldNotNull(FieldImagePrice4k))
} }
// ClaudeCodeOnlyEQ applies the EQ predicate on the "claude_code_only" field.
func ClaudeCodeOnlyEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))
}
// ClaudeCodeOnlyNEQ applies the NEQ predicate on the "claude_code_only" field.
func ClaudeCodeOnlyNEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldClaudeCodeOnly, v))
}
// FallbackGroupIDEQ applies the EQ predicate on the "fallback_group_id" field.
func FallbackGroupIDEQ(v int64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v))
}
// FallbackGroupIDNEQ applies the NEQ predicate on the "fallback_group_id" field.
func FallbackGroupIDNEQ(v int64) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldFallbackGroupID, v))
}
// FallbackGroupIDIn applies the In predicate on the "fallback_group_id" field.
func FallbackGroupIDIn(vs ...int64) predicate.Group {
return predicate.Group(sql.FieldIn(FieldFallbackGroupID, vs...))
}
// FallbackGroupIDNotIn applies the NotIn predicate on the "fallback_group_id" field.
func FallbackGroupIDNotIn(vs ...int64) predicate.Group {
return predicate.Group(sql.FieldNotIn(FieldFallbackGroupID, vs...))
}
// FallbackGroupIDGT applies the GT predicate on the "fallback_group_id" field.
func FallbackGroupIDGT(v int64) predicate.Group {
return predicate.Group(sql.FieldGT(FieldFallbackGroupID, v))
}
// FallbackGroupIDGTE applies the GTE predicate on the "fallback_group_id" field.
func FallbackGroupIDGTE(v int64) predicate.Group {
return predicate.Group(sql.FieldGTE(FieldFallbackGroupID, v))
}
// FallbackGroupIDLT applies the LT predicate on the "fallback_group_id" field.
func FallbackGroupIDLT(v int64) predicate.Group {
return predicate.Group(sql.FieldLT(FieldFallbackGroupID, v))
}
// FallbackGroupIDLTE applies the LTE predicate on the "fallback_group_id" field.
func FallbackGroupIDLTE(v int64) predicate.Group {
return predicate.Group(sql.FieldLTE(FieldFallbackGroupID, v))
}
// FallbackGroupIDIsNil applies the IsNil predicate on the "fallback_group_id" field.
func FallbackGroupIDIsNil() predicate.Group {
return predicate.Group(sql.FieldIsNull(FieldFallbackGroupID))
}
// FallbackGroupIDNotNil applies the NotNil predicate on the "fallback_group_id" field.
func FallbackGroupIDNotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldFallbackGroupID))
}
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge. // HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.Group { func HasAPIKeys() predicate.Group {
return predicate.Group(func(s *sql.Selector) { return predicate.Group(func(s *sql.Selector) {
......
...@@ -258,6 +258,34 @@ func (_c *GroupCreate) SetNillableImagePrice4k(v *float64) *GroupCreate { ...@@ -258,6 +258,34 @@ func (_c *GroupCreate) SetNillableImagePrice4k(v *float64) *GroupCreate {
return _c return _c
} }
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (_c *GroupCreate) SetClaudeCodeOnly(v bool) *GroupCreate {
_c.mutation.SetClaudeCodeOnly(v)
return _c
}
// SetNillableClaudeCodeOnly sets the "claude_code_only" field if the given value is not nil.
func (_c *GroupCreate) SetNillableClaudeCodeOnly(v *bool) *GroupCreate {
if v != nil {
_c.SetClaudeCodeOnly(*v)
}
return _c
}
// SetFallbackGroupID sets the "fallback_group_id" field.
func (_c *GroupCreate) SetFallbackGroupID(v int64) *GroupCreate {
_c.mutation.SetFallbackGroupID(v)
return _c
}
// SetNillableFallbackGroupID sets the "fallback_group_id" field if the given value is not nil.
func (_c *GroupCreate) SetNillableFallbackGroupID(v *int64) *GroupCreate {
if v != nil {
_c.SetFallbackGroupID(*v)
}
return _c
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate { func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
_c.mutation.AddAPIKeyIDs(ids...) _c.mutation.AddAPIKeyIDs(ids...)
...@@ -423,6 +451,10 @@ func (_c *GroupCreate) defaults() error { ...@@ -423,6 +451,10 @@ func (_c *GroupCreate) defaults() error {
v := group.DefaultDefaultValidityDays v := group.DefaultDefaultValidityDays
_c.mutation.SetDefaultValidityDays(v) _c.mutation.SetDefaultValidityDays(v)
} }
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
v := group.DefaultClaudeCodeOnly
_c.mutation.SetClaudeCodeOnly(v)
}
return nil return nil
} }
...@@ -475,6 +507,9 @@ func (_c *GroupCreate) check() error { ...@@ -475,6 +507,9 @@ func (_c *GroupCreate) check() error {
if _, ok := _c.mutation.DefaultValidityDays(); !ok { if _, ok := _c.mutation.DefaultValidityDays(); !ok {
return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)} return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)}
} }
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)}
}
return nil return nil
} }
...@@ -570,6 +605,14 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { ...@@ -570,6 +605,14 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec.SetField(group.FieldImagePrice4k, field.TypeFloat64, value) _spec.SetField(group.FieldImagePrice4k, field.TypeFloat64, value)
_node.ImagePrice4k = &value _node.ImagePrice4k = &value
} }
if value, ok := _c.mutation.ClaudeCodeOnly(); ok {
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
_node.ClaudeCodeOnly = value
}
if value, ok := _c.mutation.FallbackGroupID(); ok {
_spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value)
_node.FallbackGroupID = &value
}
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M, Rel: sqlgraph.O2M,
...@@ -1014,6 +1057,42 @@ func (u *GroupUpsert) ClearImagePrice4k() *GroupUpsert { ...@@ -1014,6 +1057,42 @@ func (u *GroupUpsert) ClearImagePrice4k() *GroupUpsert {
return u return u
} }
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (u *GroupUpsert) SetClaudeCodeOnly(v bool) *GroupUpsert {
u.Set(group.FieldClaudeCodeOnly, v)
return u
}
// UpdateClaudeCodeOnly sets the "claude_code_only" field to the value that was provided on create.
func (u *GroupUpsert) UpdateClaudeCodeOnly() *GroupUpsert {
u.SetExcluded(group.FieldClaudeCodeOnly)
return u
}
// SetFallbackGroupID sets the "fallback_group_id" field.
func (u *GroupUpsert) SetFallbackGroupID(v int64) *GroupUpsert {
u.Set(group.FieldFallbackGroupID, v)
return u
}
// UpdateFallbackGroupID sets the "fallback_group_id" field to the value that was provided on create.
func (u *GroupUpsert) UpdateFallbackGroupID() *GroupUpsert {
u.SetExcluded(group.FieldFallbackGroupID)
return u
}
// AddFallbackGroupID adds v to the "fallback_group_id" field.
func (u *GroupUpsert) AddFallbackGroupID(v int64) *GroupUpsert {
u.Add(group.FieldFallbackGroupID, v)
return u
}
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
func (u *GroupUpsert) ClearFallbackGroupID() *GroupUpsert {
u.SetNull(group.FieldFallbackGroupID)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create. // UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using: // Using this option is equivalent to using:
// //
...@@ -1395,6 +1474,48 @@ func (u *GroupUpsertOne) ClearImagePrice4k() *GroupUpsertOne { ...@@ -1395,6 +1474,48 @@ func (u *GroupUpsertOne) ClearImagePrice4k() *GroupUpsertOne {
}) })
} }
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (u *GroupUpsertOne) SetClaudeCodeOnly(v bool) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetClaudeCodeOnly(v)
})
}
// UpdateClaudeCodeOnly sets the "claude_code_only" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateClaudeCodeOnly() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateClaudeCodeOnly()
})
}
// SetFallbackGroupID sets the "fallback_group_id" field.
func (u *GroupUpsertOne) SetFallbackGroupID(v int64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetFallbackGroupID(v)
})
}
// AddFallbackGroupID adds v to the "fallback_group_id" field.
func (u *GroupUpsertOne) AddFallbackGroupID(v int64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.AddFallbackGroupID(v)
})
}
// UpdateFallbackGroupID sets the "fallback_group_id" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateFallbackGroupID() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateFallbackGroupID()
})
}
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
func (u *GroupUpsertOne) ClearFallbackGroupID() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.ClearFallbackGroupID()
})
}
// Exec executes the query. // Exec executes the query.
func (u *GroupUpsertOne) Exec(ctx context.Context) error { func (u *GroupUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 { if len(u.create.conflict) == 0 {
...@@ -1942,6 +2063,48 @@ func (u *GroupUpsertBulk) ClearImagePrice4k() *GroupUpsertBulk { ...@@ -1942,6 +2063,48 @@ func (u *GroupUpsertBulk) ClearImagePrice4k() *GroupUpsertBulk {
}) })
} }
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (u *GroupUpsertBulk) SetClaudeCodeOnly(v bool) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetClaudeCodeOnly(v)
})
}
// UpdateClaudeCodeOnly sets the "claude_code_only" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateClaudeCodeOnly() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateClaudeCodeOnly()
})
}
// SetFallbackGroupID sets the "fallback_group_id" field.
func (u *GroupUpsertBulk) SetFallbackGroupID(v int64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetFallbackGroupID(v)
})
}
// AddFallbackGroupID adds v to the "fallback_group_id" field.
func (u *GroupUpsertBulk) AddFallbackGroupID(v int64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.AddFallbackGroupID(v)
})
}
// UpdateFallbackGroupID sets the "fallback_group_id" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateFallbackGroupID() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateFallbackGroupID()
})
}
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
func (u *GroupUpsertBulk) ClearFallbackGroupID() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.ClearFallbackGroupID()
})
}
// Exec executes the query. // Exec executes the query.
func (u *GroupUpsertBulk) Exec(ctx context.Context) error { func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil { if u.create.err != nil {
......
...@@ -354,6 +354,47 @@ func (_u *GroupUpdate) ClearImagePrice4k() *GroupUpdate { ...@@ -354,6 +354,47 @@ func (_u *GroupUpdate) ClearImagePrice4k() *GroupUpdate {
return _u return _u
} }
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (_u *GroupUpdate) SetClaudeCodeOnly(v bool) *GroupUpdate {
_u.mutation.SetClaudeCodeOnly(v)
return _u
}
// SetNillableClaudeCodeOnly sets the "claude_code_only" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableClaudeCodeOnly(v *bool) *GroupUpdate {
if v != nil {
_u.SetClaudeCodeOnly(*v)
}
return _u
}
// SetFallbackGroupID sets the "fallback_group_id" field.
func (_u *GroupUpdate) SetFallbackGroupID(v int64) *GroupUpdate {
_u.mutation.ResetFallbackGroupID()
_u.mutation.SetFallbackGroupID(v)
return _u
}
// SetNillableFallbackGroupID sets the "fallback_group_id" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableFallbackGroupID(v *int64) *GroupUpdate {
if v != nil {
_u.SetFallbackGroupID(*v)
}
return _u
}
// AddFallbackGroupID adds value to the "fallback_group_id" field.
func (_u *GroupUpdate) AddFallbackGroupID(v int64) *GroupUpdate {
_u.mutation.AddFallbackGroupID(v)
return _u
}
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
func (_u *GroupUpdate) ClearFallbackGroupID() *GroupUpdate {
_u.mutation.ClearFallbackGroupID()
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate { func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
_u.mutation.AddAPIKeyIDs(ids...) _u.mutation.AddAPIKeyIDs(ids...)
...@@ -750,6 +791,18 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { ...@@ -750,6 +791,18 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.ImagePrice4kCleared() { if _u.mutation.ImagePrice4kCleared() {
_spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64) _spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64)
} }
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
}
if value, ok := _u.mutation.FallbackGroupID(); ok {
_spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedFallbackGroupID(); ok {
_spec.AddField(group.FieldFallbackGroupID, field.TypeInt64, value)
}
if _u.mutation.FallbackGroupIDCleared() {
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
}
if _u.mutation.APIKeysCleared() { if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M, Rel: sqlgraph.O2M,
...@@ -1384,6 +1437,47 @@ func (_u *GroupUpdateOne) ClearImagePrice4k() *GroupUpdateOne { ...@@ -1384,6 +1437,47 @@ func (_u *GroupUpdateOne) ClearImagePrice4k() *GroupUpdateOne {
return _u return _u
} }
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (_u *GroupUpdateOne) SetClaudeCodeOnly(v bool) *GroupUpdateOne {
_u.mutation.SetClaudeCodeOnly(v)
return _u
}
// SetNillableClaudeCodeOnly sets the "claude_code_only" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableClaudeCodeOnly(v *bool) *GroupUpdateOne {
if v != nil {
_u.SetClaudeCodeOnly(*v)
}
return _u
}
// SetFallbackGroupID sets the "fallback_group_id" field.
func (_u *GroupUpdateOne) SetFallbackGroupID(v int64) *GroupUpdateOne {
_u.mutation.ResetFallbackGroupID()
_u.mutation.SetFallbackGroupID(v)
return _u
}
// SetNillableFallbackGroupID sets the "fallback_group_id" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableFallbackGroupID(v *int64) *GroupUpdateOne {
if v != nil {
_u.SetFallbackGroupID(*v)
}
return _u
}
// AddFallbackGroupID adds value to the "fallback_group_id" field.
func (_u *GroupUpdateOne) AddFallbackGroupID(v int64) *GroupUpdateOne {
_u.mutation.AddFallbackGroupID(v)
return _u
}
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
func (_u *GroupUpdateOne) ClearFallbackGroupID() *GroupUpdateOne {
_u.mutation.ClearFallbackGroupID()
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne { func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...) _u.mutation.AddAPIKeyIDs(ids...)
...@@ -1810,6 +1904,18 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) ...@@ -1810,6 +1904,18 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
if _u.mutation.ImagePrice4kCleared() { if _u.mutation.ImagePrice4kCleared() {
_spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64) _spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64)
} }
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
}
if value, ok := _u.mutation.FallbackGroupID(); ok {
_spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedFallbackGroupID(); ok {
_spec.AddField(group.FieldFallbackGroupID, field.TypeInt64, value)
}
if _u.mutation.FallbackGroupIDCleared() {
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
}
if _u.mutation.APIKeysCleared() { if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M, Rel: sqlgraph.O2M,
......
...@@ -221,6 +221,8 @@ var ( ...@@ -221,6 +221,8 @@ var (
{Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "claude_code_only", Type: field.TypeBool, Default: false},
{Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true},
} }
// GroupsTable holds the schema information for the "groups" table. // GroupsTable holds the schema information for the "groups" table.
GroupsTable = &schema.Table{ GroupsTable = &schema.Table{
......
...@@ -3590,6 +3590,9 @@ type GroupMutation struct { ...@@ -3590,6 +3590,9 @@ type GroupMutation struct {
addimage_price_2k *float64 addimage_price_2k *float64
image_price_4k *float64 image_price_4k *float64
addimage_price_4k *float64 addimage_price_4k *float64
claude_code_only *bool
fallback_group_id *int64
addfallback_group_id *int64
clearedFields map[string]struct{} clearedFields map[string]struct{}
api_keys map[int64]struct{} api_keys map[int64]struct{}
removedapi_keys map[int64]struct{} removedapi_keys map[int64]struct{}
...@@ -4594,6 +4597,112 @@ func (m *GroupMutation) ResetImagePrice4k() { ...@@ -4594,6 +4597,112 @@ func (m *GroupMutation) ResetImagePrice4k() {
delete(m.clearedFields, group.FieldImagePrice4k) delete(m.clearedFields, group.FieldImagePrice4k)
} }
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (m *GroupMutation) SetClaudeCodeOnly(b bool) {
m.claude_code_only = &b
}
// ClaudeCodeOnly returns the value of the "claude_code_only" field in the mutation.
func (m *GroupMutation) ClaudeCodeOnly() (r bool, exists bool) {
v := m.claude_code_only
if v == nil {
return
}
return *v, true
}
// OldClaudeCodeOnly returns the old "claude_code_only" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldClaudeCodeOnly(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldClaudeCodeOnly is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldClaudeCodeOnly requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldClaudeCodeOnly: %w", err)
}
return oldValue.ClaudeCodeOnly, nil
}
// ResetClaudeCodeOnly resets all changes to the "claude_code_only" field.
func (m *GroupMutation) ResetClaudeCodeOnly() {
m.claude_code_only = nil
}
// SetFallbackGroupID sets the "fallback_group_id" field.
func (m *GroupMutation) SetFallbackGroupID(i int64) {
m.fallback_group_id = &i
m.addfallback_group_id = nil
}
// FallbackGroupID returns the value of the "fallback_group_id" field in the mutation.
func (m *GroupMutation) FallbackGroupID() (r int64, exists bool) {
v := m.fallback_group_id
if v == nil {
return
}
return *v, true
}
// OldFallbackGroupID returns the old "fallback_group_id" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldFallbackGroupID(ctx context.Context) (v *int64, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldFallbackGroupID is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldFallbackGroupID requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldFallbackGroupID: %w", err)
}
return oldValue.FallbackGroupID, nil
}
// AddFallbackGroupID adds i to the "fallback_group_id" field.
func (m *GroupMutation) AddFallbackGroupID(i int64) {
if m.addfallback_group_id != nil {
*m.addfallback_group_id += i
} else {
m.addfallback_group_id = &i
}
}
// AddedFallbackGroupID returns the value that was added to the "fallback_group_id" field in this mutation.
func (m *GroupMutation) AddedFallbackGroupID() (r int64, exists bool) {
v := m.addfallback_group_id
if v == nil {
return
}
return *v, true
}
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
func (m *GroupMutation) ClearFallbackGroupID() {
m.fallback_group_id = nil
m.addfallback_group_id = nil
m.clearedFields[group.FieldFallbackGroupID] = struct{}{}
}
// FallbackGroupIDCleared returns if the "fallback_group_id" field was cleared in this mutation.
func (m *GroupMutation) FallbackGroupIDCleared() bool {
_, ok := m.clearedFields[group.FieldFallbackGroupID]
return ok
}
// ResetFallbackGroupID resets all changes to the "fallback_group_id" field.
func (m *GroupMutation) ResetFallbackGroupID() {
m.fallback_group_id = nil
m.addfallback_group_id = nil
delete(m.clearedFields, group.FieldFallbackGroupID)
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) { func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
if m.api_keys == nil { if m.api_keys == nil {
...@@ -4952,7 +5061,7 @@ func (m *GroupMutation) Type() string { ...@@ -4952,7 +5061,7 @@ func (m *GroupMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call // order to get all numeric fields that were incremented/decremented, call
// AddedFields(). // AddedFields().
func (m *GroupMutation) Fields() []string { func (m *GroupMutation) Fields() []string {
fields := make([]string, 0, 17) fields := make([]string, 0, 19)
if m.created_at != nil { if m.created_at != nil {
fields = append(fields, group.FieldCreatedAt) fields = append(fields, group.FieldCreatedAt)
} }
...@@ -5004,6 +5113,12 @@ func (m *GroupMutation) Fields() []string { ...@@ -5004,6 +5113,12 @@ func (m *GroupMutation) Fields() []string {
if m.image_price_4k != nil { if m.image_price_4k != nil {
fields = append(fields, group.FieldImagePrice4k) fields = append(fields, group.FieldImagePrice4k)
} }
if m.claude_code_only != nil {
fields = append(fields, group.FieldClaudeCodeOnly)
}
if m.fallback_group_id != nil {
fields = append(fields, group.FieldFallbackGroupID)
}
return fields return fields
} }
...@@ -5046,6 +5161,10 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { ...@@ -5046,6 +5161,10 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
return m.ImagePrice2k() return m.ImagePrice2k()
case group.FieldImagePrice4k: case group.FieldImagePrice4k:
return m.ImagePrice4k() return m.ImagePrice4k()
case group.FieldClaudeCodeOnly:
return m.ClaudeCodeOnly()
case group.FieldFallbackGroupID:
return m.FallbackGroupID()
} }
return nil, false return nil, false
} }
...@@ -5089,6 +5208,10 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e ...@@ -5089,6 +5208,10 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
return m.OldImagePrice2k(ctx) return m.OldImagePrice2k(ctx)
case group.FieldImagePrice4k: case group.FieldImagePrice4k:
return m.OldImagePrice4k(ctx) return m.OldImagePrice4k(ctx)
case group.FieldClaudeCodeOnly:
return m.OldClaudeCodeOnly(ctx)
case group.FieldFallbackGroupID:
return m.OldFallbackGroupID(ctx)
} }
return nil, fmt.Errorf("unknown Group field %s", name) return nil, fmt.Errorf("unknown Group field %s", name)
} }
...@@ -5217,6 +5340,20 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { ...@@ -5217,6 +5340,20 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
} }
m.SetImagePrice4k(v) m.SetImagePrice4k(v)
return nil return nil
case group.FieldClaudeCodeOnly:
v, ok := value.(bool)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetClaudeCodeOnly(v)
return nil
case group.FieldFallbackGroupID:
v, ok := value.(int64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetFallbackGroupID(v)
return nil
} }
return fmt.Errorf("unknown Group field %s", name) return fmt.Errorf("unknown Group field %s", name)
} }
...@@ -5249,6 +5386,9 @@ func (m *GroupMutation) AddedFields() []string { ...@@ -5249,6 +5386,9 @@ func (m *GroupMutation) AddedFields() []string {
if m.addimage_price_4k != nil { if m.addimage_price_4k != nil {
fields = append(fields, group.FieldImagePrice4k) fields = append(fields, group.FieldImagePrice4k)
} }
if m.addfallback_group_id != nil {
fields = append(fields, group.FieldFallbackGroupID)
}
return fields return fields
} }
...@@ -5273,6 +5413,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { ...@@ -5273,6 +5413,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) {
return m.AddedImagePrice2k() return m.AddedImagePrice2k()
case group.FieldImagePrice4k: case group.FieldImagePrice4k:
return m.AddedImagePrice4k() return m.AddedImagePrice4k()
case group.FieldFallbackGroupID:
return m.AddedFallbackGroupID()
} }
return nil, false return nil, false
} }
...@@ -5338,6 +5480,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error { ...@@ -5338,6 +5480,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error {
} }
m.AddImagePrice4k(v) m.AddImagePrice4k(v)
return nil return nil
case group.FieldFallbackGroupID:
v, ok := value.(int64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.AddFallbackGroupID(v)
return nil
} }
return fmt.Errorf("unknown Group numeric field %s", name) return fmt.Errorf("unknown Group numeric field %s", name)
} }
...@@ -5370,6 +5519,9 @@ func (m *GroupMutation) ClearedFields() []string { ...@@ -5370,6 +5519,9 @@ func (m *GroupMutation) ClearedFields() []string {
if m.FieldCleared(group.FieldImagePrice4k) { if m.FieldCleared(group.FieldImagePrice4k) {
fields = append(fields, group.FieldImagePrice4k) fields = append(fields, group.FieldImagePrice4k)
} }
if m.FieldCleared(group.FieldFallbackGroupID) {
fields = append(fields, group.FieldFallbackGroupID)
}
return fields return fields
} }
...@@ -5408,6 +5560,9 @@ func (m *GroupMutation) ClearField(name string) error { ...@@ -5408,6 +5560,9 @@ func (m *GroupMutation) ClearField(name string) error {
case group.FieldImagePrice4k: case group.FieldImagePrice4k:
m.ClearImagePrice4k() m.ClearImagePrice4k()
return nil return nil
case group.FieldFallbackGroupID:
m.ClearFallbackGroupID()
return nil
} }
return fmt.Errorf("unknown Group nullable field %s", name) return fmt.Errorf("unknown Group nullable field %s", name)
} }
...@@ -5467,6 +5622,12 @@ func (m *GroupMutation) ResetField(name string) error { ...@@ -5467,6 +5622,12 @@ func (m *GroupMutation) ResetField(name string) error {
case group.FieldImagePrice4k: case group.FieldImagePrice4k:
m.ResetImagePrice4k() m.ResetImagePrice4k()
return nil return nil
case group.FieldClaudeCodeOnly:
m.ResetClaudeCodeOnly()
return nil
case group.FieldFallbackGroupID:
m.ResetFallbackGroupID()
return nil
} }
return fmt.Errorf("unknown Group field %s", name) return fmt.Errorf("unknown Group field %s", name)
} }
......
...@@ -270,6 +270,10 @@ func init() { ...@@ -270,6 +270,10 @@ func init() {
groupDescDefaultValidityDays := groupFields[10].Descriptor() groupDescDefaultValidityDays := groupFields[10].Descriptor()
// group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field. // group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field.
group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int) group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int)
// groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field.
groupDescClaudeCodeOnly := groupFields[14].Descriptor()
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
proxyMixin := schema.Proxy{}.Mixin() proxyMixin := schema.Proxy{}.Mixin()
proxyMixinHooks1 := proxyMixin[1].Hooks() proxyMixinHooks1 := proxyMixin[1].Hooks()
proxy.Hooks[0] = proxyMixinHooks1[0] proxy.Hooks[0] = proxyMixinHooks1[0]
......
...@@ -86,6 +86,15 @@ func (Group) Fields() []ent.Field { ...@@ -86,6 +86,15 @@ func (Group) Fields() []ent.Field {
Optional(). Optional().
Nillable(). Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
// Claude Code 客户端限制 (added by migration 029)
field.Bool("claude_code_only").
Default(false).
Comment("是否仅允许 Claude Code 客户端"),
field.Int64("fallback_group_id").
Optional().
Nillable().
Comment("非 Claude Code 请求降级使用的分组 ID"),
} }
} }
...@@ -101,6 +110,8 @@ func (Group) Edges() []ent.Edge { ...@@ -101,6 +110,8 @@ func (Group) Edges() []ent.Edge {
edge.From("allowed_users", User.Type). edge.From("allowed_users", User.Type).
Ref("allowed_groups"). Ref("allowed_groups").
Through("user_allowed_groups", UserAllowedGroup.Type), Through("user_allowed_groups", UserAllowedGroup.Type),
// 注意:fallback_group_id 直接作为字段使用,不定义 edge
// 这样允许多个分组指向同一个降级分组(M2O 关系)
} }
} }
......
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"log" "log"
"net/url"
"os" "os"
"strings" "strings"
"time" "time"
...@@ -35,24 +36,25 @@ const ( ...@@ -35,24 +36,25 @@ const (
) )
type Config struct { type Config struct {
Server ServerConfig `mapstructure:"server"` Server ServerConfig `mapstructure:"server"`
CORS CORSConfig `mapstructure:"cors"` CORS CORSConfig `mapstructure:"cors"`
Security SecurityConfig `mapstructure:"security"` Security SecurityConfig `mapstructure:"security"`
Billing BillingConfig `mapstructure:"billing"` Billing BillingConfig `mapstructure:"billing"`
Turnstile TurnstileConfig `mapstructure:"turnstile"` Turnstile TurnstileConfig `mapstructure:"turnstile"`
Database DatabaseConfig `mapstructure:"database"` Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"` Redis RedisConfig `mapstructure:"redis"`
JWT JWTConfig `mapstructure:"jwt"` JWT JWTConfig `mapstructure:"jwt"`
Default DefaultConfig `mapstructure:"default"` LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"` Default DefaultConfig `mapstructure:"default"`
Pricing PricingConfig `mapstructure:"pricing"` RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Gateway GatewayConfig `mapstructure:"gateway"` Pricing PricingConfig `mapstructure:"pricing"`
Concurrency ConcurrencyConfig `mapstructure:"concurrency"` Gateway GatewayConfig `mapstructure:"gateway"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
RunMode string `mapstructure:"run_mode" yaml:"run_mode"` TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
Gemini GeminiConfig `mapstructure:"gemini"` Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
Update UpdateConfig `mapstructure:"update"` Gemini GeminiConfig `mapstructure:"gemini"`
Update UpdateConfig `mapstructure:"update"`
} }
// UpdateConfig 在线更新相关配置 // UpdateConfig 在线更新相关配置
...@@ -322,6 +324,30 @@ type TurnstileConfig struct { ...@@ -322,6 +324,30 @@ type TurnstileConfig struct {
Required bool `mapstructure:"required"` Required bool `mapstructure:"required"`
} }
// LinuxDoConnectConfig 用于 LinuxDo Connect OAuth 登录(终端用户 SSO)。
//
// 注意:这与上游账号的 OAuth(例如 OpenAI/Gemini 账号接入)不是一回事。
// 这里是用于登录 Sub2API 本身的用户体系。
type LinuxDoConnectConfig struct {
Enabled bool `mapstructure:"enabled"`
ClientID string `mapstructure:"client_id"`
ClientSecret string `mapstructure:"client_secret"`
AuthorizeURL string `mapstructure:"authorize_url"`
TokenURL string `mapstructure:"token_url"`
UserInfoURL string `mapstructure:"userinfo_url"`
Scopes string `mapstructure:"scopes"`
RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记)
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/linuxdo/callback)
TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none
UsePKCE bool `mapstructure:"use_pkce"`
// 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。
// 为空时,服务端会尝试一组常见字段名。
UserInfoEmailPath string `mapstructure:"userinfo_email_path"`
UserInfoIDPath string `mapstructure:"userinfo_id_path"`
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
}
type DefaultConfig struct { type DefaultConfig struct {
AdminEmail string `mapstructure:"admin_email"` AdminEmail string `mapstructure:"admin_email"`
AdminPassword string `mapstructure:"admin_password"` AdminPassword string `mapstructure:"admin_password"`
...@@ -388,6 +414,18 @@ func Load() (*Config, error) { ...@@ -388,6 +414,18 @@ func Load() (*Config, error) {
cfg.Server.Mode = "debug" cfg.Server.Mode = "debug"
} }
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret) cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
cfg.LinuxDo.ClientID = strings.TrimSpace(cfg.LinuxDo.ClientID)
cfg.LinuxDo.ClientSecret = strings.TrimSpace(cfg.LinuxDo.ClientSecret)
cfg.LinuxDo.AuthorizeURL = strings.TrimSpace(cfg.LinuxDo.AuthorizeURL)
cfg.LinuxDo.TokenURL = strings.TrimSpace(cfg.LinuxDo.TokenURL)
cfg.LinuxDo.UserInfoURL = strings.TrimSpace(cfg.LinuxDo.UserInfoURL)
cfg.LinuxDo.Scopes = strings.TrimSpace(cfg.LinuxDo.Scopes)
cfg.LinuxDo.RedirectURL = strings.TrimSpace(cfg.LinuxDo.RedirectURL)
cfg.LinuxDo.FrontendRedirectURL = strings.TrimSpace(cfg.LinuxDo.FrontendRedirectURL)
cfg.LinuxDo.TokenAuthMethod = strings.ToLower(strings.TrimSpace(cfg.LinuxDo.TokenAuthMethod))
cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath)
cfg.LinuxDo.UserInfoIDPath = strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath)
cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath)
cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins) cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed) cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove) cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
...@@ -426,6 +464,81 @@ func Load() (*Config, error) { ...@@ -426,6 +464,81 @@ func Load() (*Config, error) {
return &cfg, nil return &cfg, nil
} }
// ValidateAbsoluteHTTPURL 校验一个绝对 http(s) URL(禁止 fragment)。
func ValidateAbsoluteHTTPURL(raw string) error {
raw = strings.TrimSpace(raw)
if raw == "" {
return fmt.Errorf("empty url")
}
u, err := url.Parse(raw)
if err != nil {
return err
}
if !u.IsAbs() {
return fmt.Errorf("must be absolute")
}
if !isHTTPScheme(u.Scheme) {
return fmt.Errorf("unsupported scheme: %s", u.Scheme)
}
if strings.TrimSpace(u.Host) == "" {
return fmt.Errorf("missing host")
}
if u.Fragment != "" {
return fmt.Errorf("must not include fragment")
}
return nil
}
// ValidateFrontendRedirectURL 校验前端回调地址:
// - 允许同源相对路径(以 / 开头)
// - 或绝对 http(s) URL(禁止 fragment)
func ValidateFrontendRedirectURL(raw string) error {
raw = strings.TrimSpace(raw)
if raw == "" {
return fmt.Errorf("empty url")
}
if strings.ContainsAny(raw, "\r\n") {
return fmt.Errorf("contains invalid characters")
}
if strings.HasPrefix(raw, "/") {
if strings.HasPrefix(raw, "//") {
return fmt.Errorf("must not start with //")
}
return nil
}
u, err := url.Parse(raw)
if err != nil {
return err
}
if !u.IsAbs() {
return fmt.Errorf("must be absolute http(s) url or relative path")
}
if !isHTTPScheme(u.Scheme) {
return fmt.Errorf("unsupported scheme: %s", u.Scheme)
}
if strings.TrimSpace(u.Host) == "" {
return fmt.Errorf("missing host")
}
if u.Fragment != "" {
return fmt.Errorf("must not include fragment")
}
return nil
}
func isHTTPScheme(scheme string) bool {
return strings.EqualFold(scheme, "http") || strings.EqualFold(scheme, "https")
}
func warnIfInsecureURL(field, raw string) {
u, err := url.Parse(strings.TrimSpace(raw))
if err != nil {
return
}
if strings.EqualFold(u.Scheme, "http") {
log.Printf("Warning: %s uses http scheme; use https in production to avoid token leakage.", field)
}
}
func setDefaults() { func setDefaults() {
viper.SetDefault("run_mode", RunModeStandard) viper.SetDefault("run_mode", RunModeStandard)
...@@ -475,6 +588,22 @@ func setDefaults() { ...@@ -475,6 +588,22 @@ func setDefaults() {
// Turnstile // Turnstile
viper.SetDefault("turnstile.required", false) viper.SetDefault("turnstile.required", false)
// LinuxDo Connect OAuth 登录(终端用户 SSO)
viper.SetDefault("linuxdo_connect.enabled", false)
viper.SetDefault("linuxdo_connect.client_id", "")
viper.SetDefault("linuxdo_connect.client_secret", "")
viper.SetDefault("linuxdo_connect.authorize_url", "https://connect.linux.do/oauth2/authorize")
viper.SetDefault("linuxdo_connect.token_url", "https://connect.linux.do/oauth2/token")
viper.SetDefault("linuxdo_connect.userinfo_url", "https://connect.linux.do/api/user")
viper.SetDefault("linuxdo_connect.scopes", "user")
viper.SetDefault("linuxdo_connect.redirect_url", "")
viper.SetDefault("linuxdo_connect.frontend_redirect_url", "/auth/linuxdo/callback")
viper.SetDefault("linuxdo_connect.token_auth_method", "client_secret_post")
viper.SetDefault("linuxdo_connect.use_pkce", false)
viper.SetDefault("linuxdo_connect.userinfo_email_path", "")
viper.SetDefault("linuxdo_connect.userinfo_id_path", "")
viper.SetDefault("linuxdo_connect.userinfo_username_path", "")
// Database // Database
viper.SetDefault("database.host", "localhost") viper.SetDefault("database.host", "localhost")
viper.SetDefault("database.port", 5432) viper.SetDefault("database.port", 5432)
...@@ -586,6 +715,60 @@ func (c *Config) Validate() error { ...@@ -586,6 +715,60 @@ func (c *Config) Validate() error {
if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" { if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" {
return fmt.Errorf("security.csp.policy is required when CSP is enabled") return fmt.Errorf("security.csp.policy is required when CSP is enabled")
} }
if c.LinuxDo.Enabled {
if strings.TrimSpace(c.LinuxDo.ClientID) == "" {
return fmt.Errorf("linuxdo_connect.client_id is required when linuxdo_connect.enabled=true")
}
if strings.TrimSpace(c.LinuxDo.AuthorizeURL) == "" {
return fmt.Errorf("linuxdo_connect.authorize_url is required when linuxdo_connect.enabled=true")
}
if strings.TrimSpace(c.LinuxDo.TokenURL) == "" {
return fmt.Errorf("linuxdo_connect.token_url is required when linuxdo_connect.enabled=true")
}
if strings.TrimSpace(c.LinuxDo.UserInfoURL) == "" {
return fmt.Errorf("linuxdo_connect.userinfo_url is required when linuxdo_connect.enabled=true")
}
if strings.TrimSpace(c.LinuxDo.RedirectURL) == "" {
return fmt.Errorf("linuxdo_connect.redirect_url is required when linuxdo_connect.enabled=true")
}
method := strings.ToLower(strings.TrimSpace(c.LinuxDo.TokenAuthMethod))
switch method {
case "", "client_secret_post", "client_secret_basic", "none":
default:
return fmt.Errorf("linuxdo_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none")
}
if method == "none" && !c.LinuxDo.UsePKCE {
return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none")
}
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && strings.TrimSpace(c.LinuxDo.ClientSecret) == "" {
return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic")
}
if strings.TrimSpace(c.LinuxDo.FrontendRedirectURL) == "" {
return fmt.Errorf("linuxdo_connect.frontend_redirect_url is required when linuxdo_connect.enabled=true")
}
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.AuthorizeURL); err != nil {
return fmt.Errorf("linuxdo_connect.authorize_url invalid: %w", err)
}
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.TokenURL); err != nil {
return fmt.Errorf("linuxdo_connect.token_url invalid: %w", err)
}
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.UserInfoURL); err != nil {
return fmt.Errorf("linuxdo_connect.userinfo_url invalid: %w", err)
}
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.RedirectURL); err != nil {
return fmt.Errorf("linuxdo_connect.redirect_url invalid: %w", err)
}
if err := ValidateFrontendRedirectURL(c.LinuxDo.FrontendRedirectURL); err != nil {
return fmt.Errorf("linuxdo_connect.frontend_redirect_url invalid: %w", err)
}
warnIfInsecureURL("linuxdo_connect.authorize_url", c.LinuxDo.AuthorizeURL)
warnIfInsecureURL("linuxdo_connect.token_url", c.LinuxDo.TokenURL)
warnIfInsecureURL("linuxdo_connect.userinfo_url", c.LinuxDo.UserInfoURL)
warnIfInsecureURL("linuxdo_connect.redirect_url", c.LinuxDo.RedirectURL)
warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL)
}
if c.Billing.CircuitBreaker.Enabled { if c.Billing.CircuitBreaker.Enabled {
if c.Billing.CircuitBreaker.FailureThreshold <= 0 { if c.Billing.CircuitBreaker.FailureThreshold <= 0 {
return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive") return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive")
......
package config package config
import ( import (
"strings"
"testing" "testing"
"time" "time"
...@@ -90,3 +91,53 @@ func TestLoadDefaultSecurityToggles(t *testing.T) { ...@@ -90,3 +91,53 @@ func TestLoadDefaultSecurityToggles(t *testing.T) {
t.Fatalf("ResponseHeaders.Enabled = true, want false") t.Fatalf("ResponseHeaders.Enabled = true, want false")
} }
} }
func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.LinuxDo.Enabled = true
cfg.LinuxDo.ClientID = "test-client"
cfg.LinuxDo.ClientSecret = "test-secret"
cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback"
cfg.LinuxDo.TokenAuthMethod = "client_secret_post"
cfg.LinuxDo.UsePKCE = false
cfg.LinuxDo.FrontendRedirectURL = "javascript:alert(1)"
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() expected error for javascript scheme, got nil")
}
if !strings.Contains(err.Error(), "linuxdo_connect.frontend_redirect_url") {
t.Fatalf("Validate() expected frontend_redirect_url error, got: %v", err)
}
}
func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.LinuxDo.Enabled = true
cfg.LinuxDo.ClientID = "test-client"
cfg.LinuxDo.ClientSecret = ""
cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback"
cfg.LinuxDo.FrontendRedirectURL = "/auth/linuxdo/callback"
cfg.LinuxDo.TokenAuthMethod = "none"
cfg.LinuxDo.UsePKCE = false
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() expected error when token_auth_method=none and use_pkce=false, got nil")
}
if !strings.Contains(err.Error(), "linuxdo_connect.use_pkce") {
t.Fatalf("Validate() expected use_pkce error, got: %v", err)
}
}
...@@ -116,6 +116,7 @@ type BulkUpdateAccountsRequest struct { ...@@ -116,6 +116,7 @@ type BulkUpdateAccountsRequest struct {
Concurrency *int `json:"concurrency"` Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"` Priority *int `json:"priority"`
Status string `json:"status" binding:"omitempty,oneof=active inactive error"` Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
Schedulable *bool `json:"schedulable"`
GroupIDs *[]int64 `json:"group_ids"` GroupIDs *[]int64 `json:"group_ids"`
Credentials map[string]any `json:"credentials"` Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"` Extra map[string]any `json:"extra"`
...@@ -136,6 +137,11 @@ func (h *AccountHandler) List(c *gin.Context) { ...@@ -136,6 +137,11 @@ func (h *AccountHandler) List(c *gin.Context) {
accountType := c.Query("type") accountType := c.Query("type")
status := c.Query("status") status := c.Query("status")
search := c.Query("search") search := c.Query("search")
// 标准化和验证 search 参数
search = strings.TrimSpace(search)
if len(search) > 100 {
search = search[:100]
}
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search) accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search)
if err != nil { if err != nil {
...@@ -655,6 +661,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { ...@@ -655,6 +661,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
req.Concurrency != nil || req.Concurrency != nil ||
req.Priority != nil || req.Priority != nil ||
req.Status != "" || req.Status != "" ||
req.Schedulable != nil ||
req.GroupIDs != nil || req.GroupIDs != nil ||
len(req.Credentials) > 0 || len(req.Credentials) > 0 ||
len(req.Extra) > 0 len(req.Extra) > 0
...@@ -671,6 +678,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { ...@@ -671,6 +678,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
Concurrency: req.Concurrency, Concurrency: req.Concurrency,
Priority: req.Priority, Priority: req.Priority,
Status: req.Status, Status: req.Status,
Schedulable: req.Schedulable,
GroupIDs: req.GroupIDs, GroupIDs: req.GroupIDs,
Credentials: req.Credentials, Credentials: req.Credentials,
Extra: req.Extra, Extra: req.Extra,
......
...@@ -2,6 +2,7 @@ package admin ...@@ -2,6 +2,7 @@ package admin
import ( import (
"strconv" "strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
...@@ -34,9 +35,11 @@ type CreateGroupRequest struct { ...@@ -34,9 +35,11 @@ type CreateGroupRequest struct {
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"` WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"` MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置) // 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"` ImagePrice4K *float64 `json:"image_price_4k"`
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
} }
// UpdateGroupRequest represents update group request // UpdateGroupRequest represents update group request
...@@ -52,9 +55,11 @@ type UpdateGroupRequest struct { ...@@ -52,9 +55,11 @@ type UpdateGroupRequest struct {
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"` WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"` MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置) // 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"` ImagePrice4K *float64 `json:"image_price_4k"`
ClaudeCodeOnly *bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
} }
// List handles listing all groups with pagination // List handles listing all groups with pagination
...@@ -63,6 +68,12 @@ func (h *GroupHandler) List(c *gin.Context) { ...@@ -63,6 +68,12 @@ func (h *GroupHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c) page, pageSize := response.ParsePagination(c)
platform := c.Query("platform") platform := c.Query("platform")
status := c.Query("status") status := c.Query("status")
search := c.Query("search")
// 标准化和验证 search 参数
search = strings.TrimSpace(search)
if len(search) > 100 {
search = search[:100]
}
isExclusiveStr := c.Query("is_exclusive") isExclusiveStr := c.Query("is_exclusive")
var isExclusive *bool var isExclusive *bool
...@@ -71,7 +82,7 @@ func (h *GroupHandler) List(c *gin.Context) { ...@@ -71,7 +82,7 @@ func (h *GroupHandler) List(c *gin.Context) {
isExclusive = &val isExclusive = &val
} }
groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, isExclusive) groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, search, isExclusive)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -150,6 +161,8 @@ func (h *GroupHandler) Create(c *gin.Context) { ...@@ -150,6 +161,8 @@ func (h *GroupHandler) Create(c *gin.Context) {
ImagePrice1K: req.ImagePrice1K, ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K, ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K, ImagePrice4K: req.ImagePrice4K,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
}) })
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
...@@ -188,6 +201,8 @@ func (h *GroupHandler) Update(c *gin.Context) { ...@@ -188,6 +201,8 @@ func (h *GroupHandler) Update(c *gin.Context) {
ImagePrice1K: req.ImagePrice1K, ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K, ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K, ImagePrice4K: req.ImagePrice4K,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
}) })
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
......
...@@ -51,16 +51,21 @@ func (h *ProxyHandler) List(c *gin.Context) { ...@@ -51,16 +51,21 @@ func (h *ProxyHandler) List(c *gin.Context) {
protocol := c.Query("protocol") protocol := c.Query("protocol")
status := c.Query("status") status := c.Query("status")
search := c.Query("search") search := c.Query("search")
// 标准化和验证 search 参数
search = strings.TrimSpace(search)
if len(search) > 100 {
search = search[:100]
}
proxies, total, err := h.adminService.ListProxies(c.Request.Context(), page, pageSize, protocol, status, search) proxies, total, err := h.adminService.ListProxiesWithAccountCount(c.Request.Context(), page, pageSize, protocol, status, search)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
out := make([]dto.Proxy, 0, len(proxies)) out := make([]dto.ProxyWithAccountCount, 0, len(proxies))
for i := range proxies { for i := range proxies {
out = append(out, *dto.ProxyFromService(&proxies[i])) out = append(out, *dto.ProxyWithAccountCountFromService(&proxies[i]))
} }
response.Paginated(c, out, total, page, pageSize) response.Paginated(c, out, total, page, pageSize)
} }
......
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"encoding/csv" "encoding/csv"
"fmt" "fmt"
"strconv" "strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
...@@ -41,6 +42,11 @@ func (h *RedeemHandler) List(c *gin.Context) { ...@@ -41,6 +42,11 @@ func (h *RedeemHandler) List(c *gin.Context) {
codeType := c.Query("type") codeType := c.Query("type")
status := c.Query("status") status := c.Query("status")
search := c.Query("search") search := c.Query("search")
// 标准化和验证 search 参数
search = strings.TrimSpace(search)
if len(search) > 100 {
search = search[:100]
}
codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search) codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search)
if err != nil { if err != nil {
......
...@@ -2,8 +2,10 @@ package admin ...@@ -2,8 +2,10 @@ package admin
import ( import (
"log" "log"
"strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/server/middleware"
...@@ -38,33 +40,37 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { ...@@ -38,33 +40,37 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
} }
response.Success(c, dto.SystemSettings{ response.Success(c, dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled, RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled,
SMTPHost: settings.SMTPHost, SMTPHost: settings.SMTPHost,
SMTPPort: settings.SMTPPort, SMTPPort: settings.SMTPPort,
SMTPUsername: settings.SMTPUsername, SMTPUsername: settings.SMTPUsername,
SMTPPasswordConfigured: settings.SMTPPasswordConfigured, SMTPPasswordConfigured: settings.SMTPPasswordConfigured,
SMTPFrom: settings.SMTPFrom, SMTPFrom: settings.SMTPFrom,
SMTPFromName: settings.SMTPFromName, SMTPFromName: settings.SMTPFromName,
SMTPUseTLS: settings.SMTPUseTLS, SMTPUseTLS: settings.SMTPUseTLS,
TurnstileEnabled: settings.TurnstileEnabled, TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey, TurnstileSiteKey: settings.TurnstileSiteKey,
TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured, TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured,
SiteName: settings.SiteName, LinuxDoConnectEnabled: settings.LinuxDoConnectEnabled,
SiteLogo: settings.SiteLogo, LinuxDoConnectClientID: settings.LinuxDoConnectClientID,
SiteSubtitle: settings.SiteSubtitle, LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured,
APIBaseURL: settings.APIBaseURL, LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL,
ContactInfo: settings.ContactInfo, SiteName: settings.SiteName,
DocURL: settings.DocURL, SiteLogo: settings.SiteLogo,
DefaultConcurrency: settings.DefaultConcurrency, SiteSubtitle: settings.SiteSubtitle,
DefaultBalance: settings.DefaultBalance, APIBaseURL: settings.APIBaseURL,
EnableModelFallback: settings.EnableModelFallback, ContactInfo: settings.ContactInfo,
FallbackModelAnthropic: settings.FallbackModelAnthropic, DocURL: settings.DocURL,
FallbackModelOpenAI: settings.FallbackModelOpenAI, DefaultConcurrency: settings.DefaultConcurrency,
FallbackModelGemini: settings.FallbackModelGemini, DefaultBalance: settings.DefaultBalance,
FallbackModelAntigravity: settings.FallbackModelAntigravity, EnableModelFallback: settings.EnableModelFallback,
EnableIdentityPatch: settings.EnableIdentityPatch, FallbackModelAnthropic: settings.FallbackModelAnthropic,
IdentityPatchPrompt: settings.IdentityPatchPrompt, FallbackModelOpenAI: settings.FallbackModelOpenAI,
FallbackModelGemini: settings.FallbackModelGemini,
FallbackModelAntigravity: settings.FallbackModelAntigravity,
EnableIdentityPatch: settings.EnableIdentityPatch,
IdentityPatchPrompt: settings.IdentityPatchPrompt,
}) })
} }
...@@ -88,6 +94,12 @@ type UpdateSettingsRequest struct { ...@@ -88,6 +94,12 @@ type UpdateSettingsRequest struct {
TurnstileSiteKey string `json:"turnstile_site_key"` TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKey string `json:"turnstile_secret_key"` TurnstileSecretKey string `json:"turnstile_secret_key"`
// LinuxDo Connect OAuth 登录(终端用户 SSO)
LinuxDoConnectEnabled bool `json:"linuxdo_connect_enabled"`
LinuxDoConnectClientID string `json:"linuxdo_connect_client_id"`
LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
// OEM设置 // OEM设置
SiteName string `json:"site_name"` SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"` SiteLogo string `json:"site_logo"`
...@@ -165,34 +177,67 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -165,34 +177,67 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
} }
} }
// LinuxDo Connect 参数验证
if req.LinuxDoConnectEnabled {
req.LinuxDoConnectClientID = strings.TrimSpace(req.LinuxDoConnectClientID)
req.LinuxDoConnectClientSecret = strings.TrimSpace(req.LinuxDoConnectClientSecret)
req.LinuxDoConnectRedirectURL = strings.TrimSpace(req.LinuxDoConnectRedirectURL)
if req.LinuxDoConnectClientID == "" {
response.BadRequest(c, "LinuxDo Client ID is required when enabled")
return
}
if req.LinuxDoConnectRedirectURL == "" {
response.BadRequest(c, "LinuxDo Redirect URL is required when enabled")
return
}
if err := config.ValidateAbsoluteHTTPURL(req.LinuxDoConnectRedirectURL); err != nil {
response.BadRequest(c, "LinuxDo Redirect URL must be an absolute http(s) URL")
return
}
// 如果未提供 client_secret,则保留现有值(如有)。
if req.LinuxDoConnectClientSecret == "" {
if previousSettings.LinuxDoConnectClientSecret == "" {
response.BadRequest(c, "LinuxDo Client Secret is required when enabled")
return
}
req.LinuxDoConnectClientSecret = previousSettings.LinuxDoConnectClientSecret
}
}
settings := &service.SystemSettings{ settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled, RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled, EmailVerifyEnabled: req.EmailVerifyEnabled,
SMTPHost: req.SMTPHost, SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort, SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername, SMTPUsername: req.SMTPUsername,
SMTPPassword: req.SMTPPassword, SMTPPassword: req.SMTPPassword,
SMTPFrom: req.SMTPFrom, SMTPFrom: req.SMTPFrom,
SMTPFromName: req.SMTPFromName, SMTPFromName: req.SMTPFromName,
SMTPUseTLS: req.SMTPUseTLS, SMTPUseTLS: req.SMTPUseTLS,
TurnstileEnabled: req.TurnstileEnabled, TurnstileEnabled: req.TurnstileEnabled,
TurnstileSiteKey: req.TurnstileSiteKey, TurnstileSiteKey: req.TurnstileSiteKey,
TurnstileSecretKey: req.TurnstileSecretKey, TurnstileSecretKey: req.TurnstileSecretKey,
SiteName: req.SiteName, LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
SiteLogo: req.SiteLogo, LinuxDoConnectClientID: req.LinuxDoConnectClientID,
SiteSubtitle: req.SiteSubtitle, LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
APIBaseURL: req.APIBaseURL, LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
ContactInfo: req.ContactInfo, SiteName: req.SiteName,
DocURL: req.DocURL, SiteLogo: req.SiteLogo,
DefaultConcurrency: req.DefaultConcurrency, SiteSubtitle: req.SiteSubtitle,
DefaultBalance: req.DefaultBalance, APIBaseURL: req.APIBaseURL,
EnableModelFallback: req.EnableModelFallback, ContactInfo: req.ContactInfo,
FallbackModelAnthropic: req.FallbackModelAnthropic, DocURL: req.DocURL,
FallbackModelOpenAI: req.FallbackModelOpenAI, DefaultConcurrency: req.DefaultConcurrency,
FallbackModelGemini: req.FallbackModelGemini, DefaultBalance: req.DefaultBalance,
FallbackModelAntigravity: req.FallbackModelAntigravity, EnableModelFallback: req.EnableModelFallback,
EnableIdentityPatch: req.EnableIdentityPatch, FallbackModelAnthropic: req.FallbackModelAnthropic,
IdentityPatchPrompt: req.IdentityPatchPrompt, FallbackModelOpenAI: req.FallbackModelOpenAI,
FallbackModelGemini: req.FallbackModelGemini,
FallbackModelAntigravity: req.FallbackModelAntigravity,
EnableIdentityPatch: req.EnableIdentityPatch,
IdentityPatchPrompt: req.IdentityPatchPrompt,
} }
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil { if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
...@@ -210,33 +255,37 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -210,33 +255,37 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
} }
response.Success(c, dto.SystemSettings{ response.Success(c, dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled, RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled, EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
SMTPHost: updatedSettings.SMTPHost, SMTPHost: updatedSettings.SMTPHost,
SMTPPort: updatedSettings.SMTPPort, SMTPPort: updatedSettings.SMTPPort,
SMTPUsername: updatedSettings.SMTPUsername, SMTPUsername: updatedSettings.SMTPUsername,
SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured, SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured,
SMTPFrom: updatedSettings.SMTPFrom, SMTPFrom: updatedSettings.SMTPFrom,
SMTPFromName: updatedSettings.SMTPFromName, SMTPFromName: updatedSettings.SMTPFromName,
SMTPUseTLS: updatedSettings.SMTPUseTLS, SMTPUseTLS: updatedSettings.SMTPUseTLS,
TurnstileEnabled: updatedSettings.TurnstileEnabled, TurnstileEnabled: updatedSettings.TurnstileEnabled,
TurnstileSiteKey: updatedSettings.TurnstileSiteKey, TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured, TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured,
SiteName: updatedSettings.SiteName, LinuxDoConnectEnabled: updatedSettings.LinuxDoConnectEnabled,
SiteLogo: updatedSettings.SiteLogo, LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID,
SiteSubtitle: updatedSettings.SiteSubtitle, LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured,
APIBaseURL: updatedSettings.APIBaseURL, LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL,
ContactInfo: updatedSettings.ContactInfo, SiteName: updatedSettings.SiteName,
DocURL: updatedSettings.DocURL, SiteLogo: updatedSettings.SiteLogo,
DefaultConcurrency: updatedSettings.DefaultConcurrency, SiteSubtitle: updatedSettings.SiteSubtitle,
DefaultBalance: updatedSettings.DefaultBalance, APIBaseURL: updatedSettings.APIBaseURL,
EnableModelFallback: updatedSettings.EnableModelFallback, ContactInfo: updatedSettings.ContactInfo,
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic, DocURL: updatedSettings.DocURL,
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI, DefaultConcurrency: updatedSettings.DefaultConcurrency,
FallbackModelGemini: updatedSettings.FallbackModelGemini, DefaultBalance: updatedSettings.DefaultBalance,
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity, EnableModelFallback: updatedSettings.EnableModelFallback,
EnableIdentityPatch: updatedSettings.EnableIdentityPatch, FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt, FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
FallbackModelGemini: updatedSettings.FallbackModelGemini,
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
}) })
} }
...@@ -298,6 +347,18 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, ...@@ -298,6 +347,18 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if req.TurnstileSecretKey != "" { if req.TurnstileSecretKey != "" {
changed = append(changed, "turnstile_secret_key") changed = append(changed, "turnstile_secret_key")
} }
if before.LinuxDoConnectEnabled != after.LinuxDoConnectEnabled {
changed = append(changed, "linuxdo_connect_enabled")
}
if before.LinuxDoConnectClientID != after.LinuxDoConnectClientID {
changed = append(changed, "linuxdo_connect_client_id")
}
if req.LinuxDoConnectClientSecret != "" {
changed = append(changed, "linuxdo_connect_client_secret")
}
if before.LinuxDoConnectRedirectURL != after.LinuxDoConnectRedirectURL {
changed = append(changed, "linuxdo_connect_redirect_url")
}
if before.SiteName != after.SiteName { if before.SiteName != after.SiteName {
changed = append(changed, "site_name") changed = append(changed, "site_name")
} }
...@@ -337,6 +398,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, ...@@ -337,6 +398,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.FallbackModelAntigravity != after.FallbackModelAntigravity { if before.FallbackModelAntigravity != after.FallbackModelAntigravity {
changed = append(changed, "fallback_model_antigravity") changed = append(changed, "fallback_model_antigravity")
} }
if before.EnableIdentityPatch != after.EnableIdentityPatch {
changed = append(changed, "enable_identity_patch")
}
if before.IdentityPatchPrompt != after.IdentityPatchPrompt {
changed = append(changed, "identity_patch_prompt")
}
return changed return changed
} }
......
...@@ -2,6 +2,7 @@ package admin ...@@ -2,6 +2,7 @@ package admin
import ( import (
"strconv" "strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
...@@ -63,10 +64,17 @@ type UpdateBalanceRequest struct { ...@@ -63,10 +64,17 @@ type UpdateBalanceRequest struct {
func (h *UserHandler) List(c *gin.Context) { func (h *UserHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c) page, pageSize := response.ParsePagination(c)
search := c.Query("search")
// 标准化和验证 search 参数
search = strings.TrimSpace(search)
if len(search) > 100 {
search = search[:100]
}
filters := service.UserListFilters{ filters := service.UserListFilters{
Status: c.Query("status"), Status: c.Query("status"),
Role: c.Query("role"), Role: c.Query("role"),
Search: c.Query("search"), Search: search,
Attributes: parseAttributeFilters(c), Attributes: parseAttributeFilters(c),
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment