Commit 7dddd065 authored by yangjianbo's avatar yangjianbo
Browse files
parents 25a0d49a e78c8646
......@@ -3,9 +3,9 @@ package repository
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
......@@ -18,7 +18,6 @@ import (
type TurnstileServiceSuite struct {
suite.Suite
ctx context.Context
srv *httptest.Server
verifier *turnstileVerifier
received chan url.Values
}
......@@ -31,21 +30,15 @@ func (s *TurnstileServiceSuite) SetupTest() {
s.verifier = verifier
}
func (s *TurnstileServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
func (s *TurnstileServiceSuite) setupTransport(handler http.HandlerFunc) {
s.verifier.verifyURL = "http://in-process/turnstile"
s.verifier.httpClient = &http.Client{
Transport: newInProcessTransport(handler, nil),
}
}
func (s *TurnstileServiceSuite) setupServer(handler http.HandlerFunc) {
s.srv = httptest.NewServer(handler)
s.verifier.verifyURL = s.srv.URL
s.verifier.httpClient = s.srv.Client()
}
func (s *TurnstileServiceSuite) TestVerifyToken_SendsFormAndDecodesJSON() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Capture form data in main goroutine context later
body, _ := io.ReadAll(r.Body)
values, _ := url.ParseQuery(string(body))
......@@ -73,7 +66,7 @@ func (s *TurnstileServiceSuite) TestVerifyToken_SendsFormAndDecodesJSON() {
func (s *TurnstileServiceSuite) TestVerifyToken_ContentType() {
var contentType string
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
contentType = r.Header.Get("Content-Type")
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true})
......@@ -85,7 +78,7 @@ func (s *TurnstileServiceSuite) TestVerifyToken_ContentType() {
}
func (s *TurnstileServiceSuite) TestVerifyToken_EmptyRemoteIP_NotSent() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
values, _ := url.ParseQuery(string(body))
s.received <- values
......@@ -106,15 +99,19 @@ func (s *TurnstileServiceSuite) TestVerifyToken_EmptyRemoteIP_NotSent() {
}
func (s *TurnstileServiceSuite) TestVerifyToken_RequestError() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
s.srv.Close()
s.verifier.verifyURL = "http://in-process/turnstile"
s.verifier.httpClient = &http.Client{
Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
return nil, errors.New("dial failed")
}),
}
_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
require.Error(s.T(), err, "expected error when server is closed")
}
func (s *TurnstileServiceSuite) TestVerifyToken_InvalidJSON() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, "not-valid-json")
}))
......@@ -124,7 +121,7 @@ func (s *TurnstileServiceSuite) TestVerifyToken_InvalidJSON() {
}
func (s *TurnstileServiceSuite) TestVerifyToken_SuccessFalse() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{
Success: false,
......
......@@ -3,6 +3,7 @@ package repository
import (
"context"
"database/sql"
"errors"
"fmt"
"os"
"strings"
......@@ -60,9 +61,16 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int
return requestCount / 5, tokenCount / 5, nil
}
func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) error {
func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
if log == nil {
return nil
return false, nil
}
// 在事务上下文中,使用 tx 绑定的 ExecQuerier 执行原生 SQL,保证与其他更新同事务。
// 无事务时回退到默认的 *sql.DB 执行器。
sqlq := r.sql
if tx := dbent.TxFromContext(ctx); tx != nil {
sqlq = tx.Client()
}
createdAt := log.CreatedAt
......@@ -70,6 +78,9 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
createdAt = time.Now()
}
requestID := strings.TrimSpace(log.RequestID)
log.RequestID = requestID
rateMultiplier := log.RateMultiplier
query := `
......@@ -107,6 +118,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
`
......@@ -115,11 +127,16 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
duration := nullInt(log.DurationMs)
firstToken := nullInt(log.FirstTokenMs)
var requestIDArg any
if requestID != "" {
requestIDArg = requestID
}
args := []any{
log.UserID,
log.ApiKeyID,
log.APIKeyID,
log.AccountID,
log.RequestID,
requestIDArg,
log.Model,
groupID,
subscriptionID,
......@@ -142,11 +159,20 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
firstToken,
createdAt,
}
if err := scanSingleRow(ctx, r.sql, query, args, &log.ID, &log.CreatedAt); err != nil {
return err
if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil {
if errors.Is(err, sql.ErrNoRows) && requestID != "" {
selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2"
if err := scanSingleRow(ctx, sqlq, selectQuery, []any{requestID, log.APIKeyID}, &log.ID, &log.CreatedAt); err != nil {
return false, err
}
log.RateMultiplier = rateMultiplier
return nil
return false, nil
} else {
return false, err
}
}
log.RateMultiplier = rateMultiplier
return true, nil
}
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) {
......@@ -183,7 +209,7 @@ func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, param
return r.listUsageLogsWithPagination(ctx, "WHERE user_id = $1", []any{userID}, params)
}
func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
func (r *usageLogRepository) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return r.listUsageLogsWithPagination(ctx, "WHERE api_key_id = $1", []any{apiKeyID}, params)
}
......@@ -270,8 +296,8 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
r.sql,
apiKeyStatsQuery,
[]any{service.StatusActive},
&stats.TotalApiKeys,
&stats.ActiveApiKeys,
&stats.TotalAPIKeys,
&stats.ActiveAPIKeys,
); err != nil {
return nil, err
}
......@@ -418,8 +444,8 @@ func (r *usageLogRepository) GetUserStatsAggregated(ctx context.Context, userID
return &stats, nil
}
// GetApiKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation
func (r *usageLogRepository) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
// GetAPIKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation
func (r *usageLogRepository) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
query := `
SELECT
COUNT(*) as total_requests,
......@@ -623,7 +649,7 @@ func resolveUsageStatsTimezone() string {
return "UTC"
}
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)
return logs, nil, err
......@@ -709,11 +735,11 @@ type ModelStat = usagestats.ModelStat
// UserUsageTrendPoint represents user usage trend data point
type UserUsageTrendPoint = usagestats.UserUsageTrendPoint
// ApiKeyUsageTrendPoint represents API key usage trend data point
type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint
// APIKeyUsageTrendPoint represents API key usage trend data point
type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date
func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []ApiKeyUsageTrendPoint, err error) {
// GetAPIKeyUsageTrend returns usage trend data grouped by API key and date
func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
......@@ -755,10 +781,10 @@ func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime,
}
}()
results = make([]ApiKeyUsageTrendPoint, 0)
results = make([]APIKeyUsageTrendPoint, 0)
for rows.Next() {
var row ApiKeyUsageTrendPoint
if err = rows.Scan(&row.Date, &row.ApiKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
var row APIKeyUsageTrendPoint
if err = rows.Scan(&row.Date, &row.APIKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
return nil, err
}
results = append(results, row)
......@@ -844,7 +870,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
r.sql,
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL",
[]any{userID},
&stats.TotalApiKeys,
&stats.TotalAPIKeys,
); err != nil {
return nil, err
}
......@@ -853,7 +879,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
r.sql,
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL",
[]any{userID, service.StatusActive},
&stats.ActiveApiKeys,
&stats.ActiveAPIKeys,
); err != nil {
return nil, err
}
......@@ -1023,9 +1049,9 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1))
args = append(args, filters.UserID)
}
if filters.ApiKeyID > 0 {
if filters.APIKeyID > 0 {
conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", len(args)+1))
args = append(args, filters.ApiKeyID)
args = append(args, filters.APIKeyID)
}
if filters.AccountID > 0 {
conditions = append(conditions, fmt.Sprintf("account_id = $%d", len(args)+1))
......@@ -1145,18 +1171,18 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
return result, nil
}
// BatchApiKeyUsageStats represents usage stats for a single API key
type BatchApiKeyUsageStats = usagestats.BatchApiKeyUsageStats
// BatchAPIKeyUsageStats represents usage stats for a single API key
type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats
// GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys
func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) {
result := make(map[int64]*BatchApiKeyUsageStats)
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchAPIKeyUsageStats, error) {
result := make(map[int64]*BatchAPIKeyUsageStats)
if len(apiKeyIDs) == 0 {
return result, nil
}
for _, id := range apiKeyIDs {
result[id] = &BatchApiKeyUsageStats{ApiKeyID: id}
result[id] = &BatchAPIKeyUsageStats{APIKeyID: id}
}
query := `
......@@ -1582,7 +1608,7 @@ func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, lo
if err != nil {
return err
}
apiKeys, err := r.loadApiKeys(ctx, ids.apiKeyIDs)
apiKeys, err := r.loadAPIKeys(ctx, ids.apiKeyIDs)
if err != nil {
return err
}
......@@ -1603,8 +1629,8 @@ func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, lo
if user, ok := users[logs[i].UserID]; ok {
logs[i].User = user
}
if key, ok := apiKeys[logs[i].ApiKeyID]; ok {
logs[i].ApiKey = key
if key, ok := apiKeys[logs[i].APIKeyID]; ok {
logs[i].APIKey = key
}
if acc, ok := accounts[logs[i].AccountID]; ok {
logs[i].Account = acc
......@@ -1642,7 +1668,7 @@ func collectUsageLogIDs(logs []service.UsageLog) usageLogIDs {
for i := range logs {
userIDs[logs[i].UserID] = struct{}{}
apiKeyIDs[logs[i].ApiKeyID] = struct{}{}
apiKeyIDs[logs[i].APIKeyID] = struct{}{}
accountIDs[logs[i].AccountID] = struct{}{}
if logs[i].GroupID != nil {
groupIDs[*logs[i].GroupID] = struct{}{}
......@@ -1676,12 +1702,12 @@ func (r *usageLogRepository) loadUsers(ctx context.Context, ids []int64) (map[in
return out, nil
}
func (r *usageLogRepository) loadApiKeys(ctx context.Context, ids []int64) (map[int64]*service.ApiKey, error) {
out := make(map[int64]*service.ApiKey)
func (r *usageLogRepository) loadAPIKeys(ctx context.Context, ids []int64) (map[int64]*service.APIKey, error) {
out := make(map[int64]*service.APIKey)
if len(ids) == 0 {
return out, nil
}
models, err := r.client.ApiKey.Query().Where(dbapikey.IDIn(ids...)).All(ctx)
models, err := r.client.APIKey.Query().Where(dbapikey.IDIn(ids...)).All(ctx)
if err != nil {
return nil, err
}
......@@ -1800,7 +1826,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
log := &service.UsageLog{
ID: id,
UserID: userID,
ApiKeyID: apiKeyID,
APIKeyID: apiKeyID,
AccountID: accountID,
Model: model,
InputTokens: inputTokens,
......
......@@ -7,6 +7,8 @@ import (
"testing"
"time"
"github.com/google/uuid"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
......@@ -35,11 +37,12 @@ func TestUsageLogRepoSuite(t *testing.T) {
suite.Run(t, new(UsageLogRepoSuite))
}
func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.ApiKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.APIKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
log := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(), // Generate unique RequestID for each log
Model: "claude-3",
InputTokens: inputTokens,
OutputTokens: outputTokens,
......@@ -47,7 +50,8 @@ func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.A
ActualCost: cost,
CreatedAt: createdAt,
}
s.Require().NoError(s.repo.Create(s.ctx, log))
_, err := s.repo.Create(s.ctx, log)
s.Require().NoError(err)
return log
}
......@@ -55,12 +59,12 @@ func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.A
func (s *UsageLogRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "create@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-create", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-create", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-create"})
log := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3",
InputTokens: 10,
......@@ -69,14 +73,14 @@ func (s *UsageLogRepoSuite) TestCreate() {
ActualCost: 0.4,
}
err := s.repo.Create(s.ctx, log)
_, err := s.repo.Create(s.ctx, log)
s.Require().NoError(err, "Create")
s.Require().NotZero(log.ID)
}
func (s *UsageLogRepoSuite) TestGetByID() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
......@@ -96,7 +100,7 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
func (s *UsageLogRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "delete@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-delete", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-delete", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-delete"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
......@@ -112,7 +116,7 @@ func (s *UsageLogRepoSuite) TestDelete() {
func (s *UsageLogRepoSuite) TestListByUser() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyuser@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyuser"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
......@@ -124,18 +128,18 @@ func (s *UsageLogRepoSuite) TestListByUser() {
s.Require().Equal(int64(2), page.Total)
}
// --- ListByApiKey ---
// --- ListByAPIKey ---
func (s *UsageLogRepoSuite) TestListByApiKey() {
func (s *UsageLogRepoSuite) TestListByAPIKey() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyapikey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyapikey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
logs, page, err := s.repo.ListByApiKey(s.ctx, apiKey.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByApiKey")
logs, page, err := s.repo.ListByAPIKey(s.ctx, apiKey.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByAPIKey")
s.Require().Len(logs, 2)
s.Require().Equal(int64(2), page.Total)
}
......@@ -144,7 +148,7 @@ func (s *UsageLogRepoSuite) TestListByApiKey() {
func (s *UsageLogRepoSuite) TestListByAccount() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyaccount@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyaccount"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
......@@ -159,7 +163,7 @@ func (s *UsageLogRepoSuite) TestListByAccount() {
func (s *UsageLogRepoSuite) TestGetUserStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "userstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-userstats", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-userstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-userstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
......@@ -179,7 +183,7 @@ func (s *UsageLogRepoSuite) TestGetUserStats() {
func (s *UsageLogRepoSuite) TestListWithFilters() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filters", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filters", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filters"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
......@@ -211,8 +215,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
})
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-ul"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled})
resetAt := now.Add(10 * time.Minute)
accNormal := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a-normal", Schedulable: true})
......@@ -223,7 +227,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
d1, d2, d3 := 100, 200, 300
logToday := &service.UsageLog{
UserID: userToday.ID,
ApiKeyID: apiKey1.ID,
APIKeyID: apiKey1.ID,
AccountID: accNormal.ID,
Model: "claude-3",
GroupID: &group.ID,
......@@ -236,11 +240,12 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
DurationMs: &d1,
CreatedAt: maxTime(todayStart.Add(2*time.Minute), now.Add(-2*time.Minute)),
}
s.Require().NoError(s.repo.Create(s.ctx, logToday), "Create logToday")
_, err = s.repo.Create(s.ctx, logToday)
s.Require().NoError(err, "Create logToday")
logOld := &service.UsageLog{
UserID: userOld.ID,
ApiKeyID: apiKey1.ID,
APIKeyID: apiKey1.ID,
AccountID: accNormal.ID,
Model: "claude-3",
InputTokens: 5,
......@@ -250,11 +255,12 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
DurationMs: &d2,
CreatedAt: todayStart.Add(-1 * time.Hour),
}
s.Require().NoError(s.repo.Create(s.ctx, logOld), "Create logOld")
_, err = s.repo.Create(s.ctx, logOld)
s.Require().NoError(err, "Create logOld")
logPerf := &service.UsageLog{
UserID: userToday.ID,
ApiKeyID: apiKey1.ID,
APIKeyID: apiKey1.ID,
AccountID: accNormal.ID,
Model: "claude-3",
InputTokens: 1,
......@@ -264,7 +270,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
DurationMs: &d3,
CreatedAt: now.Add(-30 * time.Second),
}
s.Require().NoError(s.repo.Create(s.ctx, logPerf), "Create logPerf")
_, err = s.repo.Create(s.ctx, logPerf)
s.Require().NoError(err, "Create logPerf")
stats, err := s.repo.GetDashboardStats(s.ctx)
s.Require().NoError(err, "GetDashboardStats")
......@@ -272,8 +279,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
s.Require().Equal(baseStats.TotalUsers+2, stats.TotalUsers, "TotalUsers mismatch")
s.Require().Equal(baseStats.TodayNewUsers+1, stats.TodayNewUsers, "TodayNewUsers mismatch")
s.Require().Equal(baseStats.ActiveUsers+1, stats.ActiveUsers, "ActiveUsers mismatch")
s.Require().Equal(baseStats.TotalApiKeys+2, stats.TotalApiKeys, "TotalApiKeys mismatch")
s.Require().Equal(baseStats.ActiveApiKeys+1, stats.ActiveApiKeys, "ActiveApiKeys mismatch")
s.Require().Equal(baseStats.TotalAPIKeys+2, stats.TotalAPIKeys, "TotalAPIKeys mismatch")
s.Require().Equal(baseStats.ActiveAPIKeys+1, stats.ActiveAPIKeys, "ActiveAPIKeys mismatch")
s.Require().Equal(baseStats.TotalAccounts+4, stats.TotalAccounts, "TotalAccounts mismatch")
s.Require().Equal(baseStats.ErrorAccounts+1, stats.ErrorAccounts, "ErrorAccounts mismatch")
s.Require().Equal(baseStats.RateLimitAccounts+1, stats.RateLimitAccounts, "RateLimitAccounts mismatch")
......@@ -300,14 +307,14 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "userdash@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-userdash", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-userdash", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-userdash"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
stats, err := s.repo.GetUserDashboardStats(s.ctx, user.ID)
s.Require().NoError(err, "GetUserDashboardStats")
s.Require().Equal(int64(1), stats.TotalApiKeys)
s.Require().Equal(int64(1), stats.TotalAPIKeys)
s.Require().Equal(int64(1), stats.TotalRequests)
}
......@@ -315,7 +322,7 @@ func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "acctoday@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-today"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
......@@ -331,8 +338,8 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "batch1@test.com"})
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "batch2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-batch"})
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
......@@ -351,24 +358,24 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
s.Require().Empty(stats)
}
// --- GetBatchApiKeyUsageStats ---
// --- GetBatchAPIKeyUsageStats ---
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "batchkey@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-batchkey"})
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID})
s.Require().NoError(err, "GetBatchApiKeyUsageStats")
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID})
s.Require().NoError(err, "GetBatchAPIKeyUsageStats")
s.Require().Len(stats, 2)
}
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{})
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{})
s.Require().NoError(err)
s.Require().Empty(stats)
}
......@@ -377,7 +384,7 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
func (s *UsageLogRepoSuite) TestGetGlobalStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "global@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-global", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-global", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-global"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
......@@ -402,7 +409,7 @@ func maxTime(a, b time.Time) time.Time {
func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "timerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-timerange", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-timerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-timerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
......@@ -417,11 +424,11 @@ func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
s.Require().Len(logs, 2)
}
// --- ListByApiKeyAndTimeRange ---
// --- ListByAPIKeyAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
func (s *UsageLogRepoSuite) TestListByAPIKeyAndTimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
......@@ -431,8 +438,8 @@ func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
logs, _, err := s.repo.ListByApiKeyAndTimeRange(s.ctx, apiKey.ID, startTime, endTime)
s.Require().NoError(err, "ListByApiKeyAndTimeRange")
logs, _, err := s.repo.ListByAPIKeyAndTimeRange(s.ctx, apiKey.ID, startTime, endTime)
s.Require().NoError(err, "ListByAPIKeyAndTimeRange")
s.Require().Len(logs, 2)
}
......@@ -440,7 +447,7 @@ func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "acctimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-acctimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
......@@ -459,7 +466,7 @@ func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modeltimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modeltimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
......@@ -467,7 +474,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
// Create logs with different models
log1 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 10,
......@@ -476,11 +483,12 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
ActualCost: 0.5,
CreatedAt: base,
}
s.Require().NoError(s.repo.Create(s.ctx, log1))
_, err := s.repo.Create(s.ctx, log1)
s.Require().NoError(err)
log2 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 15,
......@@ -489,11 +497,12 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
ActualCost: 0.6,
CreatedAt: base.Add(30 * time.Minute),
}
s.Require().NoError(s.repo.Create(s.ctx, log2))
_, err = s.repo.Create(s.ctx, log2)
s.Require().NoError(err)
log3 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 20,
......@@ -502,7 +511,8 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
ActualCost: 0.7,
CreatedAt: base.Add(1 * time.Hour),
}
s.Require().NoError(s.repo.Create(s.ctx, log3))
_, err = s.repo.Create(s.ctx, log3)
s.Require().NoError(err)
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
......@@ -515,7 +525,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "windowstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-windowstats"})
now := time.Now()
......@@ -535,7 +545,7 @@ func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrend"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
......@@ -552,7 +562,7 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrendhourly@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrendhourly"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
......@@ -571,7 +581,7 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
func (s *UsageLogRepoSuite) TestGetUserModelStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modelstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modelstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
......@@ -579,7 +589,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
// Create logs with different models
log1 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 100,
......@@ -588,11 +598,12 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
ActualCost: 0.5,
CreatedAt: base,
}
s.Require().NoError(s.repo.Create(s.ctx, log1))
_, err := s.repo.Create(s.ctx, log1)
s.Require().NoError(err)
log2 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 50,
......@@ -601,7 +612,8 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
ActualCost: 0.2,
CreatedAt: base.Add(1 * time.Hour),
}
s.Require().NoError(s.repo.Create(s.ctx, log2))
_, err = s.repo.Create(s.ctx, log2)
s.Require().NoError(err)
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
......@@ -618,7 +630,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "trendfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-trendfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
......@@ -646,7 +658,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "trendfilters-h@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-trendfilters-h"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
......@@ -665,14 +677,14 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modelfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modelfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
log1 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 100,
......@@ -681,11 +693,12 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
ActualCost: 0.5,
CreatedAt: base,
}
s.Require().NoError(s.repo.Create(s.ctx, log1))
_, err := s.repo.Create(s.ctx, log1)
s.Require().NoError(err)
log2 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 50,
......@@ -694,7 +707,8 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
ActualCost: 0.2,
CreatedAt: base.Add(1 * time.Hour),
}
s.Require().NoError(s.repo.Create(s.ctx, log2))
_, err = s.repo.Create(s.ctx, log2)
s.Require().NoError(err)
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
......@@ -719,7 +733,7 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "accstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-accstats", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-accstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-accstats"})
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
......@@ -727,7 +741,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
// Create logs on different days
log1 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 100,
......@@ -736,11 +750,12 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
ActualCost: 0.4,
CreatedAt: base.Add(12 * time.Hour),
}
s.Require().NoError(s.repo.Create(s.ctx, log1))
_, err := s.repo.Create(s.ctx, log1)
s.Require().NoError(err)
log2 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 50,
......@@ -749,7 +764,8 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
ActualCost: 0.15,
CreatedAt: base.Add(36 * time.Hour), // next day
}
s.Require().NoError(s.repo.Create(s.ctx, log2))
_, err = s.repo.Create(s.ctx, log2)
s.Require().NoError(err)
startTime := base
endTime := base.Add(72 * time.Hour)
......@@ -782,8 +798,8 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() {
func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend1@test.com"})
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
......@@ -799,12 +815,12 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
s.Require().GreaterOrEqual(len(trend), 2)
}
// --- GetApiKeyUsageTrend ---
// --- GetAPIKeyUsageTrend ---
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrend@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
......@@ -815,14 +831,14 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(48 * time.Hour)
trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "day", 10)
s.Require().NoError(err, "GetApiKeyUsageTrend")
trend, err := s.repo.GetAPIKeyUsageTrend(s.ctx, startTime, endTime, "day", 10)
s.Require().NoError(err, "GetAPIKeyUsageTrend")
s.Require().GreaterOrEqual(len(trend), 2)
}
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend_HourlyGranularity() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrendh@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrendh"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
......@@ -832,8 +848,8 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(3 * time.Hour)
trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "hour", 10)
s.Require().NoError(err, "GetApiKeyUsageTrend hourly")
trend, err := s.repo.GetAPIKeyUsageTrend(s.ctx, startTime, endTime, "hour", 10)
s.Require().NoError(err, "GetAPIKeyUsageTrend hourly")
s.Require().Len(trend, 2)
}
......@@ -841,12 +857,12 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterskey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterskey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
filters := usagestats.UsageLogFilters{ApiKeyID: apiKey.ID}
filters := usagestats.UsageLogFilters{APIKeyID: apiKey.ID}
logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters)
s.Require().NoError(err, "ListWithFilters apiKey")
s.Require().Len(logs, 1)
......@@ -855,7 +871,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterstime@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterstime"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
......@@ -874,7 +890,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterscombined@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterscombined"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
......@@ -885,7 +901,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
endTime := base.Add(2 * time.Hour)
filters := usagestats.UsageLogFilters{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
StartTime: &startTime,
EndTime: &endTime,
}
......
......@@ -4,12 +4,13 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"sort"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
......@@ -17,14 +18,15 @@ import (
type userRepository struct {
client *dbent.Client
sql sqlExecutor
}
func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository {
return newUserRepositoryWithSQL(client, sqlDB)
}
func newUserRepositoryWithSQL(client *dbent.Client, _ sqlExecutor) *userRepository {
return &userRepository{client: client}
func newUserRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *userRepository {
return &userRepository{client: client, sql: sqlq}
}
func (r *userRepository) Create(ctx context.Context, userIn *service.User) error {
......@@ -194,7 +196,11 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
// If attribute filters are specified, we need to filter by user IDs first
var allowedUserIDs []int64
if len(filters.Attributes) > 0 {
allowedUserIDs = r.filterUsersByAttributes(ctx, filters.Attributes)
var attrErr error
allowedUserIDs, attrErr = r.filterUsersByAttributes(ctx, filters.Attributes)
if attrErr != nil {
return nil, nil, attrErr
}
if len(allowedUserIDs) == 0 {
// No users match the attribute filters
return []service.User{}, paginationResultFromTotal(0, params), nil
......@@ -262,56 +268,53 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
}
// filterUsersByAttributes returns user IDs that match ALL the given attribute filters
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) []int64 {
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) ([]int64, error) {
if len(attrs) == 0 {
return nil
return nil, nil
}
// For each attribute filter, get the set of matching user IDs
// Then intersect all sets to get users matching ALL filters
var resultSet map[int64]struct{}
first := true
for attrID, value := range attrs {
// Query user_attribute_values for this attribute
values, err := r.client.UserAttributeValue.Query().
Where(
userattributevalue.AttributeIDEQ(attrID),
userattributevalue.ValueContainsFold(value),
).
All(ctx)
if err != nil {
continue
if r.sql == nil {
return nil, fmt.Errorf("sql executor is not configured")
}
currentSet := make(map[int64]struct{}, len(values))
for _, v := range values {
currentSet[v.UserID] = struct{}{}
}
clauses := make([]string, 0, len(attrs))
args := make([]any, 0, len(attrs)*2+1)
argIndex := 1
for attrID, value := range attrs {
clauses = append(clauses, fmt.Sprintf("(attribute_id = $%d AND value ILIKE $%d)", argIndex, argIndex+1))
args = append(args, attrID, "%"+value+"%")
argIndex += 2
}
query := fmt.Sprintf(
`SELECT user_id
FROM user_attribute_values
WHERE %s
GROUP BY user_id
HAVING COUNT(DISTINCT attribute_id) = $%d`,
strings.Join(clauses, " OR "),
argIndex,
)
args = append(args, len(attrs))
if first {
resultSet = currentSet
first = false
} else {
// Intersect with previous results
for userID := range resultSet {
if _, ok := currentSet[userID]; !ok {
delete(resultSet, userID)
}
}
rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
// Early exit if no users match
if len(resultSet) == 0 {
return nil
}
result := make([]int64, 0)
for rows.Next() {
var userID int64
if scanErr := rows.Scan(&userID); scanErr != nil {
return nil, scanErr
}
result := make([]int64, 0, len(resultSet))
for userID := range resultSet {
result = append(result, userID)
}
return result
if err := rows.Err(); err != nil {
return nil, err
}
return result, nil
}
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
......
......@@ -28,7 +28,7 @@ func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.Conc
// ProviderSet is the Wire provider set for all repositories
var ProviderSet = wire.NewSet(
NewUserRepository,
NewApiKeyRepository,
NewAPIKeyRepository,
NewGroupRepository,
NewAccountRepository,
NewProxyRepository,
......@@ -42,7 +42,8 @@ var ProviderSet = wire.NewSet(
// Cache implementations
NewGatewayCache,
NewBillingCache,
NewApiKeyCache,
NewAPIKeyCache,
NewTempUnschedCache,
ProvideConcurrencyCache,
NewEmailCache,
NewIdentityCache,
......
......@@ -91,7 +91,7 @@ func TestAPIContracts(t *testing.T) {
name: "GET /api/v1/keys (paginated)",
setup: func(t *testing.T, deps *contractDeps) {
t.Helper()
deps.apiKeyRepo.MustSeed(&service.ApiKey{
deps.apiKeyRepo.MustSeed(&service.APIKey{
ID: 100,
UserID: 1,
Key: "sk_custom_1234567890",
......@@ -135,7 +135,7 @@ func TestAPIContracts(t *testing.T) {
{
ID: 1,
UserID: 1,
ApiKeyID: 100,
APIKeyID: 100,
AccountID: 200,
Model: "claude-3",
InputTokens: 10,
......@@ -150,7 +150,7 @@ func TestAPIContracts(t *testing.T) {
{
ID: 2,
UserID: 1,
ApiKeyID: 100,
APIKeyID: 100,
AccountID: 200,
Model: "claude-3",
InputTokens: 5,
......@@ -188,7 +188,7 @@ func TestAPIContracts(t *testing.T) {
{
ID: 1,
UserID: 1,
ApiKeyID: 100,
APIKeyID: 100,
AccountID: 200,
RequestID: "req_123",
Model: "claude-3",
......@@ -259,13 +259,13 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyEmailVerifyEnabled: "false",
service.SettingKeySmtpHost: "smtp.example.com",
service.SettingKeySmtpPort: "587",
service.SettingKeySmtpUsername: "user",
service.SettingKeySmtpPassword: "secret",
service.SettingKeySmtpFrom: "no-reply@example.com",
service.SettingKeySmtpFromName: "Sub2API",
service.SettingKeySmtpUseTLS: "true",
service.SettingKeySMTPHost: "smtp.example.com",
service.SettingKeySMTPPort: "587",
service.SettingKeySMTPUsername: "user",
service.SettingKeySMTPPassword: "secret",
service.SettingKeySMTPFrom: "no-reply@example.com",
service.SettingKeySMTPFromName: "Sub2API",
service.SettingKeySMTPUseTLS: "true",
service.SettingKeyTurnstileEnabled: "true",
service.SettingKeyTurnstileSiteKey: "site-key",
......@@ -274,9 +274,9 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeySiteName: "Sub2API",
service.SettingKeySiteLogo: "",
service.SettingKeySiteSubtitle: "Subtitle",
service.SettingKeyApiBaseUrl: "https://api.example.com",
service.SettingKeyAPIBaseURL: "https://api.example.com",
service.SettingKeyContactInfo: "support",
service.SettingKeyDocUrl: "https://docs.example.com",
service.SettingKeyDocURL: "https://docs.example.com",
service.SettingKeyDefaultConcurrency: "5",
service.SettingKeyDefaultBalance: "1.25",
......@@ -308,7 +308,12 @@ func TestAPIContracts(t *testing.T) {
"contact_info": "support",
"doc_url": "https://docs.example.com",
"default_concurrency": 5,
"default_balance": 1.25
"default_balance": 1.25,
"enable_model_fallback": false,
"fallback_model_anthropic": "claude-3-5-sonnet-20241022",
"fallback_model_antigravity": "gemini-2.5-pro",
"fallback_model_gemini": "gemini-2.5-pro",
"fallback_model_openai": "gpt-4o"
}
}`,
},
......@@ -366,16 +371,16 @@ func newContractDeps(t *testing.T) *contractDeps {
cfg := &config.Config{
Default: config.DefaultConfig{
ApiKeyPrefix: "sk-",
APIKeyPrefix: "sk-",
},
RunMode: config.RunModeStandard,
}
userService := service.NewUserService(userRepo)
apiKeyService := service.NewApiKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
usageRepo := newStubUsageLogRepo()
usageService := service.NewUsageService(usageRepo, userRepo)
usageService := service.NewUsageService(usageRepo, userRepo, nil)
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
......@@ -664,20 +669,20 @@ type stubApiKeyRepo struct {
now time.Time
nextID int64
byID map[int64]*service.ApiKey
byKey map[string]*service.ApiKey
byID map[int64]*service.APIKey
byKey map[string]*service.APIKey
}
func newStubApiKeyRepo(now time.Time) *stubApiKeyRepo {
return &stubApiKeyRepo{
now: now,
nextID: 100,
byID: make(map[int64]*service.ApiKey),
byKey: make(map[string]*service.ApiKey),
byID: make(map[int64]*service.APIKey),
byKey: make(map[string]*service.APIKey),
}
}
func (r *stubApiKeyRepo) MustSeed(key *service.ApiKey) {
func (r *stubApiKeyRepo) MustSeed(key *service.APIKey) {
if key == nil {
return
}
......@@ -686,7 +691,7 @@ func (r *stubApiKeyRepo) MustSeed(key *service.ApiKey) {
r.byKey[clone.Key] = &clone
}
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
if key == nil {
return errors.New("nil key")
}
......@@ -706,10 +711,10 @@ func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error
return nil
}
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
key, ok := r.byID[id]
if !ok {
return nil, service.ErrApiKeyNotFound
return nil, service.ErrAPIKeyNotFound
}
clone := *key
return &clone, nil
......@@ -718,26 +723,26 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
key, ok := r.byID[id]
if !ok {
return 0, service.ErrApiKeyNotFound
return 0, service.ErrAPIKeyNotFound
}
return key.UserID, nil
}
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
found, ok := r.byKey[key]
if !ok {
return nil, service.ErrApiKeyNotFound
return nil, service.ErrAPIKeyNotFound
}
clone := *found
return &clone, nil
}
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
if key == nil {
return errors.New("nil key")
}
if _, ok := r.byID[key.ID]; !ok {
return service.ErrApiKeyNotFound
return service.ErrAPIKeyNotFound
}
if key.UpdatedAt.IsZero() {
key.UpdatedAt = r.now
......@@ -751,14 +756,14 @@ func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error
func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
key, ok := r.byID[id]
if !ok {
return service.ErrApiKeyNotFound
return service.ErrAPIKeyNotFound
}
delete(r.byID, id)
delete(r.byKey, key.Key)
return nil
}
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
ids := make([]int64, 0, len(r.byID))
for id := range r.byID {
if r.byID[id].UserID == userID {
......@@ -776,7 +781,7 @@ func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params
end = len(ids)
}
out := make([]service.ApiKey, 0, end-start)
out := make([]service.APIKey, 0, end-start)
for _, id := range ids[start:end] {
clone := *r.byID[id]
out = append(out, clone)
......@@ -830,11 +835,11 @@ func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, err
return ok, nil
}
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
func (r *stubApiKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
return nil, errors.New("not implemented")
}
......@@ -858,8 +863,8 @@ func (r *stubUsageLogRepo) SetUserLogs(userID int64, logs []service.UsageLog) {
r.userLogs[userID] = logs
}
func (r *stubUsageLogRepo) Create(ctx context.Context, log *service.UsageLog) error {
return errors.New("not implemented")
func (r *stubUsageLogRepo) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
return false, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
......@@ -877,7 +882,7 @@ func (r *stubUsageLogRepo) ListByUser(ctx context.Context, userID int64, params
return out, paginationResult(total, params), nil
}
func (r *stubUsageLogRepo) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
func (r *stubUsageLogRepo) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
......@@ -890,7 +895,7 @@ func (r *stubUsageLogRepo) ListByUserAndTimeRange(ctx context.Context, userID in
return logs, paginationResult(int64(len(logs)), pagination.PaginationParams{Page: 1, PageSize: 100}), nil
}
func (r *stubUsageLogRepo) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
func (r *stubUsageLogRepo) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
......@@ -922,7 +927,7 @@ func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTi
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) {
func (r *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
return nil, errors.New("not implemented")
}
......@@ -975,7 +980,7 @@ func (r *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID in
}, nil
}
func (r *stubUsageLogRepo) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
func (r *stubUsageLogRepo) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, errors.New("not implemented")
}
......@@ -995,7 +1000,7 @@ func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs [
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
return nil, errors.New("not implemented")
}
......@@ -1017,8 +1022,8 @@ func (r *stubUsageLogRepo) ListWithFilters(ctx context.Context, params paginatio
// Apply filters
var filtered []service.UsageLog
for _, log := range logs {
// Apply ApiKeyID filter
if filters.ApiKeyID > 0 && log.ApiKeyID != filters.ApiKeyID {
// Apply APIKeyID filter
if filters.APIKeyID > 0 && log.APIKeyID != filters.APIKeyID {
continue
}
// Apply Model filter
......@@ -1151,8 +1156,8 @@ func paginationResult(total int64, params pagination.PaginationParams) *paginati
// Ensure compile-time interface compliance.
var (
_ service.UserRepository = (*stubUserRepo)(nil)
_ service.ApiKeyRepository = (*stubApiKeyRepo)(nil)
_ service.ApiKeyCache = (*stubApiKeyCache)(nil)
_ service.APIKeyRepository = (*stubApiKeyRepo)(nil)
_ service.APIKeyCache = (*stubApiKeyCache)(nil)
_ service.GroupRepository = (*stubGroupRepo)(nil)
_ service.UserSubscriptionRepository = (*stubUserSubscriptionRepo)(nil)
_ service.UsageLogRepository = (*stubUsageLogRepo)(nil)
......
// Package server provides HTTP server initialization and configuration.
package server
import (
......@@ -26,8 +27,8 @@ func ProvideRouter(
handlers *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
apiKeyService *service.ApiKeyService,
apiKeyAuth middleware2.APIKeyAuthMiddleware,
apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService,
) *gin.Engine {
if cfg.Server.Mode == "release" {
......
// Package middleware provides HTTP middleware for authentication, authorization, and request processing.
package middleware
import (
......@@ -32,7 +33,7 @@ func adminAuth(
// 检查 x-api-key header(Admin API Key 认证)
apiKey := c.GetHeader("x-api-key")
if apiKey != "" {
if !validateAdminApiKey(c, apiKey, settingService, userService) {
if !validateAdminAPIKey(c, apiKey, settingService, userService) {
return
}
c.Next()
......@@ -57,14 +58,14 @@ func adminAuth(
}
}
// validateAdminApiKey 验证管理员 API Key
func validateAdminApiKey(
// validateAdminAPIKey 验证管理员 API Key
func validateAdminAPIKey(
c *gin.Context,
key string,
settingService *service.SettingService,
userService *service.UserService,
) bool {
storedKey, err := settingService.GetAdminApiKey(c.Request.Context())
storedKey, err := settingService.GetAdminAPIKey(c.Request.Context())
if err != nil {
AbortWithError(c, 500, "INTERNAL_ERROR", "Internal server error")
return false
......
......@@ -11,13 +11,13 @@ import (
"github.com/gin-gonic/gin"
)
// NewApiKeyAuthMiddleware 创建 API Key 认证中间件
func NewApiKeyAuthMiddleware(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) ApiKeyAuthMiddleware {
return ApiKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg))
// NewAPIKeyAuthMiddleware 创建 API Key 认证中间件
func NewAPIKeyAuthMiddleware(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) APIKeyAuthMiddleware {
return APIKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg))
}
// apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证)
func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) {
queryKey := strings.TrimSpace(c.Query("key"))
queryApiKey := strings.TrimSpace(c.Query("api_key"))
......@@ -57,7 +57,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
// 从数据库验证API key
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
if err != nil {
if errors.Is(err, service.ErrApiKeyNotFound) {
if errors.Is(err, service.ErrAPIKeyNotFound) {
AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key")
return
}
......@@ -85,7 +85,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
if cfg.RunMode == config.RunModeSimple {
// 简易模式:跳过余额和订阅检查,但仍需设置必要的上下文
c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency,
......@@ -143,7 +143,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
}
// 将API key和用户信息存入上下文
c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency,
......@@ -154,13 +154,13 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
}
}
// GetApiKeyFromContext 从上下文中获取API key
func GetApiKeyFromContext(c *gin.Context) (*service.ApiKey, bool) {
value, exists := c.Get(string(ContextKeyApiKey))
// GetAPIKeyFromContext 从上下文中获取API key
func GetAPIKeyFromContext(c *gin.Context) (*service.APIKey, bool) {
value, exists := c.Get(string(ContextKeyAPIKey))
if !exists {
return nil, false
}
apiKey, ok := value.(*service.ApiKey)
apiKey, ok := value.(*service.APIKey)
return apiKey, ok
}
......
......@@ -11,16 +11,16 @@ import (
"github.com/gin-gonic/gin"
)
// ApiKeyAuthGoogle is a Google-style error wrapper for API key auth.
func ApiKeyAuthGoogle(apiKeyService *service.ApiKeyService, cfg *config.Config) gin.HandlerFunc {
return ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)
// APIKeyAuthGoogle is a Google-style error wrapper for API key auth.
func APIKeyAuthGoogle(apiKeyService *service.APIKeyService, cfg *config.Config) gin.HandlerFunc {
return APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)
}
// ApiKeyAuthWithSubscriptionGoogle behaves like ApiKeyAuthWithSubscription but returns Google-style errors:
// APIKeyAuthWithSubscriptionGoogle behaves like ApiKeyAuthWithSubscription but returns Google-style errors:
// {"error":{"code":401,"message":"...","status":"UNAUTHENTICATED"}}
//
// It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations.
func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) {
if v := strings.TrimSpace(c.Query("api_key")); v != "" {
abortWithGoogleError(c, 400, "Query parameter api_key is deprecated. Use Authorization header or key instead.")
......@@ -34,7 +34,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
if err != nil {
if errors.Is(err, service.ErrApiKeyNotFound) {
if errors.Is(err, service.ErrAPIKeyNotFound) {
abortWithGoogleError(c, 401, "Invalid API key")
return
}
......@@ -57,7 +57,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs
// 简易模式:跳过余额和订阅检查
if cfg.RunMode == config.RunModeSimple {
c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency,
......@@ -96,7 +96,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs
}
}
c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency,
......
......@@ -16,53 +16,53 @@ import (
"github.com/stretchr/testify/require"
)
type fakeApiKeyRepo struct {
getByKey func(ctx context.Context, key string) (*service.ApiKey, error)
type fakeAPIKeyRepo struct {
getByKey func(ctx context.Context, key string) (*service.APIKey, error)
}
func (f fakeApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented")
}
func (f fakeApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
func (f fakeAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
return nil, errors.New("not implemented")
}
func (f fakeApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
func (f fakeAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (f fakeApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
if f.getByKey == nil {
return nil, errors.New("unexpected call")
}
return f.getByKey(ctx, key)
}
func (f fakeApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented")
}
func (f fakeApiKeyRepo) Delete(ctx context.Context, id int64) error {
func (f fakeAPIKeyRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (f fakeApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
func (f fakeAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (f fakeApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
func (f fakeAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
return nil, errors.New("not implemented")
}
func (f fakeApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
func (f fakeAPIKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (f fakeApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
func (f fakeAPIKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
return false, errors.New("not implemented")
}
func (f fakeApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
func (f fakeAPIKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (f fakeApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
func (f fakeAPIKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
return nil, errors.New("not implemented")
}
func (f fakeApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
func (f fakeAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (f fakeApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
func (f fakeAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
......@@ -74,8 +74,8 @@ type googleErrorResponse struct {
} `json:"error"`
}
func newTestApiKeyService(repo service.ApiKeyRepository) *service.ApiKeyService {
return service.NewApiKeyService(
func newTestAPIKeyService(repo service.APIKeyRepository) *service.APIKeyService {
return service.NewAPIKeyService(
repo,
nil, // userRepo (unused in GetByKey)
nil, // groupRepo
......@@ -89,12 +89,12 @@ func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return nil, errors.New("should not be called")
},
})
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
......@@ -165,12 +165,12 @@ func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return nil, service.ErrApiKeyNotFound
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return nil, service.ErrAPIKeyNotFound
},
})
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
......@@ -190,12 +190,12 @@ func TestApiKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return nil, errors.New("db down")
},
})
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
......@@ -215,9 +215,9 @@ func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return &service.ApiKey{
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return &service.APIKey{
ID: 1,
Key: key,
Status: service.StatusDisabled,
......@@ -228,7 +228,7 @@ func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
}, nil
},
})
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
......@@ -248,9 +248,9 @@ func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return &service.ApiKey{
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return &service.APIKey{
ID: 1,
Key: key,
Status: service.StatusActive,
......@@ -262,7 +262,7 @@ func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
}, nil
},
})
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
......
......@@ -35,7 +35,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
Balance: 10,
Concurrency: 3,
}
apiKey := &service.ApiKey{
apiKey := &service.APIKey{
ID: 100,
UserID: user.ID,
Key: "test-key",
......@@ -46,9 +46,9 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
apiKey.GroupID = &group.ID
apiKeyRepo := &stubApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
if key != apiKey.Key {
return nil, service.ErrApiKeyNotFound
return nil, service.ErrAPIKeyNotFound
}
clone := *apiKey
return &clone, nil
......@@ -57,7 +57,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
......@@ -71,7 +71,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("standard_mode_enforces_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeStandard}
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
now := time.Now()
sub := &service.UserSubscription{
......@@ -110,9 +110,9 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
})
}
func newAuthTestRouter(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
router := gin.New()
router.Use(gin.HandlerFunc(NewApiKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
router.GET("/t", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
......@@ -120,14 +120,14 @@ func newAuthTestRouter(apiKeyService *service.ApiKeyService, subscriptionService
}
type stubApiKeyRepo struct {
getByKey func(ctx context.Context, key string) (*service.ApiKey, error)
getByKey func(ctx context.Context, key string) (*service.APIKey, error)
}
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented")
}
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
return nil, errors.New("not implemented")
}
......@@ -135,14 +135,14 @@ func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error
return 0, errors.New("not implemented")
}
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
if r.getByKey != nil {
return r.getByKey(ctx, key)
}
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented")
}
......@@ -150,7 +150,7 @@ func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
......@@ -166,11 +166,11 @@ func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, err
return false, errors.New("not implemented")
}
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
func (r *stubApiKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
return nil, errors.New("not implemented")
}
......
......@@ -15,8 +15,8 @@ const (
ContextKeyUser ContextKey = "user"
// ContextKeyUserRole 当前用户角色(string)
ContextKeyUserRole ContextKey = "user_role"
// ContextKeyApiKey API密钥上下文键
ContextKeyApiKey ContextKey = "api_key"
// ContextKeyAPIKey API密钥上下文键
ContextKeyAPIKey ContextKey = "api_key"
// ContextKeySubscription 订阅上下文键
ContextKeySubscription ContextKey = "subscription"
// ContextKeyForcePlatform 强制平台(用于 /antigravity 路由)
......
......@@ -11,12 +11,12 @@ type JWTAuthMiddleware gin.HandlerFunc
// AdminAuthMiddleware 管理员认证中间件类型
type AdminAuthMiddleware gin.HandlerFunc
// ApiKeyAuthMiddleware API Key 认证中间件类型
type ApiKeyAuthMiddleware gin.HandlerFunc
// APIKeyAuthMiddleware API Key 认证中间件类型
type APIKeyAuthMiddleware gin.HandlerFunc
// ProviderSet 中间件层的依赖注入
var ProviderSet = wire.NewSet(
NewJWTAuthMiddleware,
NewAdminAuthMiddleware,
NewApiKeyAuthMiddleware,
NewAPIKeyAuthMiddleware,
)
......@@ -17,8 +17,8 @@ func SetupRouter(
handlers *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
apiKeyService *service.ApiKeyService,
apiKeyAuth middleware2.APIKeyAuthMiddleware,
apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService,
cfg *config.Config,
) *gin.Engine {
......@@ -44,8 +44,8 @@ func registerRoutes(
h *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
apiKeyService *service.ApiKeyService,
apiKeyAuth middleware2.APIKeyAuthMiddleware,
apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService,
cfg *config.Config,
) {
......
// Package routes provides HTTP route registration and handlers.
package routes
import (
......@@ -67,10 +68,10 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
dashboard.GET("/models", h.Admin.Dashboard.GetModelStats)
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend)
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetAPIKeyUsageTrend)
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage)
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
}
}
......@@ -123,6 +124,8 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
accounts.GET("/:id/temp-unschedulable", h.Admin.Account.GetTempUnschedulable)
accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable)
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
accounts.POST("/batch", h.Admin.Account.BatchCreate)
......@@ -203,12 +206,12 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
{
adminSettings.GET("", h.Admin.Setting.GetSettings)
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection)
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSMTPConnection)
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
// Admin API Key 管理
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminApiKey)
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminApiKey)
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminApiKey)
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey)
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey)
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminAPIKey)
}
}
......@@ -248,7 +251,7 @@ func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
usage.GET("", h.Admin.Usage.List)
usage.GET("/stats", h.Admin.Usage.Stats)
usage.GET("/search-users", h.Admin.Usage.SearchUsers)
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys)
usage.GET("/search-api-keys", h.Admin.Usage.SearchAPIKeys)
}
}
......
......@@ -13,8 +13,8 @@ import (
func RegisterGatewayRoutes(
r *gin.Engine,
h *handler.Handlers,
apiKeyAuth middleware.ApiKeyAuthMiddleware,
apiKeyService *service.ApiKeyService,
apiKeyAuth middleware.APIKeyAuthMiddleware,
apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService,
cfg *config.Config,
) {
......@@ -36,7 +36,7 @@ func RegisterGatewayRoutes(
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
gemini := r.Group("/v1beta")
gemini.Use(bodyLimit)
gemini.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
gemini.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
{
gemini.GET("/models", h.Gateway.GeminiV1BetaListModels)
gemini.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
......@@ -65,7 +65,7 @@ func RegisterGatewayRoutes(
antigravityV1Beta := r.Group("/antigravity/v1beta")
antigravityV1Beta.Use(bodyLimit)
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1Beta.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
antigravityV1Beta.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
{
antigravityV1Beta.GET("/models", h.Gateway.GeminiV1BetaListModels)
antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
......
......@@ -50,7 +50,7 @@ func RegisterUserRoutes(
usage.GET("/dashboard/stats", h.Usage.DashboardStats)
usage.GET("/dashboard/trend", h.Usage.DashboardTrend)
usage.GET("/dashboard/models", h.Usage.DashboardModels)
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage)
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardAPIKeysUsage)
}
// 卡密兑换
......
// Package service provides business logic and domain services for the application.
package service
import (
......@@ -29,6 +30,9 @@ type Account struct {
RateLimitResetAt *time.Time
OverloadUntil *time.Time
TempUnschedulableUntil *time.Time
TempUnschedulableReason string
SessionWindowStart *time.Time
SessionWindowEnd *time.Time
SessionWindowStatus string
......@@ -39,6 +43,13 @@ type Account struct {
Groups []*Group
}
type TempUnschedulableRule struct {
ErrorCode int `json:"error_code"`
Keywords []string `json:"keywords"`
DurationMinutes int `json:"duration_minutes"`
Description string `json:"description"`
}
func (a *Account) IsActive() bool {
return a.Status == StatusActive
}
......@@ -54,6 +65,9 @@ func (a *Account) IsSchedulable() bool {
if a.RateLimitResetAt != nil && now.Before(*a.RateLimitResetAt) {
return false
}
if a.TempUnschedulableUntil != nil && now.Before(*a.TempUnschedulableUntil) {
return false
}
return true
}
......@@ -92,10 +106,7 @@ func (a *Account) GeminiOAuthType() string {
func (a *Account) GeminiTierID() string {
tierID := strings.TrimSpace(a.GetCredential("tier_id"))
if tierID == "" {
return ""
}
return strings.ToUpper(tierID)
return tierID
}
func (a *Account) IsGeminiCodeAssist() bool {
......@@ -163,6 +174,114 @@ func (a *Account) GetCredentialAsTime(key string) *time.Time {
return nil
}
func (a *Account) IsTempUnschedulableEnabled() bool {
if a.Credentials == nil {
return false
}
raw, ok := a.Credentials["temp_unschedulable_enabled"]
if !ok || raw == nil {
return false
}
enabled, ok := raw.(bool)
return ok && enabled
}
func (a *Account) GetTempUnschedulableRules() []TempUnschedulableRule {
if a.Credentials == nil {
return nil
}
raw, ok := a.Credentials["temp_unschedulable_rules"]
if !ok || raw == nil {
return nil
}
arr, ok := raw.([]any)
if !ok {
return nil
}
rules := make([]TempUnschedulableRule, 0, len(arr))
for _, item := range arr {
entry, ok := item.(map[string]any)
if !ok || entry == nil {
continue
}
rule := TempUnschedulableRule{
ErrorCode: parseTempUnschedInt(entry["error_code"]),
Keywords: parseTempUnschedStrings(entry["keywords"]),
DurationMinutes: parseTempUnschedInt(entry["duration_minutes"]),
Description: parseTempUnschedString(entry["description"]),
}
if rule.ErrorCode <= 0 || rule.DurationMinutes <= 0 || len(rule.Keywords) == 0 {
continue
}
rules = append(rules, rule)
}
return rules
}
func parseTempUnschedString(value any) string {
s, ok := value.(string)
if !ok {
return ""
}
return strings.TrimSpace(s)
}
func parseTempUnschedStrings(value any) []string {
if value == nil {
return nil
}
var raw []string
switch v := value.(type) {
case []string:
raw = v
case []any:
raw = make([]string, 0, len(v))
for _, item := range v {
if s, ok := item.(string); ok {
raw = append(raw, s)
}
}
default:
return nil
}
out := make([]string, 0, len(raw))
for _, item := range raw {
s := strings.TrimSpace(item)
if s != "" {
out = append(out, s)
}
}
return out
}
func parseTempUnschedInt(value any) int {
switch v := value.(type) {
case int:
return v
case int64:
return int(v)
case float64:
return int(v)
case json.Number:
if i, err := v.Int64(); err == nil {
return int(i)
}
case string:
if i, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
return i
}
}
return 0
}
func (a *Account) GetModelMapping() map[string]string {
if a.Credentials == nil {
return nil
......@@ -206,7 +325,7 @@ func (a *Account) GetMappedModel(requestedModel string) string {
}
func (a *Account) GetBaseURL() string {
if a.Type != AccountTypeApiKey {
if a.Type != AccountTypeAPIKey {
return ""
}
baseURL := a.GetCredential("base_url")
......@@ -229,7 +348,7 @@ func (a *Account) GetExtraString(key string) string {
}
func (a *Account) IsCustomErrorCodesEnabled() bool {
if a.Type != AccountTypeApiKey || a.Credentials == nil {
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
return false
}
if v, ok := a.Credentials["custom_error_codes_enabled"]; ok {
......@@ -301,14 +420,14 @@ func (a *Account) IsOpenAIOAuth() bool {
}
func (a *Account) IsOpenAIApiKey() bool {
return a.IsOpenAI() && a.Type == AccountTypeApiKey
return a.IsOpenAI() && a.Type == AccountTypeAPIKey
}
func (a *Account) GetOpenAIBaseURL() string {
if !a.IsOpenAI() {
return ""
}
if a.Type == AccountTypeApiKey {
if a.Type == AccountTypeAPIKey {
baseURL := a.GetCredential("base_url")
if baseURL != "" {
return baseURL
......
......@@ -49,6 +49,8 @@ type AccountRepository interface {
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
ClearTempUnschedulable(ctx context.Context, id int64) error
ClearRateLimit(ctx context.Context, id int64) error
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment