Commit c86d445c authored by IanShaw027's avatar IanShaw027
Browse files

fix(frontend): sync with main and finalize i18n & component optimizations

parents 6c036d7b e78c8646
...@@ -3,9 +3,9 @@ package repository ...@@ -3,9 +3,9 @@ package repository
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"io" "io"
"net/http" "net/http"
"net/http/httptest"
"net/url" "net/url"
"strings" "strings"
"testing" "testing"
...@@ -18,7 +18,6 @@ import ( ...@@ -18,7 +18,6 @@ import (
type TurnstileServiceSuite struct { type TurnstileServiceSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
srv *httptest.Server
verifier *turnstileVerifier verifier *turnstileVerifier
received chan url.Values received chan url.Values
} }
...@@ -31,20 +30,15 @@ func (s *TurnstileServiceSuite) SetupTest() { ...@@ -31,20 +30,15 @@ func (s *TurnstileServiceSuite) SetupTest() {
s.verifier = verifier s.verifier = verifier
} }
func (s *TurnstileServiceSuite) TearDownTest() { func (s *TurnstileServiceSuite) setupTransport(handler http.HandlerFunc) {
if s.srv != nil { s.verifier.verifyURL = "http://in-process/turnstile"
s.srv.Close() s.verifier.httpClient = &http.Client{
s.srv = nil Transport: newInProcessTransport(handler, nil),
} }
} }
func (s *TurnstileServiceSuite) setupServer(handler http.HandlerFunc) {
s.srv = httptest.NewServer(handler)
s.verifier.verifyURL = s.srv.URL
}
func (s *TurnstileServiceSuite) TestVerifyToken_SendsFormAndDecodesJSON() { func (s *TurnstileServiceSuite) TestVerifyToken_SendsFormAndDecodesJSON() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Capture form data in main goroutine context later // Capture form data in main goroutine context later
body, _ := io.ReadAll(r.Body) body, _ := io.ReadAll(r.Body)
values, _ := url.ParseQuery(string(body)) values, _ := url.ParseQuery(string(body))
...@@ -72,7 +66,7 @@ func (s *TurnstileServiceSuite) TestVerifyToken_SendsFormAndDecodesJSON() { ...@@ -72,7 +66,7 @@ func (s *TurnstileServiceSuite) TestVerifyToken_SendsFormAndDecodesJSON() {
func (s *TurnstileServiceSuite) TestVerifyToken_ContentType() { func (s *TurnstileServiceSuite) TestVerifyToken_ContentType() {
var contentType string var contentType string
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
contentType = r.Header.Get("Content-Type") contentType = r.Header.Get("Content-Type")
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true}) _ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true})
...@@ -84,7 +78,7 @@ func (s *TurnstileServiceSuite) TestVerifyToken_ContentType() { ...@@ -84,7 +78,7 @@ func (s *TurnstileServiceSuite) TestVerifyToken_ContentType() {
} }
func (s *TurnstileServiceSuite) TestVerifyToken_EmptyRemoteIP_NotSent() { func (s *TurnstileServiceSuite) TestVerifyToken_EmptyRemoteIP_NotSent() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body) body, _ := io.ReadAll(r.Body)
values, _ := url.ParseQuery(string(body)) values, _ := url.ParseQuery(string(body))
s.received <- values s.received <- values
...@@ -105,15 +99,19 @@ func (s *TurnstileServiceSuite) TestVerifyToken_EmptyRemoteIP_NotSent() { ...@@ -105,15 +99,19 @@ func (s *TurnstileServiceSuite) TestVerifyToken_EmptyRemoteIP_NotSent() {
} }
func (s *TurnstileServiceSuite) TestVerifyToken_RequestError() { func (s *TurnstileServiceSuite) TestVerifyToken_RequestError() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) s.verifier.verifyURL = "http://in-process/turnstile"
s.srv.Close() s.verifier.httpClient = &http.Client{
Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
return nil, errors.New("dial failed")
}),
}
_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1") _, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
require.Error(s.T(), err, "expected error when server is closed") require.Error(s.T(), err, "expected error when server is closed")
} }
func (s *TurnstileServiceSuite) TestVerifyToken_InvalidJSON() { func (s *TurnstileServiceSuite) TestVerifyToken_InvalidJSON() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, "not-valid-json") _, _ = io.WriteString(w, "not-valid-json")
})) }))
...@@ -123,7 +121,7 @@ func (s *TurnstileServiceSuite) TestVerifyToken_InvalidJSON() { ...@@ -123,7 +121,7 @@ func (s *TurnstileServiceSuite) TestVerifyToken_InvalidJSON() {
} }
func (s *TurnstileServiceSuite) TestVerifyToken_SuccessFalse() { func (s *TurnstileServiceSuite) TestVerifyToken_SuccessFalse() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{ _ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{
Success: false, Success: false,
......
...@@ -3,6 +3,7 @@ package repository ...@@ -3,6 +3,7 @@ package repository
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"os" "os"
"strings" "strings"
...@@ -60,9 +61,16 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int ...@@ -60,9 +61,16 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int
return requestCount / 5, tokenCount / 5, nil return requestCount / 5, tokenCount / 5, nil
} }
func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) error { func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
if log == nil { if log == nil {
return nil return false, nil
}
// 在事务上下文中,使用 tx 绑定的 ExecQuerier 执行原生 SQL,保证与其他更新同事务。
// 无事务时回退到默认的 *sql.DB 执行器。
sqlq := r.sql
if tx := dbent.TxFromContext(ctx); tx != nil {
sqlq = tx.Client()
} }
createdAt := log.CreatedAt createdAt := log.CreatedAt
...@@ -70,6 +78,9 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -70,6 +78,9 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
createdAt = time.Now() createdAt = time.Now()
} }
requestID := strings.TrimSpace(log.RequestID)
log.RequestID = requestID
rateMultiplier := log.RateMultiplier rateMultiplier := log.RateMultiplier
query := ` query := `
...@@ -107,6 +118,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -107,6 +118,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$14, $15, $16, $17, $18, $19, $14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25 $20, $21, $22, $23, $24, $25
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at RETURNING id, created_at
` `
...@@ -115,11 +127,16 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -115,11 +127,16 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
duration := nullInt(log.DurationMs) duration := nullInt(log.DurationMs)
firstToken := nullInt(log.FirstTokenMs) firstToken := nullInt(log.FirstTokenMs)
var requestIDArg any
if requestID != "" {
requestIDArg = requestID
}
args := []any{ args := []any{
log.UserID, log.UserID,
log.ApiKeyID, log.APIKeyID,
log.AccountID, log.AccountID,
log.RequestID, requestIDArg,
log.Model, log.Model,
groupID, groupID,
subscriptionID, subscriptionID,
...@@ -142,11 +159,20 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -142,11 +159,20 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
firstToken, firstToken,
createdAt, createdAt,
} }
if err := scanSingleRow(ctx, r.sql, query, args, &log.ID, &log.CreatedAt); err != nil { if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil {
return err if errors.Is(err, sql.ErrNoRows) && requestID != "" {
selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2"
if err := scanSingleRow(ctx, sqlq, selectQuery, []any{requestID, log.APIKeyID}, &log.ID, &log.CreatedAt); err != nil {
return false, err
}
log.RateMultiplier = rateMultiplier
return false, nil
} else {
return false, err
}
} }
log.RateMultiplier = rateMultiplier log.RateMultiplier = rateMultiplier
return nil return true, nil
} }
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) { func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) {
...@@ -183,7 +209,7 @@ func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, param ...@@ -183,7 +209,7 @@ func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, param
return r.listUsageLogsWithPagination(ctx, "WHERE user_id = $1", []any{userID}, params) return r.listUsageLogsWithPagination(ctx, "WHERE user_id = $1", []any{userID}, params)
} }
func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return r.listUsageLogsWithPagination(ctx, "WHERE api_key_id = $1", []any{apiKeyID}, params) return r.listUsageLogsWithPagination(ctx, "WHERE api_key_id = $1", []any{apiKeyID}, params)
} }
...@@ -270,8 +296,8 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS ...@@ -270,8 +296,8 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
r.sql, r.sql,
apiKeyStatsQuery, apiKeyStatsQuery,
[]any{service.StatusActive}, []any{service.StatusActive},
&stats.TotalApiKeys, &stats.TotalAPIKeys,
&stats.ActiveApiKeys, &stats.ActiveAPIKeys,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
...@@ -418,8 +444,8 @@ func (r *usageLogRepository) GetUserStatsAggregated(ctx context.Context, userID ...@@ -418,8 +444,8 @@ func (r *usageLogRepository) GetUserStatsAggregated(ctx context.Context, userID
return &stats, nil return &stats, nil
} }
// GetApiKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation // GetAPIKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation
func (r *usageLogRepository) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) { func (r *usageLogRepository) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
query := ` query := `
SELECT SELECT
COUNT(*) as total_requests, COUNT(*) as total_requests,
...@@ -623,7 +649,7 @@ func resolveUsageStatsTimezone() string { ...@@ -623,7 +649,7 @@ func resolveUsageStatsTimezone() string {
return "UTC" return "UTC"
} }
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime) logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)
return logs, nil, err return logs, nil, err
...@@ -709,11 +735,11 @@ type ModelStat = usagestats.ModelStat ...@@ -709,11 +735,11 @@ type ModelStat = usagestats.ModelStat
// UserUsageTrendPoint represents user usage trend data point // UserUsageTrendPoint represents user usage trend data point
type UserUsageTrendPoint = usagestats.UserUsageTrendPoint type UserUsageTrendPoint = usagestats.UserUsageTrendPoint
// ApiKeyUsageTrendPoint represents API key usage trend data point // APIKeyUsageTrendPoint represents API key usage trend data point
type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date // GetAPIKeyUsageTrend returns usage trend data grouped by API key and date
func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []ApiKeyUsageTrendPoint, err error) { func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) {
dateFormat := "YYYY-MM-DD" dateFormat := "YYYY-MM-DD"
if granularity == "hour" { if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00" dateFormat = "YYYY-MM-DD HH24:00"
...@@ -755,10 +781,10 @@ func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, ...@@ -755,10 +781,10 @@ func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime,
} }
}() }()
results = make([]ApiKeyUsageTrendPoint, 0) results = make([]APIKeyUsageTrendPoint, 0)
for rows.Next() { for rows.Next() {
var row ApiKeyUsageTrendPoint var row APIKeyUsageTrendPoint
if err = rows.Scan(&row.Date, &row.ApiKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil { if err = rows.Scan(&row.Date, &row.APIKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
return nil, err return nil, err
} }
results = append(results, row) results = append(results, row)
...@@ -844,7 +870,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i ...@@ -844,7 +870,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
r.sql, r.sql,
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL", "SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL",
[]any{userID}, []any{userID},
&stats.TotalApiKeys, &stats.TotalAPIKeys,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
...@@ -853,7 +879,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i ...@@ -853,7 +879,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
r.sql, r.sql,
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL", "SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL",
[]any{userID, service.StatusActive}, []any{userID, service.StatusActive},
&stats.ActiveApiKeys, &stats.ActiveAPIKeys,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
...@@ -1023,9 +1049,9 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat ...@@ -1023,9 +1049,9 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1)) conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1))
args = append(args, filters.UserID) args = append(args, filters.UserID)
} }
if filters.ApiKeyID > 0 { if filters.APIKeyID > 0 {
conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", len(args)+1)) conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", len(args)+1))
args = append(args, filters.ApiKeyID) args = append(args, filters.APIKeyID)
} }
if filters.AccountID > 0 { if filters.AccountID > 0 {
conditions = append(conditions, fmt.Sprintf("account_id = $%d", len(args)+1)) conditions = append(conditions, fmt.Sprintf("account_id = $%d", len(args)+1))
...@@ -1145,18 +1171,18 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs ...@@ -1145,18 +1171,18 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
return result, nil return result, nil
} }
// BatchApiKeyUsageStats represents usage stats for a single API key // BatchAPIKeyUsageStats represents usage stats for a single API key
type BatchApiKeyUsageStats = usagestats.BatchApiKeyUsageStats type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats
// GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys // GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys
func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) { func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchAPIKeyUsageStats, error) {
result := make(map[int64]*BatchApiKeyUsageStats) result := make(map[int64]*BatchAPIKeyUsageStats)
if len(apiKeyIDs) == 0 { if len(apiKeyIDs) == 0 {
return result, nil return result, nil
} }
for _, id := range apiKeyIDs { for _, id := range apiKeyIDs {
result[id] = &BatchApiKeyUsageStats{ApiKeyID: id} result[id] = &BatchAPIKeyUsageStats{APIKeyID: id}
} }
query := ` query := `
...@@ -1582,7 +1608,7 @@ func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, lo ...@@ -1582,7 +1608,7 @@ func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, lo
if err != nil { if err != nil {
return err return err
} }
apiKeys, err := r.loadApiKeys(ctx, ids.apiKeyIDs) apiKeys, err := r.loadAPIKeys(ctx, ids.apiKeyIDs)
if err != nil { if err != nil {
return err return err
} }
...@@ -1603,8 +1629,8 @@ func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, lo ...@@ -1603,8 +1629,8 @@ func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, lo
if user, ok := users[logs[i].UserID]; ok { if user, ok := users[logs[i].UserID]; ok {
logs[i].User = user logs[i].User = user
} }
if key, ok := apiKeys[logs[i].ApiKeyID]; ok { if key, ok := apiKeys[logs[i].APIKeyID]; ok {
logs[i].ApiKey = key logs[i].APIKey = key
} }
if acc, ok := accounts[logs[i].AccountID]; ok { if acc, ok := accounts[logs[i].AccountID]; ok {
logs[i].Account = acc logs[i].Account = acc
...@@ -1642,7 +1668,7 @@ func collectUsageLogIDs(logs []service.UsageLog) usageLogIDs { ...@@ -1642,7 +1668,7 @@ func collectUsageLogIDs(logs []service.UsageLog) usageLogIDs {
for i := range logs { for i := range logs {
userIDs[logs[i].UserID] = struct{}{} userIDs[logs[i].UserID] = struct{}{}
apiKeyIDs[logs[i].ApiKeyID] = struct{}{} apiKeyIDs[logs[i].APIKeyID] = struct{}{}
accountIDs[logs[i].AccountID] = struct{}{} accountIDs[logs[i].AccountID] = struct{}{}
if logs[i].GroupID != nil { if logs[i].GroupID != nil {
groupIDs[*logs[i].GroupID] = struct{}{} groupIDs[*logs[i].GroupID] = struct{}{}
...@@ -1676,12 +1702,12 @@ func (r *usageLogRepository) loadUsers(ctx context.Context, ids []int64) (map[in ...@@ -1676,12 +1702,12 @@ func (r *usageLogRepository) loadUsers(ctx context.Context, ids []int64) (map[in
return out, nil return out, nil
} }
func (r *usageLogRepository) loadApiKeys(ctx context.Context, ids []int64) (map[int64]*service.ApiKey, error) { func (r *usageLogRepository) loadAPIKeys(ctx context.Context, ids []int64) (map[int64]*service.APIKey, error) {
out := make(map[int64]*service.ApiKey) out := make(map[int64]*service.APIKey)
if len(ids) == 0 { if len(ids) == 0 {
return out, nil return out, nil
} }
models, err := r.client.ApiKey.Query().Where(dbapikey.IDIn(ids...)).All(ctx) models, err := r.client.APIKey.Query().Where(dbapikey.IDIn(ids...)).All(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -1800,7 +1826,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -1800,7 +1826,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
log := &service.UsageLog{ log := &service.UsageLog{
ID: id, ID: id,
UserID: userID, UserID: userID,
ApiKeyID: apiKeyID, APIKeyID: apiKeyID,
AccountID: accountID, AccountID: accountID,
Model: model, Model: model,
InputTokens: inputTokens, InputTokens: inputTokens,
......
...@@ -7,6 +7,8 @@ import ( ...@@ -7,6 +7,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/google/uuid"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
...@@ -35,11 +37,12 @@ func TestUsageLogRepoSuite(t *testing.T) { ...@@ -35,11 +37,12 @@ func TestUsageLogRepoSuite(t *testing.T) {
suite.Run(t, new(UsageLogRepoSuite)) suite.Run(t, new(UsageLogRepoSuite))
} }
func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.ApiKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog { func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.APIKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
log := &service.UsageLog{ log := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, APIKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
RequestID: uuid.New().String(), // Generate unique RequestID for each log
Model: "claude-3", Model: "claude-3",
InputTokens: inputTokens, InputTokens: inputTokens,
OutputTokens: outputTokens, OutputTokens: outputTokens,
...@@ -47,7 +50,8 @@ func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.A ...@@ -47,7 +50,8 @@ func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.A
ActualCost: cost, ActualCost: cost,
CreatedAt: createdAt, CreatedAt: createdAt,
} }
s.Require().NoError(s.repo.Create(s.ctx, log)) _, err := s.repo.Create(s.ctx, log)
s.Require().NoError(err)
return log return log
} }
...@@ -55,12 +59,12 @@ func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.A ...@@ -55,12 +59,12 @@ func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.A
func (s *UsageLogRepoSuite) TestCreate() { func (s *UsageLogRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "create@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "create@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-create", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-create", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-create"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-create"})
log := &service.UsageLog{ log := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, APIKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
Model: "claude-3", Model: "claude-3",
InputTokens: 10, InputTokens: 10,
...@@ -69,14 +73,14 @@ func (s *UsageLogRepoSuite) TestCreate() { ...@@ -69,14 +73,14 @@ func (s *UsageLogRepoSuite) TestCreate() {
ActualCost: 0.4, ActualCost: 0.4,
} }
err := s.repo.Create(s.ctx, log) _, err := s.repo.Create(s.ctx, log)
s.Require().NoError(err, "Create") s.Require().NoError(err, "Create")
s.Require().NotZero(log.ID) s.Require().NotZero(log.ID)
} }
func (s *UsageLogRepoSuite) TestGetByID() { func (s *UsageLogRepoSuite) TestGetByID() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
...@@ -96,7 +100,7 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() { ...@@ -96,7 +100,7 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
func (s *UsageLogRepoSuite) TestDelete() { func (s *UsageLogRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "delete@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "delete@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-delete", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-delete", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-delete"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-delete"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
...@@ -112,7 +116,7 @@ func (s *UsageLogRepoSuite) TestDelete() { ...@@ -112,7 +116,7 @@ func (s *UsageLogRepoSuite) TestDelete() {
func (s *UsageLogRepoSuite) TestListByUser() { func (s *UsageLogRepoSuite) TestListByUser() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyuser@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyuser@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyuser"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyuser"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
...@@ -124,18 +128,18 @@ func (s *UsageLogRepoSuite) TestListByUser() { ...@@ -124,18 +128,18 @@ func (s *UsageLogRepoSuite) TestListByUser() {
s.Require().Equal(int64(2), page.Total) s.Require().Equal(int64(2), page.Total)
} }
// --- ListByApiKey --- // --- ListByAPIKey ---
func (s *UsageLogRepoSuite) TestListByApiKey() { func (s *UsageLogRepoSuite) TestListByAPIKey() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyapikey@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyapikey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyapikey"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyapikey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now()) s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
logs, page, err := s.repo.ListByApiKey(s.ctx, apiKey.ID, pagination.PaginationParams{Page: 1, PageSize: 10}) logs, page, err := s.repo.ListByAPIKey(s.ctx, apiKey.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByApiKey") s.Require().NoError(err, "ListByAPIKey")
s.Require().Len(logs, 2) s.Require().Len(logs, 2)
s.Require().Equal(int64(2), page.Total) s.Require().Equal(int64(2), page.Total)
} }
...@@ -144,7 +148,7 @@ func (s *UsageLogRepoSuite) TestListByApiKey() { ...@@ -144,7 +148,7 @@ func (s *UsageLogRepoSuite) TestListByApiKey() {
func (s *UsageLogRepoSuite) TestListByAccount() { func (s *UsageLogRepoSuite) TestListByAccount() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyaccount@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyaccount@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyaccount"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyaccount"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
...@@ -159,7 +163,7 @@ func (s *UsageLogRepoSuite) TestListByAccount() { ...@@ -159,7 +163,7 @@ func (s *UsageLogRepoSuite) TestListByAccount() {
func (s *UsageLogRepoSuite) TestGetUserStats() { func (s *UsageLogRepoSuite) TestGetUserStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "userstats@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "userstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-userstats", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-userstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-userstats"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-userstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
...@@ -179,7 +183,7 @@ func (s *UsageLogRepoSuite) TestGetUserStats() { ...@@ -179,7 +183,7 @@ func (s *UsageLogRepoSuite) TestGetUserStats() {
func (s *UsageLogRepoSuite) TestListWithFilters() { func (s *UsageLogRepoSuite) TestListWithFilters() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filters@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "filters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filters", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filters", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filters"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filters"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
...@@ -211,8 +215,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { ...@@ -211,8 +215,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
}) })
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-ul"}) group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-ul"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"}) apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled}) mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled})
resetAt := now.Add(10 * time.Minute) resetAt := now.Add(10 * time.Minute)
accNormal := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a-normal", Schedulable: true}) accNormal := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a-normal", Schedulable: true})
...@@ -223,7 +227,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { ...@@ -223,7 +227,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
d1, d2, d3 := 100, 200, 300 d1, d2, d3 := 100, 200, 300
logToday := &service.UsageLog{ logToday := &service.UsageLog{
UserID: userToday.ID, UserID: userToday.ID,
ApiKeyID: apiKey1.ID, APIKeyID: apiKey1.ID,
AccountID: accNormal.ID, AccountID: accNormal.ID,
Model: "claude-3", Model: "claude-3",
GroupID: &group.ID, GroupID: &group.ID,
...@@ -236,11 +240,12 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { ...@@ -236,11 +240,12 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
DurationMs: &d1, DurationMs: &d1,
CreatedAt: maxTime(todayStart.Add(2*time.Minute), now.Add(-2*time.Minute)), CreatedAt: maxTime(todayStart.Add(2*time.Minute), now.Add(-2*time.Minute)),
} }
s.Require().NoError(s.repo.Create(s.ctx, logToday), "Create logToday") _, err = s.repo.Create(s.ctx, logToday)
s.Require().NoError(err, "Create logToday")
logOld := &service.UsageLog{ logOld := &service.UsageLog{
UserID: userOld.ID, UserID: userOld.ID,
ApiKeyID: apiKey1.ID, APIKeyID: apiKey1.ID,
AccountID: accNormal.ID, AccountID: accNormal.ID,
Model: "claude-3", Model: "claude-3",
InputTokens: 5, InputTokens: 5,
...@@ -250,11 +255,12 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { ...@@ -250,11 +255,12 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
DurationMs: &d2, DurationMs: &d2,
CreatedAt: todayStart.Add(-1 * time.Hour), CreatedAt: todayStart.Add(-1 * time.Hour),
} }
s.Require().NoError(s.repo.Create(s.ctx, logOld), "Create logOld") _, err = s.repo.Create(s.ctx, logOld)
s.Require().NoError(err, "Create logOld")
logPerf := &service.UsageLog{ logPerf := &service.UsageLog{
UserID: userToday.ID, UserID: userToday.ID,
ApiKeyID: apiKey1.ID, APIKeyID: apiKey1.ID,
AccountID: accNormal.ID, AccountID: accNormal.ID,
Model: "claude-3", Model: "claude-3",
InputTokens: 1, InputTokens: 1,
...@@ -264,7 +270,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { ...@@ -264,7 +270,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
DurationMs: &d3, DurationMs: &d3,
CreatedAt: now.Add(-30 * time.Second), CreatedAt: now.Add(-30 * time.Second),
} }
s.Require().NoError(s.repo.Create(s.ctx, logPerf), "Create logPerf") _, err = s.repo.Create(s.ctx, logPerf)
s.Require().NoError(err, "Create logPerf")
stats, err := s.repo.GetDashboardStats(s.ctx) stats, err := s.repo.GetDashboardStats(s.ctx)
s.Require().NoError(err, "GetDashboardStats") s.Require().NoError(err, "GetDashboardStats")
...@@ -272,8 +279,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { ...@@ -272,8 +279,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
s.Require().Equal(baseStats.TotalUsers+2, stats.TotalUsers, "TotalUsers mismatch") s.Require().Equal(baseStats.TotalUsers+2, stats.TotalUsers, "TotalUsers mismatch")
s.Require().Equal(baseStats.TodayNewUsers+1, stats.TodayNewUsers, "TodayNewUsers mismatch") s.Require().Equal(baseStats.TodayNewUsers+1, stats.TodayNewUsers, "TodayNewUsers mismatch")
s.Require().Equal(baseStats.ActiveUsers+1, stats.ActiveUsers, "ActiveUsers mismatch") s.Require().Equal(baseStats.ActiveUsers+1, stats.ActiveUsers, "ActiveUsers mismatch")
s.Require().Equal(baseStats.TotalApiKeys+2, stats.TotalApiKeys, "TotalApiKeys mismatch") s.Require().Equal(baseStats.TotalAPIKeys+2, stats.TotalAPIKeys, "TotalAPIKeys mismatch")
s.Require().Equal(baseStats.ActiveApiKeys+1, stats.ActiveApiKeys, "ActiveApiKeys mismatch") s.Require().Equal(baseStats.ActiveAPIKeys+1, stats.ActiveAPIKeys, "ActiveAPIKeys mismatch")
s.Require().Equal(baseStats.TotalAccounts+4, stats.TotalAccounts, "TotalAccounts mismatch") s.Require().Equal(baseStats.TotalAccounts+4, stats.TotalAccounts, "TotalAccounts mismatch")
s.Require().Equal(baseStats.ErrorAccounts+1, stats.ErrorAccounts, "ErrorAccounts mismatch") s.Require().Equal(baseStats.ErrorAccounts+1, stats.ErrorAccounts, "ErrorAccounts mismatch")
s.Require().Equal(baseStats.RateLimitAccounts+1, stats.RateLimitAccounts, "RateLimitAccounts mismatch") s.Require().Equal(baseStats.RateLimitAccounts+1, stats.RateLimitAccounts, "RateLimitAccounts mismatch")
...@@ -300,14 +307,14 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { ...@@ -300,14 +307,14 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
func (s *UsageLogRepoSuite) TestGetUserDashboardStats() { func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "userdash@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "userdash@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-userdash", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-userdash", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-userdash"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-userdash"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
stats, err := s.repo.GetUserDashboardStats(s.ctx, user.ID) stats, err := s.repo.GetUserDashboardStats(s.ctx, user.ID)
s.Require().NoError(err, "GetUserDashboardStats") s.Require().NoError(err, "GetUserDashboardStats")
s.Require().Equal(int64(1), stats.TotalApiKeys) s.Require().Equal(int64(1), stats.TotalAPIKeys)
s.Require().Equal(int64(1), stats.TotalRequests) s.Require().Equal(int64(1), stats.TotalRequests)
} }
...@@ -315,7 +322,7 @@ func (s *UsageLogRepoSuite) TestGetUserDashboardStats() { ...@@ -315,7 +322,7 @@ func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
func (s *UsageLogRepoSuite) TestGetAccountTodayStats() { func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "acctoday@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "acctoday@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-today"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-today"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
...@@ -331,8 +338,8 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() { ...@@ -331,8 +338,8 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() { func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "batch1@test.com"}) user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "batch1@test.com"})
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "batch2@test.com"}) user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "batch2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"}) apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"}) apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-batch"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-batch"})
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
...@@ -351,24 +358,24 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() { ...@@ -351,24 +358,24 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
s.Require().Empty(stats) s.Require().Empty(stats)
} }
// --- GetBatchApiKeyUsageStats --- // --- GetBatchAPIKeyUsageStats ---
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() { func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "batchkey@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "batchkey@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"}) apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"}) apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-batchkey"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-batchkey"})
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now()) s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID}) stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID})
s.Require().NoError(err, "GetBatchApiKeyUsageStats") s.Require().NoError(err, "GetBatchAPIKeyUsageStats")
s.Require().Len(stats, 2) s.Require().Len(stats, 2)
} }
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() { func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{}) stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{})
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Empty(stats) s.Require().Empty(stats)
} }
...@@ -377,7 +384,7 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() { ...@@ -377,7 +384,7 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
func (s *UsageLogRepoSuite) TestGetGlobalStats() { func (s *UsageLogRepoSuite) TestGetGlobalStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "global@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "global@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-global", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-global", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-global"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-global"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
...@@ -402,7 +409,7 @@ func maxTime(a, b time.Time) time.Time { ...@@ -402,7 +409,7 @@ func maxTime(a, b time.Time) time.Time {
func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() { func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "timerange@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "timerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-timerange", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-timerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-timerange"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-timerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
...@@ -417,11 +424,11 @@ func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() { ...@@ -417,11 +424,11 @@ func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
s.Require().Len(logs, 2) s.Require().Len(logs, 2)
} }
// --- ListByApiKeyAndTimeRange --- // --- ListByAPIKeyAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() { func (s *UsageLogRepoSuite) TestListByAPIKeyAndTimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytimerange@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytimerange"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
...@@ -431,8 +438,8 @@ func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() { ...@@ -431,8 +438,8 @@ func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
startTime := base.Add(-1 * time.Hour) startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour) endTime := base.Add(2 * time.Hour)
logs, _, err := s.repo.ListByApiKeyAndTimeRange(s.ctx, apiKey.ID, startTime, endTime) logs, _, err := s.repo.ListByAPIKeyAndTimeRange(s.ctx, apiKey.ID, startTime, endTime)
s.Require().NoError(err, "ListByApiKeyAndTimeRange") s.Require().NoError(err, "ListByAPIKeyAndTimeRange")
s.Require().Len(logs, 2) s.Require().Len(logs, 2)
} }
...@@ -440,7 +447,7 @@ func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() { ...@@ -440,7 +447,7 @@ func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() { func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "acctimerange@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "acctimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-acctimerange"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-acctimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
...@@ -459,7 +466,7 @@ func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() { ...@@ -459,7 +466,7 @@ func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() { func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modeltimerange@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "modeltimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modeltimerange"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modeltimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
...@@ -467,7 +474,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() { ...@@ -467,7 +474,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
// Create logs with different models // Create logs with different models
log1 := &service.UsageLog{ log1 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, APIKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
Model: "claude-3-opus", Model: "claude-3-opus",
InputTokens: 10, InputTokens: 10,
...@@ -476,11 +483,12 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() { ...@@ -476,11 +483,12 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
ActualCost: 0.5, ActualCost: 0.5,
CreatedAt: base, CreatedAt: base,
} }
s.Require().NoError(s.repo.Create(s.ctx, log1)) _, err := s.repo.Create(s.ctx, log1)
s.Require().NoError(err)
log2 := &service.UsageLog{ log2 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, APIKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
Model: "claude-3-opus", Model: "claude-3-opus",
InputTokens: 15, InputTokens: 15,
...@@ -489,11 +497,12 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() { ...@@ -489,11 +497,12 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
ActualCost: 0.6, ActualCost: 0.6,
CreatedAt: base.Add(30 * time.Minute), CreatedAt: base.Add(30 * time.Minute),
} }
s.Require().NoError(s.repo.Create(s.ctx, log2)) _, err = s.repo.Create(s.ctx, log2)
s.Require().NoError(err)
log3 := &service.UsageLog{ log3 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, APIKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
Model: "claude-3-sonnet", Model: "claude-3-sonnet",
InputTokens: 20, InputTokens: 20,
...@@ -502,7 +511,8 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() { ...@@ -502,7 +511,8 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
ActualCost: 0.7, ActualCost: 0.7,
CreatedAt: base.Add(1 * time.Hour), CreatedAt: base.Add(1 * time.Hour),
} }
s.Require().NoError(s.repo.Create(s.ctx, log3)) _, err = s.repo.Create(s.ctx, log3)
s.Require().NoError(err)
startTime := base.Add(-1 * time.Hour) startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour) endTime := base.Add(2 * time.Hour)
...@@ -515,7 +525,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() { ...@@ -515,7 +525,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
func (s *UsageLogRepoSuite) TestGetAccountWindowStats() { func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "windowstats@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "windowstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-windowstats"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-windowstats"})
now := time.Now() now := time.Now()
...@@ -535,7 +545,7 @@ func (s *UsageLogRepoSuite) TestGetAccountWindowStats() { ...@@ -535,7 +545,7 @@ func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() { func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrend"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrend"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
...@@ -552,7 +562,7 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() { ...@@ -552,7 +562,7 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() { func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrendhourly@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrendhourly@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrendhourly"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrendhourly"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
...@@ -571,7 +581,7 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() { ...@@ -571,7 +581,7 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
func (s *UsageLogRepoSuite) TestGetUserModelStats() { func (s *UsageLogRepoSuite) TestGetUserModelStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modelstats@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "modelstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modelstats"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modelstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
...@@ -579,7 +589,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() { ...@@ -579,7 +589,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
// Create logs with different models // Create logs with different models
log1 := &service.UsageLog{ log1 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, APIKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
Model: "claude-3-opus", Model: "claude-3-opus",
InputTokens: 100, InputTokens: 100,
...@@ -588,11 +598,12 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() { ...@@ -588,11 +598,12 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
ActualCost: 0.5, ActualCost: 0.5,
CreatedAt: base, CreatedAt: base,
} }
s.Require().NoError(s.repo.Create(s.ctx, log1)) _, err := s.repo.Create(s.ctx, log1)
s.Require().NoError(err)
log2 := &service.UsageLog{ log2 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, APIKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
Model: "claude-3-sonnet", Model: "claude-3-sonnet",
InputTokens: 50, InputTokens: 50,
...@@ -601,7 +612,8 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() { ...@@ -601,7 +612,8 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
ActualCost: 0.2, ActualCost: 0.2,
CreatedAt: base.Add(1 * time.Hour), CreatedAt: base.Add(1 * time.Hour),
} }
s.Require().NoError(s.repo.Create(s.ctx, log2)) _, err = s.repo.Create(s.ctx, log2)
s.Require().NoError(err)
startTime := base.Add(-1 * time.Hour) startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour) endTime := base.Add(2 * time.Hour)
...@@ -618,7 +630,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() { ...@@ -618,7 +630,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() { func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "trendfilters@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "trendfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-trendfilters"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-trendfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
...@@ -646,7 +658,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() { ...@@ -646,7 +658,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() { func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "trendfilters-h@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "trendfilters-h@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-trendfilters-h"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-trendfilters-h"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
...@@ -665,14 +677,14 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() { ...@@ -665,14 +677,14 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() { func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modelfilters@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "modelfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modelfilters"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modelfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
log1 := &service.UsageLog{ log1 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, APIKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
Model: "claude-3-opus", Model: "claude-3-opus",
InputTokens: 100, InputTokens: 100,
...@@ -681,11 +693,12 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() { ...@@ -681,11 +693,12 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
ActualCost: 0.5, ActualCost: 0.5,
CreatedAt: base, CreatedAt: base,
} }
s.Require().NoError(s.repo.Create(s.ctx, log1)) _, err := s.repo.Create(s.ctx, log1)
s.Require().NoError(err)
log2 := &service.UsageLog{ log2 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, APIKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
Model: "claude-3-sonnet", Model: "claude-3-sonnet",
InputTokens: 50, InputTokens: 50,
...@@ -694,7 +707,8 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() { ...@@ -694,7 +707,8 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
ActualCost: 0.2, ActualCost: 0.2,
CreatedAt: base.Add(1 * time.Hour), CreatedAt: base.Add(1 * time.Hour),
} }
s.Require().NoError(s.repo.Create(s.ctx, log2)) _, err = s.repo.Create(s.ctx, log2)
s.Require().NoError(err)
startTime := base.Add(-1 * time.Hour) startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour) endTime := base.Add(2 * time.Hour)
...@@ -719,7 +733,7 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() { ...@@ -719,7 +733,7 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
func (s *UsageLogRepoSuite) TestGetAccountUsageStats() { func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "accstats@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "accstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-accstats", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-accstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-accstats"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-accstats"})
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
...@@ -727,7 +741,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() { ...@@ -727,7 +741,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
// Create logs on different days // Create logs on different days
log1 := &service.UsageLog{ log1 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, APIKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
Model: "claude-3-opus", Model: "claude-3-opus",
InputTokens: 100, InputTokens: 100,
...@@ -736,11 +750,12 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() { ...@@ -736,11 +750,12 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
ActualCost: 0.4, ActualCost: 0.4,
CreatedAt: base.Add(12 * time.Hour), CreatedAt: base.Add(12 * time.Hour),
} }
s.Require().NoError(s.repo.Create(s.ctx, log1)) _, err := s.repo.Create(s.ctx, log1)
s.Require().NoError(err)
log2 := &service.UsageLog{ log2 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, APIKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
Model: "claude-3-sonnet", Model: "claude-3-sonnet",
InputTokens: 50, InputTokens: 50,
...@@ -749,7 +764,8 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() { ...@@ -749,7 +764,8 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
ActualCost: 0.15, ActualCost: 0.15,
CreatedAt: base.Add(36 * time.Hour), // next day CreatedAt: base.Add(36 * time.Hour), // next day
} }
s.Require().NoError(s.repo.Create(s.ctx, log2)) _, err = s.repo.Create(s.ctx, log2)
s.Require().NoError(err)
startTime := base startTime := base
endTime := base.Add(72 * time.Hour) endTime := base.Add(72 * time.Hour)
...@@ -782,8 +798,8 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() { ...@@ -782,8 +798,8 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() {
func (s *UsageLogRepoSuite) TestGetUserUsageTrend() { func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend1@test.com"}) user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend1@test.com"})
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend2@test.com"}) user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"}) apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"}) apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrends"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
...@@ -799,12 +815,12 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrend() { ...@@ -799,12 +815,12 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
s.Require().GreaterOrEqual(len(trend), 2) s.Require().GreaterOrEqual(len(trend), 2)
} }
// --- GetApiKeyUsageTrend --- // --- GetAPIKeyUsageTrend ---
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() { func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrend@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrend@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"}) apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"}) apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrends"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
...@@ -815,14 +831,14 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() { ...@@ -815,14 +831,14 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
startTime := base.Add(-1 * time.Hour) startTime := base.Add(-1 * time.Hour)
endTime := base.Add(48 * time.Hour) endTime := base.Add(48 * time.Hour)
trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "day", 10) trend, err := s.repo.GetAPIKeyUsageTrend(s.ctx, startTime, endTime, "day", 10)
s.Require().NoError(err, "GetApiKeyUsageTrend") s.Require().NoError(err, "GetAPIKeyUsageTrend")
s.Require().GreaterOrEqual(len(trend), 2) s.Require().GreaterOrEqual(len(trend), 2)
} }
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() { func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend_HourlyGranularity() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrendh@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrendh@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrendh"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrendh"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
...@@ -832,8 +848,8 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() { ...@@ -832,8 +848,8 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
startTime := base.Add(-1 * time.Hour) startTime := base.Add(-1 * time.Hour)
endTime := base.Add(3 * time.Hour) endTime := base.Add(3 * time.Hour)
trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "hour", 10) trend, err := s.repo.GetAPIKeyUsageTrend(s.ctx, startTime, endTime, "hour", 10)
s.Require().NoError(err, "GetApiKeyUsageTrend hourly") s.Require().NoError(err, "GetAPIKeyUsageTrend hourly")
s.Require().Len(trend, 2) s.Require().Len(trend, 2)
} }
...@@ -841,12 +857,12 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() { ...@@ -841,12 +857,12 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() { func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterskey@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterskey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterskey"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterskey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
filters := usagestats.UsageLogFilters{ApiKeyID: apiKey.ID} filters := usagestats.UsageLogFilters{APIKeyID: apiKey.ID}
logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters) logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters)
s.Require().NoError(err, "ListWithFilters apiKey") s.Require().NoError(err, "ListWithFilters apiKey")
s.Require().Len(logs, 1) s.Require().Len(logs, 1)
...@@ -855,7 +871,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() { ...@@ -855,7 +871,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() { func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterstime@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterstime@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterstime"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterstime"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
...@@ -874,7 +890,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() { ...@@ -874,7 +890,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() { func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterscombined@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterscombined@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterscombined"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterscombined"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
...@@ -885,7 +901,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() { ...@@ -885,7 +901,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
endTime := base.Add(2 * time.Hour) endTime := base.Add(2 * time.Hour)
filters := usagestats.UsageLogFilters{ filters := usagestats.UsageLogFilters{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, APIKeyID: apiKey.ID,
StartTime: &startTime, StartTime: &startTime,
EndTime: &endTime, EndTime: &endTime,
} }
......
...@@ -4,12 +4,13 @@ import ( ...@@ -4,12 +4,13 @@ import (
"context" "context"
"database/sql" "database/sql"
"errors" "errors"
"fmt"
"sort" "sort"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
dbuser "github.com/Wei-Shaw/sub2api/ent/user" dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -17,14 +18,15 @@ import ( ...@@ -17,14 +18,15 @@ import (
type userRepository struct { type userRepository struct {
client *dbent.Client client *dbent.Client
sql sqlExecutor
} }
func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository { func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository {
return newUserRepositoryWithSQL(client, sqlDB) return newUserRepositoryWithSQL(client, sqlDB)
} }
func newUserRepositoryWithSQL(client *dbent.Client, _ sqlExecutor) *userRepository { func newUserRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *userRepository {
return &userRepository{client: client} return &userRepository{client: client, sql: sqlq}
} }
func (r *userRepository) Create(ctx context.Context, userIn *service.User) error { func (r *userRepository) Create(ctx context.Context, userIn *service.User) error {
...@@ -194,7 +196,11 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination. ...@@ -194,7 +196,11 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
// If attribute filters are specified, we need to filter by user IDs first // If attribute filters are specified, we need to filter by user IDs first
var allowedUserIDs []int64 var allowedUserIDs []int64
if len(filters.Attributes) > 0 { if len(filters.Attributes) > 0 {
allowedUserIDs = r.filterUsersByAttributes(ctx, filters.Attributes) var attrErr error
allowedUserIDs, attrErr = r.filterUsersByAttributes(ctx, filters.Attributes)
if attrErr != nil {
return nil, nil, attrErr
}
if len(allowedUserIDs) == 0 { if len(allowedUserIDs) == 0 {
// No users match the attribute filters // No users match the attribute filters
return []service.User{}, paginationResultFromTotal(0, params), nil return []service.User{}, paginationResultFromTotal(0, params), nil
...@@ -262,56 +268,53 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination. ...@@ -262,56 +268,53 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
} }
// filterUsersByAttributes returns user IDs that match ALL the given attribute filters // filterUsersByAttributes returns user IDs that match ALL the given attribute filters
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) []int64 { func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) ([]int64, error) {
if len(attrs) == 0 { if len(attrs) == 0 {
return nil return nil, nil
} }
// For each attribute filter, get the set of matching user IDs if r.sql == nil {
// Then intersect all sets to get users matching ALL filters return nil, fmt.Errorf("sql executor is not configured")
var resultSet map[int64]struct{} }
first := true
clauses := make([]string, 0, len(attrs))
args := make([]any, 0, len(attrs)*2+1)
argIndex := 1
for attrID, value := range attrs { for attrID, value := range attrs {
// Query user_attribute_values for this attribute clauses = append(clauses, fmt.Sprintf("(attribute_id = $%d AND value ILIKE $%d)", argIndex, argIndex+1))
values, err := r.client.UserAttributeValue.Query(). args = append(args, attrID, "%"+value+"%")
Where( argIndex += 2
userattributevalue.AttributeIDEQ(attrID), }
userattributevalue.ValueContainsFold(value),
). query := fmt.Sprintf(
All(ctx) `SELECT user_id
if err != nil { FROM user_attribute_values
continue WHERE %s
} GROUP BY user_id
HAVING COUNT(DISTINCT attribute_id) = $%d`,
currentSet := make(map[int64]struct{}, len(values)) strings.Join(clauses, " OR "),
for _, v := range values { argIndex,
currentSet[v.UserID] = struct{}{} )
} args = append(args, len(attrs))
if first { rows, err := r.sql.QueryContext(ctx, query, args...)
resultSet = currentSet if err != nil {
first = false return nil, err
} else {
// Intersect with previous results
for userID := range resultSet {
if _, ok := currentSet[userID]; !ok {
delete(resultSet, userID)
}
}
}
// Early exit if no users match
if len(resultSet) == 0 {
return nil
}
} }
defer func() { _ = rows.Close() }()
result := make([]int64, 0, len(resultSet)) result := make([]int64, 0)
for userID := range resultSet { for rows.Next() {
var userID int64
if scanErr := rows.Scan(&userID); scanErr != nil {
return nil, scanErr
}
result = append(result, userID) result = append(result, userID)
} }
return result if err := rows.Err(); err != nil {
return nil, err
}
return result, nil
} }
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error { func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
......
...@@ -28,7 +28,7 @@ func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.Conc ...@@ -28,7 +28,7 @@ func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.Conc
// ProviderSet is the Wire provider set for all repositories // ProviderSet is the Wire provider set for all repositories
var ProviderSet = wire.NewSet( var ProviderSet = wire.NewSet(
NewUserRepository, NewUserRepository,
NewApiKeyRepository, NewAPIKeyRepository,
NewGroupRepository, NewGroupRepository,
NewAccountRepository, NewAccountRepository,
NewProxyRepository, NewProxyRepository,
...@@ -42,7 +42,8 @@ var ProviderSet = wire.NewSet( ...@@ -42,7 +42,8 @@ var ProviderSet = wire.NewSet(
// Cache implementations // Cache implementations
NewGatewayCache, NewGatewayCache,
NewBillingCache, NewBillingCache,
NewApiKeyCache, NewAPIKeyCache,
NewTempUnschedCache,
ProvideConcurrencyCache, ProvideConcurrencyCache,
NewEmailCache, NewEmailCache,
NewIdentityCache, NewIdentityCache,
......
...@@ -91,7 +91,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -91,7 +91,7 @@ func TestAPIContracts(t *testing.T) {
name: "GET /api/v1/keys (paginated)", name: "GET /api/v1/keys (paginated)",
setup: func(t *testing.T, deps *contractDeps) { setup: func(t *testing.T, deps *contractDeps) {
t.Helper() t.Helper()
deps.apiKeyRepo.MustSeed(&service.ApiKey{ deps.apiKeyRepo.MustSeed(&service.APIKey{
ID: 100, ID: 100,
UserID: 1, UserID: 1,
Key: "sk_custom_1234567890", Key: "sk_custom_1234567890",
...@@ -135,7 +135,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -135,7 +135,7 @@ func TestAPIContracts(t *testing.T) {
{ {
ID: 1, ID: 1,
UserID: 1, UserID: 1,
ApiKeyID: 100, APIKeyID: 100,
AccountID: 200, AccountID: 200,
Model: "claude-3", Model: "claude-3",
InputTokens: 10, InputTokens: 10,
...@@ -150,7 +150,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -150,7 +150,7 @@ func TestAPIContracts(t *testing.T) {
{ {
ID: 2, ID: 2,
UserID: 1, UserID: 1,
ApiKeyID: 100, APIKeyID: 100,
AccountID: 200, AccountID: 200,
Model: "claude-3", Model: "claude-3",
InputTokens: 5, InputTokens: 5,
...@@ -188,7 +188,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -188,7 +188,7 @@ func TestAPIContracts(t *testing.T) {
{ {
ID: 1, ID: 1,
UserID: 1, UserID: 1,
ApiKeyID: 100, APIKeyID: 100,
AccountID: 200, AccountID: 200,
RequestID: "req_123", RequestID: "req_123",
Model: "claude-3", Model: "claude-3",
...@@ -259,13 +259,13 @@ func TestAPIContracts(t *testing.T) { ...@@ -259,13 +259,13 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyRegistrationEnabled: "true", service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyEmailVerifyEnabled: "false", service.SettingKeyEmailVerifyEnabled: "false",
service.SettingKeySmtpHost: "smtp.example.com", service.SettingKeySMTPHost: "smtp.example.com",
service.SettingKeySmtpPort: "587", service.SettingKeySMTPPort: "587",
service.SettingKeySmtpUsername: "user", service.SettingKeySMTPUsername: "user",
service.SettingKeySmtpPassword: "secret", service.SettingKeySMTPPassword: "secret",
service.SettingKeySmtpFrom: "no-reply@example.com", service.SettingKeySMTPFrom: "no-reply@example.com",
service.SettingKeySmtpFromName: "Sub2API", service.SettingKeySMTPFromName: "Sub2API",
service.SettingKeySmtpUseTLS: "true", service.SettingKeySMTPUseTLS: "true",
service.SettingKeyTurnstileEnabled: "true", service.SettingKeyTurnstileEnabled: "true",
service.SettingKeyTurnstileSiteKey: "site-key", service.SettingKeyTurnstileSiteKey: "site-key",
...@@ -274,9 +274,9 @@ func TestAPIContracts(t *testing.T) { ...@@ -274,9 +274,9 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeySiteName: "Sub2API", service.SettingKeySiteName: "Sub2API",
service.SettingKeySiteLogo: "", service.SettingKeySiteLogo: "",
service.SettingKeySiteSubtitle: "Subtitle", service.SettingKeySiteSubtitle: "Subtitle",
service.SettingKeyApiBaseUrl: "https://api.example.com", service.SettingKeyAPIBaseURL: "https://api.example.com",
service.SettingKeyContactInfo: "support", service.SettingKeyContactInfo: "support",
service.SettingKeyDocUrl: "https://docs.example.com", service.SettingKeyDocURL: "https://docs.example.com",
service.SettingKeyDefaultConcurrency: "5", service.SettingKeyDefaultConcurrency: "5",
service.SettingKeyDefaultBalance: "1.25", service.SettingKeyDefaultBalance: "1.25",
...@@ -308,7 +308,12 @@ func TestAPIContracts(t *testing.T) { ...@@ -308,7 +308,12 @@ func TestAPIContracts(t *testing.T) {
"contact_info": "support", "contact_info": "support",
"doc_url": "https://docs.example.com", "doc_url": "https://docs.example.com",
"default_concurrency": 5, "default_concurrency": 5,
"default_balance": 1.25 "default_balance": 1.25,
"enable_model_fallback": false,
"fallback_model_anthropic": "claude-3-5-sonnet-20241022",
"fallback_model_antigravity": "gemini-2.5-pro",
"fallback_model_gemini": "gemini-2.5-pro",
"fallback_model_openai": "gpt-4o"
} }
}`, }`,
}, },
...@@ -366,16 +371,16 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -366,16 +371,16 @@ func newContractDeps(t *testing.T) *contractDeps {
cfg := &config.Config{ cfg := &config.Config{
Default: config.DefaultConfig{ Default: config.DefaultConfig{
ApiKeyPrefix: "sk-", APIKeyPrefix: "sk-",
}, },
RunMode: config.RunModeStandard, RunMode: config.RunModeStandard,
} }
userService := service.NewUserService(userRepo) userService := service.NewUserService(userRepo)
apiKeyService := service.NewApiKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg) apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
usageRepo := newStubUsageLogRepo() usageRepo := newStubUsageLogRepo()
usageService := service.NewUsageService(usageRepo, userRepo) usageService := service.NewUsageService(usageRepo, userRepo, nil)
settingRepo := newStubSettingRepo() settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg) settingService := service.NewSettingService(settingRepo, cfg)
...@@ -664,20 +669,20 @@ type stubApiKeyRepo struct { ...@@ -664,20 +669,20 @@ type stubApiKeyRepo struct {
now time.Time now time.Time
nextID int64 nextID int64
byID map[int64]*service.ApiKey byID map[int64]*service.APIKey
byKey map[string]*service.ApiKey byKey map[string]*service.APIKey
} }
func newStubApiKeyRepo(now time.Time) *stubApiKeyRepo { func newStubApiKeyRepo(now time.Time) *stubApiKeyRepo {
return &stubApiKeyRepo{ return &stubApiKeyRepo{
now: now, now: now,
nextID: 100, nextID: 100,
byID: make(map[int64]*service.ApiKey), byID: make(map[int64]*service.APIKey),
byKey: make(map[string]*service.ApiKey), byKey: make(map[string]*service.APIKey),
} }
} }
func (r *stubApiKeyRepo) MustSeed(key *service.ApiKey) { func (r *stubApiKeyRepo) MustSeed(key *service.APIKey) {
if key == nil { if key == nil {
return return
} }
...@@ -686,7 +691,7 @@ func (r *stubApiKeyRepo) MustSeed(key *service.ApiKey) { ...@@ -686,7 +691,7 @@ func (r *stubApiKeyRepo) MustSeed(key *service.ApiKey) {
r.byKey[clone.Key] = &clone r.byKey[clone.Key] = &clone
} }
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error { func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
if key == nil { if key == nil {
return errors.New("nil key") return errors.New("nil key")
} }
...@@ -706,10 +711,10 @@ func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error ...@@ -706,10 +711,10 @@ func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error
return nil return nil
} }
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) { func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
key, ok := r.byID[id] key, ok := r.byID[id]
if !ok { if !ok {
return nil, service.ErrApiKeyNotFound return nil, service.ErrAPIKeyNotFound
} }
clone := *key clone := *key
return &clone, nil return &clone, nil
...@@ -718,26 +723,26 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey ...@@ -718,26 +723,26 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) { func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
key, ok := r.byID[id] key, ok := r.byID[id]
if !ok { if !ok {
return 0, service.ErrApiKeyNotFound return 0, service.ErrAPIKeyNotFound
} }
return key.UserID, nil return key.UserID, nil
} }
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) { func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
found, ok := r.byKey[key] found, ok := r.byKey[key]
if !ok { if !ok {
return nil, service.ErrApiKeyNotFound return nil, service.ErrAPIKeyNotFound
} }
clone := *found clone := *found
return &clone, nil return &clone, nil
} }
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error { func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
if key == nil { if key == nil {
return errors.New("nil key") return errors.New("nil key")
} }
if _, ok := r.byID[key.ID]; !ok { if _, ok := r.byID[key.ID]; !ok {
return service.ErrApiKeyNotFound return service.ErrAPIKeyNotFound
} }
if key.UpdatedAt.IsZero() { if key.UpdatedAt.IsZero() {
key.UpdatedAt = r.now key.UpdatedAt = r.now
...@@ -751,14 +756,14 @@ func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error ...@@ -751,14 +756,14 @@ func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error
func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error { func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
key, ok := r.byID[id] key, ok := r.byID[id]
if !ok { if !ok {
return service.ErrApiKeyNotFound return service.ErrAPIKeyNotFound
} }
delete(r.byID, id) delete(r.byID, id)
delete(r.byKey, key.Key) delete(r.byKey, key.Key)
return nil return nil
} }
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
ids := make([]int64, 0, len(r.byID)) ids := make([]int64, 0, len(r.byID))
for id := range r.byID { for id := range r.byID {
if r.byID[id].UserID == userID { if r.byID[id].UserID == userID {
...@@ -776,7 +781,7 @@ func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params ...@@ -776,7 +781,7 @@ func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params
end = len(ids) end = len(ids)
} }
out := make([]service.ApiKey, 0, end-start) out := make([]service.APIKey, 0, end-start)
for _, id := range ids[start:end] { for _, id := range ids[start:end] {
clone := *r.byID[id] clone := *r.byID[id]
out = append(out, clone) out = append(out, clone)
...@@ -830,11 +835,11 @@ func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, err ...@@ -830,11 +835,11 @@ func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, err
return ok, nil return ok, nil
} }
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) { func (r *stubApiKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
...@@ -858,8 +863,8 @@ func (r *stubUsageLogRepo) SetUserLogs(userID int64, logs []service.UsageLog) { ...@@ -858,8 +863,8 @@ func (r *stubUsageLogRepo) SetUserLogs(userID int64, logs []service.UsageLog) {
r.userLogs[userID] = logs r.userLogs[userID] = logs
} }
func (r *stubUsageLogRepo) Create(ctx context.Context, log *service.UsageLog) error { func (r *stubUsageLogRepo) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
return errors.New("not implemented") return false, errors.New("not implemented")
} }
func (r *stubUsageLogRepo) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) { func (r *stubUsageLogRepo) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
...@@ -877,7 +882,7 @@ func (r *stubUsageLogRepo) ListByUser(ctx context.Context, userID int64, params ...@@ -877,7 +882,7 @@ func (r *stubUsageLogRepo) ListByUser(ctx context.Context, userID int64, params
return out, paginationResult(total, params), nil return out, paginationResult(total, params), nil
} }
func (r *stubUsageLogRepo) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { func (r *stubUsageLogRepo) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
...@@ -890,7 +895,7 @@ func (r *stubUsageLogRepo) ListByUserAndTimeRange(ctx context.Context, userID in ...@@ -890,7 +895,7 @@ func (r *stubUsageLogRepo) ListByUserAndTimeRange(ctx context.Context, userID in
return logs, paginationResult(int64(len(logs)), pagination.PaginationParams{Page: 1, PageSize: 100}), nil return logs, paginationResult(int64(len(logs)), pagination.PaginationParams{Page: 1, PageSize: 100}), nil
} }
func (r *stubUsageLogRepo) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { func (r *stubUsageLogRepo) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
...@@ -922,7 +927,7 @@ func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTi ...@@ -922,7 +927,7 @@ func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTi
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubUsageLogRepo) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) { func (r *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
...@@ -975,7 +980,7 @@ func (r *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID in ...@@ -975,7 +980,7 @@ func (r *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID in
}, nil }, nil
} }
func (r *stubUsageLogRepo) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) { func (r *stubUsageLogRepo) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
...@@ -995,7 +1000,7 @@ func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs [ ...@@ -995,7 +1000,7 @@ func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs [
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubUsageLogRepo) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) { func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
...@@ -1017,8 +1022,8 @@ func (r *stubUsageLogRepo) ListWithFilters(ctx context.Context, params paginatio ...@@ -1017,8 +1022,8 @@ func (r *stubUsageLogRepo) ListWithFilters(ctx context.Context, params paginatio
// Apply filters // Apply filters
var filtered []service.UsageLog var filtered []service.UsageLog
for _, log := range logs { for _, log := range logs {
// Apply ApiKeyID filter // Apply APIKeyID filter
if filters.ApiKeyID > 0 && log.ApiKeyID != filters.ApiKeyID { if filters.APIKeyID > 0 && log.APIKeyID != filters.APIKeyID {
continue continue
} }
// Apply Model filter // Apply Model filter
...@@ -1151,8 +1156,8 @@ func paginationResult(total int64, params pagination.PaginationParams) *paginati ...@@ -1151,8 +1156,8 @@ func paginationResult(total int64, params pagination.PaginationParams) *paginati
// Ensure compile-time interface compliance. // Ensure compile-time interface compliance.
var ( var (
_ service.UserRepository = (*stubUserRepo)(nil) _ service.UserRepository = (*stubUserRepo)(nil)
_ service.ApiKeyRepository = (*stubApiKeyRepo)(nil) _ service.APIKeyRepository = (*stubApiKeyRepo)(nil)
_ service.ApiKeyCache = (*stubApiKeyCache)(nil) _ service.APIKeyCache = (*stubApiKeyCache)(nil)
_ service.GroupRepository = (*stubGroupRepo)(nil) _ service.GroupRepository = (*stubGroupRepo)(nil)
_ service.UserSubscriptionRepository = (*stubUserSubscriptionRepo)(nil) _ service.UserSubscriptionRepository = (*stubUserSubscriptionRepo)(nil)
_ service.UsageLogRepository = (*stubUsageLogRepo)(nil) _ service.UsageLogRepository = (*stubUsageLogRepo)(nil)
......
// Package server provides HTTP server initialization and configuration.
package server package server
import ( import (
...@@ -25,8 +26,8 @@ func ProvideRouter( ...@@ -25,8 +26,8 @@ func ProvideRouter(
handlers *handler.Handlers, handlers *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware, jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware, adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.ApiKeyAuthMiddleware, apiKeyAuth middleware2.APIKeyAuthMiddleware,
apiKeyService *service.ApiKeyService, apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
) *gin.Engine { ) *gin.Engine {
if cfg.Server.Mode == "release" { if cfg.Server.Mode == "release" {
......
// Package middleware provides HTTP middleware for authentication, authorization, and request processing.
package middleware package middleware
import ( import (
...@@ -32,7 +33,7 @@ func adminAuth( ...@@ -32,7 +33,7 @@ func adminAuth(
// 检查 x-api-key header(Admin API Key 认证) // 检查 x-api-key header(Admin API Key 认证)
apiKey := c.GetHeader("x-api-key") apiKey := c.GetHeader("x-api-key")
if apiKey != "" { if apiKey != "" {
if !validateAdminApiKey(c, apiKey, settingService, userService) { if !validateAdminAPIKey(c, apiKey, settingService, userService) {
return return
} }
c.Next() c.Next()
...@@ -57,14 +58,14 @@ func adminAuth( ...@@ -57,14 +58,14 @@ func adminAuth(
} }
} }
// validateAdminApiKey 验证管理员 API Key // validateAdminAPIKey 验证管理员 API Key
func validateAdminApiKey( func validateAdminAPIKey(
c *gin.Context, c *gin.Context,
key string, key string,
settingService *service.SettingService, settingService *service.SettingService,
userService *service.UserService, userService *service.UserService,
) bool { ) bool {
storedKey, err := settingService.GetAdminApiKey(c.Request.Context()) storedKey, err := settingService.GetAdminAPIKey(c.Request.Context())
if err != nil { if err != nil {
AbortWithError(c, 500, "INTERNAL_ERROR", "Internal server error") AbortWithError(c, 500, "INTERNAL_ERROR", "Internal server error")
return false return false
......
...@@ -11,13 +11,13 @@ import ( ...@@ -11,13 +11,13 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// NewApiKeyAuthMiddleware 创建 API Key 认证中间件 // NewAPIKeyAuthMiddleware 创建 API Key 认证中间件
func NewApiKeyAuthMiddleware(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) ApiKeyAuthMiddleware { func NewAPIKeyAuthMiddleware(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) APIKeyAuthMiddleware {
return ApiKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg)) return APIKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg))
} }
// apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证) // apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证)
func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc { func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// 尝试从Authorization header中提取API key (Bearer scheme) // 尝试从Authorization header中提取API key (Bearer scheme)
authHeader := c.GetHeader("Authorization") authHeader := c.GetHeader("Authorization")
...@@ -60,7 +60,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti ...@@ -60,7 +60,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
// 从数据库验证API key // 从数据库验证API key
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString) apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
if err != nil { if err != nil {
if errors.Is(err, service.ErrApiKeyNotFound) { if errors.Is(err, service.ErrAPIKeyNotFound) {
AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key") AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key")
return return
} }
...@@ -88,7 +88,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti ...@@ -88,7 +88,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
if cfg.RunMode == config.RunModeSimple { if cfg.RunMode == config.RunModeSimple {
// 简易模式:跳过余额和订阅检查,但仍需设置必要的上下文 // 简易模式:跳过余额和订阅检查,但仍需设置必要的上下文
c.Set(string(ContextKeyApiKey), apiKey) c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{ c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID, UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency, Concurrency: apiKey.User.Concurrency,
...@@ -146,7 +146,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti ...@@ -146,7 +146,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
} }
// 将API key和用户信息存入上下文 // 将API key和用户信息存入上下文
c.Set(string(ContextKeyApiKey), apiKey) c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{ c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID, UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency, Concurrency: apiKey.User.Concurrency,
...@@ -157,13 +157,13 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti ...@@ -157,13 +157,13 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
} }
} }
// GetApiKeyFromContext 从上下文中获取API key // GetAPIKeyFromContext 从上下文中获取API key
func GetApiKeyFromContext(c *gin.Context) (*service.ApiKey, bool) { func GetAPIKeyFromContext(c *gin.Context) (*service.APIKey, bool) {
value, exists := c.Get(string(ContextKeyApiKey)) value, exists := c.Get(string(ContextKeyAPIKey))
if !exists { if !exists {
return nil, false return nil, false
} }
apiKey, ok := value.(*service.ApiKey) apiKey, ok := value.(*service.APIKey)
return apiKey, ok return apiKey, ok
} }
......
...@@ -11,16 +11,16 @@ import ( ...@@ -11,16 +11,16 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// ApiKeyAuthGoogle is a Google-style error wrapper for API key auth. // APIKeyAuthGoogle is a Google-style error wrapper for API key auth.
func ApiKeyAuthGoogle(apiKeyService *service.ApiKeyService, cfg *config.Config) gin.HandlerFunc { func APIKeyAuthGoogle(apiKeyService *service.APIKeyService, cfg *config.Config) gin.HandlerFunc {
return ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg) return APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)
} }
// ApiKeyAuthWithSubscriptionGoogle behaves like ApiKeyAuthWithSubscription but returns Google-style errors: // APIKeyAuthWithSubscriptionGoogle behaves like ApiKeyAuthWithSubscription but returns Google-style errors:
// {"error":{"code":401,"message":"...","status":"UNAUTHENTICATED"}} // {"error":{"code":401,"message":"...","status":"UNAUTHENTICATED"}}
// //
// It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations. // It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations.
func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc { func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
apiKeyString := extractAPIKeyFromRequest(c) apiKeyString := extractAPIKeyFromRequest(c)
if apiKeyString == "" { if apiKeyString == "" {
...@@ -30,7 +30,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs ...@@ -30,7 +30,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString) apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
if err != nil { if err != nil {
if errors.Is(err, service.ErrApiKeyNotFound) { if errors.Is(err, service.ErrAPIKeyNotFound) {
abortWithGoogleError(c, 401, "Invalid API key") abortWithGoogleError(c, 401, "Invalid API key")
return return
} }
...@@ -53,7 +53,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs ...@@ -53,7 +53,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs
// 简易模式:跳过余额和订阅检查 // 简易模式:跳过余额和订阅检查
if cfg.RunMode == config.RunModeSimple { if cfg.RunMode == config.RunModeSimple {
c.Set(string(ContextKeyApiKey), apiKey) c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{ c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID, UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency, Concurrency: apiKey.User.Concurrency,
...@@ -92,7 +92,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs ...@@ -92,7 +92,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs
} }
} }
c.Set(string(ContextKeyApiKey), apiKey) c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{ c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID, UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency, Concurrency: apiKey.User.Concurrency,
......
...@@ -16,53 +16,53 @@ import ( ...@@ -16,53 +16,53 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
type fakeApiKeyRepo struct { type fakeAPIKeyRepo struct {
getByKey func(ctx context.Context, key string) (*service.ApiKey, error) getByKey func(ctx context.Context, key string) (*service.APIKey, error)
} }
func (f fakeApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error { func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
func (f fakeApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) { func (f fakeAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (f fakeApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) { func (f fakeAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
func (f fakeApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) { func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
if f.getByKey == nil { if f.getByKey == nil {
return nil, errors.New("unexpected call") return nil, errors.New("unexpected call")
} }
return f.getByKey(ctx, key) return f.getByKey(ctx, key)
} }
func (f fakeApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error { func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
func (f fakeApiKeyRepo) Delete(ctx context.Context, id int64) error { func (f fakeAPIKeyRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
func (f fakeApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { func (f fakeAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (f fakeApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) { func (f fakeAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (f fakeApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) { func (f fakeAPIKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
func (f fakeApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) { func (f fakeAPIKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
return false, errors.New("not implemented") return false, errors.New("not implemented")
} }
func (f fakeApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { func (f fakeAPIKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (f fakeApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) { func (f fakeAPIKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (f fakeApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { func (f fakeAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
func (f fakeApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { func (f fakeAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
...@@ -74,8 +74,8 @@ type googleErrorResponse struct { ...@@ -74,8 +74,8 @@ type googleErrorResponse struct {
} `json:"error"` } `json:"error"`
} }
func newTestApiKeyService(repo service.ApiKeyRepository) *service.ApiKeyService { func newTestAPIKeyService(repo service.APIKeyRepository) *service.APIKeyService {
return service.NewApiKeyService( return service.NewAPIKeyService(
repo, repo,
nil, // userRepo (unused in GetByKey) nil, // userRepo (unused in GetByKey)
nil, // groupRepo nil, // groupRepo
...@@ -89,12 +89,12 @@ func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) { ...@@ -89,12 +89,12 @@ func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
r := gin.New() r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{ apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) { getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return nil, errors.New("should not be called") return nil, errors.New("should not be called")
}, },
}) })
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
...@@ -113,12 +113,12 @@ func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) { ...@@ -113,12 +113,12 @@ func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
r := gin.New() r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{ apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) { getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return nil, service.ErrApiKeyNotFound return nil, service.ErrAPIKeyNotFound
}, },
}) })
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
...@@ -138,12 +138,12 @@ func TestApiKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) { ...@@ -138,12 +138,12 @@ func TestApiKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
r := gin.New() r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{ apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) { getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return nil, errors.New("db down") return nil, errors.New("db down")
}, },
}) })
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
...@@ -163,9 +163,9 @@ func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) { ...@@ -163,9 +163,9 @@ func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
r := gin.New() r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{ apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) { getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return &service.ApiKey{ return &service.APIKey{
ID: 1, ID: 1,
Key: key, Key: key,
Status: service.StatusDisabled, Status: service.StatusDisabled,
...@@ -176,7 +176,7 @@ func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) { ...@@ -176,7 +176,7 @@ func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
}, nil }, nil
}, },
}) })
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
...@@ -196,9 +196,9 @@ func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) { ...@@ -196,9 +196,9 @@ func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
r := gin.New() r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{ apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) { getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return &service.ApiKey{ return &service.APIKey{
ID: 1, ID: 1,
Key: key, Key: key,
Status: service.StatusActive, Status: service.StatusActive,
...@@ -210,7 +210,7 @@ func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) { ...@@ -210,7 +210,7 @@ func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
}, nil }, nil
}, },
}) })
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
......
...@@ -35,7 +35,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { ...@@ -35,7 +35,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
Balance: 10, Balance: 10,
Concurrency: 3, Concurrency: 3,
} }
apiKey := &service.ApiKey{ apiKey := &service.APIKey{
ID: 100, ID: 100,
UserID: user.ID, UserID: user.ID,
Key: "test-key", Key: "test-key",
...@@ -46,9 +46,9 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { ...@@ -46,9 +46,9 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
apiKey.GroupID = &group.ID apiKey.GroupID = &group.ID
apiKeyRepo := &stubApiKeyRepo{ apiKeyRepo := &stubApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) { getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
if key != apiKey.Key { if key != apiKey.Key {
return nil, service.ErrApiKeyNotFound return nil, service.ErrAPIKeyNotFound
} }
clone := *apiKey clone := *apiKey
return &clone, nil return &clone, nil
...@@ -57,7 +57,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { ...@@ -57,7 +57,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) { t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeSimple} cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil) subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
...@@ -71,7 +71,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { ...@@ -71,7 +71,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("standard_mode_enforces_quota_check", func(t *testing.T) { t.Run("standard_mode_enforces_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeStandard} cfg := &config.Config{RunMode: config.RunModeStandard}
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
now := time.Now() now := time.Now()
sub := &service.UserSubscription{ sub := &service.UserSubscription{
...@@ -110,9 +110,9 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { ...@@ -110,9 +110,9 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
}) })
} }
func newAuthTestRouter(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine { func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
router := gin.New() router := gin.New()
router.Use(gin.HandlerFunc(NewApiKeyAuthMiddleware(apiKeyService, subscriptionService, cfg))) router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
router.GET("/t", func(c *gin.Context) { router.GET("/t", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true}) c.JSON(http.StatusOK, gin.H{"ok": true})
}) })
...@@ -120,14 +120,14 @@ func newAuthTestRouter(apiKeyService *service.ApiKeyService, subscriptionService ...@@ -120,14 +120,14 @@ func newAuthTestRouter(apiKeyService *service.ApiKeyService, subscriptionService
} }
type stubApiKeyRepo struct { type stubApiKeyRepo struct {
getByKey func(ctx context.Context, key string) (*service.ApiKey, error) getByKey func(ctx context.Context, key string) (*service.APIKey, error)
} }
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error { func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) { func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
...@@ -135,14 +135,14 @@ func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error ...@@ -135,14 +135,14 @@ func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) { func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
if r.getByKey != nil { if r.getByKey != nil {
return r.getByKey(ctx, key) return r.getByKey(ctx, key)
} }
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error { func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
...@@ -150,7 +150,7 @@ func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error { ...@@ -150,7 +150,7 @@ func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
...@@ -166,11 +166,11 @@ func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, err ...@@ -166,11 +166,11 @@ func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, err
return false, errors.New("not implemented") return false, errors.New("not implemented")
} }
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) { func (r *stubApiKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
......
...@@ -15,8 +15,8 @@ const ( ...@@ -15,8 +15,8 @@ const (
ContextKeyUser ContextKey = "user" ContextKeyUser ContextKey = "user"
// ContextKeyUserRole 当前用户角色(string) // ContextKeyUserRole 当前用户角色(string)
ContextKeyUserRole ContextKey = "user_role" ContextKeyUserRole ContextKey = "user_role"
// ContextKeyApiKey API密钥上下文键 // ContextKeyAPIKey API密钥上下文键
ContextKeyApiKey ContextKey = "api_key" ContextKeyAPIKey ContextKey = "api_key"
// ContextKeySubscription 订阅上下文键 // ContextKeySubscription 订阅上下文键
ContextKeySubscription ContextKey = "subscription" ContextKeySubscription ContextKey = "subscription"
// ContextKeyForcePlatform 强制平台(用于 /antigravity 路由) // ContextKeyForcePlatform 强制平台(用于 /antigravity 路由)
......
...@@ -11,12 +11,12 @@ type JWTAuthMiddleware gin.HandlerFunc ...@@ -11,12 +11,12 @@ type JWTAuthMiddleware gin.HandlerFunc
// AdminAuthMiddleware 管理员认证中间件类型 // AdminAuthMiddleware 管理员认证中间件类型
type AdminAuthMiddleware gin.HandlerFunc type AdminAuthMiddleware gin.HandlerFunc
// ApiKeyAuthMiddleware API Key 认证中间件类型 // APIKeyAuthMiddleware API Key 认证中间件类型
type ApiKeyAuthMiddleware gin.HandlerFunc type APIKeyAuthMiddleware gin.HandlerFunc
// ProviderSet 中间件层的依赖注入 // ProviderSet 中间件层的依赖注入
var ProviderSet = wire.NewSet( var ProviderSet = wire.NewSet(
NewJWTAuthMiddleware, NewJWTAuthMiddleware,
NewAdminAuthMiddleware, NewAdminAuthMiddleware,
NewApiKeyAuthMiddleware, NewAPIKeyAuthMiddleware,
) )
...@@ -17,8 +17,8 @@ func SetupRouter( ...@@ -17,8 +17,8 @@ func SetupRouter(
handlers *handler.Handlers, handlers *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware, jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware, adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.ApiKeyAuthMiddleware, apiKeyAuth middleware2.APIKeyAuthMiddleware,
apiKeyService *service.ApiKeyService, apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
cfg *config.Config, cfg *config.Config,
) *gin.Engine { ) *gin.Engine {
...@@ -43,8 +43,8 @@ func registerRoutes( ...@@ -43,8 +43,8 @@ func registerRoutes(
h *handler.Handlers, h *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware, jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware, adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.ApiKeyAuthMiddleware, apiKeyAuth middleware2.APIKeyAuthMiddleware,
apiKeyService *service.ApiKeyService, apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
cfg *config.Config, cfg *config.Config,
) { ) {
......
// Package routes provides HTTP route registration and handlers.
package routes package routes
import ( import (
...@@ -67,10 +68,10 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -67,10 +68,10 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics) dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend) dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
dashboard.GET("/models", h.Admin.Dashboard.GetModelStats) dashboard.GET("/models", h.Admin.Dashboard.GetModelStats)
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend) dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetAPIKeyUsageTrend)
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend) dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage) dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage) dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
} }
} }
...@@ -123,6 +124,8 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -123,6 +124,8 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts.GET("/:id/usage", h.Admin.Account.GetUsage) accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats) accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit) accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
accounts.GET("/:id/temp-unschedulable", h.Admin.Account.GetTempUnschedulable)
accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable)
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable) accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels) accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
accounts.POST("/batch", h.Admin.Account.BatchCreate) accounts.POST("/batch", h.Admin.Account.BatchCreate)
...@@ -203,12 +206,12 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -203,12 +206,12 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
{ {
adminSettings.GET("", h.Admin.Setting.GetSettings) adminSettings.GET("", h.Admin.Setting.GetSettings)
adminSettings.PUT("", h.Admin.Setting.UpdateSettings) adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection) adminSettings.POST("/test-smtp", h.Admin.Setting.TestSMTPConnection)
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail) adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
// Admin API Key 管理 // Admin API Key 管理
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminApiKey) adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey)
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminApiKey) adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey)
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminApiKey) adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminAPIKey)
} }
} }
...@@ -248,7 +251,7 @@ func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -248,7 +251,7 @@ func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
usage.GET("", h.Admin.Usage.List) usage.GET("", h.Admin.Usage.List)
usage.GET("/stats", h.Admin.Usage.Stats) usage.GET("/stats", h.Admin.Usage.Stats)
usage.GET("/search-users", h.Admin.Usage.SearchUsers) usage.GET("/search-users", h.Admin.Usage.SearchUsers)
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys) usage.GET("/search-api-keys", h.Admin.Usage.SearchAPIKeys)
} }
} }
......
...@@ -13,8 +13,8 @@ import ( ...@@ -13,8 +13,8 @@ import (
func RegisterGatewayRoutes( func RegisterGatewayRoutes(
r *gin.Engine, r *gin.Engine,
h *handler.Handlers, h *handler.Handlers,
apiKeyAuth middleware.ApiKeyAuthMiddleware, apiKeyAuth middleware.APIKeyAuthMiddleware,
apiKeyService *service.ApiKeyService, apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
cfg *config.Config, cfg *config.Config,
) { ) {
...@@ -36,7 +36,7 @@ func RegisterGatewayRoutes( ...@@ -36,7 +36,7 @@ func RegisterGatewayRoutes(
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
gemini := r.Group("/v1beta") gemini := r.Group("/v1beta")
gemini.Use(bodyLimit) gemini.Use(bodyLimit)
gemini.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg)) gemini.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
{ {
gemini.GET("/models", h.Gateway.GeminiV1BetaListModels) gemini.GET("/models", h.Gateway.GeminiV1BetaListModels)
gemini.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel) gemini.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
...@@ -65,7 +65,7 @@ func RegisterGatewayRoutes( ...@@ -65,7 +65,7 @@ func RegisterGatewayRoutes(
antigravityV1Beta := r.Group("/antigravity/v1beta") antigravityV1Beta := r.Group("/antigravity/v1beta")
antigravityV1Beta.Use(bodyLimit) antigravityV1Beta.Use(bodyLimit)
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity)) antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1Beta.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg)) antigravityV1Beta.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
{ {
antigravityV1Beta.GET("/models", h.Gateway.GeminiV1BetaListModels) antigravityV1Beta.GET("/models", h.Gateway.GeminiV1BetaListModels)
antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel) antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
......
...@@ -50,7 +50,7 @@ func RegisterUserRoutes( ...@@ -50,7 +50,7 @@ func RegisterUserRoutes(
usage.GET("/dashboard/stats", h.Usage.DashboardStats) usage.GET("/dashboard/stats", h.Usage.DashboardStats)
usage.GET("/dashboard/trend", h.Usage.DashboardTrend) usage.GET("/dashboard/trend", h.Usage.DashboardTrend)
usage.GET("/dashboard/models", h.Usage.DashboardModels) usage.GET("/dashboard/models", h.Usage.DashboardModels)
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage) usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardAPIKeysUsage)
} }
// 卡密兑换 // 卡密兑换
......
// Package service provides business logic and domain services for the application.
package service package service
import ( import (
...@@ -29,6 +30,9 @@ type Account struct { ...@@ -29,6 +30,9 @@ type Account struct {
RateLimitResetAt *time.Time RateLimitResetAt *time.Time
OverloadUntil *time.Time OverloadUntil *time.Time
TempUnschedulableUntil *time.Time
TempUnschedulableReason string
SessionWindowStart *time.Time SessionWindowStart *time.Time
SessionWindowEnd *time.Time SessionWindowEnd *time.Time
SessionWindowStatus string SessionWindowStatus string
...@@ -39,6 +43,13 @@ type Account struct { ...@@ -39,6 +43,13 @@ type Account struct {
Groups []*Group Groups []*Group
} }
type TempUnschedulableRule struct {
ErrorCode int `json:"error_code"`
Keywords []string `json:"keywords"`
DurationMinutes int `json:"duration_minutes"`
Description string `json:"description"`
}
func (a *Account) IsActive() bool { func (a *Account) IsActive() bool {
return a.Status == StatusActive return a.Status == StatusActive
} }
...@@ -54,6 +65,9 @@ func (a *Account) IsSchedulable() bool { ...@@ -54,6 +65,9 @@ func (a *Account) IsSchedulable() bool {
if a.RateLimitResetAt != nil && now.Before(*a.RateLimitResetAt) { if a.RateLimitResetAt != nil && now.Before(*a.RateLimitResetAt) {
return false return false
} }
if a.TempUnschedulableUntil != nil && now.Before(*a.TempUnschedulableUntil) {
return false
}
return true return true
} }
...@@ -92,10 +106,7 @@ func (a *Account) GeminiOAuthType() string { ...@@ -92,10 +106,7 @@ func (a *Account) GeminiOAuthType() string {
func (a *Account) GeminiTierID() string { func (a *Account) GeminiTierID() string {
tierID := strings.TrimSpace(a.GetCredential("tier_id")) tierID := strings.TrimSpace(a.GetCredential("tier_id"))
if tierID == "" { return tierID
return ""
}
return strings.ToUpper(tierID)
} }
func (a *Account) IsGeminiCodeAssist() bool { func (a *Account) IsGeminiCodeAssist() bool {
...@@ -163,6 +174,114 @@ func (a *Account) GetCredentialAsTime(key string) *time.Time { ...@@ -163,6 +174,114 @@ func (a *Account) GetCredentialAsTime(key string) *time.Time {
return nil return nil
} }
func (a *Account) IsTempUnschedulableEnabled() bool {
if a.Credentials == nil {
return false
}
raw, ok := a.Credentials["temp_unschedulable_enabled"]
if !ok || raw == nil {
return false
}
enabled, ok := raw.(bool)
return ok && enabled
}
func (a *Account) GetTempUnschedulableRules() []TempUnschedulableRule {
if a.Credentials == nil {
return nil
}
raw, ok := a.Credentials["temp_unschedulable_rules"]
if !ok || raw == nil {
return nil
}
arr, ok := raw.([]any)
if !ok {
return nil
}
rules := make([]TempUnschedulableRule, 0, len(arr))
for _, item := range arr {
entry, ok := item.(map[string]any)
if !ok || entry == nil {
continue
}
rule := TempUnschedulableRule{
ErrorCode: parseTempUnschedInt(entry["error_code"]),
Keywords: parseTempUnschedStrings(entry["keywords"]),
DurationMinutes: parseTempUnschedInt(entry["duration_minutes"]),
Description: parseTempUnschedString(entry["description"]),
}
if rule.ErrorCode <= 0 || rule.DurationMinutes <= 0 || len(rule.Keywords) == 0 {
continue
}
rules = append(rules, rule)
}
return rules
}
func parseTempUnschedString(value any) string {
s, ok := value.(string)
if !ok {
return ""
}
return strings.TrimSpace(s)
}
func parseTempUnschedStrings(value any) []string {
if value == nil {
return nil
}
var raw []string
switch v := value.(type) {
case []string:
raw = v
case []any:
raw = make([]string, 0, len(v))
for _, item := range v {
if s, ok := item.(string); ok {
raw = append(raw, s)
}
}
default:
return nil
}
out := make([]string, 0, len(raw))
for _, item := range raw {
s := strings.TrimSpace(item)
if s != "" {
out = append(out, s)
}
}
return out
}
func parseTempUnschedInt(value any) int {
switch v := value.(type) {
case int:
return v
case int64:
return int(v)
case float64:
return int(v)
case json.Number:
if i, err := v.Int64(); err == nil {
return int(i)
}
case string:
if i, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
return i
}
}
return 0
}
func (a *Account) GetModelMapping() map[string]string { func (a *Account) GetModelMapping() map[string]string {
if a.Credentials == nil { if a.Credentials == nil {
return nil return nil
...@@ -206,7 +325,7 @@ func (a *Account) GetMappedModel(requestedModel string) string { ...@@ -206,7 +325,7 @@ func (a *Account) GetMappedModel(requestedModel string) string {
} }
func (a *Account) GetBaseURL() string { func (a *Account) GetBaseURL() string {
if a.Type != AccountTypeApiKey { if a.Type != AccountTypeAPIKey {
return "" return ""
} }
baseURL := a.GetCredential("base_url") baseURL := a.GetCredential("base_url")
...@@ -229,7 +348,7 @@ func (a *Account) GetExtraString(key string) string { ...@@ -229,7 +348,7 @@ func (a *Account) GetExtraString(key string) string {
} }
func (a *Account) IsCustomErrorCodesEnabled() bool { func (a *Account) IsCustomErrorCodesEnabled() bool {
if a.Type != AccountTypeApiKey || a.Credentials == nil { if a.Type != AccountTypeAPIKey || a.Credentials == nil {
return false return false
} }
if v, ok := a.Credentials["custom_error_codes_enabled"]; ok { if v, ok := a.Credentials["custom_error_codes_enabled"]; ok {
...@@ -301,14 +420,14 @@ func (a *Account) IsOpenAIOAuth() bool { ...@@ -301,14 +420,14 @@ func (a *Account) IsOpenAIOAuth() bool {
} }
func (a *Account) IsOpenAIApiKey() bool { func (a *Account) IsOpenAIApiKey() bool {
return a.IsOpenAI() && a.Type == AccountTypeApiKey return a.IsOpenAI() && a.Type == AccountTypeAPIKey
} }
func (a *Account) GetOpenAIBaseURL() string { func (a *Account) GetOpenAIBaseURL() string {
if !a.IsOpenAI() { if !a.IsOpenAI() {
return "" return ""
} }
if a.Type == AccountTypeApiKey { if a.Type == AccountTypeAPIKey {
baseURL := a.GetCredential("base_url") baseURL := a.GetCredential("base_url")
if baseURL != "" { if baseURL != "" {
return baseURL return baseURL
......
...@@ -49,6 +49,8 @@ type AccountRepository interface { ...@@ -49,6 +49,8 @@ type AccountRepository interface {
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error SetOverloaded(ctx context.Context, id int64, until time.Time) error
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
ClearTempUnschedulable(ctx context.Context, id int64) error
ClearRateLimit(ctx context.Context, id int64) error ClearRateLimit(ctx context.Context, id int64) error
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment