Commit 61f55674 authored by yangjianbo's avatar yangjianbo
Browse files
parents eeb1282f 7d1fe818
# Linux DO Connect
OAuth(Open Authorization)是一个开放的网络授权标准,目前最新版本为 OAuth 2.0。我们日常使用的第三方登录(如 Google 账号登录)就采用了该标准。OAuth 允许用户授权第三方应用访问存储在其他服务提供商(如 Google)上的信息,无需在不同平台上重复填写注册信息。用户授权后,平台可以直接访问用户的账户信息进行身份验证,而用户无需向第三方应用提供密码。
目前系统已实现完整的 OAuth2 授权码(code)方式鉴权,但界面等配套功能还在持续完善中。让我们一起打造一个更完善的共享方案。
## 基本介绍
这是一套标准的 OAuth2 鉴权系统,可以让开发者共享论坛的用户基本信息。
- 可获取字段:
| 参数 | 说明 |
| ----------------- | ------------------------------- |
| `id` | 用户唯一标识(不可变) |
| `username` | 论坛用户名 |
| `name` | 论坛用户昵称(可变) |
| `avatar_template` | 用户头像模板URL(支持多种尺寸) |
| `active` | 账号活跃状态 |
| `trust_level` | 信任等级(0-4) |
| `silenced` | 禁言状态 |
| `external_ids` | 外部ID关联信息 |
| `api_key` | API访问密钥 |
通过这些信息,公益网站/接口可以实现:
1. 基于 `id` 的服务频率限制
2. 基于 `trust_level` 的服务额度分配
3. 基于用户信息的滥用举报机制
## 相关端点
- Authorize 端点: `https://connect.linux.do/oauth2/authorize`
- Token 端点:`https://connect.linux.do/oauth2/token`
- 用户信息 端点:`https://connect.linux.do/api/user`
## 申请使用
- 访问 [Connect.Linux.Do](https://connect.linux.do/) 申请接入你的应用。
![linuxdoconnect_1](https://wiki.linux.do/_next/image?url=%2Flinuxdoconnect_1.png&w=1080&q=75)
- 点击 **`我的应用接入`** - **`申请新接入`**,填写相关信息。其中 **`回调地址`** 是你的应用接收用户信息的地址。
![linuxdoconnect_2](https://wiki.linux.do/_next/image?url=%2Flinuxdoconnect_2.png&w=1080&q=75)
- 申请成功后,你将获得 **`Client Id`****`Client Secret`**,这是你应用的唯一身份凭证。
![linuxdoconnect_3](https://wiki.linux.do/_next/image?url=%2Flinuxdoconnect_3.png&w=1080&q=75)
## 接入 Linux Do
JavaScript
```JavaScript
// 安装第三方请求库(或使用原生的 Fetch API),本例中使用 axios
// npm install axios
// 通过 OAuth2 获取 Linux Do 用户信息的参考流程
const axios = require('axios');
const readline = require('readline');
// 配置信息(建议通过环境变量配置,避免使用硬编码)
const CLIENT_ID = '你的 Client ID';
const CLIENT_SECRET = '你的 Client Secret';
const REDIRECT_URI = '你的回调地址';
const AUTH_URL = 'https://connect.linux.do/oauth2/authorize';
const TOKEN_URL = 'https://connect.linux.do/oauth2/token';
const USER_INFO_URL = 'https://connect.linux.do/api/user';
// 第一步:生成授权 URL
function getAuthUrl() {
const params = new URLSearchParams({
client_id: CLIENT_ID,
redirect_uri: REDIRECT_URI,
response_type: 'code',
scope: 'user'
});
return `${AUTH_URL}?${params.toString()}`;
}
// 第二步:获取 code 参数
function getCode() {
return new Promise((resolve) => {
// 本例中使用终端输入来模拟流程,仅供本地测试
// 请在实际应用中替换为真实的处理逻辑
const rl = readline.createInterface({ input: process.stdin, output: process.stdout });
rl.question('从回调 URL 中提取出 code,粘贴到此处并按回车:', (answer) => {
rl.close();
resolve(answer.trim());
});
});
}
// 第三步:使用 code 参数获取访问令牌
async function getAccessToken(code) {
try {
const form = new URLSearchParams({
client_id: CLIENT_ID,
client_secret: CLIENT_SECRET,
code: code,
redirect_uri: REDIRECT_URI,
grant_type: 'authorization_code'
}).toString();
const response = await axios.post(TOKEN_URL, form, {
// 提醒:需正确配置请求头,否则无法正常获取访问令牌
headers: {
'Content-Type': 'application/x-www-form-urlencoded',
'Accept': 'application/json'
}
});
return response.data;
} catch (error) {
console.error(`获取访问令牌失败:${error.response ? JSON.stringify(error.response.data) : error.message}`);
throw error;
}
}
// 第四步:使用访问令牌获取用户信息
async function getUserInfo(accessToken) {
try {
const response = await axios.get(USER_INFO_URL, {
headers: {
Authorization: `Bearer ${accessToken}`
}
});
return response.data;
} catch (error) {
console.error(`获取用户信息失败:${error.response ? JSON.stringify(error.response.data) : error.message}`);
throw error;
}
}
// 主流程
async function main() {
// 1. 生成授权 URL,前端引导用户访问授权页
const authUrl = getAuthUrl();
console.log(`请访问此 URL 授权:${authUrl}
`);
// 2. 用户授权后,从回调 URL 获取 code 参数
const code = await getCode();
try {
// 3. 使用 code 参数获取访问令牌
const tokenData = await getAccessToken(code);
const accessToken = tokenData.access_token;
// 4. 使用访问令牌获取用户信息
if (accessToken) {
const userInfo = await getUserInfo(accessToken);
console.log(`
获取用户信息成功:${JSON.stringify(userInfo, null, 2)}`);
} else {
console.log(`
获取访问令牌失败:${JSON.stringify(tokenData)}`);
}
} catch (error) {
console.error('发生错误:', error);
}
}
```
Python
```python
# 安装第三方请求库,本例中使用 requests
# pip install requests
# 通过 OAuth2 获取 Linux Do 用户信息的参考流程
import requests
import json
# 配置信息(建议通过环境变量配置,避免使用硬编码)
CLIENT_ID = '你的 Client ID'
CLIENT_SECRET = '你的 Client Secret'
REDIRECT_URI = '你的回调地址'
AUTH_URL = 'https://connect.linux.do/oauth2/authorize'
TOKEN_URL = 'https://connect.linux.do/oauth2/token'
USER_INFO_URL = 'https://connect.linux.do/api/user'
# 第一步:生成授权 URL
def get_auth_url():
params = {
'client_id': CLIENT_ID,
'redirect_uri': REDIRECT_URI,
'response_type': 'code',
'scope': 'user'
}
auth_url = f"{AUTH_URL}?{'&'.join(f'{k}={v}' for k, v in params.items())}"
return auth_url
# 第二步:获取 code 参数
def get_code():
# 本例中使用终端输入来模拟流程,仅供本地测试
# 请在实际应用中替换为真实的处理逻辑
return input('从回调 URL 中提取出 code,粘贴到此处并按回车:').strip()
# 第三步:使用 code 参数获取访问令牌
def get_access_token(code):
try:
data = {
'client_id': CLIENT_ID,
'client_secret': CLIENT_SECRET,
'code': code,
'redirect_uri': REDIRECT_URI,
'grant_type': 'authorization_code'
}
# 提醒:需正确配置请求头,否则无法正常获取访问令牌
headers = {
'Content-Type': 'application/x-www-form-urlencoded',
'Accept': 'application/json'
}
response = requests.post(TOKEN_URL, data=data, headers=headers)
response.raise_for_status()
return response.json()
except requests.exceptions.RequestException as e:
print(f"获取访问令牌失败:{e}")
return None
# 第四步:使用访问令牌获取用户信息
def get_user_info(access_token):
try:
headers = {
'Authorization': f'Bearer {access_token}'
}
response = requests.get(USER_INFO_URL, headers=headers)
response.raise_for_status()
return response.json()
except requests.exceptions.RequestException as e:
print(f"获取用户信息失败:{e}")
return None
# 主流程
if __name__ == '__main__':
# 1. 生成授权 URL,前端引导用户访问授权页
auth_url = get_auth_url()
print(f'请访问此 URL 授权:{auth_url}
')
# 2. 用户授权后,从回调 URL 获取 code 参数
code = get_code()
# 3. 使用 code 参数获取访问令牌
token_data = get_access_token(code)
if token_data:
access_token = token_data.get('access_token')
# 4. 使用访问令牌获取用户信息
if access_token:
user_info = get_user_info(access_token)
if user_info:
print(f"
获取用户信息成功:{json.dumps(user_info, indent=2)}")
else:
print("
获取用户信息失败")
else:
print(f"
获取访问令牌失败:{json.dumps(token_data, indent=2)}")
else:
print("
获取访问令牌失败")
```
PHP
```php
// 通过 OAuth2 获取 Linux Do 用户信息的参考流程
// 配置信息
$CLIENT_ID = '你的 Client ID';
$CLIENT_SECRET = '你的 Client Secret';
$REDIRECT_URI = '你的回调地址';
$AUTH_URL = 'https://connect.linux.do/oauth2/authorize';
$TOKEN_URL = 'https://connect.linux.do/oauth2/token';
$USER_INFO_URL = 'https://connect.linux.do/api/user';
// 生成授权 URL
function getAuthUrl($clientId, $redirectUri) {
global $AUTH_URL;
return $AUTH_URL . '?' . http_build_query([
'client_id' => $clientId,
'redirect_uri' => $redirectUri,
'response_type' => 'code',
'scope' => 'user'
]);
}
// 使用 code 参数获取用户信息(合并获取令牌和获取用户信息的步骤)
function getUserInfoWithCode($code, $clientId, $clientSecret, $redirectUri) {
global $TOKEN_URL, $USER_INFO_URL;
// 1. 获取访问令牌
$ch = curl_init($TOKEN_URL);
curl_setopt($ch, CURLOPT_RETURNTRANSFER, true);
curl_setopt($ch, CURLOPT_POST, true);
curl_setopt($ch, CURLOPT_POSTFIELDS, http_build_query([
'client_id' => $clientId,
'client_secret' => $clientSecret,
'code' => $code,
'redirect_uri' => $redirectUri,
'grant_type' => 'authorization_code'
]));
curl_setopt($ch, CURLOPT_HTTPHEADER, [
'Content-Type: application/x-www-form-urlencoded',
'Accept: application/json'
]);
$tokenResponse = curl_exec($ch);
curl_close($ch);
$tokenData = json_decode($tokenResponse, true);
if (!isset($tokenData['access_token'])) {
return ['error' => '获取访问令牌失败', 'details' => $tokenData];
}
// 2. 获取用户信息
$ch = curl_init($USER_INFO_URL);
curl_setopt($ch, CURLOPT_RETURNTRANSFER, true);
curl_setopt($ch, CURLOPT_HTTPHEADER, [
'Authorization: Bearer ' . $tokenData['access_token']
]);
$userResponse = curl_exec($ch);
curl_close($ch);
return json_decode($userResponse, true);
}
// 主流程
// 1. 生成授权 URL
$authUrl = getAuthUrl($CLIENT_ID, $REDIRECT_URI);
echo "<a href='$authUrl'>使用 Linux Do 登录</a>";
// 2. 处理回调并获取用户信息
if (isset($_GET['code'])) {
$userInfo = getUserInfoWithCode(
$_GET['code'],
$CLIENT_ID,
$CLIENT_SECRET,
$REDIRECT_URI
);
if (isset($userInfo['error'])) {
echo '错误: ' . $userInfo['error'];
} else {
echo '欢迎, ' . $userInfo['name'] . '!';
// 处理用户登录逻辑...
}
}
```
## 使用说明
### 授权流程
1. 用户点击应用中的’使用 Linux Do 登录’按钮
2. 系统将用户重定向至 Linux Do 的授权页面
3. 用户完成授权后,系统自动重定向回应用并携带授权码
4. 应用使用授权码获取访问令牌
5. 使用访问令牌获取用户信息
### 安全建议
- 切勿在前端代码中暴露 Client Secret
- 对所有用户输入数据进行严格验证
- 确保使用 HTTPS 协议传输数据
- 定期更新并妥善保管 Client Secret
\ No newline at end of file
...@@ -53,7 +53,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -53,7 +53,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
emailQueueService := service.ProvideEmailQueueService(emailService) emailQueueService := service.ProvideEmailQueueService(emailService)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService) authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
userService := service.NewUserService(userRepository) userService := service.NewUserService(userRepository)
authHandler := handler.NewAuthHandler(configConfig, authService, userService) authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService)
userHandler := handler.NewUserHandler(userService) userHandler := handler.NewUserHandler(userService)
apiKeyRepository := repository.NewAPIKeyRepository(client) apiKeyRepository := repository.NewAPIKeyRepository(client)
groupRepository := repository.NewGroupRepository(client, db) groupRepository := repository.NewGroupRepository(client, db)
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
package ent package ent
import ( import (
"encoding/json"
"fmt" "fmt"
"strings" "strings"
"time" "time"
...@@ -35,6 +36,10 @@ type APIKey struct { ...@@ -35,6 +36,10 @@ type APIKey struct {
GroupID *int64 `json:"group_id,omitempty"` GroupID *int64 `json:"group_id,omitempty"`
// Status holds the value of the "status" field. // Status holds the value of the "status" field.
Status string `json:"status,omitempty"` Status string `json:"status,omitempty"`
// Allowed IPs/CIDRs, e.g. ["192.168.1.100", "10.0.0.0/8"]
IPWhitelist []string `json:"ip_whitelist,omitempty"`
// Blocked IPs/CIDRs
IPBlacklist []string `json:"ip_blacklist,omitempty"`
// Edges holds the relations/edges for other nodes in the graph. // Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the APIKeyQuery when eager-loading is set. // The values are being populated by the APIKeyQuery when eager-loading is set.
Edges APIKeyEdges `json:"edges"` Edges APIKeyEdges `json:"edges"`
...@@ -90,6 +95,8 @@ func (*APIKey) scanValues(columns []string) ([]any, error) { ...@@ -90,6 +95,8 @@ func (*APIKey) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns)) values := make([]any, len(columns))
for i := range columns { for i := range columns {
switch columns[i] { switch columns[i] {
case apikey.FieldIPWhitelist, apikey.FieldIPBlacklist:
values[i] = new([]byte)
case apikey.FieldID, apikey.FieldUserID, apikey.FieldGroupID: case apikey.FieldID, apikey.FieldUserID, apikey.FieldGroupID:
values[i] = new(sql.NullInt64) values[i] = new(sql.NullInt64)
case apikey.FieldKey, apikey.FieldName, apikey.FieldStatus: case apikey.FieldKey, apikey.FieldName, apikey.FieldStatus:
...@@ -167,6 +174,22 @@ func (_m *APIKey) assignValues(columns []string, values []any) error { ...@@ -167,6 +174,22 @@ func (_m *APIKey) assignValues(columns []string, values []any) error {
} else if value.Valid { } else if value.Valid {
_m.Status = value.String _m.Status = value.String
} }
case apikey.FieldIPWhitelist:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field ip_whitelist", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.IPWhitelist); err != nil {
return fmt.Errorf("unmarshal field ip_whitelist: %w", err)
}
}
case apikey.FieldIPBlacklist:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field ip_blacklist", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.IPBlacklist); err != nil {
return fmt.Errorf("unmarshal field ip_blacklist: %w", err)
}
}
default: default:
_m.selectValues.Set(columns[i], values[i]) _m.selectValues.Set(columns[i], values[i])
} }
...@@ -245,6 +268,12 @@ func (_m *APIKey) String() string { ...@@ -245,6 +268,12 @@ func (_m *APIKey) String() string {
builder.WriteString(", ") builder.WriteString(", ")
builder.WriteString("status=") builder.WriteString("status=")
builder.WriteString(_m.Status) builder.WriteString(_m.Status)
builder.WriteString(", ")
builder.WriteString("ip_whitelist=")
builder.WriteString(fmt.Sprintf("%v", _m.IPWhitelist))
builder.WriteString(", ")
builder.WriteString("ip_blacklist=")
builder.WriteString(fmt.Sprintf("%v", _m.IPBlacklist))
builder.WriteByte(')') builder.WriteByte(')')
return builder.String() return builder.String()
} }
......
...@@ -31,6 +31,10 @@ const ( ...@@ -31,6 +31,10 @@ const (
FieldGroupID = "group_id" FieldGroupID = "group_id"
// FieldStatus holds the string denoting the status field in the database. // FieldStatus holds the string denoting the status field in the database.
FieldStatus = "status" FieldStatus = "status"
// FieldIPWhitelist holds the string denoting the ip_whitelist field in the database.
FieldIPWhitelist = "ip_whitelist"
// FieldIPBlacklist holds the string denoting the ip_blacklist field in the database.
FieldIPBlacklist = "ip_blacklist"
// EdgeUser holds the string denoting the user edge name in mutations. // EdgeUser holds the string denoting the user edge name in mutations.
EdgeUser = "user" EdgeUser = "user"
// EdgeGroup holds the string denoting the group edge name in mutations. // EdgeGroup holds the string denoting the group edge name in mutations.
...@@ -73,6 +77,8 @@ var Columns = []string{ ...@@ -73,6 +77,8 @@ var Columns = []string{
FieldName, FieldName,
FieldGroupID, FieldGroupID,
FieldStatus, FieldStatus,
FieldIPWhitelist,
FieldIPBlacklist,
} }
// ValidColumn reports if the column name is valid (part of the table columns). // ValidColumn reports if the column name is valid (part of the table columns).
......
...@@ -470,6 +470,26 @@ func StatusContainsFold(v string) predicate.APIKey { ...@@ -470,6 +470,26 @@ func StatusContainsFold(v string) predicate.APIKey {
return predicate.APIKey(sql.FieldContainsFold(FieldStatus, v)) return predicate.APIKey(sql.FieldContainsFold(FieldStatus, v))
} }
// IPWhitelistIsNil applies the IsNil predicate on the "ip_whitelist" field.
func IPWhitelistIsNil() predicate.APIKey {
return predicate.APIKey(sql.FieldIsNull(FieldIPWhitelist))
}
// IPWhitelistNotNil applies the NotNil predicate on the "ip_whitelist" field.
func IPWhitelistNotNil() predicate.APIKey {
return predicate.APIKey(sql.FieldNotNull(FieldIPWhitelist))
}
// IPBlacklistIsNil applies the IsNil predicate on the "ip_blacklist" field.
func IPBlacklistIsNil() predicate.APIKey {
return predicate.APIKey(sql.FieldIsNull(FieldIPBlacklist))
}
// IPBlacklistNotNil applies the NotNil predicate on the "ip_blacklist" field.
func IPBlacklistNotNil() predicate.APIKey {
return predicate.APIKey(sql.FieldNotNull(FieldIPBlacklist))
}
// HasUser applies the HasEdge predicate on the "user" edge. // HasUser applies the HasEdge predicate on the "user" edge.
func HasUser() predicate.APIKey { func HasUser() predicate.APIKey {
return predicate.APIKey(func(s *sql.Selector) { return predicate.APIKey(func(s *sql.Selector) {
......
...@@ -113,6 +113,18 @@ func (_c *APIKeyCreate) SetNillableStatus(v *string) *APIKeyCreate { ...@@ -113,6 +113,18 @@ func (_c *APIKeyCreate) SetNillableStatus(v *string) *APIKeyCreate {
return _c return _c
} }
// SetIPWhitelist sets the "ip_whitelist" field.
func (_c *APIKeyCreate) SetIPWhitelist(v []string) *APIKeyCreate {
_c.mutation.SetIPWhitelist(v)
return _c
}
// SetIPBlacklist sets the "ip_blacklist" field.
func (_c *APIKeyCreate) SetIPBlacklist(v []string) *APIKeyCreate {
_c.mutation.SetIPBlacklist(v)
return _c
}
// SetUser sets the "user" edge to the User entity. // SetUser sets the "user" edge to the User entity.
func (_c *APIKeyCreate) SetUser(v *User) *APIKeyCreate { func (_c *APIKeyCreate) SetUser(v *User) *APIKeyCreate {
return _c.SetUserID(v.ID) return _c.SetUserID(v.ID)
...@@ -285,6 +297,14 @@ func (_c *APIKeyCreate) createSpec() (*APIKey, *sqlgraph.CreateSpec) { ...@@ -285,6 +297,14 @@ func (_c *APIKeyCreate) createSpec() (*APIKey, *sqlgraph.CreateSpec) {
_spec.SetField(apikey.FieldStatus, field.TypeString, value) _spec.SetField(apikey.FieldStatus, field.TypeString, value)
_node.Status = value _node.Status = value
} }
if value, ok := _c.mutation.IPWhitelist(); ok {
_spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value)
_node.IPWhitelist = value
}
if value, ok := _c.mutation.IPBlacklist(); ok {
_spec.SetField(apikey.FieldIPBlacklist, field.TypeJSON, value)
_node.IPBlacklist = value
}
if nodes := _c.mutation.UserIDs(); len(nodes) > 0 { if nodes := _c.mutation.UserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O, Rel: sqlgraph.M2O,
...@@ -483,6 +503,42 @@ func (u *APIKeyUpsert) UpdateStatus() *APIKeyUpsert { ...@@ -483,6 +503,42 @@ func (u *APIKeyUpsert) UpdateStatus() *APIKeyUpsert {
return u return u
} }
// SetIPWhitelist sets the "ip_whitelist" field.
func (u *APIKeyUpsert) SetIPWhitelist(v []string) *APIKeyUpsert {
u.Set(apikey.FieldIPWhitelist, v)
return u
}
// UpdateIPWhitelist sets the "ip_whitelist" field to the value that was provided on create.
func (u *APIKeyUpsert) UpdateIPWhitelist() *APIKeyUpsert {
u.SetExcluded(apikey.FieldIPWhitelist)
return u
}
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
func (u *APIKeyUpsert) ClearIPWhitelist() *APIKeyUpsert {
u.SetNull(apikey.FieldIPWhitelist)
return u
}
// SetIPBlacklist sets the "ip_blacklist" field.
func (u *APIKeyUpsert) SetIPBlacklist(v []string) *APIKeyUpsert {
u.Set(apikey.FieldIPBlacklist, v)
return u
}
// UpdateIPBlacklist sets the "ip_blacklist" field to the value that was provided on create.
func (u *APIKeyUpsert) UpdateIPBlacklist() *APIKeyUpsert {
u.SetExcluded(apikey.FieldIPBlacklist)
return u
}
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
func (u *APIKeyUpsert) ClearIPBlacklist() *APIKeyUpsert {
u.SetNull(apikey.FieldIPBlacklist)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create. // UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using: // Using this option is equivalent to using:
// //
...@@ -640,6 +696,48 @@ func (u *APIKeyUpsertOne) UpdateStatus() *APIKeyUpsertOne { ...@@ -640,6 +696,48 @@ func (u *APIKeyUpsertOne) UpdateStatus() *APIKeyUpsertOne {
}) })
} }
// SetIPWhitelist sets the "ip_whitelist" field.
func (u *APIKeyUpsertOne) SetIPWhitelist(v []string) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.SetIPWhitelist(v)
})
}
// UpdateIPWhitelist sets the "ip_whitelist" field to the value that was provided on create.
func (u *APIKeyUpsertOne) UpdateIPWhitelist() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateIPWhitelist()
})
}
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
func (u *APIKeyUpsertOne) ClearIPWhitelist() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.ClearIPWhitelist()
})
}
// SetIPBlacklist sets the "ip_blacklist" field.
func (u *APIKeyUpsertOne) SetIPBlacklist(v []string) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.SetIPBlacklist(v)
})
}
// UpdateIPBlacklist sets the "ip_blacklist" field to the value that was provided on create.
func (u *APIKeyUpsertOne) UpdateIPBlacklist() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateIPBlacklist()
})
}
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
func (u *APIKeyUpsertOne) ClearIPBlacklist() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.ClearIPBlacklist()
})
}
// Exec executes the query. // Exec executes the query.
func (u *APIKeyUpsertOne) Exec(ctx context.Context) error { func (u *APIKeyUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 { if len(u.create.conflict) == 0 {
...@@ -963,6 +1061,48 @@ func (u *APIKeyUpsertBulk) UpdateStatus() *APIKeyUpsertBulk { ...@@ -963,6 +1061,48 @@ func (u *APIKeyUpsertBulk) UpdateStatus() *APIKeyUpsertBulk {
}) })
} }
// SetIPWhitelist sets the "ip_whitelist" field.
func (u *APIKeyUpsertBulk) SetIPWhitelist(v []string) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.SetIPWhitelist(v)
})
}
// UpdateIPWhitelist sets the "ip_whitelist" field to the value that was provided on create.
func (u *APIKeyUpsertBulk) UpdateIPWhitelist() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateIPWhitelist()
})
}
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
func (u *APIKeyUpsertBulk) ClearIPWhitelist() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.ClearIPWhitelist()
})
}
// SetIPBlacklist sets the "ip_blacklist" field.
func (u *APIKeyUpsertBulk) SetIPBlacklist(v []string) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.SetIPBlacklist(v)
})
}
// UpdateIPBlacklist sets the "ip_blacklist" field to the value that was provided on create.
func (u *APIKeyUpsertBulk) UpdateIPBlacklist() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateIPBlacklist()
})
}
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
func (u *APIKeyUpsertBulk) ClearIPBlacklist() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.ClearIPBlacklist()
})
}
// Exec executes the query. // Exec executes the query.
func (u *APIKeyUpsertBulk) Exec(ctx context.Context) error { func (u *APIKeyUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil { if u.create.err != nil {
......
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/dialect/sql/sqljson"
"entgo.io/ent/schema/field" "entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/group"
...@@ -133,6 +134,42 @@ func (_u *APIKeyUpdate) SetNillableStatus(v *string) *APIKeyUpdate { ...@@ -133,6 +134,42 @@ func (_u *APIKeyUpdate) SetNillableStatus(v *string) *APIKeyUpdate {
return _u return _u
} }
// SetIPWhitelist sets the "ip_whitelist" field.
func (_u *APIKeyUpdate) SetIPWhitelist(v []string) *APIKeyUpdate {
_u.mutation.SetIPWhitelist(v)
return _u
}
// AppendIPWhitelist appends value to the "ip_whitelist" field.
func (_u *APIKeyUpdate) AppendIPWhitelist(v []string) *APIKeyUpdate {
_u.mutation.AppendIPWhitelist(v)
return _u
}
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
func (_u *APIKeyUpdate) ClearIPWhitelist() *APIKeyUpdate {
_u.mutation.ClearIPWhitelist()
return _u
}
// SetIPBlacklist sets the "ip_blacklist" field.
func (_u *APIKeyUpdate) SetIPBlacklist(v []string) *APIKeyUpdate {
_u.mutation.SetIPBlacklist(v)
return _u
}
// AppendIPBlacklist appends value to the "ip_blacklist" field.
func (_u *APIKeyUpdate) AppendIPBlacklist(v []string) *APIKeyUpdate {
_u.mutation.AppendIPBlacklist(v)
return _u
}
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
func (_u *APIKeyUpdate) ClearIPBlacklist() *APIKeyUpdate {
_u.mutation.ClearIPBlacklist()
return _u
}
// SetUser sets the "user" edge to the User entity. // SetUser sets the "user" edge to the User entity.
func (_u *APIKeyUpdate) SetUser(v *User) *APIKeyUpdate { func (_u *APIKeyUpdate) SetUser(v *User) *APIKeyUpdate {
return _u.SetUserID(v.ID) return _u.SetUserID(v.ID)
...@@ -291,6 +328,28 @@ func (_u *APIKeyUpdate) sqlSave(ctx context.Context) (_node int, err error) { ...@@ -291,6 +328,28 @@ func (_u *APIKeyUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.Status(); ok { if value, ok := _u.mutation.Status(); ok {
_spec.SetField(apikey.FieldStatus, field.TypeString, value) _spec.SetField(apikey.FieldStatus, field.TypeString, value)
} }
if value, ok := _u.mutation.IPWhitelist(); ok {
_spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedIPWhitelist(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, apikey.FieldIPWhitelist, value)
})
}
if _u.mutation.IPWhitelistCleared() {
_spec.ClearField(apikey.FieldIPWhitelist, field.TypeJSON)
}
if value, ok := _u.mutation.IPBlacklist(); ok {
_spec.SetField(apikey.FieldIPBlacklist, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedIPBlacklist(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, apikey.FieldIPBlacklist, value)
})
}
if _u.mutation.IPBlacklistCleared() {
_spec.ClearField(apikey.FieldIPBlacklist, field.TypeJSON)
}
if _u.mutation.UserCleared() { if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O, Rel: sqlgraph.M2O,
...@@ -516,6 +575,42 @@ func (_u *APIKeyUpdateOne) SetNillableStatus(v *string) *APIKeyUpdateOne { ...@@ -516,6 +575,42 @@ func (_u *APIKeyUpdateOne) SetNillableStatus(v *string) *APIKeyUpdateOne {
return _u return _u
} }
// SetIPWhitelist sets the "ip_whitelist" field.
func (_u *APIKeyUpdateOne) SetIPWhitelist(v []string) *APIKeyUpdateOne {
_u.mutation.SetIPWhitelist(v)
return _u
}
// AppendIPWhitelist appends value to the "ip_whitelist" field.
func (_u *APIKeyUpdateOne) AppendIPWhitelist(v []string) *APIKeyUpdateOne {
_u.mutation.AppendIPWhitelist(v)
return _u
}
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
func (_u *APIKeyUpdateOne) ClearIPWhitelist() *APIKeyUpdateOne {
_u.mutation.ClearIPWhitelist()
return _u
}
// SetIPBlacklist sets the "ip_blacklist" field.
func (_u *APIKeyUpdateOne) SetIPBlacklist(v []string) *APIKeyUpdateOne {
_u.mutation.SetIPBlacklist(v)
return _u
}
// AppendIPBlacklist appends value to the "ip_blacklist" field.
func (_u *APIKeyUpdateOne) AppendIPBlacklist(v []string) *APIKeyUpdateOne {
_u.mutation.AppendIPBlacklist(v)
return _u
}
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
func (_u *APIKeyUpdateOne) ClearIPBlacklist() *APIKeyUpdateOne {
_u.mutation.ClearIPBlacklist()
return _u
}
// SetUser sets the "user" edge to the User entity. // SetUser sets the "user" edge to the User entity.
func (_u *APIKeyUpdateOne) SetUser(v *User) *APIKeyUpdateOne { func (_u *APIKeyUpdateOne) SetUser(v *User) *APIKeyUpdateOne {
return _u.SetUserID(v.ID) return _u.SetUserID(v.ID)
...@@ -704,6 +799,28 @@ func (_u *APIKeyUpdateOne) sqlSave(ctx context.Context) (_node *APIKey, err erro ...@@ -704,6 +799,28 @@ func (_u *APIKeyUpdateOne) sqlSave(ctx context.Context) (_node *APIKey, err erro
if value, ok := _u.mutation.Status(); ok { if value, ok := _u.mutation.Status(); ok {
_spec.SetField(apikey.FieldStatus, field.TypeString, value) _spec.SetField(apikey.FieldStatus, field.TypeString, value)
} }
if value, ok := _u.mutation.IPWhitelist(); ok {
_spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedIPWhitelist(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, apikey.FieldIPWhitelist, value)
})
}
if _u.mutation.IPWhitelistCleared() {
_spec.ClearField(apikey.FieldIPWhitelist, field.TypeJSON)
}
if value, ok := _u.mutation.IPBlacklist(); ok {
_spec.SetField(apikey.FieldIPBlacklist, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedIPBlacklist(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, apikey.FieldIPBlacklist, value)
})
}
if _u.mutation.IPBlacklistCleared() {
_spec.ClearField(apikey.FieldIPBlacklist, field.TypeJSON)
}
if _u.mutation.UserCleared() { if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O, Rel: sqlgraph.M2O,
......
...@@ -18,6 +18,8 @@ var ( ...@@ -18,6 +18,8 @@ var (
{Name: "key", Type: field.TypeString, Unique: true, Size: 128}, {Name: "key", Type: field.TypeString, Unique: true, Size: 128},
{Name: "name", Type: field.TypeString, Size: 100}, {Name: "name", Type: field.TypeString, Size: 100},
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"}, {Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
{Name: "ip_whitelist", Type: field.TypeJSON, Nullable: true},
{Name: "ip_blacklist", Type: field.TypeJSON, Nullable: true},
{Name: "group_id", Type: field.TypeInt64, Nullable: true}, {Name: "group_id", Type: field.TypeInt64, Nullable: true},
{Name: "user_id", Type: field.TypeInt64}, {Name: "user_id", Type: field.TypeInt64},
} }
...@@ -29,13 +31,13 @@ var ( ...@@ -29,13 +31,13 @@ var (
ForeignKeys: []*schema.ForeignKey{ ForeignKeys: []*schema.ForeignKey{
{ {
Symbol: "api_keys_groups_api_keys", Symbol: "api_keys_groups_api_keys",
Columns: []*schema.Column{APIKeysColumns[7]}, Columns: []*schema.Column{APIKeysColumns[9]},
RefColumns: []*schema.Column{GroupsColumns[0]}, RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull, OnDelete: schema.SetNull,
}, },
{ {
Symbol: "api_keys_users_api_keys", Symbol: "api_keys_users_api_keys",
Columns: []*schema.Column{APIKeysColumns[8]}, Columns: []*schema.Column{APIKeysColumns[10]},
RefColumns: []*schema.Column{UsersColumns[0]}, RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction, OnDelete: schema.NoAction,
}, },
...@@ -44,12 +46,12 @@ var ( ...@@ -44,12 +46,12 @@ var (
{ {
Name: "apikey_user_id", Name: "apikey_user_id",
Unique: false, Unique: false,
Columns: []*schema.Column{APIKeysColumns[8]}, Columns: []*schema.Column{APIKeysColumns[10]},
}, },
{ {
Name: "apikey_group_id", Name: "apikey_group_id",
Unique: false, Unique: false,
Columns: []*schema.Column{APIKeysColumns[7]}, Columns: []*schema.Column{APIKeysColumns[9]},
}, },
{ {
Name: "apikey_status", Name: "apikey_status",
...@@ -376,6 +378,7 @@ var ( ...@@ -376,6 +378,7 @@ var (
{Name: "duration_ms", Type: field.TypeInt, Nullable: true}, {Name: "duration_ms", Type: field.TypeInt, Nullable: true},
{Name: "first_token_ms", Type: field.TypeInt, Nullable: true}, {Name: "first_token_ms", Type: field.TypeInt, Nullable: true},
{Name: "user_agent", Type: field.TypeString, Nullable: true, Size: 512}, {Name: "user_agent", Type: field.TypeString, Nullable: true, Size: 512},
{Name: "ip_address", Type: field.TypeString, Nullable: true, Size: 45},
{Name: "image_count", Type: field.TypeInt, Default: 0}, {Name: "image_count", Type: field.TypeInt, Default: 0},
{Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10}, {Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
...@@ -393,31 +396,31 @@ var ( ...@@ -393,31 +396,31 @@ var (
ForeignKeys: []*schema.ForeignKey{ ForeignKeys: []*schema.ForeignKey{
{ {
Symbol: "usage_logs_api_keys_usage_logs", Symbol: "usage_logs_api_keys_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[24]}, Columns: []*schema.Column{UsageLogsColumns[25]},
RefColumns: []*schema.Column{APIKeysColumns[0]}, RefColumns: []*schema.Column{APIKeysColumns[0]},
OnDelete: schema.NoAction, OnDelete: schema.NoAction,
}, },
{ {
Symbol: "usage_logs_accounts_usage_logs", Symbol: "usage_logs_accounts_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[25]}, Columns: []*schema.Column{UsageLogsColumns[26]},
RefColumns: []*schema.Column{AccountsColumns[0]}, RefColumns: []*schema.Column{AccountsColumns[0]},
OnDelete: schema.NoAction, OnDelete: schema.NoAction,
}, },
{ {
Symbol: "usage_logs_groups_usage_logs", Symbol: "usage_logs_groups_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[26]}, Columns: []*schema.Column{UsageLogsColumns[27]},
RefColumns: []*schema.Column{GroupsColumns[0]}, RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull, OnDelete: schema.SetNull,
}, },
{ {
Symbol: "usage_logs_users_usage_logs", Symbol: "usage_logs_users_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[27]}, Columns: []*schema.Column{UsageLogsColumns[28]},
RefColumns: []*schema.Column{UsersColumns[0]}, RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction, OnDelete: schema.NoAction,
}, },
{ {
Symbol: "usage_logs_user_subscriptions_usage_logs", Symbol: "usage_logs_user_subscriptions_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[28]}, Columns: []*schema.Column{UsageLogsColumns[29]},
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]}, RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
OnDelete: schema.SetNull, OnDelete: schema.SetNull,
}, },
...@@ -426,32 +429,32 @@ var ( ...@@ -426,32 +429,32 @@ var (
{ {
Name: "usagelog_user_id", Name: "usagelog_user_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[27]}, Columns: []*schema.Column{UsageLogsColumns[28]},
}, },
{ {
Name: "usagelog_api_key_id", Name: "usagelog_api_key_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[24]}, Columns: []*schema.Column{UsageLogsColumns[25]},
}, },
{ {
Name: "usagelog_account_id", Name: "usagelog_account_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[25]}, Columns: []*schema.Column{UsageLogsColumns[26]},
}, },
{ {
Name: "usagelog_group_id", Name: "usagelog_group_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[26]}, Columns: []*schema.Column{UsageLogsColumns[27]},
}, },
{ {
Name: "usagelog_subscription_id", Name: "usagelog_subscription_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[28]}, Columns: []*schema.Column{UsageLogsColumns[29]},
}, },
{ {
Name: "usagelog_created_at", Name: "usagelog_created_at",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[23]}, Columns: []*schema.Column{UsageLogsColumns[24]},
}, },
{ {
Name: "usagelog_model", Name: "usagelog_model",
...@@ -466,12 +469,12 @@ var ( ...@@ -466,12 +469,12 @@ var (
{ {
Name: "usagelog_user_id_created_at", Name: "usagelog_user_id_created_at",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[27], UsageLogsColumns[23]}, Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[24]},
}, },
{ {
Name: "usagelog_api_key_id_created_at", Name: "usagelog_api_key_id_created_at",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[24], UsageLogsColumns[23]}, Columns: []*schema.Column{UsageLogsColumns[25], UsageLogsColumns[24]},
}, },
}, },
} }
......
...@@ -63,6 +63,10 @@ type APIKeyMutation struct { ...@@ -63,6 +63,10 @@ type APIKeyMutation struct {
key *string key *string
name *string name *string
status *string status *string
ip_whitelist *[]string
appendip_whitelist []string
ip_blacklist *[]string
appendip_blacklist []string
clearedFields map[string]struct{} clearedFields map[string]struct{}
user *int64 user *int64
cleareduser bool cleareduser bool
...@@ -488,6 +492,136 @@ func (m *APIKeyMutation) ResetStatus() { ...@@ -488,6 +492,136 @@ func (m *APIKeyMutation) ResetStatus() {
m.status = nil m.status = nil
} }
// SetIPWhitelist sets the "ip_whitelist" field.
func (m *APIKeyMutation) SetIPWhitelist(s []string) {
m.ip_whitelist = &s
m.appendip_whitelist = nil
}
// IPWhitelist returns the value of the "ip_whitelist" field in the mutation.
func (m *APIKeyMutation) IPWhitelist() (r []string, exists bool) {
v := m.ip_whitelist
if v == nil {
return
}
return *v, true
}
// OldIPWhitelist returns the old "ip_whitelist" field's value of the APIKey entity.
// If the APIKey object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *APIKeyMutation) OldIPWhitelist(ctx context.Context) (v []string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldIPWhitelist is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldIPWhitelist requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldIPWhitelist: %w", err)
}
return oldValue.IPWhitelist, nil
}
// AppendIPWhitelist adds s to the "ip_whitelist" field.
func (m *APIKeyMutation) AppendIPWhitelist(s []string) {
m.appendip_whitelist = append(m.appendip_whitelist, s...)
}
// AppendedIPWhitelist returns the list of values that were appended to the "ip_whitelist" field in this mutation.
func (m *APIKeyMutation) AppendedIPWhitelist() ([]string, bool) {
if len(m.appendip_whitelist) == 0 {
return nil, false
}
return m.appendip_whitelist, true
}
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
func (m *APIKeyMutation) ClearIPWhitelist() {
m.ip_whitelist = nil
m.appendip_whitelist = nil
m.clearedFields[apikey.FieldIPWhitelist] = struct{}{}
}
// IPWhitelistCleared returns if the "ip_whitelist" field was cleared in this mutation.
func (m *APIKeyMutation) IPWhitelistCleared() bool {
_, ok := m.clearedFields[apikey.FieldIPWhitelist]
return ok
}
// ResetIPWhitelist resets all changes to the "ip_whitelist" field.
func (m *APIKeyMutation) ResetIPWhitelist() {
m.ip_whitelist = nil
m.appendip_whitelist = nil
delete(m.clearedFields, apikey.FieldIPWhitelist)
}
// SetIPBlacklist sets the "ip_blacklist" field.
func (m *APIKeyMutation) SetIPBlacklist(s []string) {
m.ip_blacklist = &s
m.appendip_blacklist = nil
}
// IPBlacklist returns the value of the "ip_blacklist" field in the mutation.
func (m *APIKeyMutation) IPBlacklist() (r []string, exists bool) {
v := m.ip_blacklist
if v == nil {
return
}
return *v, true
}
// OldIPBlacklist returns the old "ip_blacklist" field's value of the APIKey entity.
// If the APIKey object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *APIKeyMutation) OldIPBlacklist(ctx context.Context) (v []string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldIPBlacklist is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldIPBlacklist requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldIPBlacklist: %w", err)
}
return oldValue.IPBlacklist, nil
}
// AppendIPBlacklist adds s to the "ip_blacklist" field.
func (m *APIKeyMutation) AppendIPBlacklist(s []string) {
m.appendip_blacklist = append(m.appendip_blacklist, s...)
}
// AppendedIPBlacklist returns the list of values that were appended to the "ip_blacklist" field in this mutation.
func (m *APIKeyMutation) AppendedIPBlacklist() ([]string, bool) {
if len(m.appendip_blacklist) == 0 {
return nil, false
}
return m.appendip_blacklist, true
}
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
func (m *APIKeyMutation) ClearIPBlacklist() {
m.ip_blacklist = nil
m.appendip_blacklist = nil
m.clearedFields[apikey.FieldIPBlacklist] = struct{}{}
}
// IPBlacklistCleared returns if the "ip_blacklist" field was cleared in this mutation.
func (m *APIKeyMutation) IPBlacklistCleared() bool {
_, ok := m.clearedFields[apikey.FieldIPBlacklist]
return ok
}
// ResetIPBlacklist resets all changes to the "ip_blacklist" field.
func (m *APIKeyMutation) ResetIPBlacklist() {
m.ip_blacklist = nil
m.appendip_blacklist = nil
delete(m.clearedFields, apikey.FieldIPBlacklist)
}
// ClearUser clears the "user" edge to the User entity. // ClearUser clears the "user" edge to the User entity.
func (m *APIKeyMutation) ClearUser() { func (m *APIKeyMutation) ClearUser() {
m.cleareduser = true m.cleareduser = true
...@@ -630,7 +764,7 @@ func (m *APIKeyMutation) Type() string { ...@@ -630,7 +764,7 @@ func (m *APIKeyMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call // order to get all numeric fields that were incremented/decremented, call
// AddedFields(). // AddedFields().
func (m *APIKeyMutation) Fields() []string { func (m *APIKeyMutation) Fields() []string {
fields := make([]string, 0, 8) fields := make([]string, 0, 10)
if m.created_at != nil { if m.created_at != nil {
fields = append(fields, apikey.FieldCreatedAt) fields = append(fields, apikey.FieldCreatedAt)
} }
...@@ -655,6 +789,12 @@ func (m *APIKeyMutation) Fields() []string { ...@@ -655,6 +789,12 @@ func (m *APIKeyMutation) Fields() []string {
if m.status != nil { if m.status != nil {
fields = append(fields, apikey.FieldStatus) fields = append(fields, apikey.FieldStatus)
} }
if m.ip_whitelist != nil {
fields = append(fields, apikey.FieldIPWhitelist)
}
if m.ip_blacklist != nil {
fields = append(fields, apikey.FieldIPBlacklist)
}
return fields return fields
} }
...@@ -679,6 +819,10 @@ func (m *APIKeyMutation) Field(name string) (ent.Value, bool) { ...@@ -679,6 +819,10 @@ func (m *APIKeyMutation) Field(name string) (ent.Value, bool) {
return m.GroupID() return m.GroupID()
case apikey.FieldStatus: case apikey.FieldStatus:
return m.Status() return m.Status()
case apikey.FieldIPWhitelist:
return m.IPWhitelist()
case apikey.FieldIPBlacklist:
return m.IPBlacklist()
} }
return nil, false return nil, false
} }
...@@ -704,6 +848,10 @@ func (m *APIKeyMutation) OldField(ctx context.Context, name string) (ent.Value, ...@@ -704,6 +848,10 @@ func (m *APIKeyMutation) OldField(ctx context.Context, name string) (ent.Value,
return m.OldGroupID(ctx) return m.OldGroupID(ctx)
case apikey.FieldStatus: case apikey.FieldStatus:
return m.OldStatus(ctx) return m.OldStatus(ctx)
case apikey.FieldIPWhitelist:
return m.OldIPWhitelist(ctx)
case apikey.FieldIPBlacklist:
return m.OldIPBlacklist(ctx)
} }
return nil, fmt.Errorf("unknown APIKey field %s", name) return nil, fmt.Errorf("unknown APIKey field %s", name)
} }
...@@ -769,6 +917,20 @@ func (m *APIKeyMutation) SetField(name string, value ent.Value) error { ...@@ -769,6 +917,20 @@ func (m *APIKeyMutation) SetField(name string, value ent.Value) error {
} }
m.SetStatus(v) m.SetStatus(v)
return nil return nil
case apikey.FieldIPWhitelist:
v, ok := value.([]string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetIPWhitelist(v)
return nil
case apikey.FieldIPBlacklist:
v, ok := value.([]string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetIPBlacklist(v)
return nil
} }
return fmt.Errorf("unknown APIKey field %s", name) return fmt.Errorf("unknown APIKey field %s", name)
} }
...@@ -808,6 +970,12 @@ func (m *APIKeyMutation) ClearedFields() []string { ...@@ -808,6 +970,12 @@ func (m *APIKeyMutation) ClearedFields() []string {
if m.FieldCleared(apikey.FieldGroupID) { if m.FieldCleared(apikey.FieldGroupID) {
fields = append(fields, apikey.FieldGroupID) fields = append(fields, apikey.FieldGroupID)
} }
if m.FieldCleared(apikey.FieldIPWhitelist) {
fields = append(fields, apikey.FieldIPWhitelist)
}
if m.FieldCleared(apikey.FieldIPBlacklist) {
fields = append(fields, apikey.FieldIPBlacklist)
}
return fields return fields
} }
...@@ -828,6 +996,12 @@ func (m *APIKeyMutation) ClearField(name string) error { ...@@ -828,6 +996,12 @@ func (m *APIKeyMutation) ClearField(name string) error {
case apikey.FieldGroupID: case apikey.FieldGroupID:
m.ClearGroupID() m.ClearGroupID()
return nil return nil
case apikey.FieldIPWhitelist:
m.ClearIPWhitelist()
return nil
case apikey.FieldIPBlacklist:
m.ClearIPBlacklist()
return nil
} }
return fmt.Errorf("unknown APIKey nullable field %s", name) return fmt.Errorf("unknown APIKey nullable field %s", name)
} }
...@@ -860,6 +1034,12 @@ func (m *APIKeyMutation) ResetField(name string) error { ...@@ -860,6 +1034,12 @@ func (m *APIKeyMutation) ResetField(name string) error {
case apikey.FieldStatus: case apikey.FieldStatus:
m.ResetStatus() m.ResetStatus()
return nil return nil
case apikey.FieldIPWhitelist:
m.ResetIPWhitelist()
return nil
case apikey.FieldIPBlacklist:
m.ResetIPBlacklist()
return nil
} }
return fmt.Errorf("unknown APIKey field %s", name) return fmt.Errorf("unknown APIKey field %s", name)
} }
...@@ -8396,6 +8576,7 @@ type UsageLogMutation struct { ...@@ -8396,6 +8576,7 @@ type UsageLogMutation struct {
first_token_ms *int first_token_ms *int
addfirst_token_ms *int addfirst_token_ms *int
user_agent *string user_agent *string
ip_address *string
image_count *int image_count *int
addimage_count *int addimage_count *int
image_size *string image_size *string
...@@ -9801,6 +9982,55 @@ func (m *UsageLogMutation) ResetUserAgent() { ...@@ -9801,6 +9982,55 @@ func (m *UsageLogMutation) ResetUserAgent() {
delete(m.clearedFields, usagelog.FieldUserAgent) delete(m.clearedFields, usagelog.FieldUserAgent)
} }
// SetIPAddress sets the "ip_address" field.
func (m *UsageLogMutation) SetIPAddress(s string) {
m.ip_address = &s
}
// IPAddress returns the value of the "ip_address" field in the mutation.
func (m *UsageLogMutation) IPAddress() (r string, exists bool) {
v := m.ip_address
if v == nil {
return
}
return *v, true
}
// OldIPAddress returns the old "ip_address" field's value of the UsageLog entity.
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UsageLogMutation) OldIPAddress(ctx context.Context) (v *string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldIPAddress is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldIPAddress requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldIPAddress: %w", err)
}
return oldValue.IPAddress, nil
}
// ClearIPAddress clears the value of the "ip_address" field.
func (m *UsageLogMutation) ClearIPAddress() {
m.ip_address = nil
m.clearedFields[usagelog.FieldIPAddress] = struct{}{}
}
// IPAddressCleared returns if the "ip_address" field was cleared in this mutation.
func (m *UsageLogMutation) IPAddressCleared() bool {
_, ok := m.clearedFields[usagelog.FieldIPAddress]
return ok
}
// ResetIPAddress resets all changes to the "ip_address" field.
func (m *UsageLogMutation) ResetIPAddress() {
m.ip_address = nil
delete(m.clearedFields, usagelog.FieldIPAddress)
}
// SetImageCount sets the "image_count" field. // SetImageCount sets the "image_count" field.
func (m *UsageLogMutation) SetImageCount(i int) { func (m *UsageLogMutation) SetImageCount(i int) {
m.image_count = &i m.image_count = &i
...@@ -10111,7 +10341,7 @@ func (m *UsageLogMutation) Type() string { ...@@ -10111,7 +10341,7 @@ func (m *UsageLogMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call // order to get all numeric fields that were incremented/decremented, call
// AddedFields(). // AddedFields().
func (m *UsageLogMutation) Fields() []string { func (m *UsageLogMutation) Fields() []string {
fields := make([]string, 0, 28) fields := make([]string, 0, 29)
if m.user != nil { if m.user != nil {
fields = append(fields, usagelog.FieldUserID) fields = append(fields, usagelog.FieldUserID)
} }
...@@ -10187,6 +10417,9 @@ func (m *UsageLogMutation) Fields() []string { ...@@ -10187,6 +10417,9 @@ func (m *UsageLogMutation) Fields() []string {
if m.user_agent != nil { if m.user_agent != nil {
fields = append(fields, usagelog.FieldUserAgent) fields = append(fields, usagelog.FieldUserAgent)
} }
if m.ip_address != nil {
fields = append(fields, usagelog.FieldIPAddress)
}
if m.image_count != nil { if m.image_count != nil {
fields = append(fields, usagelog.FieldImageCount) fields = append(fields, usagelog.FieldImageCount)
} }
...@@ -10254,6 +10487,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) { ...@@ -10254,6 +10487,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
return m.FirstTokenMs() return m.FirstTokenMs()
case usagelog.FieldUserAgent: case usagelog.FieldUserAgent:
return m.UserAgent() return m.UserAgent()
case usagelog.FieldIPAddress:
return m.IPAddress()
case usagelog.FieldImageCount: case usagelog.FieldImageCount:
return m.ImageCount() return m.ImageCount()
case usagelog.FieldImageSize: case usagelog.FieldImageSize:
...@@ -10319,6 +10554,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value ...@@ -10319,6 +10554,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
return m.OldFirstTokenMs(ctx) return m.OldFirstTokenMs(ctx)
case usagelog.FieldUserAgent: case usagelog.FieldUserAgent:
return m.OldUserAgent(ctx) return m.OldUserAgent(ctx)
case usagelog.FieldIPAddress:
return m.OldIPAddress(ctx)
case usagelog.FieldImageCount: case usagelog.FieldImageCount:
return m.OldImageCount(ctx) return m.OldImageCount(ctx)
case usagelog.FieldImageSize: case usagelog.FieldImageSize:
...@@ -10509,6 +10746,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error { ...@@ -10509,6 +10746,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
} }
m.SetUserAgent(v) m.SetUserAgent(v)
return nil return nil
case usagelog.FieldIPAddress:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetIPAddress(v)
return nil
case usagelog.FieldImageCount: case usagelog.FieldImageCount:
v, ok := value.(int) v, ok := value.(int)
if !ok { if !ok {
...@@ -10782,6 +11026,9 @@ func (m *UsageLogMutation) ClearedFields() []string { ...@@ -10782,6 +11026,9 @@ func (m *UsageLogMutation) ClearedFields() []string {
if m.FieldCleared(usagelog.FieldUserAgent) { if m.FieldCleared(usagelog.FieldUserAgent) {
fields = append(fields, usagelog.FieldUserAgent) fields = append(fields, usagelog.FieldUserAgent)
} }
if m.FieldCleared(usagelog.FieldIPAddress) {
fields = append(fields, usagelog.FieldIPAddress)
}
if m.FieldCleared(usagelog.FieldImageSize) { if m.FieldCleared(usagelog.FieldImageSize) {
fields = append(fields, usagelog.FieldImageSize) fields = append(fields, usagelog.FieldImageSize)
} }
...@@ -10814,6 +11061,9 @@ func (m *UsageLogMutation) ClearField(name string) error { ...@@ -10814,6 +11061,9 @@ func (m *UsageLogMutation) ClearField(name string) error {
case usagelog.FieldUserAgent: case usagelog.FieldUserAgent:
m.ClearUserAgent() m.ClearUserAgent()
return nil return nil
case usagelog.FieldIPAddress:
m.ClearIPAddress()
return nil
case usagelog.FieldImageSize: case usagelog.FieldImageSize:
m.ClearImageSize() m.ClearImageSize()
return nil return nil
...@@ -10900,6 +11150,9 @@ func (m *UsageLogMutation) ResetField(name string) error { ...@@ -10900,6 +11150,9 @@ func (m *UsageLogMutation) ResetField(name string) error {
case usagelog.FieldUserAgent: case usagelog.FieldUserAgent:
m.ResetUserAgent() m.ResetUserAgent()
return nil return nil
case usagelog.FieldIPAddress:
m.ResetIPAddress()
return nil
case usagelog.FieldImageCount: case usagelog.FieldImageCount:
m.ResetImageCount() m.ResetImageCount()
return nil return nil
......
...@@ -533,16 +533,20 @@ func init() { ...@@ -533,16 +533,20 @@ func init() {
usagelogDescUserAgent := usagelogFields[24].Descriptor() usagelogDescUserAgent := usagelogFields[24].Descriptor()
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save. // usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error) usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
// usagelogDescIPAddress is the schema descriptor for ip_address field.
usagelogDescIPAddress := usagelogFields[25].Descriptor()
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
// usagelogDescImageCount is the schema descriptor for image_count field. // usagelogDescImageCount is the schema descriptor for image_count field.
usagelogDescImageCount := usagelogFields[25].Descriptor() usagelogDescImageCount := usagelogFields[26].Descriptor()
// usagelog.DefaultImageCount holds the default value on creation for the image_count field. // usagelog.DefaultImageCount holds the default value on creation for the image_count field.
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int) usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
// usagelogDescImageSize is the schema descriptor for image_size field. // usagelogDescImageSize is the schema descriptor for image_size field.
usagelogDescImageSize := usagelogFields[26].Descriptor() usagelogDescImageSize := usagelogFields[27].Descriptor()
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. // usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error) usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
// usagelogDescCreatedAt is the schema descriptor for created_at field. // usagelogDescCreatedAt is the schema descriptor for created_at field.
usagelogDescCreatedAt := usagelogFields[27].Descriptor() usagelogDescCreatedAt := usagelogFields[28].Descriptor()
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field. // usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time) usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
userMixin := schema.User{}.Mixin() userMixin := schema.User{}.Mixin()
......
...@@ -46,6 +46,12 @@ func (APIKey) Fields() []ent.Field { ...@@ -46,6 +46,12 @@ func (APIKey) Fields() []ent.Field {
field.String("status"). field.String("status").
MaxLen(20). MaxLen(20).
Default(service.StatusActive), Default(service.StatusActive),
field.JSON("ip_whitelist", []string{}).
Optional().
Comment("Allowed IPs/CIDRs, e.g. [\"192.168.1.100\", \"10.0.0.0/8\"]"),
field.JSON("ip_blacklist", []string{}).
Optional().
Comment("Blocked IPs/CIDRs"),
} }
} }
......
...@@ -100,6 +100,10 @@ func (UsageLog) Fields() []ent.Field { ...@@ -100,6 +100,10 @@ func (UsageLog) Fields() []ent.Field {
MaxLen(512). MaxLen(512).
Optional(). Optional().
Nillable(), Nillable(),
field.String("ip_address").
MaxLen(45). // 支持 IPv6
Optional().
Nillable(),
// 图片生成字段(仅 gemini-3-pro-image 等图片模型使用) // 图片生成字段(仅 gemini-3-pro-image 等图片模型使用)
field.Int("image_count"). field.Int("image_count").
......
...@@ -72,6 +72,8 @@ type UsageLog struct { ...@@ -72,6 +72,8 @@ type UsageLog struct {
FirstTokenMs *int `json:"first_token_ms,omitempty"` FirstTokenMs *int `json:"first_token_ms,omitempty"`
// UserAgent holds the value of the "user_agent" field. // UserAgent holds the value of the "user_agent" field.
UserAgent *string `json:"user_agent,omitempty"` UserAgent *string `json:"user_agent,omitempty"`
// IPAddress holds the value of the "ip_address" field.
IPAddress *string `json:"ip_address,omitempty"`
// ImageCount holds the value of the "image_count" field. // ImageCount holds the value of the "image_count" field.
ImageCount int `json:"image_count,omitempty"` ImageCount int `json:"image_count,omitempty"`
// ImageSize holds the value of the "image_size" field. // ImageSize holds the value of the "image_size" field.
...@@ -167,7 +169,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) { ...@@ -167,7 +169,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullFloat64) values[i] = new(sql.NullFloat64)
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount: case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
values[i] = new(sql.NullInt64) values[i] = new(sql.NullInt64)
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldImageSize: case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize:
values[i] = new(sql.NullString) values[i] = new(sql.NullString)
case usagelog.FieldCreatedAt: case usagelog.FieldCreatedAt:
values[i] = new(sql.NullTime) values[i] = new(sql.NullTime)
...@@ -347,6 +349,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error { ...@@ -347,6 +349,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
_m.UserAgent = new(string) _m.UserAgent = new(string)
*_m.UserAgent = value.String *_m.UserAgent = value.String
} }
case usagelog.FieldIPAddress:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field ip_address", values[i])
} else if value.Valid {
_m.IPAddress = new(string)
*_m.IPAddress = value.String
}
case usagelog.FieldImageCount: case usagelog.FieldImageCount:
if value, ok := values[i].(*sql.NullInt64); !ok { if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field image_count", values[i]) return fmt.Errorf("unexpected type %T for field image_count", values[i])
...@@ -512,6 +521,11 @@ func (_m *UsageLog) String() string { ...@@ -512,6 +521,11 @@ func (_m *UsageLog) String() string {
builder.WriteString(*v) builder.WriteString(*v)
} }
builder.WriteString(", ") builder.WriteString(", ")
if v := _m.IPAddress; v != nil {
builder.WriteString("ip_address=")
builder.WriteString(*v)
}
builder.WriteString(", ")
builder.WriteString("image_count=") builder.WriteString("image_count=")
builder.WriteString(fmt.Sprintf("%v", _m.ImageCount)) builder.WriteString(fmt.Sprintf("%v", _m.ImageCount))
builder.WriteString(", ") builder.WriteString(", ")
......
...@@ -64,6 +64,8 @@ const ( ...@@ -64,6 +64,8 @@ const (
FieldFirstTokenMs = "first_token_ms" FieldFirstTokenMs = "first_token_ms"
// FieldUserAgent holds the string denoting the user_agent field in the database. // FieldUserAgent holds the string denoting the user_agent field in the database.
FieldUserAgent = "user_agent" FieldUserAgent = "user_agent"
// FieldIPAddress holds the string denoting the ip_address field in the database.
FieldIPAddress = "ip_address"
// FieldImageCount holds the string denoting the image_count field in the database. // FieldImageCount holds the string denoting the image_count field in the database.
FieldImageCount = "image_count" FieldImageCount = "image_count"
// FieldImageSize holds the string denoting the image_size field in the database. // FieldImageSize holds the string denoting the image_size field in the database.
...@@ -147,6 +149,7 @@ var Columns = []string{ ...@@ -147,6 +149,7 @@ var Columns = []string{
FieldDurationMs, FieldDurationMs,
FieldFirstTokenMs, FieldFirstTokenMs,
FieldUserAgent, FieldUserAgent,
FieldIPAddress,
FieldImageCount, FieldImageCount,
FieldImageSize, FieldImageSize,
FieldCreatedAt, FieldCreatedAt,
...@@ -199,6 +202,8 @@ var ( ...@@ -199,6 +202,8 @@ var (
DefaultStream bool DefaultStream bool
// UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save. // UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
UserAgentValidator func(string) error UserAgentValidator func(string) error
// IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
IPAddressValidator func(string) error
// DefaultImageCount holds the default value on creation for the "image_count" field. // DefaultImageCount holds the default value on creation for the "image_count" field.
DefaultImageCount int DefaultImageCount int
// ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. // ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
...@@ -340,6 +345,11 @@ func ByUserAgent(opts ...sql.OrderTermOption) OrderOption { ...@@ -340,6 +345,11 @@ func ByUserAgent(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUserAgent, opts...).ToFunc() return sql.OrderByField(FieldUserAgent, opts...).ToFunc()
} }
// ByIPAddress orders the results by the ip_address field.
func ByIPAddress(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldIPAddress, opts...).ToFunc()
}
// ByImageCount orders the results by the image_count field. // ByImageCount orders the results by the image_count field.
func ByImageCount(opts ...sql.OrderTermOption) OrderOption { func ByImageCount(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImageCount, opts...).ToFunc() return sql.OrderByField(FieldImageCount, opts...).ToFunc()
......
...@@ -180,6 +180,11 @@ func UserAgent(v string) predicate.UsageLog { ...@@ -180,6 +180,11 @@ func UserAgent(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldUserAgent, v)) return predicate.UsageLog(sql.FieldEQ(FieldUserAgent, v))
} }
// IPAddress applies equality check predicate on the "ip_address" field. It's identical to IPAddressEQ.
func IPAddress(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldIPAddress, v))
}
// ImageCount applies equality check predicate on the "image_count" field. It's identical to ImageCountEQ. // ImageCount applies equality check predicate on the "image_count" field. It's identical to ImageCountEQ.
func ImageCount(v int) predicate.UsageLog { func ImageCount(v int) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldImageCount, v)) return predicate.UsageLog(sql.FieldEQ(FieldImageCount, v))
...@@ -1190,6 +1195,81 @@ func UserAgentContainsFold(v string) predicate.UsageLog { ...@@ -1190,6 +1195,81 @@ func UserAgentContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldUserAgent, v)) return predicate.UsageLog(sql.FieldContainsFold(FieldUserAgent, v))
} }
// IPAddressEQ applies the EQ predicate on the "ip_address" field.
func IPAddressEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldIPAddress, v))
}
// IPAddressNEQ applies the NEQ predicate on the "ip_address" field.
func IPAddressNEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldIPAddress, v))
}
// IPAddressIn applies the In predicate on the "ip_address" field.
func IPAddressIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldIPAddress, vs...))
}
// IPAddressNotIn applies the NotIn predicate on the "ip_address" field.
func IPAddressNotIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldIPAddress, vs...))
}
// IPAddressGT applies the GT predicate on the "ip_address" field.
func IPAddressGT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldIPAddress, v))
}
// IPAddressGTE applies the GTE predicate on the "ip_address" field.
func IPAddressGTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldIPAddress, v))
}
// IPAddressLT applies the LT predicate on the "ip_address" field.
func IPAddressLT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldIPAddress, v))
}
// IPAddressLTE applies the LTE predicate on the "ip_address" field.
func IPAddressLTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldIPAddress, v))
}
// IPAddressContains applies the Contains predicate on the "ip_address" field.
func IPAddressContains(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContains(FieldIPAddress, v))
}
// IPAddressHasPrefix applies the HasPrefix predicate on the "ip_address" field.
func IPAddressHasPrefix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasPrefix(FieldIPAddress, v))
}
// IPAddressHasSuffix applies the HasSuffix predicate on the "ip_address" field.
func IPAddressHasSuffix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasSuffix(FieldIPAddress, v))
}
// IPAddressIsNil applies the IsNil predicate on the "ip_address" field.
func IPAddressIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldIPAddress))
}
// IPAddressNotNil applies the NotNil predicate on the "ip_address" field.
func IPAddressNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldIPAddress))
}
// IPAddressEqualFold applies the EqualFold predicate on the "ip_address" field.
func IPAddressEqualFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEqualFold(FieldIPAddress, v))
}
// IPAddressContainsFold applies the ContainsFold predicate on the "ip_address" field.
func IPAddressContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldIPAddress, v))
}
// ImageCountEQ applies the EQ predicate on the "image_count" field. // ImageCountEQ applies the EQ predicate on the "image_count" field.
func ImageCountEQ(v int) predicate.UsageLog { func ImageCountEQ(v int) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldImageCount, v)) return predicate.UsageLog(sql.FieldEQ(FieldImageCount, v))
......
...@@ -337,6 +337,20 @@ func (_c *UsageLogCreate) SetNillableUserAgent(v *string) *UsageLogCreate { ...@@ -337,6 +337,20 @@ func (_c *UsageLogCreate) SetNillableUserAgent(v *string) *UsageLogCreate {
return _c return _c
} }
// SetIPAddress sets the "ip_address" field.
func (_c *UsageLogCreate) SetIPAddress(v string) *UsageLogCreate {
_c.mutation.SetIPAddress(v)
return _c
}
// SetNillableIPAddress sets the "ip_address" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableIPAddress(v *string) *UsageLogCreate {
if v != nil {
_c.SetIPAddress(*v)
}
return _c
}
// SetImageCount sets the "image_count" field. // SetImageCount sets the "image_count" field.
func (_c *UsageLogCreate) SetImageCount(v int) *UsageLogCreate { func (_c *UsageLogCreate) SetImageCount(v int) *UsageLogCreate {
_c.mutation.SetImageCount(v) _c.mutation.SetImageCount(v)
...@@ -586,6 +600,11 @@ func (_c *UsageLogCreate) check() error { ...@@ -586,6 +600,11 @@ func (_c *UsageLogCreate) check() error {
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)} return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
} }
} }
if v, ok := _c.mutation.IPAddress(); ok {
if err := usagelog.IPAddressValidator(v); err != nil {
return &ValidationError{Name: "ip_address", err: fmt.Errorf(`ent: validator failed for field "UsageLog.ip_address": %w`, err)}
}
}
if _, ok := _c.mutation.ImageCount(); !ok { if _, ok := _c.mutation.ImageCount(); !ok {
return &ValidationError{Name: "image_count", err: errors.New(`ent: missing required field "UsageLog.image_count"`)} return &ValidationError{Name: "image_count", err: errors.New(`ent: missing required field "UsageLog.image_count"`)}
} }
...@@ -713,6 +732,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) { ...@@ -713,6 +732,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
_spec.SetField(usagelog.FieldUserAgent, field.TypeString, value) _spec.SetField(usagelog.FieldUserAgent, field.TypeString, value)
_node.UserAgent = &value _node.UserAgent = &value
} }
if value, ok := _c.mutation.IPAddress(); ok {
_spec.SetField(usagelog.FieldIPAddress, field.TypeString, value)
_node.IPAddress = &value
}
if value, ok := _c.mutation.ImageCount(); ok { if value, ok := _c.mutation.ImageCount(); ok {
_spec.SetField(usagelog.FieldImageCount, field.TypeInt, value) _spec.SetField(usagelog.FieldImageCount, field.TypeInt, value)
_node.ImageCount = value _node.ImageCount = value
...@@ -1288,6 +1311,24 @@ func (u *UsageLogUpsert) ClearUserAgent() *UsageLogUpsert { ...@@ -1288,6 +1311,24 @@ func (u *UsageLogUpsert) ClearUserAgent() *UsageLogUpsert {
return u return u
} }
// SetIPAddress sets the "ip_address" field.
func (u *UsageLogUpsert) SetIPAddress(v string) *UsageLogUpsert {
u.Set(usagelog.FieldIPAddress, v)
return u
}
// UpdateIPAddress sets the "ip_address" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateIPAddress() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldIPAddress)
return u
}
// ClearIPAddress clears the value of the "ip_address" field.
func (u *UsageLogUpsert) ClearIPAddress() *UsageLogUpsert {
u.SetNull(usagelog.FieldIPAddress)
return u
}
// SetImageCount sets the "image_count" field. // SetImageCount sets the "image_count" field.
func (u *UsageLogUpsert) SetImageCount(v int) *UsageLogUpsert { func (u *UsageLogUpsert) SetImageCount(v int) *UsageLogUpsert {
u.Set(usagelog.FieldImageCount, v) u.Set(usagelog.FieldImageCount, v)
...@@ -1866,6 +1907,27 @@ func (u *UsageLogUpsertOne) ClearUserAgent() *UsageLogUpsertOne { ...@@ -1866,6 +1907,27 @@ func (u *UsageLogUpsertOne) ClearUserAgent() *UsageLogUpsertOne {
}) })
} }
// SetIPAddress sets the "ip_address" field.
func (u *UsageLogUpsertOne) SetIPAddress(v string) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetIPAddress(v)
})
}
// UpdateIPAddress sets the "ip_address" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateIPAddress() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateIPAddress()
})
}
// ClearIPAddress clears the value of the "ip_address" field.
func (u *UsageLogUpsertOne) ClearIPAddress() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearIPAddress()
})
}
// SetImageCount sets the "image_count" field. // SetImageCount sets the "image_count" field.
func (u *UsageLogUpsertOne) SetImageCount(v int) *UsageLogUpsertOne { func (u *UsageLogUpsertOne) SetImageCount(v int) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) { return u.Update(func(s *UsageLogUpsert) {
...@@ -2616,6 +2678,27 @@ func (u *UsageLogUpsertBulk) ClearUserAgent() *UsageLogUpsertBulk { ...@@ -2616,6 +2678,27 @@ func (u *UsageLogUpsertBulk) ClearUserAgent() *UsageLogUpsertBulk {
}) })
} }
// SetIPAddress sets the "ip_address" field.
func (u *UsageLogUpsertBulk) SetIPAddress(v string) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetIPAddress(v)
})
}
// UpdateIPAddress sets the "ip_address" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateIPAddress() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateIPAddress()
})
}
// ClearIPAddress clears the value of the "ip_address" field.
func (u *UsageLogUpsertBulk) ClearIPAddress() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearIPAddress()
})
}
// SetImageCount sets the "image_count" field. // SetImageCount sets the "image_count" field.
func (u *UsageLogUpsertBulk) SetImageCount(v int) *UsageLogUpsertBulk { func (u *UsageLogUpsertBulk) SetImageCount(v int) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) { return u.Update(func(s *UsageLogUpsert) {
......
...@@ -524,6 +524,26 @@ func (_u *UsageLogUpdate) ClearUserAgent() *UsageLogUpdate { ...@@ -524,6 +524,26 @@ func (_u *UsageLogUpdate) ClearUserAgent() *UsageLogUpdate {
return _u return _u
} }
// SetIPAddress sets the "ip_address" field.
func (_u *UsageLogUpdate) SetIPAddress(v string) *UsageLogUpdate {
_u.mutation.SetIPAddress(v)
return _u
}
// SetNillableIPAddress sets the "ip_address" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableIPAddress(v *string) *UsageLogUpdate {
if v != nil {
_u.SetIPAddress(*v)
}
return _u
}
// ClearIPAddress clears the value of the "ip_address" field.
func (_u *UsageLogUpdate) ClearIPAddress() *UsageLogUpdate {
_u.mutation.ClearIPAddress()
return _u
}
// SetImageCount sets the "image_count" field. // SetImageCount sets the "image_count" field.
func (_u *UsageLogUpdate) SetImageCount(v int) *UsageLogUpdate { func (_u *UsageLogUpdate) SetImageCount(v int) *UsageLogUpdate {
_u.mutation.ResetImageCount() _u.mutation.ResetImageCount()
...@@ -669,6 +689,11 @@ func (_u *UsageLogUpdate) check() error { ...@@ -669,6 +689,11 @@ func (_u *UsageLogUpdate) check() error {
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)} return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
} }
} }
if v, ok := _u.mutation.IPAddress(); ok {
if err := usagelog.IPAddressValidator(v); err != nil {
return &ValidationError{Name: "ip_address", err: fmt.Errorf(`ent: validator failed for field "UsageLog.ip_address": %w`, err)}
}
}
if v, ok := _u.mutation.ImageSize(); ok { if v, ok := _u.mutation.ImageSize(); ok {
if err := usagelog.ImageSizeValidator(v); err != nil { if err := usagelog.ImageSizeValidator(v); err != nil {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
...@@ -815,6 +840,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) { ...@@ -815,6 +840,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.UserAgentCleared() { if _u.mutation.UserAgentCleared() {
_spec.ClearField(usagelog.FieldUserAgent, field.TypeString) _spec.ClearField(usagelog.FieldUserAgent, field.TypeString)
} }
if value, ok := _u.mutation.IPAddress(); ok {
_spec.SetField(usagelog.FieldIPAddress, field.TypeString, value)
}
if _u.mutation.IPAddressCleared() {
_spec.ClearField(usagelog.FieldIPAddress, field.TypeString)
}
if value, ok := _u.mutation.ImageCount(); ok { if value, ok := _u.mutation.ImageCount(); ok {
_spec.SetField(usagelog.FieldImageCount, field.TypeInt, value) _spec.SetField(usagelog.FieldImageCount, field.TypeInt, value)
} }
...@@ -1484,6 +1515,26 @@ func (_u *UsageLogUpdateOne) ClearUserAgent() *UsageLogUpdateOne { ...@@ -1484,6 +1515,26 @@ func (_u *UsageLogUpdateOne) ClearUserAgent() *UsageLogUpdateOne {
return _u return _u
} }
// SetIPAddress sets the "ip_address" field.
func (_u *UsageLogUpdateOne) SetIPAddress(v string) *UsageLogUpdateOne {
_u.mutation.SetIPAddress(v)
return _u
}
// SetNillableIPAddress sets the "ip_address" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableIPAddress(v *string) *UsageLogUpdateOne {
if v != nil {
_u.SetIPAddress(*v)
}
return _u
}
// ClearIPAddress clears the value of the "ip_address" field.
func (_u *UsageLogUpdateOne) ClearIPAddress() *UsageLogUpdateOne {
_u.mutation.ClearIPAddress()
return _u
}
// SetImageCount sets the "image_count" field. // SetImageCount sets the "image_count" field.
func (_u *UsageLogUpdateOne) SetImageCount(v int) *UsageLogUpdateOne { func (_u *UsageLogUpdateOne) SetImageCount(v int) *UsageLogUpdateOne {
_u.mutation.ResetImageCount() _u.mutation.ResetImageCount()
...@@ -1642,6 +1693,11 @@ func (_u *UsageLogUpdateOne) check() error { ...@@ -1642,6 +1693,11 @@ func (_u *UsageLogUpdateOne) check() error {
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)} return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
} }
} }
if v, ok := _u.mutation.IPAddress(); ok {
if err := usagelog.IPAddressValidator(v); err != nil {
return &ValidationError{Name: "ip_address", err: fmt.Errorf(`ent: validator failed for field "UsageLog.ip_address": %w`, err)}
}
}
if v, ok := _u.mutation.ImageSize(); ok { if v, ok := _u.mutation.ImageSize(); ok {
if err := usagelog.ImageSizeValidator(v); err != nil { if err := usagelog.ImageSizeValidator(v); err != nil {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
...@@ -1805,6 +1861,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err ...@@ -1805,6 +1861,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
if _u.mutation.UserAgentCleared() { if _u.mutation.UserAgentCleared() {
_spec.ClearField(usagelog.FieldUserAgent, field.TypeString) _spec.ClearField(usagelog.FieldUserAgent, field.TypeString)
} }
if value, ok := _u.mutation.IPAddress(); ok {
_spec.SetField(usagelog.FieldIPAddress, field.TypeString, value)
}
if _u.mutation.IPAddressCleared() {
_spec.ClearField(usagelog.FieldIPAddress, field.TypeString)
}
if value, ok := _u.mutation.ImageCount(); ok { if value, ok := _u.mutation.ImageCount(); ok {
_spec.SetField(usagelog.FieldImageCount, field.TypeInt, value) _spec.SetField(usagelog.FieldImageCount, field.TypeInt, value)
} }
......
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"log" "log"
"net/url"
"os" "os"
"strings" "strings"
"time" "time"
...@@ -43,6 +44,7 @@ type Config struct { ...@@ -43,6 +44,7 @@ type Config struct {
Database DatabaseConfig `mapstructure:"database"` Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"` Redis RedisConfig `mapstructure:"redis"`
JWT JWTConfig `mapstructure:"jwt"` JWT JWTConfig `mapstructure:"jwt"`
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
Default DefaultConfig `mapstructure:"default"` Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"` RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"` Pricing PricingConfig `mapstructure:"pricing"`
...@@ -322,6 +324,30 @@ type TurnstileConfig struct { ...@@ -322,6 +324,30 @@ type TurnstileConfig struct {
Required bool `mapstructure:"required"` Required bool `mapstructure:"required"`
} }
// LinuxDoConnectConfig 用于 LinuxDo Connect OAuth 登录(终端用户 SSO)。
//
// 注意:这与上游账号的 OAuth(例如 OpenAI/Gemini 账号接入)不是一回事。
// 这里是用于登录 Sub2API 本身的用户体系。
type LinuxDoConnectConfig struct {
Enabled bool `mapstructure:"enabled"`
ClientID string `mapstructure:"client_id"`
ClientSecret string `mapstructure:"client_secret"`
AuthorizeURL string `mapstructure:"authorize_url"`
TokenURL string `mapstructure:"token_url"`
UserInfoURL string `mapstructure:"userinfo_url"`
Scopes string `mapstructure:"scopes"`
RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记)
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/linuxdo/callback)
TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none
UsePKCE bool `mapstructure:"use_pkce"`
// 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。
// 为空时,服务端会尝试一组常见字段名。
UserInfoEmailPath string `mapstructure:"userinfo_email_path"`
UserInfoIDPath string `mapstructure:"userinfo_id_path"`
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
}
type DefaultConfig struct { type DefaultConfig struct {
AdminEmail string `mapstructure:"admin_email"` AdminEmail string `mapstructure:"admin_email"`
AdminPassword string `mapstructure:"admin_password"` AdminPassword string `mapstructure:"admin_password"`
...@@ -388,6 +414,18 @@ func Load() (*Config, error) { ...@@ -388,6 +414,18 @@ func Load() (*Config, error) {
cfg.Server.Mode = "debug" cfg.Server.Mode = "debug"
} }
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret) cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
cfg.LinuxDo.ClientID = strings.TrimSpace(cfg.LinuxDo.ClientID)
cfg.LinuxDo.ClientSecret = strings.TrimSpace(cfg.LinuxDo.ClientSecret)
cfg.LinuxDo.AuthorizeURL = strings.TrimSpace(cfg.LinuxDo.AuthorizeURL)
cfg.LinuxDo.TokenURL = strings.TrimSpace(cfg.LinuxDo.TokenURL)
cfg.LinuxDo.UserInfoURL = strings.TrimSpace(cfg.LinuxDo.UserInfoURL)
cfg.LinuxDo.Scopes = strings.TrimSpace(cfg.LinuxDo.Scopes)
cfg.LinuxDo.RedirectURL = strings.TrimSpace(cfg.LinuxDo.RedirectURL)
cfg.LinuxDo.FrontendRedirectURL = strings.TrimSpace(cfg.LinuxDo.FrontendRedirectURL)
cfg.LinuxDo.TokenAuthMethod = strings.ToLower(strings.TrimSpace(cfg.LinuxDo.TokenAuthMethod))
cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath)
cfg.LinuxDo.UserInfoIDPath = strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath)
cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath)
cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins) cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed) cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove) cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
...@@ -426,6 +464,81 @@ func Load() (*Config, error) { ...@@ -426,6 +464,81 @@ func Load() (*Config, error) {
return &cfg, nil return &cfg, nil
} }
// ValidateAbsoluteHTTPURL 校验一个绝对 http(s) URL(禁止 fragment)。
func ValidateAbsoluteHTTPURL(raw string) error {
raw = strings.TrimSpace(raw)
if raw == "" {
return fmt.Errorf("empty url")
}
u, err := url.Parse(raw)
if err != nil {
return err
}
if !u.IsAbs() {
return fmt.Errorf("must be absolute")
}
if !isHTTPScheme(u.Scheme) {
return fmt.Errorf("unsupported scheme: %s", u.Scheme)
}
if strings.TrimSpace(u.Host) == "" {
return fmt.Errorf("missing host")
}
if u.Fragment != "" {
return fmt.Errorf("must not include fragment")
}
return nil
}
// ValidateFrontendRedirectURL 校验前端回调地址:
// - 允许同源相对路径(以 / 开头)
// - 或绝对 http(s) URL(禁止 fragment)
func ValidateFrontendRedirectURL(raw string) error {
raw = strings.TrimSpace(raw)
if raw == "" {
return fmt.Errorf("empty url")
}
if strings.ContainsAny(raw, "\r\n") {
return fmt.Errorf("contains invalid characters")
}
if strings.HasPrefix(raw, "/") {
if strings.HasPrefix(raw, "//") {
return fmt.Errorf("must not start with //")
}
return nil
}
u, err := url.Parse(raw)
if err != nil {
return err
}
if !u.IsAbs() {
return fmt.Errorf("must be absolute http(s) url or relative path")
}
if !isHTTPScheme(u.Scheme) {
return fmt.Errorf("unsupported scheme: %s", u.Scheme)
}
if strings.TrimSpace(u.Host) == "" {
return fmt.Errorf("missing host")
}
if u.Fragment != "" {
return fmt.Errorf("must not include fragment")
}
return nil
}
func isHTTPScheme(scheme string) bool {
return strings.EqualFold(scheme, "http") || strings.EqualFold(scheme, "https")
}
func warnIfInsecureURL(field, raw string) {
u, err := url.Parse(strings.TrimSpace(raw))
if err != nil {
return
}
if strings.EqualFold(u.Scheme, "http") {
log.Printf("Warning: %s uses http scheme; use https in production to avoid token leakage.", field)
}
}
func setDefaults() { func setDefaults() {
viper.SetDefault("run_mode", RunModeStandard) viper.SetDefault("run_mode", RunModeStandard)
...@@ -475,6 +588,22 @@ func setDefaults() { ...@@ -475,6 +588,22 @@ func setDefaults() {
// Turnstile // Turnstile
viper.SetDefault("turnstile.required", false) viper.SetDefault("turnstile.required", false)
// LinuxDo Connect OAuth 登录(终端用户 SSO)
viper.SetDefault("linuxdo_connect.enabled", false)
viper.SetDefault("linuxdo_connect.client_id", "")
viper.SetDefault("linuxdo_connect.client_secret", "")
viper.SetDefault("linuxdo_connect.authorize_url", "https://connect.linux.do/oauth2/authorize")
viper.SetDefault("linuxdo_connect.token_url", "https://connect.linux.do/oauth2/token")
viper.SetDefault("linuxdo_connect.userinfo_url", "https://connect.linux.do/api/user")
viper.SetDefault("linuxdo_connect.scopes", "user")
viper.SetDefault("linuxdo_connect.redirect_url", "")
viper.SetDefault("linuxdo_connect.frontend_redirect_url", "/auth/linuxdo/callback")
viper.SetDefault("linuxdo_connect.token_auth_method", "client_secret_post")
viper.SetDefault("linuxdo_connect.use_pkce", false)
viper.SetDefault("linuxdo_connect.userinfo_email_path", "")
viper.SetDefault("linuxdo_connect.userinfo_id_path", "")
viper.SetDefault("linuxdo_connect.userinfo_username_path", "")
// Database // Database
viper.SetDefault("database.host", "localhost") viper.SetDefault("database.host", "localhost")
viper.SetDefault("database.port", 5432) viper.SetDefault("database.port", 5432)
...@@ -544,7 +673,7 @@ func setDefaults() { ...@@ -544,7 +673,7 @@ func setDefaults() {
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求) viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
viper.SetDefault("gateway.stream_data_interval_timeout", 180) viper.SetDefault("gateway.stream_data_interval_timeout", 180)
viper.SetDefault("gateway.stream_keepalive_interval", 10) viper.SetDefault("gateway.stream_keepalive_interval", 10)
viper.SetDefault("gateway.max_line_size", 10*1024*1024) viper.SetDefault("gateway.max_line_size", 40*1024*1024)
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3) viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second) viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second)
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second) viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
...@@ -586,6 +715,60 @@ func (c *Config) Validate() error { ...@@ -586,6 +715,60 @@ func (c *Config) Validate() error {
if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" { if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" {
return fmt.Errorf("security.csp.policy is required when CSP is enabled") return fmt.Errorf("security.csp.policy is required when CSP is enabled")
} }
if c.LinuxDo.Enabled {
if strings.TrimSpace(c.LinuxDo.ClientID) == "" {
return fmt.Errorf("linuxdo_connect.client_id is required when linuxdo_connect.enabled=true")
}
if strings.TrimSpace(c.LinuxDo.AuthorizeURL) == "" {
return fmt.Errorf("linuxdo_connect.authorize_url is required when linuxdo_connect.enabled=true")
}
if strings.TrimSpace(c.LinuxDo.TokenURL) == "" {
return fmt.Errorf("linuxdo_connect.token_url is required when linuxdo_connect.enabled=true")
}
if strings.TrimSpace(c.LinuxDo.UserInfoURL) == "" {
return fmt.Errorf("linuxdo_connect.userinfo_url is required when linuxdo_connect.enabled=true")
}
if strings.TrimSpace(c.LinuxDo.RedirectURL) == "" {
return fmt.Errorf("linuxdo_connect.redirect_url is required when linuxdo_connect.enabled=true")
}
method := strings.ToLower(strings.TrimSpace(c.LinuxDo.TokenAuthMethod))
switch method {
case "", "client_secret_post", "client_secret_basic", "none":
default:
return fmt.Errorf("linuxdo_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none")
}
if method == "none" && !c.LinuxDo.UsePKCE {
return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none")
}
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && strings.TrimSpace(c.LinuxDo.ClientSecret) == "" {
return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic")
}
if strings.TrimSpace(c.LinuxDo.FrontendRedirectURL) == "" {
return fmt.Errorf("linuxdo_connect.frontend_redirect_url is required when linuxdo_connect.enabled=true")
}
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.AuthorizeURL); err != nil {
return fmt.Errorf("linuxdo_connect.authorize_url invalid: %w", err)
}
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.TokenURL); err != nil {
return fmt.Errorf("linuxdo_connect.token_url invalid: %w", err)
}
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.UserInfoURL); err != nil {
return fmt.Errorf("linuxdo_connect.userinfo_url invalid: %w", err)
}
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.RedirectURL); err != nil {
return fmt.Errorf("linuxdo_connect.redirect_url invalid: %w", err)
}
if err := ValidateFrontendRedirectURL(c.LinuxDo.FrontendRedirectURL); err != nil {
return fmt.Errorf("linuxdo_connect.frontend_redirect_url invalid: %w", err)
}
warnIfInsecureURL("linuxdo_connect.authorize_url", c.LinuxDo.AuthorizeURL)
warnIfInsecureURL("linuxdo_connect.token_url", c.LinuxDo.TokenURL)
warnIfInsecureURL("linuxdo_connect.userinfo_url", c.LinuxDo.UserInfoURL)
warnIfInsecureURL("linuxdo_connect.redirect_url", c.LinuxDo.RedirectURL)
warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL)
}
if c.Billing.CircuitBreaker.Enabled { if c.Billing.CircuitBreaker.Enabled {
if c.Billing.CircuitBreaker.FailureThreshold <= 0 { if c.Billing.CircuitBreaker.FailureThreshold <= 0 {
return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive") return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive")
......
package config package config
import ( import (
"strings"
"testing" "testing"
"time" "time"
...@@ -90,3 +91,53 @@ func TestLoadDefaultSecurityToggles(t *testing.T) { ...@@ -90,3 +91,53 @@ func TestLoadDefaultSecurityToggles(t *testing.T) {
t.Fatalf("ResponseHeaders.Enabled = true, want false") t.Fatalf("ResponseHeaders.Enabled = true, want false")
} }
} }
func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.LinuxDo.Enabled = true
cfg.LinuxDo.ClientID = "test-client"
cfg.LinuxDo.ClientSecret = "test-secret"
cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback"
cfg.LinuxDo.TokenAuthMethod = "client_secret_post"
cfg.LinuxDo.UsePKCE = false
cfg.LinuxDo.FrontendRedirectURL = "javascript:alert(1)"
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() expected error for javascript scheme, got nil")
}
if !strings.Contains(err.Error(), "linuxdo_connect.frontend_redirect_url") {
t.Fatalf("Validate() expected frontend_redirect_url error, got: %v", err)
}
}
func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.LinuxDo.Enabled = true
cfg.LinuxDo.ClientID = "test-client"
cfg.LinuxDo.ClientSecret = ""
cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback"
cfg.LinuxDo.FrontendRedirectURL = "/auth/linuxdo/callback"
cfg.LinuxDo.TokenAuthMethod = "none"
cfg.LinuxDo.UsePKCE = false
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() expected error when token_auth_method=none and use_pkce=false, got nil")
}
if !strings.Contains(err.Error(), "linuxdo_connect.use_pkce") {
t.Fatalf("Validate() expected use_pkce error, got: %v", err)
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment