Commit 112a2d08 authored by ianshaw's avatar ianshaw
Browse files

chore: 更新依赖、配置和代码生成

主要更新:
- 更新 go.mod/go.sum 依赖
- 重新生成 Ent ORM 代码
- 更新 Wire 依赖注入配置
- 添加 docker-compose.override.yml 到 .gitignore
- 更新 README 文档(Simple Mode 说明和已知问题)
- 清理调试日志
- 其他代码优化和格式修复
parent b1702de5
...@@ -293,8 +293,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, ...@@ -293,8 +293,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
// 2. Clear group_id for api keys bound to this group. // 2. Clear group_id for api keys bound to this group.
// 仅更新未软删除的记录,避免修改已删除数据,保证审计与历史回溯一致性。 // 仅更新未软删除的记录,避免修改已删除数据,保证审计与历史回溯一致性。
// 与 APIKeyRepository 的软删除语义保持一致,减少跨模块行为差异。 // 与 ApiKeyRepository 的软删除语义保持一致,减少跨模块行为差异。
if _, err := txClient.APIKey.Update(). if _, err := txClient.ApiKey.Update().
Where(apikey.GroupIDEQ(id), apikey.DeletedAtIsNil()). Where(apikey.GroupIDEQ(id), apikey.DeletedAtIsNil()).
ClearGroupID(). ClearGroupID().
Save(ctx); err != nil { Save(ctx); err != nil {
......
...@@ -34,15 +34,15 @@ func createEntUser(t *testing.T, ctx context.Context, client *dbent.Client, emai ...@@ -34,15 +34,15 @@ func createEntUser(t *testing.T, ctx context.Context, client *dbent.Client, emai
return u return u
} }
func TestEntSoftDelete_APIKey_DefaultFilterAndSkip(t *testing.T) { func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
ctx := context.Background() ctx := context.Background()
// 使用全局 ent client,确保软删除验证在实际持久化数据上进行。 // 使用全局 ent client,确保软删除验证在实际持久化数据上进行。
client := testEntClient(t) client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user")+"@example.com") u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user")+"@example.com")
repo := NewAPIKeyRepository(client) repo := NewApiKeyRepository(client)
key := &service.APIKey{ key := &service.ApiKey{
UserID: u.ID, UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete"), Key: uniqueSoftDeleteValue(t, "sk-soft-delete"),
Name: "soft-delete", Name: "soft-delete",
...@@ -53,28 +53,28 @@ func TestEntSoftDelete_APIKey_DefaultFilterAndSkip(t *testing.T) { ...@@ -53,28 +53,28 @@ func TestEntSoftDelete_APIKey_DefaultFilterAndSkip(t *testing.T) {
require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key") require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key")
_, err := repo.GetByID(ctx, key.ID) _, err := repo.GetByID(ctx, key.ID)
require.ErrorIs(t, err, service.ErrAPIKeyNotFound, "deleted rows should be hidden by default") require.ErrorIs(t, err, service.ErrApiKeyNotFound, "deleted rows should be hidden by default")
_, err = client.APIKey.Query().Where(apikey.IDEQ(key.ID)).Only(ctx) _, err = client.ApiKey.Query().Where(apikey.IDEQ(key.ID)).Only(ctx)
require.Error(t, err, "default ent query should not see soft-deleted rows") require.Error(t, err, "default ent query should not see soft-deleted rows")
require.True(t, dbent.IsNotFound(err), "expected ent not-found after default soft delete filter") require.True(t, dbent.IsNotFound(err), "expected ent not-found after default soft delete filter")
got, err := client.APIKey.Query(). got, err := client.ApiKey.Query().
Where(apikey.IDEQ(key.ID)). Where(apikey.IDEQ(key.ID)).
Only(mixins.SkipSoftDelete(ctx)) Only(mixins.SkipSoftDelete(ctx))
require.NoError(t, err, "SkipSoftDelete should include soft-deleted rows") require.NoError(t, err, "SkipSoftDelete should include soft-deleted rows")
require.NotNil(t, got.DeletedAt, "deleted_at should be set after soft delete") require.NotNil(t, got.DeletedAt, "deleted_at should be set after soft delete")
} }
func TestEntSoftDelete_APIKey_DeleteIdempotent(t *testing.T) { func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) {
ctx := context.Background() ctx := context.Background()
// 使用全局 ent client,避免事务回滚影响幂等性验证。 // 使用全局 ent client,避免事务回滚影响幂等性验证。
client := testEntClient(t) client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user2")+"@example.com") u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user2")+"@example.com")
repo := NewAPIKeyRepository(client) repo := NewApiKeyRepository(client)
key := &service.APIKey{ key := &service.ApiKey{
UserID: u.ID, UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete2"), Key: uniqueSoftDeleteValue(t, "sk-soft-delete2"),
Name: "soft-delete2", Name: "soft-delete2",
...@@ -86,15 +86,15 @@ func TestEntSoftDelete_APIKey_DeleteIdempotent(t *testing.T) { ...@@ -86,15 +86,15 @@ func TestEntSoftDelete_APIKey_DeleteIdempotent(t *testing.T) {
require.NoError(t, repo.Delete(ctx, key.ID), "second delete should be idempotent") require.NoError(t, repo.Delete(ctx, key.ID), "second delete should be idempotent")
} }
func TestEntSoftDelete_APIKey_HardDeleteViaSkipSoftDelete(t *testing.T) { func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
ctx := context.Background() ctx := context.Background()
// 使用全局 ent client,确保 SkipSoftDelete 的硬删除语义可验证。 // 使用全局 ent client,确保 SkipSoftDelete 的硬删除语义可验证。
client := testEntClient(t) client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user3")+"@example.com") u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user3")+"@example.com")
repo := NewAPIKeyRepository(client) repo := NewApiKeyRepository(client)
key := &service.APIKey{ key := &service.ApiKey{
UserID: u.ID, UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete3"), Key: uniqueSoftDeleteValue(t, "sk-soft-delete3"),
Name: "soft-delete3", Name: "soft-delete3",
...@@ -105,10 +105,10 @@ func TestEntSoftDelete_APIKey_HardDeleteViaSkipSoftDelete(t *testing.T) { ...@@ -105,10 +105,10 @@ func TestEntSoftDelete_APIKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key") require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key")
// Hard delete using SkipSoftDelete so the hook doesn't convert it to update-deleted_at. // Hard delete using SkipSoftDelete so the hook doesn't convert it to update-deleted_at.
_, err := client.APIKey.Delete().Where(apikey.IDEQ(key.ID)).Exec(mixins.SkipSoftDelete(ctx)) _, err := client.ApiKey.Delete().Where(apikey.IDEQ(key.ID)).Exec(mixins.SkipSoftDelete(ctx))
require.NoError(t, err, "hard delete") require.NoError(t, err, "hard delete")
_, err = client.APIKey.Query(). _, err = client.ApiKey.Query().
Where(apikey.IDEQ(key.ID)). Where(apikey.IDEQ(key.ID)).
Only(mixins.SkipSoftDelete(ctx)) Only(mixins.SkipSoftDelete(ctx))
require.True(t, dbent.IsNotFound(err), "expected row to be hard deleted") require.True(t, dbent.IsNotFound(err), "expected row to be hard deleted")
......
...@@ -4,12 +4,13 @@ import ( ...@@ -4,12 +4,13 @@ import (
"context" "context"
"database/sql" "database/sql"
"errors" "errors"
"fmt"
"sort" "sort"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
dbuser "github.com/Wei-Shaw/sub2api/ent/user" dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -17,14 +18,15 @@ import ( ...@@ -17,14 +18,15 @@ import (
type userRepository struct { type userRepository struct {
client *dbent.Client client *dbent.Client
sql sqlExecutor
} }
func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository { func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository {
return newUserRepositoryWithSQL(client, sqlDB) return newUserRepositoryWithSQL(client, sqlDB)
} }
func newUserRepositoryWithSQL(client *dbent.Client, _ sqlExecutor) *userRepository { func newUserRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *userRepository {
return &userRepository{client: client} return &userRepository{client: client, sql: sqlq}
} }
func (r *userRepository) Create(ctx context.Context, userIn *service.User) error { func (r *userRepository) Create(ctx context.Context, userIn *service.User) error {
...@@ -194,7 +196,11 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination. ...@@ -194,7 +196,11 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
// If attribute filters are specified, we need to filter by user IDs first // If attribute filters are specified, we need to filter by user IDs first
var allowedUserIDs []int64 var allowedUserIDs []int64
if len(filters.Attributes) > 0 { if len(filters.Attributes) > 0 {
allowedUserIDs = r.filterUsersByAttributes(ctx, filters.Attributes) var attrErr error
allowedUserIDs, attrErr = r.filterUsersByAttributes(ctx, filters.Attributes)
if attrErr != nil {
return nil, nil, attrErr
}
if len(allowedUserIDs) == 0 { if len(allowedUserIDs) == 0 {
// No users match the attribute filters // No users match the attribute filters
return []service.User{}, paginationResultFromTotal(0, params), nil return []service.User{}, paginationResultFromTotal(0, params), nil
...@@ -262,56 +268,53 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination. ...@@ -262,56 +268,53 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
} }
// filterUsersByAttributes returns user IDs that match ALL the given attribute filters // filterUsersByAttributes returns user IDs that match ALL the given attribute filters
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) []int64 { func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) ([]int64, error) {
if len(attrs) == 0 { if len(attrs) == 0 {
return nil return nil, nil
} }
// For each attribute filter, get the set of matching user IDs if r.sql == nil {
// Then intersect all sets to get users matching ALL filters return nil, fmt.Errorf("sql executor is not configured")
var resultSet map[int64]struct{} }
first := true
clauses := make([]string, 0, len(attrs))
args := make([]any, 0, len(attrs)*2+1)
argIndex := 1
for attrID, value := range attrs { for attrID, value := range attrs {
// Query user_attribute_values for this attribute clauses = append(clauses, fmt.Sprintf("(attribute_id = $%d AND value ILIKE $%d)", argIndex, argIndex+1))
values, err := r.client.UserAttributeValue.Query(). args = append(args, attrID, "%"+value+"%")
Where( argIndex += 2
userattributevalue.AttributeIDEQ(attrID), }
userattributevalue.ValueContainsFold(value),
). query := fmt.Sprintf(
All(ctx) `SELECT user_id
if err != nil { FROM user_attribute_values
continue WHERE %s
} GROUP BY user_id
HAVING COUNT(DISTINCT attribute_id) = $%d`,
currentSet := make(map[int64]struct{}, len(values)) strings.Join(clauses, " OR "),
for _, v := range values { argIndex,
currentSet[v.UserID] = struct{}{} )
} args = append(args, len(attrs))
if first { rows, err := r.sql.QueryContext(ctx, query, args...)
resultSet = currentSet if err != nil {
first = false return nil, err
} else {
// Intersect with previous results
for userID := range resultSet {
if _, ok := currentSet[userID]; !ok {
delete(resultSet, userID)
}
}
}
// Early exit if no users match
if len(resultSet) == 0 {
return nil
}
} }
defer rows.Close()
result := make([]int64, 0, len(resultSet)) result := make([]int64, 0)
for userID := range resultSet { for rows.Next() {
var userID int64
if scanErr := rows.Scan(&userID); scanErr != nil {
return nil, scanErr
}
result = append(result, userID) result = append(result, userID)
} }
return result if err := rows.Err(); err != nil {
return nil, err
}
return result, nil
} }
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error { func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
......
...@@ -28,13 +28,12 @@ func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.Conc ...@@ -28,13 +28,12 @@ func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.Conc
// ProviderSet is the Wire provider set for all repositories // ProviderSet is the Wire provider set for all repositories
var ProviderSet = wire.NewSet( var ProviderSet = wire.NewSet(
NewUserRepository, NewUserRepository,
NewAPIKeyRepository, NewApiKeyRepository,
NewGroupRepository, NewGroupRepository,
NewAccountRepository, NewAccountRepository,
NewProxyRepository, NewProxyRepository,
NewRedeemCodeRepository, NewRedeemCodeRepository,
NewUsageLogRepository, NewUsageLogRepository,
NewOpsRepository,
NewSettingRepository, NewSettingRepository,
NewUserSubscriptionRepository, NewUserSubscriptionRepository,
NewUserAttributeDefinitionRepository, NewUserAttributeDefinitionRepository,
...@@ -43,7 +42,8 @@ var ProviderSet = wire.NewSet( ...@@ -43,7 +42,8 @@ var ProviderSet = wire.NewSet(
// Cache implementations // Cache implementations
NewGatewayCache, NewGatewayCache,
NewBillingCache, NewBillingCache,
NewAPIKeyCache, NewApiKeyCache,
NewTempUnschedCache,
ProvideConcurrencyCache, ProvideConcurrencyCache,
NewEmailCache, NewEmailCache,
NewIdentityCache, NewIdentityCache,
......
// Package server provides HTTP server setup and routing configuration.
package server package server
import ( import (
...@@ -26,8 +25,8 @@ func ProvideRouter( ...@@ -26,8 +25,8 @@ func ProvideRouter(
handlers *handler.Handlers, handlers *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware, jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware, adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.APIKeyAuthMiddleware, apiKeyAuth middleware2.ApiKeyAuthMiddleware,
apiKeyService *service.APIKeyService, apiKeyService *service.ApiKeyService,
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
) *gin.Engine { ) *gin.Engine {
if cfg.Server.Mode == "release" { if cfg.Server.Mode == "release" {
......
...@@ -32,7 +32,7 @@ func adminAuth( ...@@ -32,7 +32,7 @@ func adminAuth(
// 检查 x-api-key header(Admin API Key 认证) // 检查 x-api-key header(Admin API Key 认证)
apiKey := c.GetHeader("x-api-key") apiKey := c.GetHeader("x-api-key")
if apiKey != "" { if apiKey != "" {
if !validateAdminAPIKey(c, apiKey, settingService, userService) { if !validateAdminApiKey(c, apiKey, settingService, userService) {
return return
} }
c.Next() c.Next()
...@@ -52,48 +52,19 @@ func adminAuth( ...@@ -52,48 +52,19 @@ func adminAuth(
} }
} }
// WebSocket 请求无法设置自定义 header,允许在 query 中携带凭证
if isWebSocketRequest(c) {
if token := strings.TrimSpace(c.Query("token")); token != "" {
if !validateJWTForAdmin(c, token, authService, userService) {
return
}
c.Next()
return
}
if apiKey := strings.TrimSpace(c.Query("api_key")); apiKey != "" {
if !validateAdminAPIKey(c, apiKey, settingService, userService) {
return
}
c.Next()
return
}
}
// 无有效认证信息 // 无有效认证信息
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization required") AbortWithError(c, 401, "UNAUTHORIZED", "Authorization required")
} }
} }
func isWebSocketRequest(c *gin.Context) bool { // validateAdminApiKey 验证管理员 API Key
if c == nil || c.Request == nil { func validateAdminApiKey(
return false
}
if strings.EqualFold(c.GetHeader("Upgrade"), "websocket") {
return true
}
conn := strings.ToLower(c.GetHeader("Connection"))
return strings.Contains(conn, "upgrade") && strings.EqualFold(c.GetHeader("Upgrade"), "websocket")
}
// validateAdminAPIKey 验证管理员 API Key
func validateAdminAPIKey(
c *gin.Context, c *gin.Context,
key string, key string,
settingService *service.SettingService, settingService *service.SettingService,
userService *service.UserService, userService *service.UserService,
) bool { ) bool {
storedKey, err := settingService.GetAdminAPIKey(c.Request.Context()) storedKey, err := settingService.GetAdminApiKey(c.Request.Context())
if err != nil { if err != nil {
AbortWithError(c, 500, "INTERNAL_ERROR", "Internal server error") AbortWithError(c, 500, "INTERNAL_ERROR", "Internal server error")
return false return false
......
...@@ -11,13 +11,13 @@ import ( ...@@ -11,13 +11,13 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// NewAPIKeyAuthMiddleware 创建 API Key 认证中间件 // NewApiKeyAuthMiddleware 创建 API Key 认证中间件
func NewAPIKeyAuthMiddleware(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config, opsService *service.OpsService) APIKeyAuthMiddleware { func NewApiKeyAuthMiddleware(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) ApiKeyAuthMiddleware {
return APIKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg, opsService)) return ApiKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg))
} }
// apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证) // apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证)
func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config, opsService *service.OpsService) gin.HandlerFunc { func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// 尝试从Authorization header中提取API key (Bearer scheme) // 尝试从Authorization header中提取API key (Bearer scheme)
authHeader := c.GetHeader("Authorization") authHeader := c.GetHeader("Authorization")
...@@ -53,7 +53,6 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti ...@@ -53,7 +53,6 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
// 如果所有header都没有API key // 如果所有header都没有API key
if apiKeyString == "" { if apiKeyString == "" {
recordOpsAuthError(c, opsService, nil, 401, "API key is required in Authorization header (Bearer scheme), x-api-key header, x-goog-api-key header, or key/api_key query parameter")
AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, x-goog-api-key header, or key/api_key query parameter") AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, x-goog-api-key header, or key/api_key query parameter")
return return
} }
...@@ -61,40 +60,35 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti ...@@ -61,40 +60,35 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
// 从数据库验证API key // 从数据库验证API key
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString) apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
if err != nil { if err != nil {
if errors.Is(err, service.ErrAPIKeyNotFound) { if errors.Is(err, service.ErrApiKeyNotFound) {
recordOpsAuthError(c, opsService, nil, 401, "Invalid API key")
AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key") AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key")
return return
} }
recordOpsAuthError(c, opsService, nil, 500, "Failed to validate API key")
AbortWithError(c, 500, "INTERNAL_ERROR", "Failed to validate API key") AbortWithError(c, 500, "INTERNAL_ERROR", "Failed to validate API key")
return return
} }
// 检查API key是否激活 // 检查API key是否激活
if !apiKey.IsActive() { if !apiKey.IsActive() {
recordOpsAuthError(c, opsService, apiKey, 401, "API key is disabled")
AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled") AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled")
return return
} }
// 检查关联的用户 // 检查关联的用户
if apiKey.User == nil { if apiKey.User == nil {
recordOpsAuthError(c, opsService, apiKey, 401, "User associated with API key not found")
AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found") AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found")
return return
} }
// 检查用户状态 // 检查用户状态
if !apiKey.User.IsActive() { if !apiKey.User.IsActive() {
recordOpsAuthError(c, opsService, apiKey, 401, "User account is not active")
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active") AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
return return
} }
if cfg.RunMode == config.RunModeSimple { if cfg.RunMode == config.RunModeSimple {
// 简易模式:跳过余额和订阅检查,但仍需设置必要的上下文 // 简易模式:跳过余额和订阅检查,但仍需设置必要的上下文
c.Set(string(ContextKeyAPIKey), apiKey) c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{ c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID, UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency, Concurrency: apiKey.User.Concurrency,
...@@ -115,14 +109,12 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti ...@@ -115,14 +109,12 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
apiKey.Group.ID, apiKey.Group.ID,
) )
if err != nil { if err != nil {
recordOpsAuthError(c, opsService, apiKey, 403, "No active subscription found for this group")
AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group") AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group")
return return
} }
// 验证订阅状态(是否过期、暂停等) // 验证订阅状态(是否过期、暂停等)
if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil { if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil {
recordOpsAuthError(c, opsService, apiKey, 403, err.Error())
AbortWithError(c, 403, "SUBSCRIPTION_INVALID", err.Error()) AbortWithError(c, 403, "SUBSCRIPTION_INVALID", err.Error())
return return
} }
...@@ -139,7 +131,6 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti ...@@ -139,7 +131,6 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
// 预检查用量限制(使用0作为额外费用进行预检查) // 预检查用量限制(使用0作为额外费用进行预检查)
if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil { if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil {
recordOpsAuthError(c, opsService, apiKey, 429, err.Error())
AbortWithError(c, 429, "USAGE_LIMIT_EXCEEDED", err.Error()) AbortWithError(c, 429, "USAGE_LIMIT_EXCEEDED", err.Error())
return return
} }
...@@ -149,14 +140,13 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti ...@@ -149,14 +140,13 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
} else { } else {
// 余额模式:检查用户余额 // 余额模式:检查用户余额
if apiKey.User.Balance <= 0 { if apiKey.User.Balance <= 0 {
recordOpsAuthError(c, opsService, apiKey, 403, "Insufficient account balance")
AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance") AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance")
return return
} }
} }
// 将API key和用户信息存入上下文 // 将API key和用户信息存入上下文
c.Set(string(ContextKeyAPIKey), apiKey) c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{ c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID, UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency, Concurrency: apiKey.User.Concurrency,
...@@ -167,66 +157,13 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti ...@@ -167,66 +157,13 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
} }
} }
func recordOpsAuthError(c *gin.Context, opsService *service.OpsService, apiKey *service.APIKey, status int, message string) { // GetApiKeyFromContext 从上下文中获取API key
if opsService == nil || c == nil { func GetApiKeyFromContext(c *gin.Context) (*service.ApiKey, bool) {
return value, exists := c.Get(string(ContextKeyApiKey))
}
errType := "authentication_error"
phase := "auth"
severity := "P3"
switch status {
case 403:
errType = "billing_error"
phase = "billing"
case 429:
errType = "rate_limit_error"
phase = "billing"
severity = "P2"
case 500:
errType = "api_error"
phase = "internal"
severity = "P1"
}
logEntry := &service.OpsErrorLog{
Phase: phase,
Type: errType,
Severity: severity,
StatusCode: status,
Message: message,
ClientIP: c.ClientIP(),
RequestPath: func() string {
if c.Request != nil && c.Request.URL != nil {
return c.Request.URL.Path
}
return ""
}(),
}
if apiKey != nil {
logEntry.APIKeyID = &apiKey.ID
if apiKey.User != nil {
logEntry.UserID = &apiKey.User.ID
}
if apiKey.GroupID != nil {
logEntry.GroupID = apiKey.GroupID
}
if apiKey.Group != nil {
logEntry.Platform = apiKey.Group.Platform
}
}
enqueueOpsAuthErrorLog(opsService, logEntry)
}
// GetAPIKeyFromContext 从上下文中获取API key
func GetAPIKeyFromContext(c *gin.Context) (*service.APIKey, bool) {
value, exists := c.Get(string(ContextKeyAPIKey))
if !exists { if !exists {
return nil, false return nil, false
} }
apiKey, ok := value.(*service.APIKey) apiKey, ok := value.(*service.ApiKey)
return apiKey, ok return apiKey, ok
} }
......
...@@ -11,16 +11,16 @@ import ( ...@@ -11,16 +11,16 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// APIKeyAuthGoogle is a Google-style error wrapper for API key auth. // ApiKeyAuthGoogle is a Google-style error wrapper for API key auth.
func APIKeyAuthGoogle(apiKeyService *service.APIKeyService, cfg *config.Config) gin.HandlerFunc { func ApiKeyAuthGoogle(apiKeyService *service.ApiKeyService, cfg *config.Config) gin.HandlerFunc {
return APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg) return ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)
} }
// APIKeyAuthWithSubscriptionGoogle behaves like APIKeyAuthWithSubscription but returns Google-style errors: // ApiKeyAuthWithSubscriptionGoogle behaves like ApiKeyAuthWithSubscription but returns Google-style errors:
// {"error":{"code":401,"message":"...","status":"UNAUTHENTICATED"}} // {"error":{"code":401,"message":"...","status":"UNAUTHENTICATED"}}
// //
// It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations. // It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations.
func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc { func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
apiKeyString := extractAPIKeyFromRequest(c) apiKeyString := extractAPIKeyFromRequest(c)
if apiKeyString == "" { if apiKeyString == "" {
...@@ -30,7 +30,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs ...@@ -30,7 +30,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString) apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
if err != nil { if err != nil {
if errors.Is(err, service.ErrAPIKeyNotFound) { if errors.Is(err, service.ErrApiKeyNotFound) {
abortWithGoogleError(c, 401, "Invalid API key") abortWithGoogleError(c, 401, "Invalid API key")
return return
} }
...@@ -53,7 +53,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs ...@@ -53,7 +53,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
// 简易模式:跳过余额和订阅检查 // 简易模式:跳过余额和订阅检查
if cfg.RunMode == config.RunModeSimple { if cfg.RunMode == config.RunModeSimple {
c.Set(string(ContextKeyAPIKey), apiKey) c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{ c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID, UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency, Concurrency: apiKey.User.Concurrency,
...@@ -92,7 +92,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs ...@@ -92,7 +92,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
} }
} }
c.Set(string(ContextKeyAPIKey), apiKey) c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{ c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID, UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency, Concurrency: apiKey.User.Concurrency,
......
...@@ -16,53 +16,53 @@ import ( ...@@ -16,53 +16,53 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
type fakeAPIKeyRepo struct { type fakeApiKeyRepo struct {
getByKey func(ctx context.Context, key string) (*service.APIKey, error) getByKey func(ctx context.Context, key string) (*service.ApiKey, error)
} }
func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error { func (f fakeApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
func (f fakeAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) { func (f fakeApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (f fakeAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) { func (f fakeApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) { func (f fakeApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
if f.getByKey == nil { if f.getByKey == nil {
return nil, errors.New("unexpected call") return nil, errors.New("unexpected call")
} }
return f.getByKey(ctx, key) return f.getByKey(ctx, key)
} }
func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error { func (f fakeApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
func (f fakeAPIKeyRepo) Delete(ctx context.Context, id int64) error { func (f fakeApiKeyRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
func (f fakeAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) { func (f fakeApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (f fakeAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) { func (f fakeApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (f fakeAPIKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) { func (f fakeApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
func (f fakeAPIKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) { func (f fakeApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
return false, errors.New("not implemented") return false, errors.New("not implemented")
} }
func (f fakeAPIKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) { func (f fakeApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (f fakeAPIKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) { func (f fakeApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (f fakeAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { func (f fakeApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
func (f fakeAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { func (f fakeApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
...@@ -74,8 +74,8 @@ type googleErrorResponse struct { ...@@ -74,8 +74,8 @@ type googleErrorResponse struct {
} `json:"error"` } `json:"error"`
} }
func newTestAPIKeyService(repo service.APIKeyRepository) *service.APIKeyService { func newTestApiKeyService(repo service.ApiKeyRepository) *service.ApiKeyService {
return service.NewAPIKeyService( return service.NewApiKeyService(
repo, repo,
nil, // userRepo (unused in GetByKey) nil, // userRepo (unused in GetByKey)
nil, // groupRepo nil, // groupRepo
...@@ -85,16 +85,16 @@ func newTestAPIKeyService(repo service.APIKeyRepository) *service.APIKeyService ...@@ -85,16 +85,16 @@ func newTestAPIKeyService(repo service.APIKeyRepository) *service.APIKeyService
) )
} }
func TestAPIKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) { func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
r := gin.New() r := gin.New()
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{ apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return nil, errors.New("should not be called") return nil, errors.New("should not be called")
}, },
}) })
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
...@@ -109,16 +109,16 @@ func TestAPIKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) { ...@@ -109,16 +109,16 @@ func TestAPIKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status) require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
} }
func TestAPIKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) { func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
r := gin.New() r := gin.New()
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{ apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return nil, service.ErrAPIKeyNotFound return nil, service.ErrApiKeyNotFound
}, },
}) })
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
...@@ -134,16 +134,16 @@ func TestAPIKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) { ...@@ -134,16 +134,16 @@ func TestAPIKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status) require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
} }
func TestAPIKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) { func TestApiKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
r := gin.New() r := gin.New()
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{ apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return nil, errors.New("db down") return nil, errors.New("db down")
}, },
}) })
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
...@@ -159,13 +159,13 @@ func TestAPIKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) { ...@@ -159,13 +159,13 @@ func TestAPIKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) {
require.Equal(t, "INTERNAL", resp.Error.Status) require.Equal(t, "INTERNAL", resp.Error.Status)
} }
func TestAPIKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) { func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
r := gin.New() r := gin.New()
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{ apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return &service.APIKey{ return &service.ApiKey{
ID: 1, ID: 1,
Key: key, Key: key,
Status: service.StatusDisabled, Status: service.StatusDisabled,
...@@ -176,7 +176,7 @@ func TestAPIKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) { ...@@ -176,7 +176,7 @@ func TestAPIKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
}, nil }, nil
}, },
}) })
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
...@@ -192,13 +192,13 @@ func TestAPIKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) { ...@@ -192,13 +192,13 @@ func TestAPIKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status) require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
} }
func TestAPIKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) { func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
r := gin.New() r := gin.New()
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{ apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return &service.APIKey{ return &service.ApiKey{
ID: 1, ID: 1,
Key: key, Key: key,
Status: service.StatusActive, Status: service.StatusActive,
...@@ -210,7 +210,7 @@ func TestAPIKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) { ...@@ -210,7 +210,7 @@ func TestAPIKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
}, nil }, nil
}, },
}) })
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
......
...@@ -35,7 +35,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { ...@@ -35,7 +35,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
Balance: 10, Balance: 10,
Concurrency: 3, Concurrency: 3,
} }
apiKey := &service.APIKey{ apiKey := &service.ApiKey{
ID: 100, ID: 100,
UserID: user.ID, UserID: user.ID,
Key: "test-key", Key: "test-key",
...@@ -45,10 +45,10 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { ...@@ -45,10 +45,10 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
} }
apiKey.GroupID = &group.ID apiKey.GroupID = &group.ID
apiKeyRepo := &stubAPIKeyRepo{ apiKeyRepo := &stubApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
if key != apiKey.Key { if key != apiKey.Key {
return nil, service.ErrAPIKeyNotFound return nil, service.ErrApiKeyNotFound
} }
clone := *apiKey clone := *apiKey
return &clone, nil return &clone, nil
...@@ -57,7 +57,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { ...@@ -57,7 +57,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) { t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeSimple} cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil) subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
...@@ -71,7 +71,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { ...@@ -71,7 +71,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("standard_mode_enforces_quota_check", func(t *testing.T) { t.Run("standard_mode_enforces_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeStandard} cfg := &config.Config{RunMode: config.RunModeStandard}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
now := time.Now() now := time.Now()
sub := &service.UserSubscription{ sub := &service.UserSubscription{
...@@ -110,75 +110,75 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { ...@@ -110,75 +110,75 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
}) })
} }
func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine { func newAuthTestRouter(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
router := gin.New() router := gin.New()
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg, nil))) router.Use(gin.HandlerFunc(NewApiKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
router.GET("/t", func(c *gin.Context) { router.GET("/t", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true}) c.JSON(http.StatusOK, gin.H{"ok": true})
}) })
return router return router
} }
type stubAPIKeyRepo struct { type stubApiKeyRepo struct {
getByKey func(ctx context.Context, key string) (*service.APIKey, error) getByKey func(ctx context.Context, key string) (*service.ApiKey, error)
} }
func (r *stubAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error { func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
func (r *stubAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) { func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) { func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
func (r *stubAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) { func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
if r.getByKey != nil { if r.getByKey != nil {
return r.getByKey(ctx, key) return r.getByKey(ctx, key)
} }
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error { func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
func (r *stubAPIKeyRepo) Delete(ctx context.Context, id int64) error { func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
func (r *stubAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) { func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (r *stubAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) { func (r *stubApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubAPIKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) { func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
func (r *stubAPIKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) { func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
return false, errors.New("not implemented") return false, errors.New("not implemented")
} }
func (r *stubAPIKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) { func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (r *stubAPIKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) { func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *stubApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
func (r *stubAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
......
...@@ -2,14 +2,11 @@ package middleware ...@@ -2,14 +2,11 @@ package middleware
import ( import (
"log" "log"
"regexp"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
var sensitiveQueryParamRE = regexp.MustCompile(`(?i)([?&](?:token|api_key)=)[^&#]*`)
// Logger 请求日志中间件 // Logger 请求日志中间件
func Logger() gin.HandlerFunc { func Logger() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
...@@ -29,7 +26,7 @@ func Logger() gin.HandlerFunc { ...@@ -29,7 +26,7 @@ func Logger() gin.HandlerFunc {
method := c.Request.Method method := c.Request.Method
// 请求路径 // 请求路径
path := sensitiveQueryParamRE.ReplaceAllString(c.Request.URL.RequestURI(), "${1}***") path := c.Request.URL.Path
// 状态码 // 状态码
statusCode := c.Writer.Status() statusCode := c.Writer.Status()
......
// Package middleware provides HTTP middleware components for authentication,
// authorization, logging, error recovery, and request processing.
package middleware package middleware
import ( import (
...@@ -17,8 +15,8 @@ const ( ...@@ -17,8 +15,8 @@ const (
ContextKeyUser ContextKey = "user" ContextKeyUser ContextKey = "user"
// ContextKeyUserRole 当前用户角色(string) // ContextKeyUserRole 当前用户角色(string)
ContextKeyUserRole ContextKey = "user_role" ContextKeyUserRole ContextKey = "user_role"
// ContextKeyAPIKey API密钥上下文键 // ContextKeyApiKey API密钥上下文键
ContextKeyAPIKey ContextKey = "api_key" ContextKeyApiKey ContextKey = "api_key"
// ContextKeySubscription 订阅上下文键 // ContextKeySubscription 订阅上下文键
ContextKeySubscription ContextKey = "subscription" ContextKeySubscription ContextKey = "subscription"
// ContextKeyForcePlatform 强制平台(用于 /antigravity 路由) // ContextKeyForcePlatform 强制平台(用于 /antigravity 路由)
......
...@@ -11,12 +11,12 @@ type JWTAuthMiddleware gin.HandlerFunc ...@@ -11,12 +11,12 @@ type JWTAuthMiddleware gin.HandlerFunc
// AdminAuthMiddleware 管理员认证中间件类型 // AdminAuthMiddleware 管理员认证中间件类型
type AdminAuthMiddleware gin.HandlerFunc type AdminAuthMiddleware gin.HandlerFunc
// APIKeyAuthMiddleware API Key 认证中间件类型 // ApiKeyAuthMiddleware API Key 认证中间件类型
type APIKeyAuthMiddleware gin.HandlerFunc type ApiKeyAuthMiddleware gin.HandlerFunc
// ProviderSet 中间件层的依赖注入 // ProviderSet 中间件层的依赖注入
var ProviderSet = wire.NewSet( var ProviderSet = wire.NewSet(
NewJWTAuthMiddleware, NewJWTAuthMiddleware,
NewAdminAuthMiddleware, NewAdminAuthMiddleware,
NewAPIKeyAuthMiddleware, NewApiKeyAuthMiddleware,
) )
...@@ -17,8 +17,8 @@ func SetupRouter( ...@@ -17,8 +17,8 @@ func SetupRouter(
handlers *handler.Handlers, handlers *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware, jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware, adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.APIKeyAuthMiddleware, apiKeyAuth middleware2.ApiKeyAuthMiddleware,
apiKeyService *service.APIKeyService, apiKeyService *service.ApiKeyService,
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
cfg *config.Config, cfg *config.Config,
) *gin.Engine { ) *gin.Engine {
...@@ -43,8 +43,8 @@ func registerRoutes( ...@@ -43,8 +43,8 @@ func registerRoutes(
h *handler.Handlers, h *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware, jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware, adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.APIKeyAuthMiddleware, apiKeyAuth middleware2.ApiKeyAuthMiddleware,
apiKeyService *service.APIKeyService, apiKeyService *service.ApiKeyService,
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
cfg *config.Config, cfg *config.Config,
) { ) {
......
// Package routes 提供 HTTP 路由注册和处理函数
package routes package routes
import ( import (
......
...@@ -50,7 +50,7 @@ func RegisterUserRoutes( ...@@ -50,7 +50,7 @@ func RegisterUserRoutes(
usage.GET("/dashboard/stats", h.Usage.DashboardStats) usage.GET("/dashboard/stats", h.Usage.DashboardStats)
usage.GET("/dashboard/trend", h.Usage.DashboardTrend) usage.GET("/dashboard/trend", h.Usage.DashboardTrend)
usage.GET("/dashboard/models", h.Usage.DashboardModels) usage.GET("/dashboard/models", h.Usage.DashboardModels)
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardAPIKeysUsage) usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage)
} }
// 卡密兑换 // 卡密兑换
......
...@@ -29,6 +29,9 @@ type Account struct { ...@@ -29,6 +29,9 @@ type Account struct {
RateLimitResetAt *time.Time RateLimitResetAt *time.Time
OverloadUntil *time.Time OverloadUntil *time.Time
TempUnschedulableUntil *time.Time
TempUnschedulableReason string
SessionWindowStart *time.Time SessionWindowStart *time.Time
SessionWindowEnd *time.Time SessionWindowEnd *time.Time
SessionWindowStatus string SessionWindowStatus string
...@@ -39,6 +42,13 @@ type Account struct { ...@@ -39,6 +42,13 @@ type Account struct {
Groups []*Group Groups []*Group
} }
type TempUnschedulableRule struct {
ErrorCode int `json:"error_code"`
Keywords []string `json:"keywords"`
DurationMinutes int `json:"duration_minutes"`
Description string `json:"description"`
}
func (a *Account) IsActive() bool { func (a *Account) IsActive() bool {
return a.Status == StatusActive return a.Status == StatusActive
} }
...@@ -54,6 +64,9 @@ func (a *Account) IsSchedulable() bool { ...@@ -54,6 +64,9 @@ func (a *Account) IsSchedulable() bool {
if a.RateLimitResetAt != nil && now.Before(*a.RateLimitResetAt) { if a.RateLimitResetAt != nil && now.Before(*a.RateLimitResetAt) {
return false return false
} }
if a.TempUnschedulableUntil != nil && now.Before(*a.TempUnschedulableUntil) {
return false
}
return true return true
} }
...@@ -163,6 +176,114 @@ func (a *Account) GetCredentialAsTime(key string) *time.Time { ...@@ -163,6 +176,114 @@ func (a *Account) GetCredentialAsTime(key string) *time.Time {
return nil return nil
} }
func (a *Account) IsTempUnschedulableEnabled() bool {
if a.Credentials == nil {
return false
}
raw, ok := a.Credentials["temp_unschedulable_enabled"]
if !ok || raw == nil {
return false
}
enabled, ok := raw.(bool)
return ok && enabled
}
func (a *Account) GetTempUnschedulableRules() []TempUnschedulableRule {
if a.Credentials == nil {
return nil
}
raw, ok := a.Credentials["temp_unschedulable_rules"]
if !ok || raw == nil {
return nil
}
arr, ok := raw.([]any)
if !ok {
return nil
}
rules := make([]TempUnschedulableRule, 0, len(arr))
for _, item := range arr {
entry, ok := item.(map[string]any)
if !ok || entry == nil {
continue
}
rule := TempUnschedulableRule{
ErrorCode: parseTempUnschedInt(entry["error_code"]),
Keywords: parseTempUnschedStrings(entry["keywords"]),
DurationMinutes: parseTempUnschedInt(entry["duration_minutes"]),
Description: parseTempUnschedString(entry["description"]),
}
if rule.ErrorCode <= 0 || rule.DurationMinutes <= 0 || len(rule.Keywords) == 0 {
continue
}
rules = append(rules, rule)
}
return rules
}
func parseTempUnschedString(value any) string {
s, ok := value.(string)
if !ok {
return ""
}
return strings.TrimSpace(s)
}
func parseTempUnschedStrings(value any) []string {
if value == nil {
return nil
}
var raw []string
switch v := value.(type) {
case []string:
raw = v
case []any:
raw = make([]string, 0, len(v))
for _, item := range v {
if s, ok := item.(string); ok {
raw = append(raw, s)
}
}
default:
return nil
}
out := make([]string, 0, len(raw))
for _, item := range raw {
s := strings.TrimSpace(item)
if s != "" {
out = append(out, s)
}
}
return out
}
func parseTempUnschedInt(value any) int {
switch v := value.(type) {
case int:
return v
case int64:
return int(v)
case float64:
return int(v)
case json.Number:
if i, err := v.Int64(); err == nil {
return int(i)
}
case string:
if i, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
return i
}
}
return 0
}
func (a *Account) GetModelMapping() map[string]string { func (a *Account) GetModelMapping() map[string]string {
if a.Credentials == nil { if a.Credentials == nil {
return nil return nil
...@@ -206,7 +327,7 @@ func (a *Account) GetMappedModel(requestedModel string) string { ...@@ -206,7 +327,7 @@ func (a *Account) GetMappedModel(requestedModel string) string {
} }
func (a *Account) GetBaseURL() string { func (a *Account) GetBaseURL() string {
if a.Type != AccountTypeAPIKey { if a.Type != AccountTypeApiKey {
return "" return ""
} }
baseURL := a.GetCredential("base_url") baseURL := a.GetCredential("base_url")
...@@ -229,7 +350,7 @@ func (a *Account) GetExtraString(key string) string { ...@@ -229,7 +350,7 @@ func (a *Account) GetExtraString(key string) string {
} }
func (a *Account) IsCustomErrorCodesEnabled() bool { func (a *Account) IsCustomErrorCodesEnabled() bool {
if a.Type != AccountTypeAPIKey || a.Credentials == nil { if a.Type != AccountTypeApiKey || a.Credentials == nil {
return false return false
} }
if v, ok := a.Credentials["custom_error_codes_enabled"]; ok { if v, ok := a.Credentials["custom_error_codes_enabled"]; ok {
...@@ -300,15 +421,15 @@ func (a *Account) IsOpenAIOAuth() bool { ...@@ -300,15 +421,15 @@ func (a *Account) IsOpenAIOAuth() bool {
return a.IsOpenAI() && a.Type == AccountTypeOAuth return a.IsOpenAI() && a.Type == AccountTypeOAuth
} }
func (a *Account) IsOpenAIAPIKey() bool { func (a *Account) IsOpenAIApiKey() bool {
return a.IsOpenAI() && a.Type == AccountTypeAPIKey return a.IsOpenAI() && a.Type == AccountTypeApiKey
} }
func (a *Account) GetOpenAIBaseURL() string { func (a *Account) GetOpenAIBaseURL() string {
if !a.IsOpenAI() { if !a.IsOpenAI() {
return "" return ""
} }
if a.Type == AccountTypeAPIKey { if a.Type == AccountTypeApiKey {
baseURL := a.GetCredential("base_url") baseURL := a.GetCredential("base_url")
if baseURL != "" { if baseURL != "" {
return baseURL return baseURL
...@@ -338,8 +459,8 @@ func (a *Account) GetOpenAIIDToken() string { ...@@ -338,8 +459,8 @@ func (a *Account) GetOpenAIIDToken() string {
return a.GetCredential("id_token") return a.GetCredential("id_token")
} }
func (a *Account) GetOpenAIAPIKey() string { func (a *Account) GetOpenAIApiKey() string {
if !a.IsOpenAIAPIKey() { if !a.IsOpenAIApiKey() {
return "" return ""
} }
return a.GetCredential("api_key") return a.GetCredential("api_key")
......
// Package service 提供业务逻辑层服务,封装领域模型的业务规则和操作流程。
// 服务层协调 repository 层的数据访问,实现跨实体的业务逻辑,并为上层 API 提供统一的业务接口。
package service package service
import ( import (
...@@ -51,6 +49,8 @@ type AccountRepository interface { ...@@ -51,6 +49,8 @@ type AccountRepository interface {
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error SetOverloaded(ctx context.Context, id int64, until time.Time) error
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
ClearTempUnschedulable(ctx context.Context, id int64) error
ClearRateLimit(ctx context.Context, id int64) error ClearRateLimit(ctx context.Context, id int64) error
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
......
...@@ -139,6 +139,14 @@ func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until tim ...@@ -139,6 +139,14 @@ func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until tim
panic("unexpected SetOverloaded call") panic("unexpected SetOverloaded call")
} }
func (s *accountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
panic("unexpected SetTempUnschedulable call")
}
func (s *accountRepoStub) ClearTempUnschedulable(ctx context.Context, id int64) error {
panic("unexpected ClearTempUnschedulable call")
}
func (s *accountRepoStub) ClearRateLimit(ctx context.Context, id int64) error { func (s *accountRepoStub) ClearRateLimit(ctx context.Context, id int64) error {
panic("unexpected ClearRateLimit call") panic("unexpected ClearRateLimit call")
} }
......
...@@ -324,7 +324,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account ...@@ -324,7 +324,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
chatgptAccountID = account.GetChatGPTAccountID() chatgptAccountID = account.GetChatGPTAccountID()
} else if account.Type == "apikey" { } else if account.Type == "apikey" {
// API Key - use Platform API // API Key - use Platform API
authToken = account.GetOpenAIAPIKey() authToken = account.GetOpenAIApiKey()
if authToken == "" { if authToken == "" {
return s.sendErrorAndEnd(c, "No API key available") return s.sendErrorAndEnd(c, "No API key available")
} }
...@@ -402,7 +402,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account ...@@ -402,7 +402,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
} }
// For API Key accounts with model mapping, map the model // For API Key accounts with model mapping, map the model
if account.Type == AccountTypeAPIKey { if account.Type == AccountTypeApiKey {
mapping := account.GetModelMapping() mapping := account.GetModelMapping()
if len(mapping) > 0 { if len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists { if mappedModel, exists := mapping[testModelID]; exists {
...@@ -426,7 +426,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account ...@@ -426,7 +426,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
var err error var err error
switch account.Type { switch account.Type {
case AccountTypeAPIKey: case AccountTypeApiKey:
req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload) req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload)
case AccountTypeOAuth: case AccountTypeOAuth:
req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload) req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment