Commit 3a67002c authored by IanShaw027's avatar IanShaw027
Browse files

merge: 合并主分支改动并保留 ops 监控实现

合并 main 分支的最新改动到 ops 监控分支。
冲突解决策略:保留当前分支的 ops 相关改动,接受主分支的其他改动。

保留的 ops 改动:
- 运维监控配置和依赖注入
- 运维监控 API 处理器和中间件
- 运维监控服务层和数据访问层
- 运维监控前端界面和状态管理

接受的主分支改动:
- Linux DO OAuth 集成
- 账号过期功能
- IP 地址限制功能
- 用量统计优化
- 其他 bug 修复和功能改进
parents c48dc097 7d1fe818
...@@ -2,6 +2,7 @@ package repository ...@@ -2,6 +2,7 @@ package repository
import ( import (
"context" "context"
"fmt"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -18,17 +19,23 @@ func NewGatewayCache(rdb *redis.Client) service.GatewayCache { ...@@ -18,17 +19,23 @@ func NewGatewayCache(rdb *redis.Client) service.GatewayCache {
return &gatewayCache{rdb: rdb} return &gatewayCache{rdb: rdb}
} }
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) { // buildSessionKey 构建 session key,包含 groupID 实现分组隔离
key := stickySessionPrefix + sessionHash // 格式: sticky_session:{groupID}:{sessionHash}
func buildSessionKey(groupID int64, sessionHash string) string {
return fmt.Sprintf("%s%d:%s", stickySessionPrefix, groupID, sessionHash)
}
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Get(ctx, key).Int64() return c.rdb.Get(ctx, key).Int64()
} }
func (c *gatewayCache) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error { func (c *gatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
key := stickySessionPrefix + sessionHash key := buildSessionKey(groupID, sessionHash)
return c.rdb.Set(ctx, key, accountID, ttl).Err() return c.rdb.Set(ctx, key, accountID, ttl).Err()
} }
func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error { func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
key := stickySessionPrefix + sessionHash key := buildSessionKey(groupID, sessionHash)
return c.rdb.Expire(ctx, key, ttl).Err() return c.rdb.Expire(ctx, key, ttl).Err()
} }
...@@ -24,18 +24,19 @@ func (s *GatewayCacheSuite) SetupTest() { ...@@ -24,18 +24,19 @@ func (s *GatewayCacheSuite) SetupTest() {
} }
func (s *GatewayCacheSuite) TestGetSessionAccountID_Missing() { func (s *GatewayCacheSuite) TestGetSessionAccountID_Missing() {
_, err := s.cache.GetSessionAccountID(s.ctx, "nonexistent") _, err := s.cache.GetSessionAccountID(s.ctx, 1, "nonexistent")
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing session") require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing session")
} }
func (s *GatewayCacheSuite) TestSetAndGetSessionAccountID() { func (s *GatewayCacheSuite) TestSetAndGetSessionAccountID() {
sessionID := "s1" sessionID := "s1"
accountID := int64(99) accountID := int64(99)
groupID := int64(1)
sessionTTL := 1 * time.Minute sessionTTL := 1 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, sessionTTL), "SetSessionAccountID") require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
sid, err := s.cache.GetSessionAccountID(s.ctx, sessionID) sid, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
require.NoError(s.T(), err, "GetSessionAccountID") require.NoError(s.T(), err, "GetSessionAccountID")
require.Equal(s.T(), accountID, sid, "session id mismatch") require.Equal(s.T(), accountID, sid, "session id mismatch")
} }
...@@ -43,11 +44,12 @@ func (s *GatewayCacheSuite) TestSetAndGetSessionAccountID() { ...@@ -43,11 +44,12 @@ func (s *GatewayCacheSuite) TestSetAndGetSessionAccountID() {
func (s *GatewayCacheSuite) TestSessionAccountID_TTL() { func (s *GatewayCacheSuite) TestSessionAccountID_TTL() {
sessionID := "s2" sessionID := "s2"
accountID := int64(100) accountID := int64(100)
groupID := int64(1)
sessionTTL := 1 * time.Minute sessionTTL := 1 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, sessionTTL), "SetSessionAccountID") require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
sessionKey := stickySessionPrefix + sessionID sessionKey := buildSessionKey(groupID, sessionID)
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result() ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
require.NoError(s.T(), err, "TTL sessionKey after Set") require.NoError(s.T(), err, "TTL sessionKey after Set")
s.AssertTTLWithin(ttl, 1*time.Second, sessionTTL) s.AssertTTLWithin(ttl, 1*time.Second, sessionTTL)
...@@ -56,14 +58,15 @@ func (s *GatewayCacheSuite) TestSessionAccountID_TTL() { ...@@ -56,14 +58,15 @@ func (s *GatewayCacheSuite) TestSessionAccountID_TTL() {
func (s *GatewayCacheSuite) TestRefreshSessionTTL() { func (s *GatewayCacheSuite) TestRefreshSessionTTL() {
sessionID := "s3" sessionID := "s3"
accountID := int64(101) accountID := int64(101)
groupID := int64(1)
initialTTL := 1 * time.Minute initialTTL := 1 * time.Minute
refreshTTL := 3 * time.Minute refreshTTL := 3 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, initialTTL), "SetSessionAccountID") require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, initialTTL), "SetSessionAccountID")
require.NoError(s.T(), s.cache.RefreshSessionTTL(s.ctx, sessionID, refreshTTL), "RefreshSessionTTL") require.NoError(s.T(), s.cache.RefreshSessionTTL(s.ctx, groupID, sessionID, refreshTTL), "RefreshSessionTTL")
sessionKey := stickySessionPrefix + sessionID sessionKey := buildSessionKey(groupID, sessionID)
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result() ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
require.NoError(s.T(), err, "TTL after Refresh") require.NoError(s.T(), err, "TTL after Refresh")
s.AssertTTLWithin(ttl, 1*time.Second, refreshTTL) s.AssertTTLWithin(ttl, 1*time.Second, refreshTTL)
...@@ -71,18 +74,19 @@ func (s *GatewayCacheSuite) TestRefreshSessionTTL() { ...@@ -71,18 +74,19 @@ func (s *GatewayCacheSuite) TestRefreshSessionTTL() {
func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() { func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() {
// RefreshSessionTTL on a missing key should not error (no-op) // RefreshSessionTTL on a missing key should not error (no-op)
err := s.cache.RefreshSessionTTL(s.ctx, "missing-session", 1*time.Minute) err := s.cache.RefreshSessionTTL(s.ctx, 1, "missing-session", 1*time.Minute)
require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error") require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error")
} }
func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() { func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
sessionID := "corrupted" sessionID := "corrupted"
sessionKey := stickySessionPrefix + sessionID groupID := int64(1)
sessionKey := buildSessionKey(groupID, sessionID)
// Set a non-integer value // Set a non-integer value
require.NoError(s.T(), s.rdb.Set(s.ctx, sessionKey, "not-a-number", 1*time.Minute).Err(), "Set invalid value") require.NoError(s.T(), s.rdb.Set(s.ctx, sessionKey, "not-a-number", 1*time.Minute).Err(), "Set invalid value")
_, err := s.cache.GetSessionAccountID(s.ctx, sessionID) _, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
require.Error(s.T(), err, "expected error for corrupted value") require.Error(s.T(), err, "expected error for corrupted value")
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil") require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
} }
......
...@@ -30,14 +30,15 @@ func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, c ...@@ -30,14 +30,15 @@ func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, c
// Use different OAuth clients based on oauthType: // Use different OAuth clients based on oauthType:
// - code_assist: always use built-in Gemini CLI OAuth client (public) // - code_assist: always use built-in Gemini CLI OAuth client (public)
// - google_one: uses configured OAuth client when provided; otherwise falls back to built-in client // - google_one: always use built-in Gemini CLI OAuth client (public)
// - ai_studio: requires a user-provided OAuth client // - ai_studio: requires a user-provided OAuth client
oauthCfgInput := geminicli.OAuthConfig{ oauthCfgInput := geminicli.OAuthConfig{
ClientID: c.cfg.Gemini.OAuth.ClientID, ClientID: c.cfg.Gemini.OAuth.ClientID,
ClientSecret: c.cfg.Gemini.OAuth.ClientSecret, ClientSecret: c.cfg.Gemini.OAuth.ClientSecret,
Scopes: c.cfg.Gemini.OAuth.Scopes, Scopes: c.cfg.Gemini.OAuth.Scopes,
} }
if oauthType == "code_assist" { if oauthType == "code_assist" || oauthType == "google_one" {
// Force use of built-in Gemini CLI OAuth client
oauthCfgInput.ClientID = "" oauthCfgInput.ClientID = ""
oauthCfgInput.ClientSecret = "" oauthCfgInput.ClientSecret = ""
} }
...@@ -78,7 +79,8 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh ...@@ -78,7 +79,8 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh
ClientSecret: c.cfg.Gemini.OAuth.ClientSecret, ClientSecret: c.cfg.Gemini.OAuth.ClientSecret,
Scopes: c.cfg.Gemini.OAuth.Scopes, Scopes: c.cfg.Gemini.OAuth.Scopes,
} }
if oauthType == "code_assist" { if oauthType == "code_assist" || oauthType == "google_one" {
// Force use of built-in Gemini CLI OAuth client
oauthCfgInput.ClientID = "" oauthCfgInput.ClientID = ""
oauthCfgInput.ClientSecret = "" oauthCfgInput.ClientSecret = ""
} }
......
...@@ -14,23 +14,33 @@ import ( ...@@ -14,23 +14,33 @@ import (
) )
type githubReleaseClient struct { type githubReleaseClient struct {
httpClient *http.Client httpClient *http.Client
allowPrivateHosts bool downloadHTTPClient *http.Client
} }
func NewGitHubReleaseClient() service.GitHubReleaseClient { // NewGitHubReleaseClient 创建 GitHub Release 客户端
allowPrivate := false // proxyURL 为空时直连 GitHub,支持 http/https/socks5/socks5h 协议
func NewGitHubReleaseClient(proxyURL string) service.GitHubReleaseClient {
sharedClient, err := httpclient.GetClient(httpclient.Options{ sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
ValidateResolvedIP: true, ProxyURL: proxyURL,
AllowPrivateHosts: allowPrivate,
}) })
if err != nil { if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second} sharedClient = &http.Client{Timeout: 30 * time.Second}
} }
// 下载客户端需要更长的超时时间
downloadClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 10 * time.Minute,
ProxyURL: proxyURL,
})
if err != nil {
downloadClient = &http.Client{Timeout: 10 * time.Minute}
}
return &githubReleaseClient{ return &githubReleaseClient{
httpClient: sharedClient, httpClient: sharedClient,
allowPrivateHosts: allowPrivate, downloadHTTPClient: downloadClient,
} }
} }
...@@ -68,15 +78,8 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string ...@@ -68,15 +78,8 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
return err return err
} }
downloadClient, err := httpclient.GetClient(httpclient.Options{ // 使用预配置的下载客户端(已包含代理配置)
Timeout: 10 * time.Minute, resp, err := c.downloadHTTPClient.Do(req)
ValidateResolvedIP: true,
AllowPrivateHosts: c.allowPrivateHosts,
})
if err != nil {
downloadClient = &http.Client{Timeout: 10 * time.Minute}
}
resp, err := downloadClient.Do(req)
if err != nil { if err != nil {
return err return err
} }
......
...@@ -39,8 +39,8 @@ func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) { ...@@ -39,8 +39,8 @@ func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
func newTestGitHubReleaseClient() *githubReleaseClient { func newTestGitHubReleaseClient() *githubReleaseClient {
return &githubReleaseClient{ return &githubReleaseClient{
httpClient: &http.Client{}, httpClient: &http.Client{},
allowPrivateHosts: true, downloadHTTPClient: &http.Client{},
} }
} }
...@@ -234,7 +234,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() { ...@@ -234,7 +234,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
httpClient: &http.Client{ httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL}, Transport: &testTransport{testServerURL: s.srv.URL},
}, },
allowPrivateHosts: true, downloadHTTPClient: &http.Client{},
} }
release, err := s.client.FetchLatestRelease(context.Background(), "test/repo") release, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
...@@ -254,7 +254,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() { ...@@ -254,7 +254,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
httpClient: &http.Client{ httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL}, Transport: &testTransport{testServerURL: s.srv.URL},
}, },
allowPrivateHosts: true, downloadHTTPClient: &http.Client{},
} }
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo") _, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
...@@ -272,7 +272,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() { ...@@ -272,7 +272,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
httpClient: &http.Client{ httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL}, Transport: &testTransport{testServerURL: s.srv.URL},
}, },
allowPrivateHosts: true, downloadHTTPClient: &http.Client{},
} }
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo") _, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
...@@ -288,7 +288,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() { ...@@ -288,7 +288,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
httpClient: &http.Client{ httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL}, Transport: &testTransport{testServerURL: s.srv.URL},
}, },
allowPrivateHosts: true, downloadHTTPClient: &http.Client{},
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
......
...@@ -46,7 +46,9 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er ...@@ -46,7 +46,9 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice1k(groupIn.ImagePrice1K). SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K). SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetDefaultValidityDays(groupIn.DefaultValidityDays) SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetNillableFallbackGroupID(groupIn.FallbackGroupID)
created, err := builder.Save(ctx) created, err := builder.Save(ctx)
if err == nil { if err == nil {
...@@ -72,7 +74,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group ...@@ -72,7 +74,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
} }
func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) error { func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) error {
updated, err := r.client.Group.UpdateOneID(groupIn.ID). builder := r.client.Group.UpdateOneID(groupIn.ID).
SetName(groupIn.Name). SetName(groupIn.Name).
SetDescription(groupIn.Description). SetDescription(groupIn.Description).
SetPlatform(groupIn.Platform). SetPlatform(groupIn.Platform).
...@@ -87,7 +89,16 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er ...@@ -87,7 +89,16 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K). SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetDefaultValidityDays(groupIn.DefaultValidityDays). SetDefaultValidityDays(groupIn.DefaultValidityDays).
Save(ctx) SetClaudeCodeOnly(groupIn.ClaudeCodeOnly)
// 处理 FallbackGroupID:nil 时清除,否则设置
if groupIn.FallbackGroupID != nil {
builder = builder.SetFallbackGroupID(*groupIn.FallbackGroupID)
} else {
builder = builder.ClearFallbackGroupID()
}
updated, err := builder.Save(ctx)
if err != nil { if err != nil {
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists) return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
} }
...@@ -101,10 +112,10 @@ func (r *groupRepository) Delete(ctx context.Context, id int64) error { ...@@ -101,10 +112,10 @@ func (r *groupRepository) Delete(ctx context.Context, id int64) error {
} }
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) { func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", nil) return r.ListWithFilters(ctx, params, "", "", "", nil)
} }
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) { func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
q := r.client.Group.Query() q := r.client.Group.Query()
if platform != "" { if platform != "" {
...@@ -113,6 +124,12 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination ...@@ -113,6 +124,12 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
if status != "" { if status != "" {
q = q.Where(group.StatusEQ(status)) q = q.Where(group.StatusEQ(status))
} }
if search != "" {
q = q.Where(group.Or(
group.NameContainsFold(search),
group.DescriptionContainsFold(search),
))
}
if isExclusive != nil { if isExclusive != nil {
q = q.Where(group.IsExclusiveEQ(*isExclusive)) q = q.Where(group.IsExclusiveEQ(*isExclusive))
} }
......
...@@ -131,6 +131,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() { ...@@ -131,6 +131,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() {
pagination.PaginationParams{Page: 1, PageSize: 10}, pagination.PaginationParams{Page: 1, PageSize: 10},
service.PlatformOpenAI, service.PlatformOpenAI,
"", "",
"",
nil, nil,
) )
s.Require().NoError(err, "ListWithFilters base") s.Require().NoError(err, "ListWithFilters base")
...@@ -152,7 +153,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() { ...@@ -152,7 +153,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() {
SubscriptionType: service.SubscriptionTypeStandard, SubscriptionType: service.SubscriptionTypeStandard,
})) }))
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil) groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", "", nil)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(groups, len(baseGroups)+1) s.Require().Len(groups, len(baseGroups)+1)
// Verify all groups are OpenAI platform // Verify all groups are OpenAI platform
...@@ -179,7 +180,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Status() { ...@@ -179,7 +180,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Status() {
SubscriptionType: service.SubscriptionTypeStandard, SubscriptionType: service.SubscriptionTypeStandard,
})) }))
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, nil) groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, "", nil)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(groups, 1) s.Require().Len(groups, 1)
s.Require().Equal(service.StatusDisabled, groups[0].Status) s.Require().Equal(service.StatusDisabled, groups[0].Status)
...@@ -204,12 +205,117 @@ func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() { ...@@ -204,12 +205,117 @@ func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
})) }))
isExclusive := true isExclusive := true
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", &isExclusive) groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", &isExclusive)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(groups, 1) s.Require().Len(groups, 1)
s.Require().True(groups[0].IsExclusive) s.Require().True(groups[0].IsExclusive)
} }
func (s *GroupRepoSuite) TestListWithFilters_Search() {
newRepo := func() (*groupRepository, context.Context) {
tx := testEntTx(s.T())
return newGroupRepositoryWithSQL(tx.Client(), tx), context.Background()
}
containsID := func(groups []service.Group, id int64) bool {
for i := range groups {
if groups[i].ID == id {
return true
}
}
return false
}
mustCreate := func(repo *groupRepository, ctx context.Context, g *service.Group) *service.Group {
s.Require().NoError(repo.Create(ctx, g))
s.Require().NotZero(g.ID)
return g
}
newGroup := func(name string) *service.Group {
return &service.Group{
Name: name,
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
}
s.Run("search_name_should_match", func() {
repo, ctx := newRepo()
target := mustCreate(repo, ctx, newGroup("it-group-search-name-target"))
other := mustCreate(repo, ctx, newGroup("it-group-search-name-other"))
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "name-target", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, target.ID), "expected target group to match by name")
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
})
s.Run("search_description_should_match", func() {
repo, ctx := newRepo()
target := newGroup("it-group-search-desc-target")
target.Description = "something about desc-needle in here"
target = mustCreate(repo, ctx, target)
other := newGroup("it-group-search-desc-other")
other.Description = "nothing to see here"
other = mustCreate(repo, ctx, other)
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "desc-needle", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, target.ID), "expected target group to match by description")
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
})
s.Run("search_nonexistent_should_return_empty", func() {
repo, ctx := newRepo()
_ = mustCreate(repo, ctx, newGroup("it-group-search-nonexistent-baseline"))
search := s.T().Name() + "__no_such_group__"
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", search, nil)
s.Require().NoError(err)
s.Require().Empty(groups)
})
s.Run("search_should_be_case_insensitive", func() {
repo, ctx := newRepo()
target := mustCreate(repo, ctx, newGroup("MiXeDCaSe-Needle"))
other := mustCreate(repo, ctx, newGroup("it-group-search-case-other"))
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "mixedcase-needle", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, target.ID), "expected case-insensitive match")
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
})
s.Run("search_should_escape_like_wildcards", func() {
repo, ctx := newRepo()
percentTarget := mustCreate(repo, ctx, newGroup("it-group-search-100%-target"))
percentOther := mustCreate(repo, ctx, newGroup("it-group-search-100X-other"))
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "100%", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, percentTarget.ID), "expected literal %% match")
s.Require().False(containsID(groups, percentOther.ID), "expected %% not to act as wildcard")
underscoreTarget := mustCreate(repo, ctx, newGroup("it-group-search-ab_cd-target"))
underscoreOther := mustCreate(repo, ctx, newGroup("it-group-search-abXcd-other"))
groups, _, err = repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "ab_cd", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, underscoreTarget.ID), "expected literal _ match")
s.Require().False(containsID(groups, underscoreOther.ID), "expected _ not to act as wildcard")
})
}
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() { func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
g1 := &service.Group{ g1 := &service.Group{
Name: "g1", Name: "g1",
...@@ -244,7 +350,7 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() { ...@@ -244,7 +350,7 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
s.Require().NoError(err) s.Require().NoError(err)
isExclusive := true isExclusive := true
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, &isExclusive) groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, "", &isExclusive)
s.Require().NoError(err, "ListWithFilters") s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total) s.Require().Equal(int64(1), page.Total)
s.Require().Len(groups, 1) s.Require().Len(groups, 1)
......
...@@ -8,7 +8,6 @@ import ( ...@@ -8,7 +8,6 @@ import (
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
) )
...@@ -17,17 +16,12 @@ type pricingRemoteClient struct { ...@@ -17,17 +16,12 @@ type pricingRemoteClient struct {
httpClient *http.Client httpClient *http.Client
} }
func NewPricingRemoteClient(cfg *config.Config) service.PricingRemoteClient { // NewPricingRemoteClient 创建定价数据远程客户端
allowPrivate := false // proxyURL 为空时直连,支持 http/https/socks5/socks5h 协议
validateResolvedIP := true func NewPricingRemoteClient(proxyURL string) service.PricingRemoteClient {
if cfg != nil {
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
validateResolvedIP = cfg.Security.URLAllowlist.Enabled
}
sharedClient, err := httpclient.GetClient(httpclient.Options{ sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
ValidateResolvedIP: validateResolvedIP, ProxyURL: proxyURL,
AllowPrivateHosts: allowPrivate,
}) })
if err != nil { if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second} sharedClient = &http.Client{Timeout: 30 * time.Second}
......
...@@ -6,7 +6,6 @@ import ( ...@@ -6,7 +6,6 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
) )
...@@ -20,13 +19,7 @@ type PricingServiceSuite struct { ...@@ -20,13 +19,7 @@ type PricingServiceSuite struct {
func (s *PricingServiceSuite) SetupTest() { func (s *PricingServiceSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
client, ok := NewPricingRemoteClient(&config.Config{ client, ok := NewPricingRemoteClient("").(*pricingRemoteClient)
Security: config.SecurityConfig{
URLAllowlist: config.URLAllowlistConfig{
AllowPrivateHosts: true,
},
},
}).(*pricingRemoteClient)
require.True(s.T(), ok, "type assertion failed") require.True(s.T(), ok, "type assertion failed")
s.client = client s.client = client
} }
......
...@@ -133,6 +133,55 @@ func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination ...@@ -133,6 +133,55 @@ func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination
return outProxies, paginationResultFromTotal(int64(total), params), nil return outProxies, paginationResultFromTotal(int64(total), params), nil
} }
// ListWithFiltersAndAccountCount lists proxies with filters and includes account count per proxy
func (r *proxyRepository) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.ProxyWithAccountCount, *pagination.PaginationResult, error) {
q := r.client.Proxy.Query()
if protocol != "" {
q = q.Where(proxy.ProtocolEQ(protocol))
}
if status != "" {
q = q.Where(proxy.StatusEQ(status))
}
if search != "" {
q = q.Where(proxy.NameContainsFold(search))
}
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
proxies, err := q.
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(proxy.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
// Get account counts
counts, err := r.GetAccountCountsForProxies(ctx)
if err != nil {
return nil, nil, err
}
// Build result with account counts
result := make([]service.ProxyWithAccountCount, 0, len(proxies))
for i := range proxies {
proxyOut := proxyEntityToService(proxies[i])
if proxyOut == nil {
continue
}
result = append(result, service.ProxyWithAccountCount{
Proxy: *proxyOut,
AccountCount: counts[proxyOut.ID],
})
}
return result, paginationResultFromTotal(int64(total), params), nil
}
func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) { func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) {
proxies, err := r.client.Proxy.Query(). proxies, err := r.client.Proxy.Query().
Where(proxy.StatusEQ(service.StatusActive)). Where(proxy.StatusEQ(service.StatusActive)).
......
...@@ -22,7 +22,7 @@ import ( ...@@ -22,7 +22,7 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
) )
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, image_count, image_size, created_at" const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at"
type usageLogRepository struct { type usageLogRepository struct {
client *dbent.Client client *dbent.Client
...@@ -109,6 +109,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -109,6 +109,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
stream, stream,
duration_ms, duration_ms,
first_token_ms, first_token_ms,
user_agent,
ip_address,
image_count, image_count,
image_size, image_size,
created_at created_at
...@@ -118,8 +120,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -118,8 +120,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11, $8, $9, $10, $11,
$12, $13, $12, $13,
$14, $15, $16, $17, $18, $19, $14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29
$25, $26, $27
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at RETURNING id, created_at
...@@ -129,6 +130,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -129,6 +130,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
subscriptionID := nullInt64(log.SubscriptionID) subscriptionID := nullInt64(log.SubscriptionID)
duration := nullInt(log.DurationMs) duration := nullInt(log.DurationMs)
firstToken := nullInt(log.FirstTokenMs) firstToken := nullInt(log.FirstTokenMs)
userAgent := nullString(log.UserAgent)
ipAddress := nullString(log.IPAddress)
imageSize := nullString(log.ImageSize) imageSize := nullString(log.ImageSize)
var requestIDArg any var requestIDArg any
...@@ -161,6 +164,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -161,6 +164,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
log.Stream, log.Stream,
duration, duration,
firstToken, firstToken,
userAgent,
ipAddress,
log.ImageCount, log.ImageCount,
imageSize, imageSize,
createdAt, createdAt,
...@@ -1388,6 +1393,81 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT ...@@ -1388,6 +1393,81 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT
return stats, nil return stats, nil
} }
// GetStatsWithFilters gets usage statistics with optional filters
func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters UsageLogFilters) (*UsageStats, error) {
conditions := make([]string, 0, 9)
args := make([]any, 0, 9)
if filters.UserID > 0 {
conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1))
args = append(args, filters.UserID)
}
if filters.APIKeyID > 0 {
conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", len(args)+1))
args = append(args, filters.APIKeyID)
}
if filters.AccountID > 0 {
conditions = append(conditions, fmt.Sprintf("account_id = $%d", len(args)+1))
args = append(args, filters.AccountID)
}
if filters.GroupID > 0 {
conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1))
args = append(args, filters.GroupID)
}
if filters.Model != "" {
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
args = append(args, filters.Model)
}
if filters.Stream != nil {
conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1))
args = append(args, *filters.Stream)
}
if filters.BillingType != nil {
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
args = append(args, int16(*filters.BillingType))
}
if filters.StartTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
args = append(args, *filters.StartTime)
}
if filters.EndTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", len(args)+1))
args = append(args, *filters.EndTime)
}
query := fmt.Sprintf(`
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs
%s
`, buildWhere(conditions))
stats := &UsageStats{}
if err := scanSingleRow(
ctx,
r.sql,
query,
args,
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
return stats, nil
}
// AccountUsageHistory represents daily usage history for an account // AccountUsageHistory represents daily usage history for an account
type AccountUsageHistory = usagestats.AccountUsageHistory type AccountUsageHistory = usagestats.AccountUsageHistory
...@@ -1795,6 +1875,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -1795,6 +1875,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
stream bool stream bool
durationMs sql.NullInt64 durationMs sql.NullInt64
firstTokenMs sql.NullInt64 firstTokenMs sql.NullInt64
userAgent sql.NullString
ipAddress sql.NullString
imageCount int imageCount int
imageSize sql.NullString imageSize sql.NullString
createdAt time.Time createdAt time.Time
...@@ -1826,6 +1908,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -1826,6 +1908,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&stream, &stream,
&durationMs, &durationMs,
&firstTokenMs, &firstTokenMs,
&userAgent,
&ipAddress,
&imageCount, &imageCount,
&imageSize, &imageSize,
&createdAt, &createdAt,
...@@ -1877,6 +1961,12 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -1877,6 +1961,12 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
value := int(firstTokenMs.Int64) value := int(firstTokenMs.Int64)
log.FirstTokenMs = &value log.FirstTokenMs = &value
} }
if userAgent.Valid {
log.UserAgent = &userAgent.String
}
if ipAddress.Valid {
log.IPAddress = &ipAddress.String
}
if imageSize.Valid { if imageSize.Valid {
log.ImageSize = &imageSize.String log.ImageSize = &imageSize.String
} }
......
...@@ -25,6 +25,18 @@ func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.Conc ...@@ -25,6 +25,18 @@ func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.Conc
return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes, waitTTLSeconds) return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes, waitTTLSeconds)
} }
// ProvideGitHubReleaseClient 创建 GitHub Release 客户端
// 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub
func ProvideGitHubReleaseClient(cfg *config.Config) service.GitHubReleaseClient {
return NewGitHubReleaseClient(cfg.Update.ProxyURL)
}
// ProvidePricingRemoteClient 创建定价数据远程客户端
// 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub 上的定价数据
func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient {
return NewPricingRemoteClient(cfg.Update.ProxyURL)
}
// ProviderSet is the Wire provider set for all repositories // ProviderSet is the Wire provider set for all repositories
var ProviderSet = wire.NewSet( var ProviderSet = wire.NewSet(
NewUserRepository, NewUserRepository,
...@@ -54,8 +66,8 @@ var ProviderSet = wire.NewSet( ...@@ -54,8 +66,8 @@ var ProviderSet = wire.NewSet(
// HTTP service ports (DI Strategy A: return interface directly) // HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier, NewTurnstileVerifier,
NewPricingRemoteClient, ProvidePricingRemoteClient,
NewGitHubReleaseClient, ProvideGitHubReleaseClient,
NewProxyExitInfoProber, NewProxyExitInfoProber,
NewClaudeUsageFetcher, NewClaudeUsageFetcher,
NewClaudeOAuthClient, NewClaudeOAuthClient,
......
...@@ -82,6 +82,8 @@ func TestAPIContracts(t *testing.T) { ...@@ -82,6 +82,8 @@ func TestAPIContracts(t *testing.T) {
"name": "Key One", "name": "Key One",
"group_id": null, "group_id": null,
"status": "active", "status": "active",
"ip_whitelist": null,
"ip_blacklist": null,
"created_at": "2025-01-02T03:04:05Z", "created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z" "updated_at": "2025-01-02T03:04:05Z"
} }
...@@ -116,6 +118,8 @@ func TestAPIContracts(t *testing.T) { ...@@ -116,6 +118,8 @@ func TestAPIContracts(t *testing.T) {
"name": "Key One", "name": "Key One",
"group_id": null, "group_id": null,
"status": "active", "status": "active",
"ip_whitelist": null,
"ip_blacklist": null,
"created_at": "2025-01-02T03:04:05Z", "created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z" "updated_at": "2025-01-02T03:04:05Z"
} }
...@@ -243,7 +247,8 @@ func TestAPIContracts(t *testing.T) { ...@@ -243,7 +247,8 @@ func TestAPIContracts(t *testing.T) {
"first_token_ms": 50, "first_token_ms": 50,
"image_count": 0, "image_count": 0,
"image_size": null, "image_size": null,
"created_at": "2025-01-02T03:04:05Z" "created_at": "2025-01-02T03:04:05Z",
"user_agent": null
} }
], ],
"total": 1, "total": 1,
...@@ -303,6 +308,10 @@ func TestAPIContracts(t *testing.T) { ...@@ -303,6 +308,10 @@ func TestAPIContracts(t *testing.T) {
"turnstile_enabled": true, "turnstile_enabled": true,
"turnstile_site_key": "site-key", "turnstile_site_key": "site-key",
"turnstile_secret_key_configured": true, "turnstile_secret_key_configured": true,
"linuxdo_connect_enabled": false,
"linuxdo_connect_client_id": "",
"linuxdo_connect_client_secret_configured": false,
"linuxdo_connect_redirect_url": "",
"site_name": "Sub2API", "site_name": "Sub2API",
"site_logo": "", "site_logo": "",
"site_subtitle": "Subtitle", "site_subtitle": "Subtitle",
...@@ -393,7 +402,7 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -393,7 +402,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo() settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg) settingService := service.NewSettingService(settingRepo, cfg)
authHandler := handler.NewAuthHandler(cfg, nil, userService) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil)
...@@ -586,7 +595,7 @@ func (stubGroupRepo) List(ctx context.Context, params pagination.PaginationParam ...@@ -586,7 +595,7 @@ func (stubGroupRepo) List(ctx context.Context, params pagination.PaginationParam
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) { func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
...@@ -1069,6 +1078,10 @@ func (r *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID i ...@@ -1069,6 +1078,10 @@ func (r *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID i
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
return nil, errors.New("not implemented")
}
type stubSettingRepo struct { type stubSettingRepo struct {
all map[string]string all map[string]string
} }
......
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"strings" "strings"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -71,6 +72,17 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti ...@@ -71,6 +72,17 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
return return
} }
// 检查 IP 限制(白名单/黑名单)
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 {
clientIP := ip.GetClientIP(c)
allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist)
if !allowed {
AbortWithError(c, 403, "ACCESS_DENIED", "Access denied")
return
}
}
// 检查关联的用户 // 检查关联的用户
if apiKey.User == nil { if apiKey.User == nil {
AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found") AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found")
......
...@@ -19,6 +19,8 @@ func RegisterAuthRoutes( ...@@ -19,6 +19,8 @@ func RegisterAuthRoutes(
auth.POST("/register", h.Auth.Register) auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login) auth.POST("/login", h.Auth.Login)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode) auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
} }
// 公开设置(无需认证) // 公开设置(无需认证)
......
...@@ -9,21 +9,23 @@ import ( ...@@ -9,21 +9,23 @@ import (
) )
type Account struct { type Account struct {
ID int64 ID int64
Name string Name string
Notes *string Notes *string
Platform string Platform string
Type string Type string
Credentials map[string]any Credentials map[string]any
Extra map[string]any Extra map[string]any
ProxyID *int64 ProxyID *int64
Concurrency int Concurrency int
Priority int Priority int
Status string Status string
ErrorMessage string ErrorMessage string
LastUsedAt *time.Time LastUsedAt *time.Time
CreatedAt time.Time ExpiresAt *time.Time
UpdatedAt time.Time AutoPauseOnExpired bool
CreatedAt time.Time
UpdatedAt time.Time
Schedulable bool Schedulable bool
...@@ -60,6 +62,9 @@ func (a *Account) IsSchedulable() bool { ...@@ -60,6 +62,9 @@ func (a *Account) IsSchedulable() bool {
return false return false
} }
now := time.Now() now := time.Now()
if a.AutoPauseOnExpired && a.ExpiresAt != nil && !now.Before(*a.ExpiresAt) {
return false
}
if a.OverloadUntil != nil && now.Before(*a.OverloadUntil) { if a.OverloadUntil != nil && now.Before(*a.OverloadUntil) {
return false return false
} }
......
package service
import (
"context"
"log"
"sync"
"time"
)
// AccountExpiryService periodically pauses expired accounts when auto-pause is enabled.
type AccountExpiryService struct {
accountRepo AccountRepository
interval time.Duration
stopCh chan struct{}
stopOnce sync.Once
wg sync.WaitGroup
}
func NewAccountExpiryService(accountRepo AccountRepository, interval time.Duration) *AccountExpiryService {
return &AccountExpiryService{
accountRepo: accountRepo,
interval: interval,
stopCh: make(chan struct{}),
}
}
func (s *AccountExpiryService) Start() {
if s == nil || s.accountRepo == nil || s.interval <= 0 {
return
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
ticker := time.NewTicker(s.interval)
defer ticker.Stop()
s.runOnce()
for {
select {
case <-ticker.C:
s.runOnce()
case <-s.stopCh:
return
}
}
}()
}
func (s *AccountExpiryService) Stop() {
if s == nil {
return
}
s.stopOnce.Do(func() {
close(s.stopCh)
})
s.wg.Wait()
}
func (s *AccountExpiryService) runOnce() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
updated, err := s.accountRepo.AutoPauseExpiredAccounts(ctx, time.Now())
if err != nil {
log.Printf("[AccountExpiry] Auto pause expired accounts failed: %v", err)
return
}
if updated > 0 {
log.Printf("[AccountExpiry] Auto paused %d expired accounts", updated)
}
}
...@@ -38,6 +38,7 @@ type AccountRepository interface { ...@@ -38,6 +38,7 @@ type AccountRepository interface {
BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
SetError(ctx context.Context, id int64, errorMsg string) error SetError(ctx context.Context, id int64, errorMsg string) error
SetSchedulable(ctx context.Context, id int64, schedulable bool) error SetSchedulable(ctx context.Context, id int64, schedulable bool) error
AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error)
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
ListSchedulable(ctx context.Context) ([]Account, error) ListSchedulable(ctx context.Context) ([]Account, error)
...@@ -48,10 +49,12 @@ type AccountRepository interface { ...@@ -48,10 +49,12 @@ type AccountRepository interface {
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error SetOverloaded(ctx context.Context, id int64, until time.Time) error
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
ClearTempUnschedulable(ctx context.Context, id int64) error ClearTempUnschedulable(ctx context.Context, id int64) error
ClearRateLimit(ctx context.Context, id int64) error ClearRateLimit(ctx context.Context, id int64) error
ClearAntigravityQuotaScopes(ctx context.Context, id int64) error
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error) BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
...@@ -65,35 +68,40 @@ type AccountBulkUpdate struct { ...@@ -65,35 +68,40 @@ type AccountBulkUpdate struct {
Concurrency *int Concurrency *int
Priority *int Priority *int
Status *string Status *string
Schedulable *bool
Credentials map[string]any Credentials map[string]any
Extra map[string]any Extra map[string]any
} }
// CreateAccountRequest 创建账号请求 // CreateAccountRequest 创建账号请求
type CreateAccountRequest struct { type CreateAccountRequest struct {
Name string `json:"name"` Name string `json:"name"`
Notes *string `json:"notes"` Notes *string `json:"notes"`
Platform string `json:"platform"` Platform string `json:"platform"`
Type string `json:"type"` Type string `json:"type"`
Credentials map[string]any `json:"credentials"` Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"` Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"` ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"` Concurrency int `json:"concurrency"`
Priority int `json:"priority"` Priority int `json:"priority"`
GroupIDs []int64 `json:"group_ids"` GroupIDs []int64 `json:"group_ids"`
ExpiresAt *time.Time `json:"expires_at"`
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
} }
// UpdateAccountRequest 更新账号请求 // UpdateAccountRequest 更新账号请求
type UpdateAccountRequest struct { type UpdateAccountRequest struct {
Name *string `json:"name"` Name *string `json:"name"`
Notes *string `json:"notes"` Notes *string `json:"notes"`
Credentials *map[string]any `json:"credentials"` Credentials *map[string]any `json:"credentials"`
Extra *map[string]any `json:"extra"` Extra *map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"` ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"` Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"` Priority *int `json:"priority"`
Status *string `json:"status"` Status *string `json:"status"`
GroupIDs *[]int64 `json:"group_ids"` GroupIDs *[]int64 `json:"group_ids"`
ExpiresAt *time.Time `json:"expires_at"`
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
} }
// AccountService 账号管理服务 // AccountService 账号管理服务
...@@ -134,6 +142,12 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) ( ...@@ -134,6 +142,12 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
Concurrency: req.Concurrency, Concurrency: req.Concurrency,
Priority: req.Priority, Priority: req.Priority,
Status: StatusActive, Status: StatusActive,
ExpiresAt: req.ExpiresAt,
}
if req.AutoPauseOnExpired != nil {
account.AutoPauseOnExpired = *req.AutoPauseOnExpired
} else {
account.AutoPauseOnExpired = true
} }
if err := s.accountRepo.Create(ctx, account); err != nil { if err := s.accountRepo.Create(ctx, account); err != nil {
...@@ -224,6 +238,12 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount ...@@ -224,6 +238,12 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
if req.Status != nil { if req.Status != nil {
account.Status = *req.Status account.Status = *req.Status
} }
if req.ExpiresAt != nil {
account.ExpiresAt = req.ExpiresAt
}
if req.AutoPauseOnExpired != nil {
account.AutoPauseOnExpired = *req.AutoPauseOnExpired
}
// 先验证分组是否存在(在任何写操作之前) // 先验证分组是否存在(在任何写操作之前)
if req.GroupIDs != nil { if req.GroupIDs != nil {
......
...@@ -103,6 +103,10 @@ func (s *accountRepoStub) SetSchedulable(ctx context.Context, id int64, schedula ...@@ -103,6 +103,10 @@ func (s *accountRepoStub) SetSchedulable(ctx context.Context, id int64, schedula
panic("unexpected SetSchedulable call") panic("unexpected SetSchedulable call")
} }
func (s *accountRepoStub) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
panic("unexpected AutoPauseExpiredAccounts call")
}
func (s *accountRepoStub) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error { func (s *accountRepoStub) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
panic("unexpected BindGroups call") panic("unexpected BindGroups call")
} }
...@@ -135,6 +139,10 @@ func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt ...@@ -135,6 +139,10 @@ func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt
panic("unexpected SetRateLimited call") panic("unexpected SetRateLimited call")
} }
func (s *accountRepoStub) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
panic("unexpected SetAntigravityQuotaScopeLimit call")
}
func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until time.Time) error { func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
panic("unexpected SetOverloaded call") panic("unexpected SetOverloaded call")
} }
...@@ -151,6 +159,10 @@ func (s *accountRepoStub) ClearRateLimit(ctx context.Context, id int64) error { ...@@ -151,6 +159,10 @@ func (s *accountRepoStub) ClearRateLimit(ctx context.Context, id int64) error {
panic("unexpected ClearRateLimit call") panic("unexpected ClearRateLimit call")
} }
func (s *accountRepoStub) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
panic("unexpected ClearAntigravityQuotaScopes call")
}
func (s *accountRepoStub) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { func (s *accountRepoStub) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
panic("unexpected UpdateSessionWindow call") panic("unexpected UpdateSessionWindow call")
} }
......
...@@ -661,13 +661,7 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader) ...@@ -661,13 +661,7 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader)
} }
if candidates, ok := data["candidates"].([]any); ok && len(candidates) > 0 { if candidates, ok := data["candidates"].([]any); ok && len(candidates) > 0 {
if candidate, ok := candidates[0].(map[string]any); ok { if candidate, ok := candidates[0].(map[string]any); ok {
// Check for completion // Extract content first (before checking completion)
if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
// Extract content
if content, ok := candidate["content"].(map[string]any); ok { if content, ok := candidate["content"].(map[string]any); ok {
if parts, ok := content["parts"].([]any); ok { if parts, ok := content["parts"].([]any); ok {
for _, part := range parts { for _, part := range parts {
...@@ -679,6 +673,12 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader) ...@@ -679,6 +673,12 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader)
} }
} }
} }
// Check for completion after extracting content
if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment