Commit 7b83d6e7 authored by 陈曦's avatar 陈曦
Browse files

Merge remote-tracking branch 'upstream/main'

parents daa2e6df dbb248df
......@@ -28,7 +28,7 @@ import (
gocache "github.com/patrickmn/go-cache"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at"
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at"
// usageLogInsertArgTypes must stay in the same order as:
// 1. prepareUsageLogInsert().args
......@@ -53,6 +53,8 @@ var usageLogInsertArgTypes = [...]string{
"integer", // cache_read_tokens
"integer", // cache_creation_5m_tokens
"integer", // cache_creation_1h_tokens
"integer", // image_output_tokens
"numeric", // image_output_cost
"numeric", // input_cost
"numeric", // output_cost
"numeric", // cache_creation_cost
......@@ -77,6 +79,10 @@ var usageLogInsertArgTypes = [...]string{
"text", // inbound_endpoint
"text", // upstream_endpoint
"boolean", // cache_ttl_overridden
"bigint", // channel_id
"text", // model_mapping_chain
"text", // billing_tier
"text", // billing_mode
"timestamptz", // created_at
}
......@@ -326,6 +332,8 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
......@@ -350,14 +358,18 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at
) VALUES (
$1, $2, $3, $4, $5, $6, $7,
$8, $9,
$10, $11, $12, $13,
$14, $15,
$16, $17, $18, $19, $20, $21,
$22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40
$14, $15, $16, $17,
$18, $19, $20, $21, $22, $23,
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
......@@ -758,6 +770,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
......@@ -782,10 +796,14 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at
) AS (VALUES `)
args := make([]any, 0, len(keys)*39)
args := make([]any, 0, len(keys)*47)
argPos := 1
for idx, key := range keys {
if idx > 0 {
......@@ -829,6 +847,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
......@@ -853,6 +873,10 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at
)
SELECT
......@@ -871,6 +895,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
......@@ -895,6 +921,10 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at
FROM input
ON CONFLICT (request_id, api_key_id) DO NOTHING
......@@ -953,6 +983,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
......@@ -977,10 +1009,14 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at
) AS (VALUES `)
args := make([]any, 0, len(preparedList)*40)
args := make([]any, 0, len(preparedList)*46)
argPos := 1
for idx, prepared := range preparedList {
if idx > 0 {
......@@ -1021,6 +1057,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
......@@ -1045,6 +1083,10 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at
)
SELECT
......@@ -1063,6 +1105,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
......@@ -1087,6 +1131,10 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at
FROM input
ON CONFLICT (request_id, api_key_id) DO NOTHING
......@@ -1113,6 +1161,8 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
......@@ -1137,14 +1187,18 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at
) VALUES (
$1, $2, $3, $4, $5, $6, $7,
$8, $9,
$10, $11, $12, $13,
$14, $15,
$16, $17, $18, $19, $20, $21,
$22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40
$14, $15, $16, $17,
$18, $19, $20, $21, $22, $23,
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...)
......@@ -1176,6 +1230,10 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
reasoningEffort := nullString(log.ReasoningEffort)
inboundEndpoint := nullString(log.InboundEndpoint)
upstreamEndpoint := nullString(log.UpstreamEndpoint)
channelID := nullInt64(log.ChannelID)
modelMappingChain := nullString(log.ModelMappingChain)
billingTier := nullString(log.BillingTier)
billingMode := nullString(log.BillingMode)
requestedModel := strings.TrimSpace(log.RequestedModel)
if requestedModel == "" {
requestedModel = strings.TrimSpace(log.Model)
......@@ -1208,6 +1266,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
log.CacheReadTokens,
log.CacheCreation5mTokens,
log.CacheCreation1hTokens,
log.ImageOutputTokens,
log.ImageOutputCost,
log.InputCost,
log.OutputCost,
log.CacheCreationCost,
......@@ -1232,6 +1292,10 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
inboundEndpoint,
upstreamEndpoint,
log.CacheTTLOverridden,
channelID,
modelMappingChain,
billingTier,
billingMode,
createdAt,
},
}
......@@ -2564,8 +2628,8 @@ type UsageLogFilters = usagestats.UsageLogFilters
// ListWithFilters lists usage logs with optional filters (for admin)
func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
conditions := make([]string, 0, 8)
args := make([]any, 0, 8)
conditions := make([]string, 0, 9)
args := make([]any, 0, 9)
if filters.UserID > 0 {
conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1))
......@@ -2589,6 +2653,10 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
args = append(args, int16(*filters.BillingType))
}
if filters.BillingMode != "" {
conditions = append(conditions, fmt.Sprintf("billing_mode = $%d", len(args)+1))
args = append(args, filters.BillingMode)
}
if filters.StartTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
args = append(args, *filters.StartTime)
......@@ -3096,6 +3164,30 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim
query += fmt.Sprintf(" AND %s = $%d", col, len(args)+1)
args = append(args, dim.Endpoint)
}
if dim.UserID > 0 {
query += fmt.Sprintf(" AND ul.user_id = $%d", len(args)+1)
args = append(args, dim.UserID)
}
if dim.APIKeyID > 0 {
query += fmt.Sprintf(" AND ul.api_key_id = $%d", len(args)+1)
args = append(args, dim.APIKeyID)
}
if dim.AccountID > 0 {
query += fmt.Sprintf(" AND ul.account_id = $%d", len(args)+1)
args = append(args, dim.AccountID)
}
if dim.RequestType != nil {
query += fmt.Sprintf(" AND ul.request_type = $%d", len(args)+1)
args = append(args, *dim.RequestType)
}
if dim.Stream != nil {
query += fmt.Sprintf(" AND ul.stream = $%d", len(args)+1)
args = append(args, *dim.Stream)
}
if dim.BillingType != nil {
query += fmt.Sprintf(" AND ul.billing_type = $%d", len(args)+1)
args = append(args, *dim.BillingType)
}
query += " GROUP BY ul.user_id, u.email ORDER BY actual_cost DESC"
if limit > 0 {
......@@ -3256,6 +3348,10 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
args = append(args, int16(*filters.BillingType))
}
if filters.BillingMode != "" {
conditions = append(conditions, fmt.Sprintf("billing_mode = $%d", len(args)+1))
args = append(args, filters.BillingMode)
}
if filters.StartTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
args = append(args, *filters.StartTime)
......@@ -3935,6 +4031,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
cacheReadTokens int
cacheCreation5m int
cacheCreation1h int
imageOutputTokens int
imageOutputCost float64
inputCost float64
outputCost float64
cacheCreationCost float64
......@@ -3959,6 +4057,10 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
inboundEndpoint sql.NullString
upstreamEndpoint sql.NullString
cacheTTLOverridden bool
channelID sql.NullInt64
modelMappingChain sql.NullString
billingTier sql.NullString
billingMode sql.NullString
createdAt time.Time
)
......@@ -3979,6 +4081,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&cacheReadTokens,
&cacheCreation5m,
&cacheCreation1h,
&imageOutputTokens,
&imageOutputCost,
&inputCost,
&outputCost,
&cacheCreationCost,
......@@ -4003,6 +4107,10 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&inboundEndpoint,
&upstreamEndpoint,
&cacheTTLOverridden,
&channelID,
&modelMappingChain,
&billingTier,
&billingMode,
&createdAt,
); err != nil {
return nil, err
......@@ -4021,6 +4129,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
CacheReadTokens: cacheReadTokens,
CacheCreation5mTokens: cacheCreation5m,
CacheCreation1hTokens: cacheCreation1h,
ImageOutputTokens: imageOutputTokens,
ImageOutputCost: imageOutputCost,
InputCost: inputCost,
OutputCost: outputCost,
CacheCreationCost: cacheCreationCost,
......@@ -4087,6 +4197,19 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if upstreamModel.Valid {
log.UpstreamModel = &upstreamModel.String
}
if channelID.Valid {
value := channelID.Int64
log.ChannelID = &value
}
if modelMappingChain.Valid {
log.ModelMappingChain = &modelMappingChain.String
}
if billingTier.Valid {
log.BillingTier = &billingTier.String
}
if billingMode.Valid {
log.BillingMode = &billingMode.String
}
return log, nil
}
......
......@@ -56,6 +56,8 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
log.CacheReadTokens,
log.CacheCreation5mTokens,
log.CacheCreation1hTokens,
log.ImageOutputTokens,
log.ImageOutputCost,
log.InputCost,
log.OutputCost,
log.CacheCreationCost,
......@@ -80,6 +82,10 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
sqlmock.AnyArg(), // inbound_endpoint
sqlmock.AnyArg(), // upstream_endpoint
log.CacheTTLOverridden,
sqlmock.AnyArg(), // channel_id
sqlmock.AnyArg(), // model_mapping_chain
sqlmock.AnyArg(), // billing_tier
sqlmock.AnyArg(), // billing_mode
createdAt,
).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt))
......@@ -129,6 +135,8 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
log.CacheReadTokens,
log.CacheCreation5mTokens,
log.CacheCreation1hTokens,
log.ImageOutputTokens,
log.ImageOutputCost,
log.InputCost,
log.OutputCost,
log.CacheCreationCost,
......@@ -153,6 +161,10 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
sqlmock.AnyArg(),
sqlmock.AnyArg(),
log.CacheTTLOverridden,
sqlmock.AnyArg(), // channel_id
sqlmock.AnyArg(), // model_mapping_chain
sqlmock.AnyArg(), // billing_tier
sqlmock.AnyArg(), // billing_mode
createdAt,
).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt))
......@@ -439,6 +451,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
4, // cache_read_tokens
5, // cache_creation_5m_tokens
6, // cache_creation_1h_tokens
0, // image_output_tokens
0.0, // image_output_cost
0.1, // input_cost
0.2, // output_cost
0.3, // cache_creation_cost
......@@ -463,6 +477,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{},
false,
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
now,
}})
require.NoError(t, err)
......@@ -487,6 +505,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullInt64{},
sql.NullInt64{},
1, 2, 3, 4, 5, 6,
0, 0.0, // image_output_tokens, image_output_cost
0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
1.0,
sql.NullFloat64{},
......@@ -506,6 +525,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{},
false,
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
now,
}})
require.NoError(t, err)
......@@ -530,6 +553,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullInt64{},
sql.NullInt64{},
1, 2, 3, 4, 5, 6,
0, 0.0, // image_output_tokens, image_output_cost
0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
1.0,
sql.NullFloat64{},
......@@ -549,6 +573,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{},
false,
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
now,
}})
require.NoError(t, err)
......
......@@ -74,6 +74,7 @@ var ProviderSet = wire.NewSet(
NewUserGroupRateRepository,
NewErrorPassthroughRepository,
NewTLSFingerprintProfileRepository,
NewChannelRepository,
// Cache implementations
NewGatewayCache,
......
......@@ -87,6 +87,9 @@ func RegisterAdminRoutes(
// 定时测试计划
registerScheduledTestRoutes(admin, h)
// 渠道管理
registerChannelRoutes(admin, h)
}
}
......@@ -567,3 +570,15 @@ func registerTLSFingerprintProfileRoutes(admin *gin.RouterGroup, h *handler.Hand
profiles.DELETE("/:id", h.Admin.TLSFingerprintProfile.Delete)
}
}
func registerChannelRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
channels := admin.Group("/channels")
{
channels.GET("", h.Admin.Channel.List)
channels.GET("/model-pricing", h.Admin.Channel.GetModelDefaultPricing)
channels.GET("/:id", h.Admin.Channel.GetByID)
channels.POST("", h.Admin.Channel.Create)
channels.PUT("/:id", h.Admin.Channel.Update)
channels.DELETE("/:id", h.Admin.Channel.Delete)
}
}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
func TestSelectAccountForModelWithExclusions_UsesFallbackGroupForChannelRestriction(t *testing.T) {
t.Parallel()
groupID := int64(10)
fallbackID := int64(11)
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{fallbackID},
RestrictModels: true,
ModelPricing: []ChannelModelPricing{
{Platform: PlatformAnthropic, Models: []string{"claude-sonnet-4-6"}},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{
fallbackID: PlatformAnthropic,
}))
accountRepo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range accountRepo.accounts {
accountRepo.accountsByID[accountRepo.accounts[i].ID] = &accountRepo.accounts[i]
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
ClaudeCodeOnly: true,
FallbackGroupID: &fallbackID,
Hydrated: true,
},
fallbackID: {
ID: fallbackID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
},
},
}
svc := &GatewayService{
accountRepo: accountRepo,
groupRepo: groupRepo,
channelService: channelSvc,
cfg: testConfig(),
}
ctx := context.WithValue(context.Background(), ctxkey.Group, groupRepo.groups[groupID])
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-sonnet-4-6", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, int64(1), account.ID)
}
func TestSelectAccountWithLoadAwareness_UsesFallbackGroupForChannelRestriction(t *testing.T) {
t.Parallel()
groupID := int64(10)
fallbackID := int64(11)
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{fallbackID},
RestrictModels: true,
ModelPricing: []ChannelModelPricing{
{Platform: PlatformAnthropic, Models: []string{"claude-sonnet-4-6"}},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{
fallbackID: PlatformAnthropic,
}))
accountRepo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range accountRepo.accounts {
accountRepo.accountsByID[accountRepo.accounts[i].ID] = &accountRepo.accounts[i]
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
ClaudeCodeOnly: true,
FallbackGroupID: &fallbackID,
Hydrated: true,
},
fallbackID: {
ID: fallbackID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
},
},
}
svc := &GatewayService{
accountRepo: accountRepo,
groupRepo: groupRepo,
channelService: channelSvc,
cfg: testConfig(),
}
ctx := context.WithValue(context.Background(), ctxkey.Group, groupRepo.groups[groupID])
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-sonnet-4-6", nil, "", 0)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(1), result.Account.ID)
}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment