Commit 3a67002c authored by IanShaw027's avatar IanShaw027
Browse files

merge: 合并主分支改动并保留 ops 监控实现

合并 main 分支的最新改动到 ops 监控分支。
冲突解决策略:保留当前分支的 ops 相关改动,接受主分支的其他改动。

保留的 ops 改动:
- 运维监控配置和依赖注入
- 运维监控 API 处理器和中间件
- 运维监控服务层和数据访问层
- 运维监控前端界面和状态管理

接受的主分支改动:
- Linux DO OAuth 集成
- 账号过期功能
- IP 地址限制功能
- 用量统计优化
- 其他 bug 修复和功能改进
parents c48dc097 7d1fe818
......@@ -2,6 +2,7 @@ package repository
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
......@@ -18,17 +19,23 @@ func NewGatewayCache(rdb *redis.Client) service.GatewayCache {
return &gatewayCache{rdb: rdb}
}
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
key := stickySessionPrefix + sessionHash
// buildSessionKey 构建 session key,包含 groupID 实现分组隔离
// 格式: sticky_session:{groupID}:{sessionHash}
func buildSessionKey(groupID int64, sessionHash string) string {
return fmt.Sprintf("%s%d:%s", stickySessionPrefix, groupID, sessionHash)
}
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Get(ctx, key).Int64()
}
func (c *gatewayCache) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
key := stickySessionPrefix + sessionHash
func (c *gatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Set(ctx, key, accountID, ttl).Err()
}
func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
key := stickySessionPrefix + sessionHash
func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Expire(ctx, key, ttl).Err()
}
......@@ -24,18 +24,19 @@ func (s *GatewayCacheSuite) SetupTest() {
}
func (s *GatewayCacheSuite) TestGetSessionAccountID_Missing() {
_, err := s.cache.GetSessionAccountID(s.ctx, "nonexistent")
_, err := s.cache.GetSessionAccountID(s.ctx, 1, "nonexistent")
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing session")
}
func (s *GatewayCacheSuite) TestSetAndGetSessionAccountID() {
sessionID := "s1"
accountID := int64(99)
groupID := int64(1)
sessionTTL := 1 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, sessionTTL), "SetSessionAccountID")
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
sid, err := s.cache.GetSessionAccountID(s.ctx, sessionID)
sid, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
require.NoError(s.T(), err, "GetSessionAccountID")
require.Equal(s.T(), accountID, sid, "session id mismatch")
}
......@@ -43,11 +44,12 @@ func (s *GatewayCacheSuite) TestSetAndGetSessionAccountID() {
func (s *GatewayCacheSuite) TestSessionAccountID_TTL() {
sessionID := "s2"
accountID := int64(100)
groupID := int64(1)
sessionTTL := 1 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, sessionTTL), "SetSessionAccountID")
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
sessionKey := stickySessionPrefix + sessionID
sessionKey := buildSessionKey(groupID, sessionID)
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
require.NoError(s.T(), err, "TTL sessionKey after Set")
s.AssertTTLWithin(ttl, 1*time.Second, sessionTTL)
......@@ -56,14 +58,15 @@ func (s *GatewayCacheSuite) TestSessionAccountID_TTL() {
func (s *GatewayCacheSuite) TestRefreshSessionTTL() {
sessionID := "s3"
accountID := int64(101)
groupID := int64(1)
initialTTL := 1 * time.Minute
refreshTTL := 3 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, initialTTL), "SetSessionAccountID")
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, initialTTL), "SetSessionAccountID")
require.NoError(s.T(), s.cache.RefreshSessionTTL(s.ctx, sessionID, refreshTTL), "RefreshSessionTTL")
require.NoError(s.T(), s.cache.RefreshSessionTTL(s.ctx, groupID, sessionID, refreshTTL), "RefreshSessionTTL")
sessionKey := stickySessionPrefix + sessionID
sessionKey := buildSessionKey(groupID, sessionID)
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
require.NoError(s.T(), err, "TTL after Refresh")
s.AssertTTLWithin(ttl, 1*time.Second, refreshTTL)
......@@ -71,18 +74,19 @@ func (s *GatewayCacheSuite) TestRefreshSessionTTL() {
func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() {
// RefreshSessionTTL on a missing key should not error (no-op)
err := s.cache.RefreshSessionTTL(s.ctx, "missing-session", 1*time.Minute)
err := s.cache.RefreshSessionTTL(s.ctx, 1, "missing-session", 1*time.Minute)
require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error")
}
func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
sessionID := "corrupted"
sessionKey := stickySessionPrefix + sessionID
groupID := int64(1)
sessionKey := buildSessionKey(groupID, sessionID)
// Set a non-integer value
require.NoError(s.T(), s.rdb.Set(s.ctx, sessionKey, "not-a-number", 1*time.Minute).Err(), "Set invalid value")
_, err := s.cache.GetSessionAccountID(s.ctx, sessionID)
_, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
require.Error(s.T(), err, "expected error for corrupted value")
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
}
......
......@@ -30,14 +30,15 @@ func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, c
// Use different OAuth clients based on oauthType:
// - code_assist: always use built-in Gemini CLI OAuth client (public)
// - google_one: uses configured OAuth client when provided; otherwise falls back to built-in client
// - google_one: always use built-in Gemini CLI OAuth client (public)
// - ai_studio: requires a user-provided OAuth client
oauthCfgInput := geminicli.OAuthConfig{
ClientID: c.cfg.Gemini.OAuth.ClientID,
ClientSecret: c.cfg.Gemini.OAuth.ClientSecret,
Scopes: c.cfg.Gemini.OAuth.Scopes,
}
if oauthType == "code_assist" {
if oauthType == "code_assist" || oauthType == "google_one" {
// Force use of built-in Gemini CLI OAuth client
oauthCfgInput.ClientID = ""
oauthCfgInput.ClientSecret = ""
}
......@@ -78,7 +79,8 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh
ClientSecret: c.cfg.Gemini.OAuth.ClientSecret,
Scopes: c.cfg.Gemini.OAuth.Scopes,
}
if oauthType == "code_assist" {
if oauthType == "code_assist" || oauthType == "google_one" {
// Force use of built-in Gemini CLI OAuth client
oauthCfgInput.ClientID = ""
oauthCfgInput.ClientSecret = ""
}
......
......@@ -15,22 +15,32 @@ import (
type githubReleaseClient struct {
httpClient *http.Client
allowPrivateHosts bool
downloadHTTPClient *http.Client
}
func NewGitHubReleaseClient() service.GitHubReleaseClient {
allowPrivate := false
// NewGitHubReleaseClient 创建 GitHub Release 客户端
// proxyURL 为空时直连 GitHub,支持 http/https/socks5/socks5h 协议
func NewGitHubReleaseClient(proxyURL string) service.GitHubReleaseClient {
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second,
ValidateResolvedIP: true,
AllowPrivateHosts: allowPrivate,
ProxyURL: proxyURL,
})
if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second}
}
// 下载客户端需要更长的超时时间
downloadClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 10 * time.Minute,
ProxyURL: proxyURL,
})
if err != nil {
downloadClient = &http.Client{Timeout: 10 * time.Minute}
}
return &githubReleaseClient{
httpClient: sharedClient,
allowPrivateHosts: allowPrivate,
downloadHTTPClient: downloadClient,
}
}
......@@ -68,15 +78,8 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
return err
}
downloadClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 10 * time.Minute,
ValidateResolvedIP: true,
AllowPrivateHosts: c.allowPrivateHosts,
})
if err != nil {
downloadClient = &http.Client{Timeout: 10 * time.Minute}
}
resp, err := downloadClient.Do(req)
// 使用预配置的下载客户端(已包含代理配置)
resp, err := c.downloadHTTPClient.Do(req)
if err != nil {
return err
}
......
......@@ -40,7 +40,7 @@ func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
func newTestGitHubReleaseClient() *githubReleaseClient {
return &githubReleaseClient{
httpClient: &http.Client{},
allowPrivateHosts: true,
downloadHTTPClient: &http.Client{},
}
}
......@@ -234,7 +234,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
allowPrivateHosts: true,
downloadHTTPClient: &http.Client{},
}
release, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
......@@ -254,7 +254,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
allowPrivateHosts: true,
downloadHTTPClient: &http.Client{},
}
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
......@@ -272,7 +272,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
allowPrivateHosts: true,
downloadHTTPClient: &http.Client{},
}
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
......@@ -288,7 +288,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
allowPrivateHosts: true,
downloadHTTPClient: &http.Client{},
}
ctx, cancel := context.WithCancel(context.Background())
......
......@@ -46,7 +46,9 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetDefaultValidityDays(groupIn.DefaultValidityDays)
SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetNillableFallbackGroupID(groupIn.FallbackGroupID)
created, err := builder.Save(ctx)
if err == nil {
......@@ -72,7 +74,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
}
func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) error {
updated, err := r.client.Group.UpdateOneID(groupIn.ID).
builder := r.client.Group.UpdateOneID(groupIn.ID).
SetName(groupIn.Name).
SetDescription(groupIn.Description).
SetPlatform(groupIn.Platform).
......@@ -87,7 +89,16 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetDefaultValidityDays(groupIn.DefaultValidityDays).
Save(ctx)
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly)
// 处理 FallbackGroupID:nil 时清除,否则设置
if groupIn.FallbackGroupID != nil {
builder = builder.SetFallbackGroupID(*groupIn.FallbackGroupID)
} else {
builder = builder.ClearFallbackGroupID()
}
updated, err := builder.Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
}
......@@ -101,10 +112,10 @@ func (r *groupRepository) Delete(ctx context.Context, id int64) error {
}
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", nil)
return r.ListWithFilters(ctx, params, "", "", "", nil)
}
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
q := r.client.Group.Query()
if platform != "" {
......@@ -113,6 +124,12 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
if status != "" {
q = q.Where(group.StatusEQ(status))
}
if search != "" {
q = q.Where(group.Or(
group.NameContainsFold(search),
group.DescriptionContainsFold(search),
))
}
if isExclusive != nil {
q = q.Where(group.IsExclusiveEQ(*isExclusive))
}
......
......@@ -131,6 +131,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() {
pagination.PaginationParams{Page: 1, PageSize: 10},
service.PlatformOpenAI,
"",
"",
nil,
)
s.Require().NoError(err, "ListWithFilters base")
......@@ -152,7 +153,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() {
SubscriptionType: service.SubscriptionTypeStandard,
}))
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil)
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", "", nil)
s.Require().NoError(err)
s.Require().Len(groups, len(baseGroups)+1)
// Verify all groups are OpenAI platform
......@@ -179,7 +180,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Status() {
SubscriptionType: service.SubscriptionTypeStandard,
}))
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, nil)
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, "", nil)
s.Require().NoError(err)
s.Require().Len(groups, 1)
s.Require().Equal(service.StatusDisabled, groups[0].Status)
......@@ -204,12 +205,117 @@ func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
}))
isExclusive := true
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", &isExclusive)
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", &isExclusive)
s.Require().NoError(err)
s.Require().Len(groups, 1)
s.Require().True(groups[0].IsExclusive)
}
func (s *GroupRepoSuite) TestListWithFilters_Search() {
newRepo := func() (*groupRepository, context.Context) {
tx := testEntTx(s.T())
return newGroupRepositoryWithSQL(tx.Client(), tx), context.Background()
}
containsID := func(groups []service.Group, id int64) bool {
for i := range groups {
if groups[i].ID == id {
return true
}
}
return false
}
mustCreate := func(repo *groupRepository, ctx context.Context, g *service.Group) *service.Group {
s.Require().NoError(repo.Create(ctx, g))
s.Require().NotZero(g.ID)
return g
}
newGroup := func(name string) *service.Group {
return &service.Group{
Name: name,
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
}
s.Run("search_name_should_match", func() {
repo, ctx := newRepo()
target := mustCreate(repo, ctx, newGroup("it-group-search-name-target"))
other := mustCreate(repo, ctx, newGroup("it-group-search-name-other"))
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "name-target", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, target.ID), "expected target group to match by name")
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
})
s.Run("search_description_should_match", func() {
repo, ctx := newRepo()
target := newGroup("it-group-search-desc-target")
target.Description = "something about desc-needle in here"
target = mustCreate(repo, ctx, target)
other := newGroup("it-group-search-desc-other")
other.Description = "nothing to see here"
other = mustCreate(repo, ctx, other)
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "desc-needle", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, target.ID), "expected target group to match by description")
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
})
s.Run("search_nonexistent_should_return_empty", func() {
repo, ctx := newRepo()
_ = mustCreate(repo, ctx, newGroup("it-group-search-nonexistent-baseline"))
search := s.T().Name() + "__no_such_group__"
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", search, nil)
s.Require().NoError(err)
s.Require().Empty(groups)
})
s.Run("search_should_be_case_insensitive", func() {
repo, ctx := newRepo()
target := mustCreate(repo, ctx, newGroup("MiXeDCaSe-Needle"))
other := mustCreate(repo, ctx, newGroup("it-group-search-case-other"))
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "mixedcase-needle", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, target.ID), "expected case-insensitive match")
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
})
s.Run("search_should_escape_like_wildcards", func() {
repo, ctx := newRepo()
percentTarget := mustCreate(repo, ctx, newGroup("it-group-search-100%-target"))
percentOther := mustCreate(repo, ctx, newGroup("it-group-search-100X-other"))
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "100%", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, percentTarget.ID), "expected literal %% match")
s.Require().False(containsID(groups, percentOther.ID), "expected %% not to act as wildcard")
underscoreTarget := mustCreate(repo, ctx, newGroup("it-group-search-ab_cd-target"))
underscoreOther := mustCreate(repo, ctx, newGroup("it-group-search-abXcd-other"))
groups, _, err = repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "ab_cd", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, underscoreTarget.ID), "expected literal _ match")
s.Require().False(containsID(groups, underscoreOther.ID), "expected _ not to act as wildcard")
})
}
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
g1 := &service.Group{
Name: "g1",
......@@ -244,7 +350,7 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
s.Require().NoError(err)
isExclusive := true
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, &isExclusive)
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, "", &isExclusive)
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(groups, 1)
......
......@@ -8,7 +8,6 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
......@@ -17,17 +16,12 @@ type pricingRemoteClient struct {
httpClient *http.Client
}
func NewPricingRemoteClient(cfg *config.Config) service.PricingRemoteClient {
allowPrivate := false
validateResolvedIP := true
if cfg != nil {
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
validateResolvedIP = cfg.Security.URLAllowlist.Enabled
}
// NewPricingRemoteClient 创建定价数据远程客户端
// proxyURL 为空时直连,支持 http/https/socks5/socks5h 协议
func NewPricingRemoteClient(proxyURL string) service.PricingRemoteClient {
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second,
ValidateResolvedIP: validateResolvedIP,
AllowPrivateHosts: allowPrivate,
ProxyURL: proxyURL,
})
if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second}
......
......@@ -6,7 +6,6 @@ import (
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
......@@ -20,13 +19,7 @@ type PricingServiceSuite struct {
func (s *PricingServiceSuite) SetupTest() {
s.ctx = context.Background()
client, ok := NewPricingRemoteClient(&config.Config{
Security: config.SecurityConfig{
URLAllowlist: config.URLAllowlistConfig{
AllowPrivateHosts: true,
},
},
}).(*pricingRemoteClient)
client, ok := NewPricingRemoteClient("").(*pricingRemoteClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
}
......
......@@ -133,6 +133,55 @@ func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination
return outProxies, paginationResultFromTotal(int64(total), params), nil
}
// ListWithFiltersAndAccountCount lists proxies with filters and includes account count per proxy
func (r *proxyRepository) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.ProxyWithAccountCount, *pagination.PaginationResult, error) {
q := r.client.Proxy.Query()
if protocol != "" {
q = q.Where(proxy.ProtocolEQ(protocol))
}
if status != "" {
q = q.Where(proxy.StatusEQ(status))
}
if search != "" {
q = q.Where(proxy.NameContainsFold(search))
}
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
proxies, err := q.
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(proxy.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
// Get account counts
counts, err := r.GetAccountCountsForProxies(ctx)
if err != nil {
return nil, nil, err
}
// Build result with account counts
result := make([]service.ProxyWithAccountCount, 0, len(proxies))
for i := range proxies {
proxyOut := proxyEntityToService(proxies[i])
if proxyOut == nil {
continue
}
result = append(result, service.ProxyWithAccountCount{
Proxy: *proxyOut,
AccountCount: counts[proxyOut.ID],
})
}
return result, paginationResultFromTotal(int64(total), params), nil
}
func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) {
proxies, err := r.client.Proxy.Query().
Where(proxy.StatusEQ(service.StatusActive)).
......
......@@ -22,7 +22,7 @@ import (
"github.com/lib/pq"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, image_count, image_size, created_at"
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at"
type usageLogRepository struct {
client *dbent.Client
......@@ -109,6 +109,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
stream,
duration_ms,
first_token_ms,
user_agent,
ip_address,
image_count,
image_size,
created_at
......@@ -118,8 +120,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24,
$25, $26, $27
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
......@@ -129,6 +130,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
subscriptionID := nullInt64(log.SubscriptionID)
duration := nullInt(log.DurationMs)
firstToken := nullInt(log.FirstTokenMs)
userAgent := nullString(log.UserAgent)
ipAddress := nullString(log.IPAddress)
imageSize := nullString(log.ImageSize)
var requestIDArg any
......@@ -161,6 +164,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
log.Stream,
duration,
firstToken,
userAgent,
ipAddress,
log.ImageCount,
imageSize,
createdAt,
......@@ -1388,6 +1393,81 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT
return stats, nil
}
// GetStatsWithFilters gets usage statistics with optional filters
func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters UsageLogFilters) (*UsageStats, error) {
conditions := make([]string, 0, 9)
args := make([]any, 0, 9)
if filters.UserID > 0 {
conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1))
args = append(args, filters.UserID)
}
if filters.APIKeyID > 0 {
conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", len(args)+1))
args = append(args, filters.APIKeyID)
}
if filters.AccountID > 0 {
conditions = append(conditions, fmt.Sprintf("account_id = $%d", len(args)+1))
args = append(args, filters.AccountID)
}
if filters.GroupID > 0 {
conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1))
args = append(args, filters.GroupID)
}
if filters.Model != "" {
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
args = append(args, filters.Model)
}
if filters.Stream != nil {
conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1))
args = append(args, *filters.Stream)
}
if filters.BillingType != nil {
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
args = append(args, int16(*filters.BillingType))
}
if filters.StartTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
args = append(args, *filters.StartTime)
}
if filters.EndTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", len(args)+1))
args = append(args, *filters.EndTime)
}
query := fmt.Sprintf(`
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs
%s
`, buildWhere(conditions))
stats := &UsageStats{}
if err := scanSingleRow(
ctx,
r.sql,
query,
args,
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
return stats, nil
}
// AccountUsageHistory represents daily usage history for an account
type AccountUsageHistory = usagestats.AccountUsageHistory
......@@ -1795,6 +1875,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
stream bool
durationMs sql.NullInt64
firstTokenMs sql.NullInt64
userAgent sql.NullString
ipAddress sql.NullString
imageCount int
imageSize sql.NullString
createdAt time.Time
......@@ -1826,6 +1908,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&stream,
&durationMs,
&firstTokenMs,
&userAgent,
&ipAddress,
&imageCount,
&imageSize,
&createdAt,
......@@ -1877,6 +1961,12 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
value := int(firstTokenMs.Int64)
log.FirstTokenMs = &value
}
if userAgent.Valid {
log.UserAgent = &userAgent.String
}
if ipAddress.Valid {
log.IPAddress = &ipAddress.String
}
if imageSize.Valid {
log.ImageSize = &imageSize.String
}
......
......@@ -25,6 +25,18 @@ func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.Conc
return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes, waitTTLSeconds)
}
// ProvideGitHubReleaseClient 创建 GitHub Release 客户端
// 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub
func ProvideGitHubReleaseClient(cfg *config.Config) service.GitHubReleaseClient {
return NewGitHubReleaseClient(cfg.Update.ProxyURL)
}
// ProvidePricingRemoteClient 创建定价数据远程客户端
// 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub 上的定价数据
func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient {
return NewPricingRemoteClient(cfg.Update.ProxyURL)
}
// ProviderSet is the Wire provider set for all repositories
var ProviderSet = wire.NewSet(
NewUserRepository,
......@@ -54,8 +66,8 @@ var ProviderSet = wire.NewSet(
// HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier,
NewPricingRemoteClient,
NewGitHubReleaseClient,
ProvidePricingRemoteClient,
ProvideGitHubReleaseClient,
NewProxyExitInfoProber,
NewClaudeUsageFetcher,
NewClaudeOAuthClient,
......
......@@ -82,6 +82,8 @@ func TestAPIContracts(t *testing.T) {
"name": "Key One",
"group_id": null,
"status": "active",
"ip_whitelist": null,
"ip_blacklist": null,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
......@@ -116,6 +118,8 @@ func TestAPIContracts(t *testing.T) {
"name": "Key One",
"group_id": null,
"status": "active",
"ip_whitelist": null,
"ip_blacklist": null,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
......@@ -243,7 +247,8 @@ func TestAPIContracts(t *testing.T) {
"first_token_ms": 50,
"image_count": 0,
"image_size": null,
"created_at": "2025-01-02T03:04:05Z"
"created_at": "2025-01-02T03:04:05Z",
"user_agent": null
}
],
"total": 1,
......@@ -303,6 +308,10 @@ func TestAPIContracts(t *testing.T) {
"turnstile_enabled": true,
"turnstile_site_key": "site-key",
"turnstile_secret_key_configured": true,
"linuxdo_connect_enabled": false,
"linuxdo_connect_client_id": "",
"linuxdo_connect_client_secret_configured": false,
"linuxdo_connect_redirect_url": "",
"site_name": "Sub2API",
"site_logo": "",
"site_subtitle": "Subtitle",
......@@ -393,7 +402,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
authHandler := handler.NewAuthHandler(cfg, nil, userService)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil)
......@@ -586,7 +595,7 @@ func (stubGroupRepo) List(ctx context.Context, params pagination.PaginationParam
return nil, nil, errors.New("not implemented")
}
func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
......@@ -1069,6 +1078,10 @@ func (r *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID i
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
return nil, errors.New("not implemented")
}
type stubSettingRepo struct {
all map[string]string
}
......
......@@ -6,6 +6,7 @@ import (
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
......@@ -71,6 +72,17 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
return
}
// 检查 IP 限制(白名单/黑名单)
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 {
clientIP := ip.GetClientIP(c)
allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist)
if !allowed {
AbortWithError(c, 403, "ACCESS_DENIED", "Access denied")
return
}
}
// 检查关联的用户
if apiKey.User == nil {
AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found")
......
......@@ -19,6 +19,8 @@ func RegisterAuthRoutes(
auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
}
// 公开设置(无需认证)
......
......@@ -22,6 +22,8 @@ type Account struct {
Status string
ErrorMessage string
LastUsedAt *time.Time
ExpiresAt *time.Time
AutoPauseOnExpired bool
CreatedAt time.Time
UpdatedAt time.Time
......@@ -60,6 +62,9 @@ func (a *Account) IsSchedulable() bool {
return false
}
now := time.Now()
if a.AutoPauseOnExpired && a.ExpiresAt != nil && !now.Before(*a.ExpiresAt) {
return false
}
if a.OverloadUntil != nil && now.Before(*a.OverloadUntil) {
return false
}
......
package service
import (
"context"
"log"
"sync"
"time"
)
// AccountExpiryService periodically pauses expired accounts when auto-pause is enabled.
type AccountExpiryService struct {
accountRepo AccountRepository
interval time.Duration
stopCh chan struct{}
stopOnce sync.Once
wg sync.WaitGroup
}
func NewAccountExpiryService(accountRepo AccountRepository, interval time.Duration) *AccountExpiryService {
return &AccountExpiryService{
accountRepo: accountRepo,
interval: interval,
stopCh: make(chan struct{}),
}
}
func (s *AccountExpiryService) Start() {
if s == nil || s.accountRepo == nil || s.interval <= 0 {
return
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
ticker := time.NewTicker(s.interval)
defer ticker.Stop()
s.runOnce()
for {
select {
case <-ticker.C:
s.runOnce()
case <-s.stopCh:
return
}
}
}()
}
func (s *AccountExpiryService) Stop() {
if s == nil {
return
}
s.stopOnce.Do(func() {
close(s.stopCh)
})
s.wg.Wait()
}
func (s *AccountExpiryService) runOnce() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
updated, err := s.accountRepo.AutoPauseExpiredAccounts(ctx, time.Now())
if err != nil {
log.Printf("[AccountExpiry] Auto pause expired accounts failed: %v", err)
return
}
if updated > 0 {
log.Printf("[AccountExpiry] Auto paused %d expired accounts", updated)
}
}
......@@ -38,6 +38,7 @@ type AccountRepository interface {
BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
SetError(ctx context.Context, id int64, errorMsg string) error
SetSchedulable(ctx context.Context, id int64, schedulable bool) error
AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error)
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
ListSchedulable(ctx context.Context) ([]Account, error)
......@@ -48,10 +49,12 @@ type AccountRepository interface {
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
ClearTempUnschedulable(ctx context.Context, id int64) error
ClearRateLimit(ctx context.Context, id int64) error
ClearAntigravityQuotaScopes(ctx context.Context, id int64) error
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
......@@ -65,6 +68,7 @@ type AccountBulkUpdate struct {
Concurrency *int
Priority *int
Status *string
Schedulable *bool
Credentials map[string]any
Extra map[string]any
}
......@@ -81,6 +85,8 @@ type CreateAccountRequest struct {
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
GroupIDs []int64 `json:"group_ids"`
ExpiresAt *time.Time `json:"expires_at"`
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
}
// UpdateAccountRequest 更新账号请求
......@@ -94,6 +100,8 @@ type UpdateAccountRequest struct {
Priority *int `json:"priority"`
Status *string `json:"status"`
GroupIDs *[]int64 `json:"group_ids"`
ExpiresAt *time.Time `json:"expires_at"`
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
}
// AccountService 账号管理服务
......@@ -134,6 +142,12 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
Concurrency: req.Concurrency,
Priority: req.Priority,
Status: StatusActive,
ExpiresAt: req.ExpiresAt,
}
if req.AutoPauseOnExpired != nil {
account.AutoPauseOnExpired = *req.AutoPauseOnExpired
} else {
account.AutoPauseOnExpired = true
}
if err := s.accountRepo.Create(ctx, account); err != nil {
......@@ -224,6 +238,12 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
if req.Status != nil {
account.Status = *req.Status
}
if req.ExpiresAt != nil {
account.ExpiresAt = req.ExpiresAt
}
if req.AutoPauseOnExpired != nil {
account.AutoPauseOnExpired = *req.AutoPauseOnExpired
}
// 先验证分组是否存在(在任何写操作之前)
if req.GroupIDs != nil {
......
......@@ -103,6 +103,10 @@ func (s *accountRepoStub) SetSchedulable(ctx context.Context, id int64, schedula
panic("unexpected SetSchedulable call")
}
func (s *accountRepoStub) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
panic("unexpected AutoPauseExpiredAccounts call")
}
func (s *accountRepoStub) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
panic("unexpected BindGroups call")
}
......@@ -135,6 +139,10 @@ func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt
panic("unexpected SetRateLimited call")
}
func (s *accountRepoStub) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
panic("unexpected SetAntigravityQuotaScopeLimit call")
}
func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
panic("unexpected SetOverloaded call")
}
......@@ -151,6 +159,10 @@ func (s *accountRepoStub) ClearRateLimit(ctx context.Context, id int64) error {
panic("unexpected ClearRateLimit call")
}
func (s *accountRepoStub) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
panic("unexpected ClearAntigravityQuotaScopes call")
}
func (s *accountRepoStub) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
panic("unexpected UpdateSessionWindow call")
}
......
......@@ -661,13 +661,7 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader)
}
if candidates, ok := data["candidates"].([]any); ok && len(candidates) > 0 {
if candidate, ok := candidates[0].(map[string]any); ok {
// Check for completion
if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
// Extract content
// Extract content first (before checking completion)
if content, ok := candidate["content"].(map[string]any); ok {
if parts, ok := content["parts"].([]any); ok {
for _, part := range parts {
......@@ -679,6 +673,12 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader)
}
}
}
// Check for completion after extracting content
if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment