Commit 2220fd18 authored by song's avatar song
Browse files

merge upstream main

parents 11ff73b5 df4c0adf
...@@ -123,7 +123,7 @@ func createTestPayload(modelID string) (map[string]any, error) { ...@@ -123,7 +123,7 @@ func createTestPayload(modelID string) (map[string]any, error) {
"system": []map[string]any{ "system": []map[string]any{
{ {
"type": "text", "type": "text",
"text": "You are Claude Code, Anthropic's official CLI for Claude.", "text": claudeCodeSystemPrompt,
"cache_control": map[string]string{ "cache_control": map[string]string{
"type": "ephemeral", "type": "ephemeral",
}, },
......
...@@ -115,6 +115,8 @@ type CreateGroupInput struct { ...@@ -115,6 +115,8 @@ type CreateGroupInput struct {
MCPXMLInject *bool MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用) // 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string SupportedModelScopes []string
// 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs []int64
} }
type UpdateGroupInput struct { type UpdateGroupInput struct {
...@@ -142,6 +144,8 @@ type UpdateGroupInput struct { ...@@ -142,6 +144,8 @@ type UpdateGroupInput struct {
MCPXMLInject *bool MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用) // 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string SupportedModelScopes *[]string
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64
} }
type CreateAccountInput struct { type CreateAccountInput struct {
...@@ -598,6 +602,38 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ...@@ -598,6 +602,38 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
mcpXMLInject = *input.MCPXMLInject mcpXMLInject = *input.MCPXMLInject
} }
// 如果指定了复制账号的源分组,先获取账号 ID 列表
var accountIDsToCopy []int64
if len(input.CopyAccountsFromGroupIDs) > 0 {
// 去重源分组 IDs
seen := make(map[int64]struct{})
uniqueSourceGroupIDs := make([]int64, 0, len(input.CopyAccountsFromGroupIDs))
for _, srcGroupID := range input.CopyAccountsFromGroupIDs {
if _, exists := seen[srcGroupID]; !exists {
seen[srcGroupID] = struct{}{}
uniqueSourceGroupIDs = append(uniqueSourceGroupIDs, srcGroupID)
}
}
// 校验源分组的平台是否与新分组一致
for _, srcGroupID := range uniqueSourceGroupIDs {
srcGroup, err := s.groupRepo.GetByIDLite(ctx, srcGroupID)
if err != nil {
return nil, fmt.Errorf("source group %d not found: %w", srcGroupID, err)
}
if srcGroup.Platform != platform {
return nil, fmt.Errorf("source group %d platform mismatch: expected %s, got %s", srcGroupID, platform, srcGroup.Platform)
}
}
// 获取所有源分组的账号(去重)
var err error
accountIDsToCopy, err = s.groupRepo.GetAccountIDsByGroupIDs(ctx, uniqueSourceGroupIDs)
if err != nil {
return nil, fmt.Errorf("failed to get accounts from source groups: %w", err)
}
}
group := &Group{ group := &Group{
Name: input.Name, Name: input.Name,
Description: input.Description, Description: input.Description,
...@@ -622,6 +658,15 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ...@@ -622,6 +658,15 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
if err := s.groupRepo.Create(ctx, group); err != nil { if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err return nil, err
} }
// 如果有需要复制的账号,绑定到新分组
if len(accountIDsToCopy) > 0 {
if err := s.groupRepo.BindAccountsToGroup(ctx, group.ID, accountIDsToCopy); err != nil {
return nil, fmt.Errorf("failed to bind accounts to new group: %w", err)
}
group.AccountCount = int64(len(accountIDsToCopy))
}
return group, nil return group, nil
} }
...@@ -810,6 +855,54 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd ...@@ -810,6 +855,54 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if err := s.groupRepo.Update(ctx, group); err != nil { if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err return nil, err
} }
// 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号)
if len(input.CopyAccountsFromGroupIDs) > 0 {
// 去重源分组 IDs
seen := make(map[int64]struct{})
uniqueSourceGroupIDs := make([]int64, 0, len(input.CopyAccountsFromGroupIDs))
for _, srcGroupID := range input.CopyAccountsFromGroupIDs {
// 校验:源分组不能是自身
if srcGroupID == id {
return nil, fmt.Errorf("cannot copy accounts from self")
}
// 去重
if _, exists := seen[srcGroupID]; !exists {
seen[srcGroupID] = struct{}{}
uniqueSourceGroupIDs = append(uniqueSourceGroupIDs, srcGroupID)
}
}
// 校验源分组的平台是否与当前分组一致
for _, srcGroupID := range uniqueSourceGroupIDs {
srcGroup, err := s.groupRepo.GetByIDLite(ctx, srcGroupID)
if err != nil {
return nil, fmt.Errorf("source group %d not found: %w", srcGroupID, err)
}
if srcGroup.Platform != group.Platform {
return nil, fmt.Errorf("source group %d platform mismatch: expected %s, got %s", srcGroupID, group.Platform, srcGroup.Platform)
}
}
// 获取所有源分组的账号(去重)
accountIDsToCopy, err := s.groupRepo.GetAccountIDsByGroupIDs(ctx, uniqueSourceGroupIDs)
if err != nil {
return nil, fmt.Errorf("failed to get accounts from source groups: %w", err)
}
// 先清空当前分组的所有账号绑定
if _, err := s.groupRepo.DeleteAccountGroupsByGroupID(ctx, id); err != nil {
return nil, fmt.Errorf("failed to clear existing account bindings: %w", err)
}
// 再绑定源分组的账号
if len(accountIDsToCopy) > 0 {
if err := s.groupRepo.BindAccountsToGroup(ctx, id, accountIDsToCopy); err != nil {
return nil, fmt.Errorf("failed to bind accounts to group: %w", err)
}
}
}
if s.authCacheInvalidator != nil { if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id) s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
} }
......
...@@ -164,6 +164,14 @@ func (s *groupRepoStub) DeleteAccountGroupsByGroupID(ctx context.Context, groupI ...@@ -164,6 +164,14 @@ func (s *groupRepoStub) DeleteAccountGroupsByGroupID(ctx context.Context, groupI
panic("unexpected DeleteAccountGroupsByGroupID call") panic("unexpected DeleteAccountGroupsByGroupID call")
} }
func (s *groupRepoStub) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
panic("unexpected BindAccountsToGroup call")
}
func (s *groupRepoStub) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
panic("unexpected GetAccountIDsByGroupIDs call")
}
type proxyRepoStub struct { type proxyRepoStub struct {
deleteErr error deleteErr error
countErr error countErr error
......
...@@ -108,6 +108,14 @@ func (s *groupRepoStubForAdmin) DeleteAccountGroupsByGroupID(_ context.Context, ...@@ -108,6 +108,14 @@ func (s *groupRepoStubForAdmin) DeleteAccountGroupsByGroupID(_ context.Context,
panic("unexpected DeleteAccountGroupsByGroupID call") panic("unexpected DeleteAccountGroupsByGroupID call")
} }
func (s *groupRepoStubForAdmin) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error {
panic("unexpected BindAccountsToGroup call")
}
func (s *groupRepoStubForAdmin) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
panic("unexpected GetAccountIDsByGroupIDs call")
}
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递 // TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) { func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) {
repo := &groupRepoStubForAdmin{} repo := &groupRepoStubForAdmin{}
...@@ -379,6 +387,14 @@ func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.C ...@@ -379,6 +387,14 @@ func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.C
panic("unexpected DeleteAccountGroupsByGroupID call") panic("unexpected DeleteAccountGroupsByGroupID call")
} }
func (s *groupRepoStubForFallbackCycle) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error {
panic("unexpected BindAccountsToGroup call")
}
func (s *groupRepoStubForFallbackCycle) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
panic("unexpected GetAccountIDsByGroupIDs call")
}
type groupRepoStubForInvalidRequestFallback struct { type groupRepoStubForInvalidRequestFallback struct {
groups map[int64]*Group groups map[int64]*Group
created *Group created *Group
...@@ -748,4 +764,4 @@ func TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes ...@@ -748,4 +764,4 @@ func TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes
require.NotNil(t, group) require.NotNil(t, group)
require.NotNil(t, repo.updated) require.NotNil(t, repo.updated)
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest) require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
} }
...@@ -302,13 +302,11 @@ func logPrefix(sessionID, accountName string) string { ...@@ -302,13 +302,11 @@ func logPrefix(sessionID, accountName string) string {
} }
// Antigravity 直接支持的模型(精确匹配透传) // Antigravity 直接支持的模型(精确匹配透传)
// 注意:gemini-2.5 系列已移除,统一映射到 gemini-3 系列
var antigravitySupportedModels = map[string]bool{ var antigravitySupportedModels = map[string]bool{
"claude-opus-4-5-thinking": true, "claude-opus-4-5-thinking": true,
"claude-sonnet-4-5": true, "claude-sonnet-4-5": true,
"claude-sonnet-4-5-thinking": true, "claude-sonnet-4-5-thinking": true,
"gemini-2.5-flash": true,
"gemini-2.5-flash-lite": true,
"gemini-2.5-flash-thinking": true,
"gemini-3-flash": true, "gemini-3-flash": true,
"gemini-3-pro-low": true, "gemini-3-pro-low": true,
"gemini-3-pro-high": true, "gemini-3-pro-high": true,
...@@ -317,23 +315,32 @@ var antigravitySupportedModels = map[string]bool{ ...@@ -317,23 +315,32 @@ var antigravitySupportedModels = map[string]bool{
// Antigravity 前缀映射表(按前缀长度降序排列,确保最长匹配优先) // Antigravity 前缀映射表(按前缀长度降序排列,确保最长匹配优先)
// 用于处理模型版本号变化(如 -20251111, -thinking, -preview 等后缀) // 用于处理模型版本号变化(如 -20251111, -thinking, -preview 等后缀)
// gemini-2.5 系列统一映射到 gemini-3 系列(Antigravity 上游不再支持 2.5)
var antigravityPrefixMapping = []struct { var antigravityPrefixMapping = []struct {
prefix string prefix string
target string target string
}{ }{
// 长前缀优先 // gemini-2.5 → gemini-3 映射(长前缀优先)
{"gemini-2.5-flash-image", "gemini-3-pro-image"}, // gemini-2.5-flash-image → 3-pro-image {"gemini-2.5-flash-thinking", "gemini-3-flash"}, // gemini-2.5-flash-thinking → gemini-3-flash
{"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等 {"gemini-2.5-flash-image", "gemini-3-pro-image"}, // gemini-2.5-flash-image → gemini-3-pro-image
{"gemini-3-flash", "gemini-3-flash"}, // gemini-3-flash-preview 等 → gemini-3-flash {"gemini-2.5-flash-lite", "gemini-3-flash"}, // gemini-2.5-flash-lite → gemini-3-flash
{"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx {"gemini-2.5-flash", "gemini-3-flash"}, // gemini-2.5-flash → gemini-3-flash
{"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx {"gemini-2.5-pro-preview", "gemini-3-pro-high"}, // gemini-2.5-pro-preview → gemini-3-pro-high
{"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet {"gemini-2.5-pro-exp", "gemini-3-pro-high"}, // gemini-2.5-pro-exp → gemini-3-pro-high
{"gemini-2.5-pro", "gemini-3-pro-high"}, // gemini-2.5-pro → gemini-3-pro-high
// gemini-3 前缀映射
{"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等
{"gemini-3-flash", "gemini-3-flash"}, // gemini-3-flash-preview 等 → gemini-3-flash
{"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview 等
// Claude 映射
{"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx
{"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
{"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet
{"claude-opus-4-5", "claude-opus-4-5-thinking"}, {"claude-opus-4-5", "claude-opus-4-5-thinking"},
{"claude-3-haiku", "claude-sonnet-4-5"}, // 旧版 claude-3-haiku-xxx → sonnet {"claude-3-haiku", "claude-sonnet-4-5"}, // 旧版 claude-3-haiku-xxx → sonnet
{"claude-sonnet-4", "claude-sonnet-4-5"}, {"claude-sonnet-4", "claude-sonnet-4-5"},
{"claude-haiku-4", "claude-sonnet-4-5"}, // → sonnet {"claude-haiku-4", "claude-sonnet-4-5"}, // → sonnet
{"claude-opus-4", "claude-opus-4-5-thinking"}, {"claude-opus-4", "claude-opus-4-5-thinking"},
{"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview 等
} }
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发 // AntigravityGatewayService 处理 Antigravity 平台的 API 转发
......
...@@ -103,6 +103,10 @@ func (s *httpUpstreamStub) Do(_ *http.Request, _ string, _ int64, _ int) (*http. ...@@ -103,6 +103,10 @@ func (s *httpUpstreamStub) Do(_ *http.Request, _ string, _ int64, _ int) (*http.
return s.resp, s.err return s.resp, s.err
} }
func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) {
return s.resp, s.err
}
func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) { func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder() writer := httptest.NewRecorder()
......
...@@ -134,18 +134,18 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) { ...@@ -134,18 +134,18 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
expected: "claude-sonnet-4-5", expected: "claude-sonnet-4-5",
}, },
// 3. Gemini 透传 // 3. Gemini 2.5 → 3 映射
{ {
name: "Gemini透传 - gemini-2.5-flash", name: "Gemini映射 - gemini-2.5-flash → gemini-3-flash",
requestedModel: "gemini-2.5-flash", requestedModel: "gemini-2.5-flash",
accountMapping: nil, accountMapping: nil,
expected: "gemini-2.5-flash", expected: "gemini-3-flash",
}, },
{ {
name: "Gemini透传 - gemini-2.5-pro", name: "Gemini映射 - gemini-2.5-pro → gemini-3-pro-high",
requestedModel: "gemini-2.5-pro", requestedModel: "gemini-2.5-pro",
accountMapping: nil, accountMapping: nil,
expected: "gemini-2.5-pro", expected: "gemini-3-pro-high",
}, },
{ {
name: "Gemini透传 - gemini-future-model", name: "Gemini透传 - gemini-future-model",
......
...@@ -19,17 +19,19 @@ import ( ...@@ -19,17 +19,19 @@ import (
) )
var ( var (
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password") ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active") ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists") ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved") ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token") ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired") ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large") ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked") ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required") ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable") ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required")
ErrInvitationCodeInvalid = infraerrors.BadRequest("INVITATION_CODE_INVALID", "invalid or used invitation code")
) )
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。 // maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
...@@ -47,6 +49,7 @@ type JWTClaims struct { ...@@ -47,6 +49,7 @@ type JWTClaims struct {
// AuthService 认证服务 // AuthService 认证服务
type AuthService struct { type AuthService struct {
userRepo UserRepository userRepo UserRepository
redeemRepo RedeemCodeRepository
cfg *config.Config cfg *config.Config
settingService *SettingService settingService *SettingService
emailService *EmailService emailService *EmailService
...@@ -58,6 +61,7 @@ type AuthService struct { ...@@ -58,6 +61,7 @@ type AuthService struct {
// NewAuthService 创建认证服务实例 // NewAuthService 创建认证服务实例
func NewAuthService( func NewAuthService(
userRepo UserRepository, userRepo UserRepository,
redeemRepo RedeemCodeRepository,
cfg *config.Config, cfg *config.Config,
settingService *SettingService, settingService *SettingService,
emailService *EmailService, emailService *EmailService,
...@@ -67,6 +71,7 @@ func NewAuthService( ...@@ -67,6 +71,7 @@ func NewAuthService(
) *AuthService { ) *AuthService {
return &AuthService{ return &AuthService{
userRepo: userRepo, userRepo: userRepo,
redeemRepo: redeemRepo,
cfg: cfg, cfg: cfg,
settingService: settingService, settingService: settingService,
emailService: emailService, emailService: emailService,
...@@ -78,11 +83,11 @@ func NewAuthService( ...@@ -78,11 +83,11 @@ func NewAuthService(
// Register 用户注册,返回token和用户 // Register 用户注册,返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) { func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
return s.RegisterWithVerification(ctx, email, password, "", "") return s.RegisterWithVerification(ctx, email, password, "", "", "")
} }
// RegisterWithVerification 用户注册(支持邮件验证优惠码),返回token和用户 // RegisterWithVerification 用户注册(支持邮件验证优惠码和邀请码),返回token和用户
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode string) (string, *User, error) { func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode string) (string, *User, error) {
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册) // 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled return "", nil, ErrRegDisabled
...@@ -93,6 +98,26 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw ...@@ -93,6 +98,26 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "", nil, ErrEmailReserved return "", nil, ErrEmailReserved
} }
// 检查是否需要邀请码
var invitationRedeemCode *RedeemCode
if s.settingService != nil && s.settingService.IsInvitationCodeEnabled(ctx) {
if invitationCode == "" {
return "", nil, ErrInvitationCodeRequired
}
// 验证邀请码
redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
if err != nil {
log.Printf("[Auth] Invalid invitation code: %s, error: %v", invitationCode, err)
return "", nil, ErrInvitationCodeInvalid
}
// 检查类型和状态
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
log.Printf("[Auth] Invitation code invalid: type=%s, status=%s", redeemCode.Type, redeemCode.Status)
return "", nil, ErrInvitationCodeInvalid
}
invitationRedeemCode = redeemCode
}
// 检查是否需要邮件验证 // 检查是否需要邮件验证
if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) { if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) {
// 如果邮件验证已开启但邮件服务未配置,拒绝注册 // 如果邮件验证已开启但邮件服务未配置,拒绝注册
...@@ -153,6 +178,13 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw ...@@ -153,6 +178,13 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "", nil, ErrServiceUnavailable return "", nil, ErrServiceUnavailable
} }
// 标记邀请码为已使用(如果使用了邀请码)
if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
// 邀请码标记失败不影响注册,只记录日志
log.Printf("[Auth] Failed to mark invitation code as used for user %d: %v", user.ID, err)
}
}
// 应用优惠码(如果提供且功能已启用) // 应用优惠码(如果提供且功能已启用)
if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) { if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) {
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil { if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
......
...@@ -115,6 +115,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E ...@@ -115,6 +115,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
return NewAuthService( return NewAuthService(
repo, repo,
nil, // redeemRepo
cfg, cfg,
settingService, settingService,
emailService, emailService,
...@@ -152,7 +153,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi ...@@ -152,7 +153,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi
}, nil) }, nil)
// 应返回服务不可用错误,而不是允许绕过验证 // 应返回服务不可用错误,而不是允许绕过验证
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "") _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "")
require.ErrorIs(t, err, ErrServiceUnavailable) require.ErrorIs(t, err, ErrServiceUnavailable)
} }
...@@ -164,7 +165,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) { ...@@ -164,7 +165,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
SettingKeyEmailVerifyEnabled: "true", SettingKeyEmailVerifyEnabled: "true",
}, cache) }, cache)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "") _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "")
require.ErrorIs(t, err, ErrEmailVerifyRequired) require.ErrorIs(t, err, ErrEmailVerifyRequired)
} }
...@@ -178,7 +179,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) { ...@@ -178,7 +179,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
SettingKeyEmailVerifyEnabled: "true", SettingKeyEmailVerifyEnabled: "true",
}, cache) }, cache)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "") _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "")
require.ErrorIs(t, err, ErrInvalidVerifyCode) require.ErrorIs(t, err, ErrInvalidVerifyCode)
require.ErrorContains(t, err, "verify code") require.ErrorContains(t, err, "verify code")
} }
......
...@@ -241,6 +241,76 @@ func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageToken ...@@ -241,6 +241,76 @@ func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageToken
return s.CalculateCost(model, tokens, multiplier) return s.CalculateCost(model, tokens, multiplier)
} }
// CalculateCostWithLongContext 计算费用,支持长上下文双倍计费
// threshold: 阈值(如 200000),超过此值的部分按 extraMultiplier 倍计费
// extraMultiplier: 超出部分的倍率(如 2.0 表示双倍)
//
// 示例:缓存 210k + 输入 10k = 220k,阈值 200k,倍率 2.0
// 拆分为:范围内 (200k, 0) + 范围外 (10k, 10k)
// 范围内正常计费,范围外 × 2 计费
func (s *BillingService) CalculateCostWithLongContext(model string, tokens UsageTokens, rateMultiplier float64, threshold int, extraMultiplier float64) (*CostBreakdown, error) {
// 未启用长上下文计费,直接走正常计费
if threshold <= 0 || extraMultiplier <= 1 {
return s.CalculateCost(model, tokens, rateMultiplier)
}
// 计算总输入 token(缓存读取 + 新输入)
total := tokens.CacheReadTokens + tokens.InputTokens
if total <= threshold {
return s.CalculateCost(model, tokens, rateMultiplier)
}
// 拆分成范围内和范围外
var inRangeCacheTokens, inRangeInputTokens int
var outRangeCacheTokens, outRangeInputTokens int
if tokens.CacheReadTokens >= threshold {
// 缓存已超过阈值:范围内只有缓存,范围外是超出的缓存+全部输入
inRangeCacheTokens = threshold
inRangeInputTokens = 0
outRangeCacheTokens = tokens.CacheReadTokens - threshold
outRangeInputTokens = tokens.InputTokens
} else {
// 缓存未超过阈值:范围内是全部缓存+部分输入,范围外是剩余输入
inRangeCacheTokens = tokens.CacheReadTokens
inRangeInputTokens = threshold - tokens.CacheReadTokens
outRangeCacheTokens = 0
outRangeInputTokens = tokens.InputTokens - inRangeInputTokens
}
// 范围内部分:正常计费
inRangeTokens := UsageTokens{
InputTokens: inRangeInputTokens,
OutputTokens: tokens.OutputTokens, // 输出只算一次
CacheCreationTokens: tokens.CacheCreationTokens,
CacheReadTokens: inRangeCacheTokens,
}
inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier)
if err != nil {
return nil, err
}
// 范围外部分:× extraMultiplier 计费
outRangeTokens := UsageTokens{
InputTokens: outRangeInputTokens,
CacheReadTokens: outRangeCacheTokens,
}
outRangeCost, err := s.CalculateCost(model, outRangeTokens, rateMultiplier*extraMultiplier)
if err != nil {
return inRangeCost, nil // 出错时返回范围内成本
}
// 合并成本
return &CostBreakdown{
InputCost: inRangeCost.InputCost + outRangeCost.InputCost,
OutputCost: inRangeCost.OutputCost,
CacheCreationCost: inRangeCost.CacheCreationCost,
CacheReadCost: inRangeCost.CacheReadCost + outRangeCost.CacheReadCost,
TotalCost: inRangeCost.TotalCost + outRangeCost.TotalCost,
ActualCost: inRangeCost.ActualCost + outRangeCost.ActualCost,
}, nil
}
// ListSupportedModels 列出所有支持的模型(现在总是返回true,因为有模糊匹配) // ListSupportedModels 列出所有支持的模型(现在总是返回true,因为有模糊匹配)
func (s *BillingService) ListSupportedModels() []string { func (s *BillingService) ListSupportedModels() []string {
models := make([]string, 0) models := make([]string, 0)
......
...@@ -39,6 +39,7 @@ const ( ...@@ -39,6 +39,7 @@ const (
RedeemTypeBalance = domain.RedeemTypeBalance RedeemTypeBalance = domain.RedeemTypeBalance
RedeemTypeConcurrency = domain.RedeemTypeConcurrency RedeemTypeConcurrency = domain.RedeemTypeConcurrency
RedeemTypeSubscription = domain.RedeemTypeSubscription RedeemTypeSubscription = domain.RedeemTypeSubscription
RedeemTypeInvitation = domain.RedeemTypeInvitation
) )
// PromoCode status constants // PromoCode status constants
...@@ -72,10 +73,11 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid" ...@@ -72,10 +73,11 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// Setting keys // Setting keys
const ( const (
// 注册设置 // 注册设置
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册 SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证 SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能 SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证) SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
// 邮件服务设置 // 邮件服务设置
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址 SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
......
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestMergeAnthropicBeta(t *testing.T) {
got := mergeAnthropicBeta(
[]string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"},
"foo, oauth-2025-04-20,bar, foo",
)
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo,bar", got)
}
func TestMergeAnthropicBeta_EmptyIncoming(t *testing.T) {
got := mergeAnthropicBeta(
[]string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"},
"",
)
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14", got)
}
...@@ -266,6 +266,14 @@ func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Conte ...@@ -266,6 +266,14 @@ func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Conte
return 0, nil return 0, nil
} }
func (m *mockGroupRepoForGateway) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
return nil
}
func (m *mockGroupRepoForGateway) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
return nil, nil
}
func ptr[T any](v T) *T { func ptr[T any](v T) *T {
return &v return &v
} }
......
package service
import (
"regexp"
"testing"
"github.com/stretchr/testify/require"
)
func TestBuildOAuthMetadataUserID_FallbackWithoutAccountUUID(t *testing.T) {
svc := &GatewayService{}
parsed := &ParsedRequest{
Model: "claude-sonnet-4-5",
Stream: true,
MetadataUserID: "",
System: nil,
Messages: nil,
}
account := &Account{
ID: 123,
Type: AccountTypeOAuth,
Extra: map[string]any{}, // intentionally missing account_uuid / claude_user_id
}
fp := &Fingerprint{ClientID: "deadbeef"} // should be used as user id in legacy format
got := svc.buildOAuthMetadataUserID(parsed, account, fp)
require.NotEmpty(t, got)
// Legacy format: user_{client}_account__session_{uuid}
re := regexp.MustCompile(`^user_[a-zA-Z0-9]+_account__session_[a-f0-9-]{36}$`)
require.True(t, re.MatchString(got), "unexpected user_id format: %s", got)
}
func TestBuildOAuthMetadataUserID_UsesAccountUUIDWhenPresent(t *testing.T) {
svc := &GatewayService{}
parsed := &ParsedRequest{
Model: "claude-sonnet-4-5",
Stream: true,
MetadataUserID: "",
}
account := &Account{
ID: 123,
Type: AccountTypeOAuth,
Extra: map[string]any{
"account_uuid": "acc-uuid",
"claude_user_id": "clientid123",
"anthropic_user_id": "",
},
}
got := svc.buildOAuthMetadataUserID(parsed, account, nil)
require.NotEmpty(t, got)
// New format: user_{client}_account_{account_uuid}_session_{uuid}
re := regexp.MustCompile(`^user_clientid123_account_acc-uuid_session_[a-f0-9-]{36}$`)
require.True(t, re.MatchString(got), "unexpected user_id format: %s", got)
}
...@@ -2,6 +2,7 @@ package service ...@@ -2,6 +2,7 @@ package service
import ( import (
"encoding/json" "encoding/json"
"strings"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
...@@ -134,6 +135,8 @@ func TestSystemIncludesClaudeCodePrompt(t *testing.T) { ...@@ -134,6 +135,8 @@ func TestSystemIncludesClaudeCodePrompt(t *testing.T) {
} }
func TestInjectClaudeCodePrompt(t *testing.T) { func TestInjectClaudeCodePrompt(t *testing.T) {
claudePrefix := strings.TrimSpace(claudeCodeSystemPrompt)
tests := []struct { tests := []struct {
name string name string
body string body string
...@@ -162,7 +165,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) { ...@@ -162,7 +165,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
system: "Custom prompt", system: "Custom prompt",
wantSystemLen: 2, wantSystemLen: 2,
wantFirstText: claudeCodeSystemPrompt, wantFirstText: claudeCodeSystemPrompt,
wantSecondText: "Custom prompt", wantSecondText: claudePrefix + "\n\nCustom prompt",
}, },
{ {
name: "string system equals Claude Code prompt", name: "string system equals Claude Code prompt",
...@@ -178,7 +181,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) { ...@@ -178,7 +181,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
// Claude Code + Custom = 2 // Claude Code + Custom = 2
wantSystemLen: 2, wantSystemLen: 2,
wantFirstText: claudeCodeSystemPrompt, wantFirstText: claudeCodeSystemPrompt,
wantSecondText: "Custom", wantSecondText: claudePrefix + "\n\nCustom",
}, },
{ {
name: "array system with existing Claude Code prompt (should dedupe)", name: "array system with existing Claude Code prompt (should dedupe)",
...@@ -190,7 +193,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) { ...@@ -190,7 +193,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
// Claude Code at start + Other = 2 (deduped) // Claude Code at start + Other = 2 (deduped)
wantSystemLen: 2, wantSystemLen: 2,
wantFirstText: claudeCodeSystemPrompt, wantFirstText: claudeCodeSystemPrompt,
wantSecondText: "Other", wantSecondText: claudePrefix + "\n\nOther",
}, },
{ {
name: "empty array", name: "empty array",
......
package service
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestSanitizeOpenCodeText_RewritesCanonicalSentence(t *testing.T) {
in := "You are OpenCode, the best coding agent on the planet."
got := sanitizeSystemText(in)
require.Equal(t, strings.TrimSpace(claudeCodeSystemPrompt), got)
}
func TestSanitizeToolDescription_DoesNotRewriteKeywords(t *testing.T) {
in := "OpenCode and opencode are mentioned."
got := sanitizeToolDescription(in)
// We no longer rewrite tool descriptions; only redact obvious path leaks.
require.Equal(t, in, got)
}
This diff is collapsed.
...@@ -36,6 +36,11 @@ const ( ...@@ -36,6 +36,11 @@ const (
geminiRetryMaxDelay = 16 * time.Second geminiRetryMaxDelay = 16 * time.Second
) )
// Gemini tool calling now requires `thoughtSignature` in parts that include `functionCall`.
// Many clients don't send it; we inject a known dummy signature to satisfy the validator.
// Ref: https://ai.google.dev/gemini-api/docs/thought-signatures
const geminiDummyThoughtSignature = "skip_thought_signature_validator"
type GeminiMessagesCompatService struct { type GeminiMessagesCompatService struct {
accountRepo AccountRepository accountRepo AccountRepository
groupRepo GroupRepository groupRepo GroupRepository
...@@ -528,6 +533,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex ...@@ -528,6 +533,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
if err != nil { if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error()) return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error())
} }
geminiReq = ensureGeminiFunctionCallThoughtSignatures(geminiReq)
originalClaudeBody := body originalClaudeBody := body
proxyURL := "" proxyURL := ""
...@@ -983,6 +989,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. ...@@ -983,6 +989,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action) return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action)
} }
// Some Gemini upstreams validate tool call parts strictly; ensure any `functionCall` part includes a
// `thoughtSignature` to avoid frequent INVALID_ARGUMENT 400s.
body = ensureGeminiFunctionCallThoughtSignatures(body)
mappedModel := originalModel mappedModel := originalModel
if account.Type == AccountTypeAPIKey { if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(originalModel) mappedModel = account.GetMappedModel(originalModel)
...@@ -2662,6 +2672,58 @@ func nextGeminiDailyResetUnix() *int64 { ...@@ -2662,6 +2672,58 @@ func nextGeminiDailyResetUnix() *int64 {
return &ts return &ts
} }
func ensureGeminiFunctionCallThoughtSignatures(body []byte) []byte {
// Fast path: only run when functionCall is present.
if !bytes.Contains(body, []byte(`"functionCall"`)) {
return body
}
var payload map[string]any
if err := json.Unmarshal(body, &payload); err != nil {
return body
}
contentsAny, ok := payload["contents"].([]any)
if !ok || len(contentsAny) == 0 {
return body
}
modified := false
for _, c := range contentsAny {
cm, ok := c.(map[string]any)
if !ok {
continue
}
partsAny, ok := cm["parts"].([]any)
if !ok || len(partsAny) == 0 {
continue
}
for _, p := range partsAny {
pm, ok := p.(map[string]any)
if !ok || pm == nil {
continue
}
if fc, ok := pm["functionCall"].(map[string]any); !ok || fc == nil {
continue
}
ts, _ := pm["thoughtSignature"].(string)
if strings.TrimSpace(ts) == "" {
pm["thoughtSignature"] = geminiDummyThoughtSignature
modified = true
}
}
}
if !modified {
return body
}
b, err := json.Marshal(payload)
if err != nil {
return body
}
return b
}
func extractGeminiFinishReason(geminiResp map[string]any) string { func extractGeminiFinishReason(geminiResp map[string]any) string {
if candidates, ok := geminiResp["candidates"].([]any); ok && len(candidates) > 0 { if candidates, ok := geminiResp["candidates"].([]any); ok && len(candidates) > 0 {
if cand, ok := candidates[0].(map[string]any); ok { if cand, ok := candidates[0].(map[string]any); ok {
...@@ -2861,7 +2923,13 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str ...@@ -2861,7 +2923,13 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str
if strings.TrimSpace(id) != "" && strings.TrimSpace(name) != "" { if strings.TrimSpace(id) != "" && strings.TrimSpace(name) != "" {
toolUseIDToName[id] = name toolUseIDToName[id] = name
} }
signature, _ := bm["signature"].(string)
signature = strings.TrimSpace(signature)
if signature == "" {
signature = geminiDummyThoughtSignature
}
parts = append(parts, map[string]any{ parts = append(parts, map[string]any{
"thoughtSignature": signature,
"functionCall": map[string]any{ "functionCall": map[string]any{
"name": name, "name": name,
"args": bm["input"], "args": bm["input"],
......
package service package service
import ( import (
"encoding/json"
"strings"
"testing" "testing"
) )
...@@ -126,3 +128,78 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) { ...@@ -126,3 +128,78 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
}) })
} }
} }
func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) {
claudeReq := map[string]any{
"model": "claude-haiku-4-5-20251001",
"max_tokens": 10,
"messages": []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{"type": "text", "text": "hi"},
},
},
map[string]any{
"role": "assistant",
"content": []any{
map[string]any{"type": "text", "text": "ok"},
map[string]any{
"type": "tool_use",
"id": "toolu_123",
"name": "default_api:write_file",
"input": map[string]any{"path": "a.txt", "content": "x"},
// no signature on purpose
},
},
},
},
"tools": []any{
map[string]any{
"name": "default_api:write_file",
"description": "write file",
"input_schema": map[string]any{
"type": "object",
"properties": map[string]any{"path": map[string]any{"type": "string"}},
},
},
},
}
b, _ := json.Marshal(claudeReq)
out, err := convertClaudeMessagesToGeminiGenerateContent(b)
if err != nil {
t.Fatalf("convert failed: %v", err)
}
s := string(out)
if !strings.Contains(s, "\"functionCall\"") {
t.Fatalf("expected functionCall in output, got: %s", s)
}
if !strings.Contains(s, "\"thoughtSignature\":\""+geminiDummyThoughtSignature+"\"") {
t.Fatalf("expected injected thoughtSignature %q, got: %s", geminiDummyThoughtSignature, s)
}
}
func TestEnsureGeminiFunctionCallThoughtSignatures_InsertsWhenMissing(t *testing.T) {
geminiReq := map[string]any{
"contents": []any{
map[string]any{
"role": "user",
"parts": []any{
map[string]any{
"functionCall": map[string]any{
"name": "default_api:write_file",
"args": map[string]any{"path": "a.txt"},
},
},
},
},
},
}
b, _ := json.Marshal(geminiReq)
out := ensureGeminiFunctionCallThoughtSignatures(b)
s := string(out)
if !strings.Contains(s, "\"thoughtSignature\":\""+geminiDummyThoughtSignature+"\"") {
t.Fatalf("expected injected thoughtSignature %q, got: %s", geminiDummyThoughtSignature, s)
}
}
...@@ -218,6 +218,14 @@ func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Contex ...@@ -218,6 +218,14 @@ func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Contex
return 0, nil return 0, nil
} }
func (m *mockGroupRepoForGemini) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
return nil
}
func (m *mockGroupRepoForGemini) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
return nil, nil
}
var _ GroupRepository = (*mockGroupRepoForGemini)(nil) var _ GroupRepository = (*mockGroupRepoForGemini)(nil)
// mockGatewayCacheForGemini Gemini 测试用的 cache mock // mockGatewayCacheForGemini Gemini 测试用的 cache mock
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment