Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
01ef7340
Commit
01ef7340
authored
Mar 14, 2026
by
Wang Lvyuan
Browse files
Merge remote-tracking branch 'origin/main' into openai-model-mapping-fix
parents
4e8615f2
e6d59216
Changes
82
Hide whitespace changes
Inline
Side-by-side
backend/internal/service/bedrock_stream_test.go
0 → 100644
View file @
01ef7340
package
service
import
(
"bytes"
"encoding/base64"
"encoding/binary"
"hash/crc32"
"io"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func
TestExtractBedrockChunkData
(
t
*
testing
.
T
)
{
t
.
Run
(
"valid base64 payload"
,
func
(
t
*
testing
.
T
)
{
original
:=
`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}`
b64
:=
base64
.
StdEncoding
.
EncodeToString
([]
byte
(
original
))
payload
:=
[]
byte
(
`{"bytes":"`
+
b64
+
`"}`
)
result
:=
extractBedrockChunkData
(
payload
)
require
.
NotNil
(
t
,
result
)
assert
.
JSONEq
(
t
,
original
,
string
(
result
))
})
t
.
Run
(
"empty bytes field"
,
func
(
t
*
testing
.
T
)
{
result
:=
extractBedrockChunkData
([]
byte
(
`{"bytes":""}`
))
assert
.
Nil
(
t
,
result
)
})
t
.
Run
(
"no bytes field"
,
func
(
t
*
testing
.
T
)
{
result
:=
extractBedrockChunkData
([]
byte
(
`{"other":"value"}`
))
assert
.
Nil
(
t
,
result
)
})
t
.
Run
(
"invalid base64"
,
func
(
t
*
testing
.
T
)
{
result
:=
extractBedrockChunkData
([]
byte
(
`{"bytes":"not-valid-base64!!!"}`
))
assert
.
Nil
(
t
,
result
)
})
}
func
TestTransformBedrockInvocationMetrics
(
t
*
testing
.
T
)
{
t
.
Run
(
"converts metrics to usage"
,
func
(
t
*
testing
.
T
)
{
input
:=
`{"type":"message_delta","delta":{"stop_reason":"end_turn"},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}`
result
:=
transformBedrockInvocationMetrics
([]
byte
(
input
))
// amazon-bedrock-invocationMetrics should be removed
assert
.
False
(
t
,
gjson
.
GetBytes
(
result
,
"amazon-bedrock-invocationMetrics"
)
.
Exists
())
// usage should be set
assert
.
Equal
(
t
,
int64
(
150
),
gjson
.
GetBytes
(
result
,
"usage.input_tokens"
)
.
Int
())
assert
.
Equal
(
t
,
int64
(
42
),
gjson
.
GetBytes
(
result
,
"usage.output_tokens"
)
.
Int
())
// original fields preserved
assert
.
Equal
(
t
,
"message_delta"
,
gjson
.
GetBytes
(
result
,
"type"
)
.
String
())
assert
.
Equal
(
t
,
"end_turn"
,
gjson
.
GetBytes
(
result
,
"delta.stop_reason"
)
.
String
())
})
t
.
Run
(
"no metrics present"
,
func
(
t
*
testing
.
T
)
{
input
:=
`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}`
result
:=
transformBedrockInvocationMetrics
([]
byte
(
input
))
assert
.
JSONEq
(
t
,
input
,
string
(
result
))
})
t
.
Run
(
"does not overwrite existing usage"
,
func
(
t
*
testing
.
T
)
{
input
:=
`{"type":"message_delta","usage":{"output_tokens":100},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}`
result
:=
transformBedrockInvocationMetrics
([]
byte
(
input
))
// metrics removed but existing usage preserved
assert
.
False
(
t
,
gjson
.
GetBytes
(
result
,
"amazon-bedrock-invocationMetrics"
)
.
Exists
())
assert
.
Equal
(
t
,
int64
(
100
),
gjson
.
GetBytes
(
result
,
"usage.output_tokens"
)
.
Int
())
})
}
// TestExtractEventStreamHeaderValue exercises lookup of string-typed
// headers inside an AWS EventStream header block.
func TestExtractEventStreamHeaderValue(t *testing.T) {
	// Build a header with :event-type = "chunk" (string type = 7).
	// Wire layout per header: name_len(1) | name | value_type(1) |
	// value_len(2, big-endian) | value.
	buildStringHeader := func(name, value string) []byte {
		var buf bytes.Buffer
		// name length (1 byte)
		_ = buf.WriteByte(byte(len(name)))
		// name
		_, _ = buf.WriteString(name)
		// value type (7 = string)
		_ = buf.WriteByte(7)
		// value length (2 bytes, big-endian)
		_ = binary.Write(&buf, binary.BigEndian, uint16(len(value)))
		// value
		_, _ = buf.WriteString(value)
		return buf.Bytes()
	}

	t.Run("find string header", func(t *testing.T) {
		headers := buildStringHeader(":event-type", "chunk")
		assert.Equal(t, "chunk", extractEventStreamHeaderValue(headers, ":event-type"))
	})

	t.Run("header not found", func(t *testing.T) {
		headers := buildStringHeader(":event-type", "chunk")
		// A missing header yields the empty string, not an error.
		assert.Equal(t, "", extractEventStreamHeaderValue(headers, ":message-type"))
	})

	t.Run("multiple headers", func(t *testing.T) {
		// Concatenated headers must each be findable by name.
		var buf bytes.Buffer
		_, _ = buf.Write(buildStringHeader(":content-type", "application/json"))
		_, _ = buf.Write(buildStringHeader(":event-type", "chunk"))
		_, _ = buf.Write(buildStringHeader(":message-type", "event"))
		headers := buf.Bytes()
		assert.Equal(t, "chunk", extractEventStreamHeaderValue(headers, ":event-type"))
		assert.Equal(t, "application/json", extractEventStreamHeaderValue(headers, ":content-type"))
		assert.Equal(t, "event", extractEventStreamHeaderValue(headers, ":message-type"))
	})

	t.Run("empty headers", func(t *testing.T) {
		assert.Equal(t, "", extractEventStreamHeaderValue([]byte{}, ":event-type"))
	})
}
// TestBedrockEventStreamDecoder exercises frame decoding of the AWS
// EventStream binary format: a 12-byte prelude (total_length, headers_length,
// prelude_crc), headers, payload, and a trailing message CRC — both checksums
// CRC32/IEEE.
func TestBedrockEventStreamDecoder(t *testing.T) {
	crc32IeeeTab := crc32.MakeTable(crc32.IEEE)

	// Build a valid EventStream frame with correct CRC32/IEEE checksums.
	buildFrame := func(eventType string, payload []byte) []byte {
		// Build headers
		var headersBuf bytes.Buffer
		// :event-type header
		_ = headersBuf.WriteByte(byte(len(":event-type")))
		_, _ = headersBuf.WriteString(":event-type")
		_ = headersBuf.WriteByte(7) // string type
		_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len(eventType)))
		_, _ = headersBuf.WriteString(eventType)
		// :message-type header
		_ = headersBuf.WriteByte(byte(len(":message-type")))
		_, _ = headersBuf.WriteString(":message-type")
		_ = headersBuf.WriteByte(7)
		_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len("event")))
		_, _ = headersBuf.WriteString("event")
		headers := headersBuf.Bytes()
		headersLen := uint32(len(headers))
		// total = 12 (prelude) + headers + payload + 4 (message_crc)
		totalLen := uint32(12 + len(headers) + len(payload) + 4)
		// Prelude: total_length(4) + headers_length(4)
		var preludeBuf bytes.Buffer
		_ = binary.Write(&preludeBuf, binary.BigEndian, totalLen)
		_ = binary.Write(&preludeBuf, binary.BigEndian, headersLen)
		preludeBytes := preludeBuf.Bytes()
		preludeCRC := crc32.Checksum(preludeBytes, crc32IeeeTab)
		// Build frame: prelude + prelude_crc + headers + payload
		var frame bytes.Buffer
		_, _ = frame.Write(preludeBytes)
		_ = binary.Write(&frame, binary.BigEndian, preludeCRC)
		_, _ = frame.Write(headers)
		_, _ = frame.Write(payload)
		// Message CRC covers everything before itself
		messageCRC := crc32.Checksum(frame.Bytes(), crc32IeeeTab)
		_ = binary.Write(&frame, binary.BigEndian, messageCRC)
		return frame.Bytes()
	}

	t.Run("decode chunk event", func(t *testing.T) {
		payload := []byte(`{"bytes":"dGVzdA=="}`) // base64("test")
		frame := buildFrame("chunk", payload)
		decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
		result, err := decoder.Decode()
		require.NoError(t, err)
		assert.Equal(t, payload, result)
	})

	t.Run("skip non-chunk events", func(t *testing.T) {
		// Write initial-response followed by chunk; Decode must skip the
		// former and return the chunk payload.
		var buf bytes.Buffer
		_, _ = buf.Write(buildFrame("initial-response", []byte(`{}`)))
		chunkPayload := []byte(`{"bytes":"aGVsbG8="}`)
		_, _ = buf.Write(buildFrame("chunk", chunkPayload))
		decoder := newBedrockEventStreamDecoder(&buf)
		result, err := decoder.Decode()
		require.NoError(t, err)
		assert.Equal(t, chunkPayload, result)
	})

	t.Run("EOF on empty input", func(t *testing.T) {
		decoder := newBedrockEventStreamDecoder(bytes.NewReader(nil))
		_, err := decoder.Decode()
		// Exhausted input surfaces as io.EOF, the conventional stream-end signal.
		assert.Equal(t, io.EOF, err)
	})

	t.Run("corrupted prelude CRC", func(t *testing.T) {
		frame := buildFrame("chunk", []byte(`{"bytes":"dGVzdA=="}`))
		// Corrupt the prelude CRC (bytes 8-11)
		frame[8] ^= 0xFF
		decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
		_, err := decoder.Decode()
		require.Error(t, err)
		assert.Contains(t, err.Error(), "prelude CRC mismatch")
	})

	t.Run("corrupted message CRC", func(t *testing.T) {
		frame := buildFrame("chunk", []byte(`{"bytes":"dGVzdA=="}`))
		// Corrupt the message CRC (last 4 bytes)
		frame[len(frame)-1] ^= 0xFF
		decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
		_, err := decoder.Decode()
		require.Error(t, err)
		assert.Contains(t, err.Error(), "message CRC mismatch")
	})

	t.Run("castagnoli encoded frame is rejected", func(t *testing.T) {
		// A frame checksummed with CRC32C (Castagnoli) instead of CRC32/IEEE
		// must fail prelude validation — the polynomials differ.
		castagnoliTab := crc32.MakeTable(crc32.Castagnoli)
		payload := []byte(`{"bytes":"dGVzdA=="}`)
		var headersBuf bytes.Buffer
		_ = headersBuf.WriteByte(byte(len(":event-type")))
		_, _ = headersBuf.WriteString(":event-type")
		_ = headersBuf.WriteByte(7)
		_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len("chunk")))
		_, _ = headersBuf.WriteString("chunk")
		headers := headersBuf.Bytes()
		headersLen := uint32(len(headers))
		totalLen := uint32(12 + len(headers) + len(payload) + 4)
		var preludeBuf bytes.Buffer
		_ = binary.Write(&preludeBuf, binary.BigEndian, totalLen)
		_ = binary.Write(&preludeBuf, binary.BigEndian, headersLen)
		preludeBytes := preludeBuf.Bytes()
		var frame bytes.Buffer
		_, _ = frame.Write(preludeBytes)
		_ = binary.Write(&frame, binary.BigEndian, crc32.Checksum(preludeBytes, castagnoliTab))
		_, _ = frame.Write(headers)
		_, _ = frame.Write(payload)
		_ = binary.Write(&frame, binary.BigEndian, crc32.Checksum(frame.Bytes(), castagnoliTab))
		decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame.Bytes()))
		_, err := decoder.Decode()
		require.Error(t, err)
		assert.Contains(t, err.Error(), "prelude CRC mismatch")
	})
}
func
TestBuildBedrockURL
(
t
*
testing
.
T
)
{
t
.
Run
(
"stream URL with colon in model ID"
,
func
(
t
*
testing
.
T
)
{
url
:=
BuildBedrockURL
(
"us-east-1"
,
"us.anthropic.claude-opus-4-5-20251101-v1:0"
,
true
)
assert
.
Equal
(
t
,
"https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-opus-4-5-20251101-v1%3A0/invoke-with-response-stream"
,
url
)
})
t
.
Run
(
"non-stream URL with colon in model ID"
,
func
(
t
*
testing
.
T
)
{
url
:=
BuildBedrockURL
(
"eu-west-1"
,
"eu.anthropic.claude-sonnet-4-5-20250929-v1:0"
,
false
)
assert
.
Equal
(
t
,
"https://bedrock-runtime.eu-west-1.amazonaws.com/model/eu.anthropic.claude-sonnet-4-5-20250929-v1%3A0/invoke"
,
url
)
})
t
.
Run
(
"model ID without colon"
,
func
(
t
*
testing
.
T
)
{
url
:=
BuildBedrockURL
(
"us-east-1"
,
"us.anthropic.claude-sonnet-4-6"
,
true
)
assert
.
Equal
(
t
,
"https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-sonnet-4-6/invoke-with-response-stream"
,
url
)
})
}
backend/internal/service/dashboard_aggregation_service.go
View file @
01ef7340
...
...
@@ -35,6 +35,7 @@ type DashboardAggregationRepository interface {
UpdateAggregationWatermark
(
ctx
context
.
Context
,
aggregatedAt
time
.
Time
)
error
CleanupAggregates
(
ctx
context
.
Context
,
hourlyCutoff
,
dailyCutoff
time
.
Time
)
error
CleanupUsageLogs
(
ctx
context
.
Context
,
cutoff
time
.
Time
)
error
CleanupUsageBillingDedup
(
ctx
context
.
Context
,
cutoff
time
.
Time
)
error
EnsureUsageLogsPartitions
(
ctx
context
.
Context
,
now
time
.
Time
)
error
}
...
...
@@ -296,6 +297,7 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
hourlyCutoff
:=
now
.
AddDate
(
0
,
0
,
-
s
.
cfg
.
Retention
.
HourlyDays
)
dailyCutoff
:=
now
.
AddDate
(
0
,
0
,
-
s
.
cfg
.
Retention
.
DailyDays
)
usageCutoff
:=
now
.
AddDate
(
0
,
0
,
-
s
.
cfg
.
Retention
.
UsageLogsDays
)
dedupCutoff
:=
now
.
AddDate
(
0
,
0
,
-
s
.
cfg
.
Retention
.
UsageBillingDedupDays
)
aggErr
:=
s
.
repo
.
CleanupAggregates
(
ctx
,
hourlyCutoff
,
dailyCutoff
)
if
aggErr
!=
nil
{
...
...
@@ -305,7 +307,11 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
if
usageErr
!=
nil
{
logger
.
LegacyPrintf
(
"service.dashboard_aggregation"
,
"[DashboardAggregation] usage_logs 保留清理失败: %v"
,
usageErr
)
}
if
aggErr
==
nil
&&
usageErr
==
nil
{
dedupErr
:=
s
.
repo
.
CleanupUsageBillingDedup
(
ctx
,
dedupCutoff
)
if
dedupErr
!=
nil
{
logger
.
LegacyPrintf
(
"service.dashboard_aggregation"
,
"[DashboardAggregation] usage_billing_dedup 保留清理失败: %v"
,
dedupErr
)
}
if
aggErr
==
nil
&&
usageErr
==
nil
&&
dedupErr
==
nil
{
s
.
lastRetentionCleanup
.
Store
(
now
)
}
}
...
...
backend/internal/service/dashboard_aggregation_service_test.go
View file @
01ef7340
...
...
@@ -12,12 +12,18 @@ import (
type
dashboardAggregationRepoTestStub
struct
{
aggregateCalls
int
recomputeCalls
int
cleanupUsageCalls
int
cleanupDedupCalls
int
ensurePartitionCalls
int
lastStart
time
.
Time
lastEnd
time
.
Time
watermark
time
.
Time
aggregateErr
error
cleanupAggregatesErr
error
cleanupUsageErr
error
cleanupDedupErr
error
ensurePartitionErr
error
}
func
(
s
*
dashboardAggregationRepoTestStub
)
AggregateRange
(
ctx
context
.
Context
,
start
,
end
time
.
Time
)
error
{
...
...
@@ -28,6 +34,7 @@ func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, s
}
func
(
s
*
dashboardAggregationRepoTestStub
)
RecomputeRange
(
ctx
context
.
Context
,
start
,
end
time
.
Time
)
error
{
s
.
recomputeCalls
++
return
s
.
AggregateRange
(
ctx
,
start
,
end
)
}
...
...
@@ -44,11 +51,18 @@ func (s *dashboardAggregationRepoTestStub) CleanupAggregates(ctx context.Context
}
func
(
s
*
dashboardAggregationRepoTestStub
)
CleanupUsageLogs
(
ctx
context
.
Context
,
cutoff
time
.
Time
)
error
{
s
.
cleanupUsageCalls
++
return
s
.
cleanupUsageErr
}
func
(
s
*
dashboardAggregationRepoTestStub
)
CleanupUsageBillingDedup
(
ctx
context
.
Context
,
cutoff
time
.
Time
)
error
{
s
.
cleanupDedupCalls
++
return
s
.
cleanupDedupErr
}
func
(
s
*
dashboardAggregationRepoTestStub
)
EnsureUsageLogsPartitions
(
ctx
context
.
Context
,
now
time
.
Time
)
error
{
return
nil
s
.
ensurePartitionCalls
++
return
s
.
ensurePartitionErr
}
func
TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart
(
t
*
testing
.
T
)
{
...
...
@@ -90,6 +104,50 @@ func TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord(t *te
svc
.
maybeCleanupRetention
(
context
.
Background
(),
time
.
Now
()
.
UTC
())
require
.
Nil
(
t
,
svc
.
lastRetentionCleanup
.
Load
())
require
.
Equal
(
t
,
1
,
repo
.
cleanupUsageCalls
)
require
.
Equal
(
t
,
1
,
repo
.
cleanupDedupCalls
)
}
// TestDashboardAggregationService_CleanupDedupFailure_DoesNotRecord verifies
// that a failing usage_billing_dedup cleanup prevents the retention-cleanup
// watermark from being recorded, while the dedup cleanup is still attempted.
func TestDashboardAggregationService_CleanupDedupFailure_DoesNotRecord(t *testing.T) {
	repo := &dashboardAggregationRepoTestStub{cleanupDedupErr: errors.New("dedup cleanup failed")}
	// NOTE(review): UsageBillingDedupDays is left at its zero value here —
	// presumably only the error injection matters for this test; confirm.
	svc := &DashboardAggregationService{
		repo: repo,
		cfg: config.DashboardAggregationConfig{
			Retention: config.DashboardAggregationRetentionConfig{
				UsageLogsDays: 1,
				HourlyDays:    1,
				DailyDays:     1,
			},
		},
	}
	svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
	// The watermark must not be stored when the dedup step fails.
	require.Nil(t, svc.lastRetentionCleanup.Load())
	// The dedup cleanup itself was attempted exactly once.
	require.Equal(t, 1, repo.cleanupDedupCalls)
}
func
TestDashboardAggregationService_PartitionFailure_DoesNotAggregate
(
t
*
testing
.
T
)
{
repo
:=
&
dashboardAggregationRepoTestStub
{
ensurePartitionErr
:
errors
.
New
(
"partition failed"
)}
svc
:=
&
DashboardAggregationService
{
repo
:
repo
,
cfg
:
config
.
DashboardAggregationConfig
{
Enabled
:
true
,
IntervalSeconds
:
60
,
LookbackSeconds
:
120
,
Retention
:
config
.
DashboardAggregationRetentionConfig
{
UsageLogsDays
:
1
,
UsageBillingDedupDays
:
2
,
HourlyDays
:
1
,
DailyDays
:
1
,
},
},
}
svc
.
runScheduledAggregation
()
require
.
Equal
(
t
,
1
,
repo
.
ensurePartitionCalls
)
require
.
Equal
(
t
,
1
,
repo
.
aggregateCalls
)
}
func
TestDashboardAggregationService_TriggerBackfill_TooLarge
(
t
*
testing
.
T
)
{
...
...
backend/internal/service/dashboard_service.go
View file @
01ef7340
...
...
@@ -327,6 +327,14 @@ func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, end
return
trend
,
nil
}
func
(
s
*
DashboardService
)
GetUserSpendingRanking
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
limit
int
)
(
*
usagestats
.
UserSpendingRankingResponse
,
error
)
{
ranking
,
err
:=
s
.
usageRepo
.
GetUserSpendingRanking
(
ctx
,
startTime
,
endTime
,
limit
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get user spending ranking: %w"
,
err
)
}
return
ranking
,
nil
}
func
(
s
*
DashboardService
)
GetBatchUserUsageStats
(
ctx
context
.
Context
,
userIDs
[]
int64
,
startTime
,
endTime
time
.
Time
)
(
map
[
int64
]
*
usagestats
.
BatchUserUsageStats
,
error
)
{
stats
,
err
:=
s
.
usageRepo
.
GetBatchUserUsageStats
(
ctx
,
userIDs
,
startTime
,
endTime
)
if
err
!=
nil
{
...
...
backend/internal/service/dashboard_service_test.go
View file @
01ef7340
...
...
@@ -124,6 +124,10 @@ func (s *dashboardAggregationRepoStub) CleanupUsageLogs(ctx context.Context, cut
return
nil
}
func
(
s
*
dashboardAggregationRepoStub
)
CleanupUsageBillingDedup
(
ctx
context
.
Context
,
cutoff
time
.
Time
)
error
{
return
nil
}
func
(
s
*
dashboardAggregationRepoStub
)
EnsureUsageLogsPartitions
(
ctx
context
.
Context
,
now
time
.
Time
)
error
{
return
nil
}
...
...
backend/internal/service/domain_constants.go
View file @
01ef7340
...
...
@@ -29,10 +29,12 @@ const (
// Account type constants
const
(
AccountTypeOAuth
=
domain
.
AccountTypeOAuth
// OAuth类型账号(full scope: profile + inference)
AccountTypeSetupToken
=
domain
.
AccountTypeSetupToken
// Setup Token类型账号(inference only scope)
AccountTypeAPIKey
=
domain
.
AccountTypeAPIKey
// API Key类型账号
AccountTypeUpstream
=
domain
.
AccountTypeUpstream
// 上游透传类型账号(通过 Base URL + API Key 连接上游)
AccountTypeOAuth
=
domain
.
AccountTypeOAuth
// OAuth类型账号(full scope: profile + inference)
AccountTypeSetupToken
=
domain
.
AccountTypeSetupToken
// Setup Token类型账号(inference only scope)
AccountTypeAPIKey
=
domain
.
AccountTypeAPIKey
// API Key类型账号
AccountTypeUpstream
=
domain
.
AccountTypeUpstream
// 上游透传类型账号(通过 Base URL + API Key 连接上游)
AccountTypeBedrock
=
domain
.
AccountTypeBedrock
// AWS Bedrock 类型账号(通过 SigV4 签名连接 Bedrock)
AccountTypeBedrockAPIKey
=
domain
.
AccountTypeBedrockAPIKey
// AWS Bedrock API Key 类型账号(通过 Bearer Token 连接 Bedrock)
)
// Redeem type constants
...
...
backend/internal/service/gateway_anthropic_apikey_passthrough_test.go
View file @
01ef7340
...
...
@@ -136,16 +136,18 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd
},
}
svc
:=
&
GatewayService
{
cfg
:
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
,
},
cfg
:=
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
,
},
httpUpstream
:
upstream
,
rateLimitService
:
&
RateLimitService
{},
deferredService
:
&
DeferredService
{},
billingCacheService
:
nil
,
}
svc
:=
&
GatewayService
{
cfg
:
cfg
,
responseHeaderFilter
:
compileResponseHeaderFilter
(
cfg
),
httpUpstream
:
upstream
,
rateLimitService
:
&
RateLimitService
{},
deferredService
:
&
DeferredService
{},
billingCacheService
:
nil
,
}
account
:=
&
Account
{
...
...
@@ -221,14 +223,16 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
},
}
svc
:=
&
GatewayService
{
cfg
:
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
,
},
cfg
:=
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
,
},
httpUpstream
:
upstream
,
rateLimitService
:
&
RateLimitService
{},
}
svc
:=
&
GatewayService
{
cfg
:
cfg
,
responseHeaderFilter
:
compileResponseHeaderFilter
(
cfg
),
httpUpstream
:
upstream
,
rateLimitService
:
&
RateLimitService
{},
}
account
:=
&
Account
{
...
...
@@ -727,6 +731,39 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAf
require
.
Equal
(
t
,
5
,
result
.
usage
.
OutputTokens
)
}
func
TestGatewayService_AnthropicAPIKeyPassthrough_MissingTerminalEventReturnsError
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
nil
)
svc
:=
&
GatewayService
{
cfg
:
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
,
},
},
rateLimitService
:
&
RateLimitService
{},
}
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"text/event-stream"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
strings
.
Join
([]
string
{
`data: {"type":"message_start","message":{"usage":{"input_tokens":11}}}`
,
""
,
`data: {"type":"message_delta","usage":{"output_tokens":5}}`
,
""
,
},
"
\n
"
))),
}
result
,
err
:=
svc
.
handleStreamingResponseAnthropicAPIKeyPassthrough
(
context
.
Background
(),
resp
,
c
,
&
Account
{
ID
:
1
},
time
.
Now
(),
"claude-3-7-sonnet-20250219"
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"missing terminal event"
)
require
.
NotNil
(
t
,
result
)
}
func
TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuccess
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
...
...
@@ -1074,7 +1111,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingTimeoutAfterClientDi
_
=
pr
.
Close
()
<-
done
require
.
NoError
(
t
,
err
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"stream usage incomplete after timeout"
)
require
.
NotNil
(
t
,
result
)
require
.
True
(
t
,
result
.
clientDisconnect
)
require
.
Equal
(
t
,
9
,
result
.
usage
.
InputTokens
)
...
...
@@ -1103,7 +1141,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingContextCanceled(t *t
}
result
,
err
:=
svc
.
handleStreamingResponseAnthropicAPIKeyPassthrough
(
context
.
Background
(),
resp
,
c
,
&
Account
{
ID
:
3
},
time
.
Now
(),
"claude-3-7-sonnet-20250219"
)
require
.
NoError
(
t
,
err
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"stream usage incomplete"
)
require
.
NotNil
(
t
,
result
)
require
.
True
(
t
,
result
.
clientDisconnect
)
}
...
...
@@ -1133,7 +1172,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingUpstreamReadErrorAft
}
result
,
err
:=
svc
.
handleStreamingResponseAnthropicAPIKeyPassthrough
(
context
.
Background
(),
resp
,
c
,
&
Account
{
ID
:
4
},
time
.
Now
(),
"claude-3-7-sonnet-20250219"
)
require
.
NoError
(
t
,
err
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"stream usage incomplete after disconnect"
)
require
.
NotNil
(
t
,
result
)
require
.
True
(
t
,
result
.
clientDisconnect
)
require
.
Equal
(
t
,
8
,
result
.
usage
.
InputTokens
)
...
...
backend/internal/service/gateway_record_usage_test.go
0 → 100644
View file @
01ef7340
//go:build unit
package
service
import
(
"context"
"errors"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
// newGatewayRecordUsageServiceForTest builds a minimal GatewayService wired
// with only the repositories RecordUsage touches; every other collaborator is
// nil. Rate multiplier is non-1 so billing math is exercised.
// NOTE(review): NewGatewayService takes a long positional parameter list —
// the nils below are position-sensitive; keep them aligned with the
// constructor signature.
func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *GatewayService {
	cfg := &config.Config{}
	cfg.Default.RateMultiplier = 1.1
	return NewGatewayService(
		nil,
		nil,
		usageRepo,
		nil,
		userRepo,
		subRepo,
		nil,
		nil,
		cfg,
		nil,
		nil,
		NewBillingService(cfg, nil),
		nil,
		&BillingCacheService{},
		nil,
		nil,
		&DeferredService{},
		nil,
		nil,
		nil,
		nil,
		nil,
	)
}
func
newGatewayRecordUsageServiceWithBillingRepoForTest
(
usageRepo
UsageLogRepository
,
billingRepo
UsageBillingRepository
,
userRepo
UserRepository
,
subRepo
UserSubscriptionRepository
)
*
GatewayService
{
svc
:=
newGatewayRecordUsageServiceForTest
(
usageRepo
,
userRepo
,
subRepo
)
svc
.
usageBillingRepo
=
billingRepo
return
svc
}
// openAIRecordUsageBestEffortLogRepoStub is a UsageLogRepository test double
// that records which write path (best-effort vs. transactional Create) was
// used and with what context/log.
type openAIRecordUsageBestEffortLogRepoStub struct {
	UsageLogRepository // embedded so unimplemented methods panic if called

	bestEffortErr   error     // error returned by CreateBestEffort
	createErr       error     // error returned by Create
	bestEffortCalls int       // number of CreateBestEffort invocations
	createCalls     int       // number of Create invocations
	lastLog         *UsageLog // last log passed to either method
	lastCtxErr      error     // ctx.Err() observed at the last call
}
func
(
s
*
openAIRecordUsageBestEffortLogRepoStub
)
CreateBestEffort
(
ctx
context
.
Context
,
log
*
UsageLog
)
error
{
s
.
bestEffortCalls
++
s
.
lastLog
=
log
s
.
lastCtxErr
=
ctx
.
Err
()
return
s
.
bestEffortErr
}
func
(
s
*
openAIRecordUsageBestEffortLogRepoStub
)
Create
(
ctx
context
.
Context
,
log
*
UsageLog
)
(
bool
,
error
)
{
s
.
createCalls
++
s
.
lastLog
=
log
s
.
lastCtxErr
=
ctx
.
Err
()
return
false
,
s
.
createErr
}
// TestGatewayServiceRecordUsage_BillingUsesDetachedContext verifies that
// RecordUsage performs its writes on a context detached from the (already
// canceled) request context: every downstream call must observe a nil
// ctx.Err().
func TestGatewayServiceRecordUsage_BillingUsesDetachedContext(t *testing.T) {
	usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
	userRepo := &openAIRecordUsageUserRepoStub{}
	subRepo := &openAIRecordUsageSubRepoStub{}
	quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
	svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
	// Cancel the request context up front: a non-detached implementation
	// would see a canceled context in every repo call below.
	reqCtx, cancel := context.WithCancel(context.Background())
	cancel()
	err := svc.RecordUsage(reqCtx, &RecordUsageInput{
		Result: &ForwardResult{
			RequestID: "gateway_detached_ctx",
			Usage: ClaudeUsage{
				InputTokens:  10,
				OutputTokens: 6,
			},
			Model:    "claude-sonnet-4",
			Duration: time.Second,
		},
		APIKey: &APIKey{
			ID:    501,
			Quota: 100,
		},
		User:          &User{ID: 601},
		Account:       &Account{ID: 701},
		APIKeyService: quotaSvc,
	})
	require.NoError(t, err)
	require.Equal(t, 1, usageRepo.calls)
	require.Equal(t, 1, userRepo.deductCalls)
	// The deduction must have run on a live (detached) context.
	require.NoError(t, userRepo.lastCtxErr)
	require.Equal(t, 1, quotaSvc.quotaCalls)
	require.NoError(t, quotaSvc.lastQuotaCtxErr)
}
// TestGatewayServiceRecordUsage_BillingFingerprintIncludesRequestPayloadHash
// verifies that an explicitly supplied request-payload hash is forwarded
// verbatim into the billing command.
func TestGatewayServiceRecordUsage_BillingFingerprintIncludesRequestPayloadHash(t *testing.T) {
	usageRepo := &openAIRecordUsageLogRepoStub{}
	billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
	svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
	payloadHash := HashUsageRequestPayload([]byte(`{"messages":[{"role":"user","content":"hello"}]}`))
	err := svc.RecordUsage(context.Background(), &RecordUsageInput{
		Result: &ForwardResult{
			RequestID: "gateway_payload_hash",
			Usage: ClaudeUsage{
				InputTokens:  10,
				OutputTokens: 6,
			},
			Model:    "claude-sonnet-4",
			Duration: time.Second,
		},
		APIKey:             &APIKey{ID: 501, Quota: 100},
		User:               &User{ID: 601},
		Account:            &Account{ID: 701},
		RequestPayloadHash: payloadHash,
	})
	require.NoError(t, err)
	require.NotNil(t, billingRepo.lastCmd)
	// The billing command must carry the exact hash supplied by the caller.
	require.Equal(t, payloadHash, billingRepo.lastCmd.RequestPayloadHash)
}
// TestGatewayServiceRecordUsage_BillingFingerprintFallsBackToContextRequestID
// verifies that when no payload hash is supplied, the billing fingerprint
// falls back to the context request ID with a "local:" prefix.
func TestGatewayServiceRecordUsage_BillingFingerprintFallsBackToContextRequestID(t *testing.T) {
	usageRepo := &openAIRecordUsageLogRepoStub{}
	billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
	svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
	ctx := context.WithValue(context.Background(), ctxkey.RequestID, "req-local-123")
	err := svc.RecordUsage(ctx, &RecordUsageInput{
		Result: &ForwardResult{
			RequestID: "gateway_payload_fallback",
			Usage: ClaudeUsage{
				InputTokens:  10,
				OutputTokens: 6,
			},
			Model:    "claude-sonnet-4",
			Duration: time.Second,
		},
		APIKey:  &APIKey{ID: 501, Quota: 100},
		User:    &User{ID: 601},
		Account: &Account{ID: 701},
	})
	require.NoError(t, err)
	require.NotNil(t, billingRepo.lastCmd)
	// No RequestPayloadHash was set, so the fallback fingerprint is used.
	require.Equal(t, "local:req-local-123", billingRepo.lastCmd.RequestPayloadHash)
}
// TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling verifies
// that a not-persisted usage-log write is tolerated: balance deduction and
// quota accounting still run and RecordUsage reports success.
func TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) {
	// The log write fails with a "not persisted" marker around a canceled ctx.
	usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)}
	userRepo := &openAIRecordUsageUserRepoStub{}
	subRepo := &openAIRecordUsageSubRepoStub{}
	quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
	svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
	err := svc.RecordUsage(context.Background(), &RecordUsageInput{
		Result: &ForwardResult{
			RequestID: "gateway_not_persisted",
			Usage: ClaudeUsage{
				InputTokens:  10,
				OutputTokens: 6,
			},
			Model:    "claude-sonnet-4",
			Duration: time.Second,
		},
		APIKey: &APIKey{
			ID:    503,
			Quota: 100,
		},
		User:          &User{ID: 603},
		Account:       &Account{ID: 703},
		APIKeyService: quotaSvc,
	})
	require.NoError(t, err)
	require.Equal(t, 1, usageRepo.calls)
	// Billing-side effects still happened despite the failed log write.
	require.Equal(t, 1, userRepo.deductCalls)
	require.Equal(t, 1, quotaSvc.quotaCalls)
}
// TestGatewayServiceRecordUsageWithLongContext_BillingUsesDetachedContext is
// the long-context variant of the detached-context test: all writes must run
// on a context independent of the canceled request context.
func TestGatewayServiceRecordUsageWithLongContext_BillingUsesDetachedContext(t *testing.T) {
	usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
	userRepo := &openAIRecordUsageUserRepoStub{}
	subRepo := &openAIRecordUsageSubRepoStub{}
	quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
	svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
	// Pre-canceled request context: only a detached write context survives it.
	reqCtx, cancel := context.WithCancel(context.Background())
	cancel()
	err := svc.RecordUsageWithLongContext(reqCtx, &RecordUsageLongContextInput{
		Result: &ForwardResult{
			RequestID: "gateway_long_context_detached_ctx",
			Usage: ClaudeUsage{
				InputTokens:  12,
				OutputTokens: 8,
			},
			Model:    "claude-sonnet-4",
			Duration: time.Second,
		},
		APIKey: &APIKey{
			ID:    502,
			Quota: 100,
		},
		User:                 &User{ID: 602},
		Account:              &Account{ID: 702},
		LongContextThreshold: 200000,
		LongContextMultiplier: 2,
		APIKeyService:        quotaSvc,
	})
	require.NoError(t, err)
	require.Equal(t, 1, usageRepo.calls)
	require.Equal(t, 1, userRepo.deductCalls)
	require.NoError(t, userRepo.lastCtxErr)
	require.Equal(t, 1, quotaSvc.quotaCalls)
	require.NoError(t, quotaSvc.lastQuotaCtxErr)
}
// TestGatewayServiceRecordUsage_UsesFallbackRequestIDForUsageLog verifies
// that when the upstream result has no request ID, the usage log falls back
// to the context request ID, prefixed with "local:".
func TestGatewayServiceRecordUsage_UsesFallbackRequestIDForUsageLog(t *testing.T) {
	usageRepo := &openAIRecordUsageLogRepoStub{}
	userRepo := &openAIRecordUsageUserRepoStub{}
	subRepo := &openAIRecordUsageSubRepoStub{}
	svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
	ctx := context.WithValue(context.Background(), ctxkey.RequestID, "gateway-local-fallback")
	err := svc.RecordUsage(ctx, &RecordUsageInput{
		Result: &ForwardResult{
			RequestID: "", // upstream ID deliberately absent
			Usage: ClaudeUsage{
				InputTokens:  10,
				OutputTokens: 6,
			},
			Model:    "claude-sonnet-4",
			Duration: time.Second,
		},
		APIKey:  &APIKey{ID: 504},
		User:    &User{ID: 604},
		Account: &Account{ID: 704},
	})
	require.NoError(t, err)
	require.NotNil(t, usageRepo.lastLog)
	require.Equal(t, "local:gateway-local-fallback", usageRepo.lastLog.RequestID)
}
// TestGatewayServiceRecordUsage_PrefersClientRequestIDOverUpstreamRequestID
// verifies the request-ID precedence: a client-supplied ID ("client:" prefix)
// wins over both the local context ID and the upstream result ID, in both
// the billing command and the usage log.
func TestGatewayServiceRecordUsage_PrefersClientRequestIDOverUpstreamRequestID(t *testing.T) {
	usageRepo := &openAIRecordUsageLogRepoStub{}
	billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
	svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
	// Both a client ID and a local ID are present; the client ID must win.
	ctx := context.WithValue(context.Background(), ctxkey.ClientRequestID, "client-stable-123")
	ctx = context.WithValue(ctx, ctxkey.RequestID, "req-local-ignored")
	err := svc.RecordUsage(ctx, &RecordUsageInput{
		Result: &ForwardResult{
			RequestID: "upstream-volatile-456",
			Usage: ClaudeUsage{
				InputTokens:  10,
				OutputTokens: 6,
			},
			Model:    "claude-sonnet-4",
			Duration: time.Second,
		},
		APIKey:  &APIKey{ID: 506},
		User:    &User{ID: 606},
		Account: &Account{ID: 706},
	})
	require.NoError(t, err)
	require.NotNil(t, billingRepo.lastCmd)
	require.Equal(t, "client:client-stable-123", billingRepo.lastCmd.RequestID)
	require.NotNil(t, usageRepo.lastLog)
	require.Equal(t, "client:client-stable-123", usageRepo.lastLog.RequestID)
}
// TestGatewayServiceRecordUsage_GeneratesRequestIDWhenAllSourcesMissing
// verifies the last resort of the request-ID chain: with no client, local, or
// upstream ID available, a "generated:"-prefixed ID is minted and used
// consistently for both billing and the usage log.
func TestGatewayServiceRecordUsage_GeneratesRequestIDWhenAllSourcesMissing(t *testing.T) {
	usageRepo := &openAIRecordUsageLogRepoStub{}
	billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
	svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
	err := svc.RecordUsage(context.Background(), &RecordUsageInput{
		Result: &ForwardResult{
			RequestID: "", // no upstream ID, and no IDs in the context either
			Usage: ClaudeUsage{
				InputTokens:  10,
				OutputTokens: 6,
			},
			Model:    "claude-sonnet-4",
			Duration: time.Second,
		},
		APIKey:  &APIKey{ID: 507},
		User:    &User{ID: 607},
		Account: &Account{ID: 707},
	})
	require.NoError(t, err)
	require.NotNil(t, billingRepo.lastCmd)
	require.True(t, strings.HasPrefix(billingRepo.lastCmd.RequestID, "generated:"))
	require.NotNil(t, usageRepo.lastLog)
	// Billing and usage log must share the same generated ID.
	require.Equal(t, billingRepo.lastCmd.RequestID, usageRepo.lastLog.RequestID)
}
// TestGatewayServiceRecordUsage_DroppedUsageLogDoesNotSyncFallback verifies
// that when the best-effort usage-log write reports a "dropped" error (queue
// full), RecordUsage does NOT fall back to the synchronous Create path.
func TestGatewayServiceRecordUsage_DroppedUsageLogDoesNotSyncFallback(t *testing.T) {
	usageRepo := &openAIRecordUsageBestEffortLogRepoStub{
		bestEffortErr: MarkUsageLogCreateDropped(errors.New("usage log best-effort queue full")),
	}
	billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
	svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
	err := svc.RecordUsage(context.Background(), &RecordUsageInput{
		Result: &ForwardResult{
			RequestID: "gateway_drop_usage_log",
			Usage: ClaudeUsage{
				InputTokens:  10,
				OutputTokens: 6,
			},
			Model:    "claude-sonnet-4",
			Duration: time.Second,
		},
		APIKey:  &APIKey{ID: 508},
		User:    &User{ID: 608},
		Account: &Account{ID: 708},
	})
	require.NoError(t, err)
	require.Equal(t, 1, usageRepo.bestEffortCalls)
	// A deliberate drop must not trigger the synchronous fallback write.
	require.Equal(t, 0, usageRepo.createCalls)
}
// TestGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite verifies that
// a billing-apply failure aborts RecordUsage before the usage log is written,
// so a failed charge never produces a usage record.
func TestGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) {
	usageRepo := &openAIRecordUsageLogRepoStub{}
	billingRepo := &openAIRecordUsageBillingRepoStub{err: context.DeadlineExceeded}
	userRepo := &openAIRecordUsageUserRepoStub{}
	subRepo := &openAIRecordUsageSubRepoStub{}
	svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo)
	err := svc.RecordUsage(context.Background(), &RecordUsageInput{
		Result: &ForwardResult{
			RequestID: "gateway_billing_fail",
			Usage: ClaudeUsage{
				InputTokens:  10,
				OutputTokens: 6,
			},
			Model:    "claude-sonnet-4",
			Duration: time.Second,
		},
		APIKey:  &APIKey{ID: 505},
		User:    &User{ID: 605},
		Account: &Account{ID: 705},
	})
	require.Error(t, err)
	require.Equal(t, 1, billingRepo.calls)
	// The usage log must not have been written after the billing failure.
	require.Equal(t, 0, usageRepo.calls)
}
backend/internal/service/gateway_service.go
View file @
01ef7340
...
...
@@ -50,6 +50,7 @@ const (
defaultUserGroupRateCacheTTL
=
30
*
time
.
Second
defaultModelsListCacheTTL
=
15
*
time
.
Second
postUsageBillingTimeout
=
15
*
time
.
Second
)
const
(
...
...
@@ -106,6 +107,36 @@ func GatewayModelsListCacheStats() (cacheHit, cacheMiss, store int64) {
return
modelsListCacheHitTotal
.
Load
(),
modelsListCacheMissTotal
.
Load
(),
modelsListCacheStoreTotal
.
Load
()
}
func
openAIStreamEventIsTerminal
(
data
string
)
bool
{
trimmed
:=
strings
.
TrimSpace
(
data
)
if
trimmed
==
""
{
return
false
}
if
trimmed
==
"[DONE]"
{
return
true
}
switch
gjson
.
Get
(
trimmed
,
"type"
)
.
String
()
{
case
"response.completed"
,
"response.done"
,
"response.failed"
:
return
true
default
:
return
false
}
}
func
anthropicStreamEventIsTerminal
(
eventName
,
data
string
)
bool
{
if
strings
.
EqualFold
(
strings
.
TrimSpace
(
eventName
),
"message_stop"
)
{
return
true
}
trimmed
:=
strings
.
TrimSpace
(
data
)
if
trimmed
==
""
{
return
false
}
if
trimmed
==
"[DONE]"
{
return
true
}
return
gjson
.
Get
(
trimmed
,
"type"
)
.
String
()
==
"message_stop"
}
func
cloneStringSlice
(
src
[]
string
)
[]
string
{
if
len
(
src
)
==
0
{
return
nil
...
...
@@ -504,6 +535,7 @@ type GatewayService struct {
accountRepo
AccountRepository
groupRepo
GroupRepository
usageLogRepo
UsageLogRepository
usageBillingRepo
UsageBillingRepository
userRepo
UserRepository
userSubRepo
UserSubscriptionRepository
userGroupRateRepo
UserGroupRateRepository
...
...
@@ -537,6 +569,7 @@ func NewGatewayService(
accountRepo
AccountRepository
,
groupRepo
GroupRepository
,
usageLogRepo
UsageLogRepository
,
usageBillingRepo
UsageBillingRepository
,
userRepo
UserRepository
,
userSubRepo
UserSubscriptionRepository
,
userGroupRateRepo
UserGroupRateRepository
,
...
...
@@ -563,6 +596,7 @@ func NewGatewayService(
accountRepo
:
accountRepo
,
groupRepo
:
groupRepo
,
usageLogRepo
:
usageLogRepo
,
usageBillingRepo
:
usageBillingRepo
,
userRepo
:
userRepo
,
userSubRepo
:
userSubRepo
,
userGroupRateRepo
:
userGroupRateRepo
,
...
...
@@ -3336,6 +3370,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
if
account
.
Platform
==
PlatformSora
{
return
s
.
isSoraModelSupportedByAccount
(
account
,
requestedModel
)
}
if
account
.
IsBedrock
()
{
_
,
ok
:=
ResolveBedrockModelID
(
account
,
requestedModel
)
return
ok
}
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
if
account
.
Platform
==
PlatformAnthropic
&&
account
.
Type
!=
AccountTypeAPIKey
{
requestedModel
=
claude
.
NormalizeModelID
(
requestedModel
)
...
...
@@ -3493,6 +3531,10 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
return
""
,
""
,
errors
.
New
(
"api_key not found in credentials"
)
}
return
apiKey
,
"apikey"
,
nil
case
AccountTypeBedrock
:
return
""
,
"bedrock"
,
nil
// Bedrock 使用 SigV4 签名,不需要 token
case
AccountTypeBedrockAPIKey
:
return
""
,
"bedrock-apikey"
,
nil
// Bedrock API Key 使用 Bearer Token,由 forwardBedrock 处理
default
:
return
""
,
""
,
fmt
.
Errorf
(
"unsupported account type: %s"
,
account
.
Type
)
}
...
...
@@ -3948,6 +3990,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return
s
.
forwardAnthropicAPIKeyPassthrough
(
ctx
,
c
,
account
,
passthroughBody
,
passthroughModel
,
parsed
.
Stream
,
startTime
)
}
if
account
!=
nil
&&
account
.
IsBedrock
()
{
return
s
.
forwardBedrock
(
ctx
,
c
,
account
,
parsed
,
startTime
)
}
// Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest.
// Always overwrite the cache to prevent stale values from a previous retry with a different account.
if
account
.
Platform
==
PlatformAnthropic
&&
c
!=
nil
{
...
...
@@ -4049,7 +4095,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
retryStart
:=
time
.
Now
()
for
attempt
:=
1
;
attempt
<=
maxRetryAttempts
;
attempt
++
{
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
upstreamReq
,
err
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
body
,
token
,
tokenType
,
reqModel
,
reqStream
,
shouldMimicClaudeCode
)
upstreamCtx
,
releaseUpstreamCtx
:=
detachStreamUpstreamContext
(
ctx
,
reqStream
)
upstreamReq
,
err
:=
s
.
buildUpstreamRequest
(
upstreamCtx
,
c
,
account
,
body
,
token
,
tokenType
,
reqModel
,
reqStream
,
shouldMimicClaudeCode
)
releaseUpstreamCtx
()
if
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -4127,7 +4175,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// also downgrade tool_use/tool_result blocks to text.
filteredBody
:=
FilterThinkingBlocksForRetry
(
body
)
retryReq
,
buildErr
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
filteredBody
,
token
,
tokenType
,
reqModel
,
reqStream
,
shouldMimicClaudeCode
)
retryCtx
,
releaseRetryCtx
:=
detachStreamUpstreamContext
(
ctx
,
reqStream
)
retryReq
,
buildErr
:=
s
.
buildUpstreamRequest
(
retryCtx
,
c
,
account
,
filteredBody
,
token
,
tokenType
,
reqModel
,
reqStream
,
shouldMimicClaudeCode
)
releaseRetryCtx
()
if
buildErr
==
nil
{
retryResp
,
retryErr
:=
s
.
httpUpstream
.
DoWithTLS
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
())
if
retryErr
==
nil
{
...
...
@@ -4159,7 +4209,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
if
looksLikeToolSignatureError
(
msg2
)
&&
time
.
Since
(
retryStart
)
<
maxRetryElapsed
{
logger
.
LegacyPrintf
(
"service.gateway"
,
"Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded"
,
account
.
ID
)
filteredBody2
:=
FilterSignatureSensitiveBlocksForRetry
(
body
)
retryReq2
,
buildErr2
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
filteredBody2
,
token
,
tokenType
,
reqModel
,
reqStream
,
shouldMimicClaudeCode
)
retryCtx2
,
releaseRetryCtx2
:=
detachStreamUpstreamContext
(
ctx
,
reqStream
)
retryReq2
,
buildErr2
:=
s
.
buildUpstreamRequest
(
retryCtx2
,
c
,
account
,
filteredBody2
,
token
,
tokenType
,
reqModel
,
reqStream
,
shouldMimicClaudeCode
)
releaseRetryCtx2
()
if
buildErr2
==
nil
{
retryResp2
,
retryErr2
:=
s
.
httpUpstream
.
DoWithTLS
(
retryReq2
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
())
if
retryErr2
==
nil
{
...
...
@@ -4226,7 +4278,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
rectifiedBody
,
applied
:=
RectifyThinkingBudget
(
body
)
if
applied
&&
time
.
Since
(
retryStart
)
<
maxRetryElapsed
{
logger
.
LegacyPrintf
(
"service.gateway"
,
"Account %d: detected budget_tokens constraint error, retrying with rectified budget (budget_tokens=%d, max_tokens=%d)"
,
account
.
ID
,
BudgetRectifyBudgetTokens
,
BudgetRectifyMaxTokens
)
budgetRetryReq
,
buildErr
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
rectifiedBody
,
token
,
tokenType
,
reqModel
,
reqStream
,
shouldMimicClaudeCode
)
budgetRetryCtx
,
releaseBudgetRetryCtx
:=
detachStreamUpstreamContext
(
ctx
,
reqStream
)
budgetRetryReq
,
buildErr
:=
s
.
buildUpstreamRequest
(
budgetRetryCtx
,
c
,
account
,
rectifiedBody
,
token
,
tokenType
,
reqModel
,
reqStream
,
shouldMimicClaudeCode
)
releaseBudgetRetryCtx
()
if
buildErr
==
nil
{
budgetRetryResp
,
retryErr
:=
s
.
httpUpstream
.
DoWithTLS
(
budgetRetryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
())
if
retryErr
==
nil
{
...
...
@@ -4498,7 +4552,9 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
var
resp
*
http
.
Response
retryStart
:=
time
.
Now
()
for
attempt
:=
1
;
attempt
<=
maxRetryAttempts
;
attempt
++
{
upstreamReq
,
err
:=
s
.
buildUpstreamRequestAnthropicAPIKeyPassthrough
(
ctx
,
c
,
account
,
body
,
token
)
upstreamCtx
,
releaseUpstreamCtx
:=
detachStreamUpstreamContext
(
ctx
,
reqStream
)
upstreamReq
,
err
:=
s
.
buildUpstreamRequestAnthropicAPIKeyPassthrough
(
upstreamCtx
,
c
,
account
,
body
,
token
)
releaseUpstreamCtx
()
if
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -4774,6 +4830,7 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
usage
:=
&
ClaudeUsage
{}
var
firstTokenMs
*
int
clientDisconnected
:=
false
sawTerminalEvent
:=
false
scanner
:=
bufio
.
NewScanner
(
resp
.
Body
)
maxLineSize
:=
defaultMaxLineSize
...
...
@@ -4836,17 +4893,20 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
// 兜底补刷,确保最后一个未以空行结尾的事件也能及时送达客户端。
flusher
.
Flush
()
}
if
!
sawTerminalEvent
{
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
clientDisconnected
},
fmt
.
Errorf
(
"stream usage incomplete: missing terminal event"
)
}
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
clientDisconnected
},
nil
}
if
ev
.
err
!=
nil
{
if
sawTerminalEvent
{
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
clientDisconnected
},
nil
}
if
clientDisconnected
{
logger
.
LegacyPrintf
(
"service.gateway"
,
"[Anthropic passthrough] Upstream read error after client disconnect: account=%d err=%v"
,
account
.
ID
,
ev
.
err
)
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
true
},
nil
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
true
},
fmt
.
Errorf
(
"stream usage incomplete after disconnect: %w"
,
ev
.
err
)
}
if
errors
.
Is
(
ev
.
err
,
context
.
Canceled
)
||
errors
.
Is
(
ev
.
err
,
context
.
DeadlineExceeded
)
{
logger
.
LegacyPrintf
(
"service.gateway"
,
"[Anthropic passthrough] 流读取被取消: account=%d request_id=%s err=%v ctx_err=%v"
,
account
.
ID
,
resp
.
Header
.
Get
(
"x-request-id"
),
ev
.
err
,
ctx
.
Err
())
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
true
},
nil
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
true
},
fmt
.
Errorf
(
"stream usage incomplete: %w"
,
ev
.
err
)
}
if
errors
.
Is
(
ev
.
err
,
bufio
.
ErrTooLong
)
{
logger
.
LegacyPrintf
(
"service.gateway"
,
"[Anthropic passthrough] SSE line too long: account=%d max_size=%d error=%v"
,
account
.
ID
,
maxLineSize
,
ev
.
err
)
...
...
@@ -4858,11 +4918,19 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
line
:=
ev
.
line
if
data
,
ok
:=
extractAnthropicSSEDataLine
(
line
);
ok
{
trimmed
:=
strings
.
TrimSpace
(
data
)
if
anthropicStreamEventIsTerminal
(
""
,
trimmed
)
{
sawTerminalEvent
=
true
}
if
firstTokenMs
==
nil
&&
trimmed
!=
""
&&
trimmed
!=
"[DONE]"
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
firstTokenMs
=
&
ms
}
s
.
parseSSEUsagePassthrough
(
data
,
usage
)
}
else
{
trimmed
:=
strings
.
TrimSpace
(
line
)
if
strings
.
HasPrefix
(
trimmed
,
"event:"
)
&&
anthropicStreamEventIsTerminal
(
strings
.
TrimSpace
(
strings
.
TrimPrefix
(
trimmed
,
"event:"
)),
""
)
{
sawTerminalEvent
=
true
}
}
if
!
clientDisconnected
{
...
...
@@ -4884,8 +4952,7 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
continue
}
if
clientDisconnected
{
logger
.
LegacyPrintf
(
"service.gateway"
,
"[Anthropic passthrough] Upstream timeout after client disconnect: account=%d model=%s"
,
account
.
ID
,
model
)
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
true
},
nil
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
true
},
fmt
.
Errorf
(
"stream usage incomplete after timeout"
)
}
logger
.
LegacyPrintf
(
"service.gateway"
,
"[Anthropic passthrough] Stream data interval timeout: account=%d model=%s interval=%s"
,
account
.
ID
,
model
,
streamInterval
)
if
s
.
rateLimitService
!=
nil
{
...
...
@@ -5068,6 +5135,366 @@ func writeAnthropicPassthroughResponseHeaders(dst http.Header, src http.Header,
}
}
// forwardBedrock 转发请求到 AWS Bedrock
func
(
s
*
GatewayService
)
forwardBedrock
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
parsed
*
ParsedRequest
,
startTime
time
.
Time
,
)
(
*
ForwardResult
,
error
)
{
reqModel
:=
parsed
.
Model
reqStream
:=
parsed
.
Stream
body
:=
parsed
.
Body
region
:=
bedrockRuntimeRegion
(
account
)
mappedModel
,
ok
:=
ResolveBedrockModelID
(
account
,
reqModel
)
if
!
ok
{
return
nil
,
fmt
.
Errorf
(
"unsupported bedrock model: %s"
,
reqModel
)
}
if
mappedModel
!=
reqModel
{
logger
.
LegacyPrintf
(
"service.gateway"
,
"[Bedrock] Model mapping: %s -> %s (account: %s)"
,
reqModel
,
mappedModel
,
account
.
Name
)
}
betaHeader
:=
""
if
c
!=
nil
&&
c
.
Request
!=
nil
{
betaHeader
=
c
.
GetHeader
(
"anthropic-beta"
)
}
// 准备请求体(注入 anthropic_version/anthropic_beta,移除 Bedrock 不支持的字段,清理 cache_control)
betaTokens
,
err
:=
s
.
resolveBedrockBetaTokensForRequest
(
ctx
,
account
,
betaHeader
,
body
,
mappedModel
)
if
err
!=
nil
{
return
nil
,
err
}
bedrockBody
,
err
:=
PrepareBedrockRequestBodyWithTokens
(
body
,
mappedModel
,
betaTokens
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"prepare bedrock request body: %w"
,
err
)
}
proxyURL
:=
""
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
proxyURL
=
account
.
Proxy
.
URL
()
}
logger
.
LegacyPrintf
(
"service.gateway"
,
"[Bedrock] 命中 Bedrock 分支: account=%d name=%s model=%s->%s stream=%v"
,
account
.
ID
,
account
.
Name
,
reqModel
,
mappedModel
,
reqStream
)
// 根据账号类型选择认证方式
var
signer
*
BedrockSigner
var
bedrockAPIKey
string
if
account
.
IsBedrockAPIKey
()
{
bedrockAPIKey
=
account
.
GetCredential
(
"api_key"
)
if
bedrockAPIKey
==
""
{
return
nil
,
fmt
.
Errorf
(
"api_key not found in bedrock-apikey credentials"
)
}
}
else
{
signer
,
err
=
NewBedrockSignerFromAccount
(
account
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"create bedrock signer: %w"
,
err
)
}
}
// 执行上游请求(含重试)
resp
,
err
:=
s
.
executeBedrockUpstream
(
ctx
,
c
,
account
,
bedrockBody
,
mappedModel
,
region
,
reqStream
,
signer
,
bedrockAPIKey
,
proxyURL
)
if
err
!=
nil
{
return
nil
,
err
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
// 将 Bedrock 的 x-amzn-requestid 映射到 x-request-id,
// 使通用错误处理函数(handleErrorResponse、handleRetryExhaustedError)能正确提取 AWS request ID。
if
awsReqID
:=
resp
.
Header
.
Get
(
"x-amzn-requestid"
);
awsReqID
!=
""
&&
resp
.
Header
.
Get
(
"x-request-id"
)
==
""
{
resp
.
Header
.
Set
(
"x-request-id"
,
awsReqID
)
}
// 错误/failover 处理
if
resp
.
StatusCode
>=
400
{
return
s
.
handleBedrockUpstreamErrors
(
ctx
,
resp
,
c
,
account
)
}
// 响应处理
var
usage
*
ClaudeUsage
var
firstTokenMs
*
int
var
clientDisconnect
bool
if
reqStream
{
streamResult
,
err
:=
s
.
handleBedrockStreamingResponse
(
ctx
,
resp
,
c
,
account
,
startTime
,
reqModel
)
if
err
!=
nil
{
return
nil
,
err
}
usage
=
streamResult
.
usage
firstTokenMs
=
streamResult
.
firstTokenMs
clientDisconnect
=
streamResult
.
clientDisconnect
}
else
{
usage
,
err
=
s
.
handleBedrockNonStreamingResponse
(
ctx
,
resp
,
c
,
account
)
if
err
!=
nil
{
return
nil
,
err
}
}
if
usage
==
nil
{
usage
=
&
ClaudeUsage
{}
}
return
&
ForwardResult
{
RequestID
:
resp
.
Header
.
Get
(
"x-amzn-requestid"
),
Usage
:
*
usage
,
Model
:
reqModel
,
Stream
:
reqStream
,
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
ClientDisconnect
:
clientDisconnect
,
},
nil
}
// executeBedrockUpstream 执行 Bedrock 上游请求(含重试逻辑)
func
(
s
*
GatewayService
)
executeBedrockUpstream
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
,
modelID
string
,
region
string
,
stream
bool
,
signer
*
BedrockSigner
,
apiKey
string
,
proxyURL
string
,
)
(
*
http
.
Response
,
error
)
{
var
resp
*
http
.
Response
var
err
error
retryStart
:=
time
.
Now
()
for
attempt
:=
1
;
attempt
<=
maxRetryAttempts
;
attempt
++
{
var
upstreamReq
*
http
.
Request
if
account
.
IsBedrockAPIKey
()
{
upstreamReq
,
err
=
s
.
buildUpstreamRequestBedrockAPIKey
(
ctx
,
body
,
modelID
,
region
,
stream
,
apiKey
)
}
else
{
upstreamReq
,
err
=
s
.
buildUpstreamRequestBedrock
(
ctx
,
body
,
modelID
,
region
,
stream
,
signer
)
}
if
err
!=
nil
{
return
nil
,
err
}
resp
,
err
=
s
.
httpUpstream
.
DoWithTLS
(
upstreamReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
false
)
if
err
!=
nil
{
if
resp
!=
nil
&&
resp
.
Body
!=
nil
{
_
=
resp
.
Body
.
Close
()
}
safeErr
:=
sanitizeUpstreamErrorMessage
(
err
.
Error
())
setOpsUpstreamError
(
c
,
0
,
safeErr
,
""
)
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
0
,
Kind
:
"request_error"
,
Message
:
safeErr
,
})
c
.
JSON
(
http
.
StatusBadGateway
,
gin
.
H
{
"type"
:
"error"
,
"error"
:
gin
.
H
{
"type"
:
"upstream_error"
,
"message"
:
"Upstream request failed"
,
},
})
return
nil
,
fmt
.
Errorf
(
"upstream request failed: %s"
,
safeErr
)
}
if
resp
.
StatusCode
>=
400
&&
resp
.
StatusCode
!=
400
&&
s
.
shouldRetryUpstreamError
(
account
,
resp
.
StatusCode
)
{
if
attempt
<
maxRetryAttempts
{
elapsed
:=
time
.
Since
(
retryStart
)
if
elapsed
>=
maxRetryElapsed
{
break
}
delay
:=
retryBackoffDelay
(
attempt
)
remaining
:=
maxRetryElapsed
-
elapsed
if
delay
>
remaining
{
delay
=
remaining
}
if
delay
<=
0
{
break
}
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
_
=
resp
.
Body
.
Close
()
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
Kind
:
"retry"
,
Message
:
extractUpstreamErrorMessage
(
respBody
),
Detail
:
func
()
string
{
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
return
truncateString
(
string
(
respBody
),
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
)
}
return
""
}(),
})
logger
.
LegacyPrintf
(
"service.gateway"
,
"[Bedrock] account %d: upstream error %d, retry %d/%d after %v"
,
account
.
ID
,
resp
.
StatusCode
,
attempt
,
maxRetryAttempts
,
delay
)
if
err
:=
sleepWithContext
(
ctx
,
delay
);
err
!=
nil
{
return
nil
,
err
}
continue
}
break
}
break
}
if
resp
==
nil
||
resp
.
Body
==
nil
{
return
nil
,
errors
.
New
(
"upstream request failed: empty response"
)
}
return
resp
,
nil
}
// handleBedrockUpstreamErrors 处理 Bedrock 上游 4xx/5xx 错误(failover + 错误响应)
func
(
s
*
GatewayService
)
handleBedrockUpstreamErrors
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
,
)
(
*
ForwardResult
,
error
)
{
// retry exhausted + failover
if
s
.
shouldRetryUpstreamError
(
account
,
resp
.
StatusCode
)
{
if
s
.
shouldFailoverUpstreamError
(
resp
.
StatusCode
)
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
_
=
resp
.
Body
.
Close
()
resp
.
Body
=
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
))
logger
.
LegacyPrintf
(
"service.gateway"
,
"[Bedrock] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d Body=%s"
,
account
.
ID
,
account
.
Name
,
resp
.
StatusCode
,
truncateString
(
string
(
respBody
),
1000
))
s
.
handleRetryExhaustedSideEffects
(
ctx
,
resp
,
account
)
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
Kind
:
"retry_exhausted_failover"
,
Message
:
extractUpstreamErrorMessage
(
respBody
),
})
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
resp
.
StatusCode
,
ResponseBody
:
respBody
,
}
}
return
s
.
handleRetryExhaustedError
(
ctx
,
resp
,
c
,
account
)
}
// non-retryable failover
if
s
.
shouldFailoverUpstreamError
(
resp
.
StatusCode
)
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
_
=
resp
.
Body
.
Close
()
resp
.
Body
=
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
))
s
.
handleFailoverSideEffects
(
ctx
,
resp
,
account
)
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
Kind
:
"failover"
,
Message
:
extractUpstreamErrorMessage
(
respBody
),
})
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
resp
.
StatusCode
,
ResponseBody
:
respBody
,
}
}
// other errors
return
s
.
handleErrorResponse
(
ctx
,
resp
,
c
,
account
)
}
// buildUpstreamRequestBedrock 构建 Bedrock 上游请求
func
(
s
*
GatewayService
)
buildUpstreamRequestBedrock
(
ctx
context
.
Context
,
body
[]
byte
,
modelID
string
,
region
string
,
stream
bool
,
signer
*
BedrockSigner
,
)
(
*
http
.
Request
,
error
)
{
targetURL
:=
BuildBedrockURL
(
region
,
modelID
,
stream
)
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodPost
,
targetURL
,
bytes
.
NewReader
(
body
))
if
err
!=
nil
{
return
nil
,
err
}
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
Header
.
Set
(
"Accept"
,
"application/json"
)
// SigV4 签名
if
err
:=
signer
.
SignRequest
(
ctx
,
req
,
body
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"sign bedrock request: %w"
,
err
)
}
return
req
,
nil
}
// buildUpstreamRequestBedrockAPIKey 构建 Bedrock API Key (Bearer Token) 上游请求
func
(
s
*
GatewayService
)
buildUpstreamRequestBedrockAPIKey
(
ctx
context
.
Context
,
body
[]
byte
,
modelID
string
,
region
string
,
stream
bool
,
apiKey
string
,
)
(
*
http
.
Request
,
error
)
{
targetURL
:=
BuildBedrockURL
(
region
,
modelID
,
stream
)
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodPost
,
targetURL
,
bytes
.
NewReader
(
body
))
if
err
!=
nil
{
return
nil
,
err
}
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
Header
.
Set
(
"Accept"
,
"application/json"
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
apiKey
)
return
req
,
nil
}
// handleBedrockNonStreamingResponse 处理 Bedrock 非流式响应
// Bedrock InvokeModel 非流式响应的 body 格式与 Claude API 兼容
func
(
s
*
GatewayService
)
handleBedrockNonStreamingResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
,
)
(
*
ClaudeUsage
,
error
)
{
maxBytes
:=
resolveUpstreamResponseReadLimit
(
s
.
cfg
)
body
,
err
:=
readUpstreamResponseBodyLimited
(
resp
.
Body
,
maxBytes
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrUpstreamResponseBodyTooLarge
)
{
setOpsUpstreamError
(
c
,
http
.
StatusBadGateway
,
"upstream response too large"
,
""
)
c
.
JSON
(
http
.
StatusBadGateway
,
gin
.
H
{
"type"
:
"error"
,
"error"
:
gin
.
H
{
"type"
:
"upstream_error"
,
"message"
:
"Upstream response too large"
,
},
})
}
return
nil
,
err
}
// 转换 Bedrock 特有的 amazon-bedrock-invocationMetrics 为标准 Anthropic usage 格式
// 并移除该字段避免透传给客户端
body
=
transformBedrockInvocationMetrics
(
body
)
usage
:=
parseClaudeUsageFromResponseBody
(
body
)
c
.
Header
(
"Content-Type"
,
"application/json"
)
if
v
:=
resp
.
Header
.
Get
(
"x-amzn-requestid"
);
v
!=
""
{
c
.
Header
(
"x-request-id"
,
v
)
}
c
.
Data
(
resp
.
StatusCode
,
"application/json"
,
body
)
return
usage
,
nil
}
func
(
s
*
GatewayService
)
buildUpstreamRequest
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
body
[]
byte
,
token
,
tokenType
,
modelID
string
,
reqStream
bool
,
mimicClaudeCode
bool
)
(
*
http
.
Request
,
error
)
{
// 确定目标URL
targetURL
:=
claudeAPIURL
...
...
@@ -5481,6 +5908,76 @@ func containsBetaToken(header, token string) bool {
return
false
}
func
filterBetaTokens
(
tokens
[]
string
,
filterSet
map
[
string
]
struct
{})
[]
string
{
if
len
(
tokens
)
==
0
||
len
(
filterSet
)
==
0
{
return
tokens
}
kept
:=
make
([]
string
,
0
,
len
(
tokens
))
for
_
,
token
:=
range
tokens
{
if
_
,
filtered
:=
filterSet
[
token
];
!
filtered
{
kept
=
append
(
kept
,
token
)
}
}
return
kept
}
func
(
s
*
GatewayService
)
resolveBedrockBetaTokensForRequest
(
ctx
context
.
Context
,
account
*
Account
,
betaHeader
string
,
body
[]
byte
,
modelID
string
,
)
([]
string
,
error
)
{
// 1. 对原始 header 中的 beta token 做 block 检查(快速失败)
policy
:=
s
.
evaluateBetaPolicy
(
ctx
,
betaHeader
,
account
)
if
policy
.
blockErr
!=
nil
{
return
nil
,
policy
.
blockErr
}
// 2. 解析 header + body 自动注入 + Bedrock 转换/过滤
betaTokens
:=
ResolveBedrockBetaTokens
(
betaHeader
,
body
,
modelID
)
// 3. 对最终 token 列表再做 block 检查,捕获通过 body 自动注入绕过 header block 的情况。
// 例如:管理员 block 了 interleaved-thinking,客户端不在 header 中带该 token,
// 但请求体中包含 thinking 字段 → autoInjectBedrockBetaTokens 会自动补齐 →
// 如果不做此检查,block 规则会被绕过。
if
blockErr
:=
s
.
checkBetaPolicyBlockForTokens
(
ctx
,
betaTokens
,
account
);
blockErr
!=
nil
{
return
nil
,
blockErr
}
return
filterBetaTokens
(
betaTokens
,
policy
.
filterSet
),
nil
}
// checkBetaPolicyBlockForTokens 检查 token 列表中是否有被管理员 block 规则命中的 token。
// 用于补充 evaluateBetaPolicy 对 header 的检查,覆盖 body 自动注入的 token。
func
(
s
*
GatewayService
)
checkBetaPolicyBlockForTokens
(
ctx
context
.
Context
,
tokens
[]
string
,
account
*
Account
)
*
BetaBlockedError
{
if
s
.
settingService
==
nil
||
len
(
tokens
)
==
0
{
return
nil
}
settings
,
err
:=
s
.
settingService
.
GetBetaPolicySettings
(
ctx
)
if
err
!=
nil
||
settings
==
nil
{
return
nil
}
isOAuth
:=
account
.
IsOAuth
()
tokenSet
:=
buildBetaTokenSet
(
tokens
)
for
_
,
rule
:=
range
settings
.
Rules
{
if
rule
.
Action
!=
BetaPolicyActionBlock
{
continue
}
if
!
betaPolicyScopeMatches
(
rule
.
Scope
,
isOAuth
)
{
continue
}
if
_
,
present
:=
tokenSet
[
rule
.
BetaToken
];
present
{
msg
:=
rule
.
ErrorMessage
if
msg
==
""
{
msg
=
"beta feature "
+
rule
.
BetaToken
+
" is not allowed"
}
return
&
BetaBlockedError
{
Message
:
msg
}
}
}
return
nil
}
func
buildBetaTokenSet
(
tokens
[]
string
)
map
[
string
]
struct
{}
{
m
:=
make
(
map
[
string
]
struct
{},
len
(
tokens
))
for
_
,
t
:=
range
tokens
{
...
...
@@ -6027,6 +6524,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
needModelReplace
:=
originalModel
!=
mappedModel
clientDisconnected
:=
false
// 客户端断开标志,断开后继续读取上游以获取完整usage
sawTerminalEvent
:=
false
pendingEventLines
:=
make
([]
string
,
0
,
4
)
...
...
@@ -6057,6 +6555,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
if
dataLine
==
"[DONE]"
{
sawTerminalEvent
=
true
block
:=
""
if
eventName
!=
""
{
block
=
"event: "
+
eventName
+
"
\n
"
...
...
@@ -6123,6 +6622,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
usagePatch
:=
s
.
extractSSEUsagePatch
(
event
)
if
anthropicStreamEventIsTerminal
(
eventName
,
dataLine
)
{
sawTerminalEvent
=
true
}
if
!
eventChanged
{
block
:=
""
if
eventName
!=
""
{
...
...
@@ -6156,18 +6658,22 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
case
ev
,
ok
:=
<-
events
:
if
!
ok
{
// 上游完成,返回结果
if
!
sawTerminalEvent
{
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
clientDisconnected
},
fmt
.
Errorf
(
"stream usage incomplete: missing terminal event"
)
}
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
clientDisconnected
},
nil
}
if
ev
.
err
!=
nil
{
if
sawTerminalEvent
{
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
clientDisconnected
},
nil
}
// 检测 context 取消(客户端断开会导致 context 取消,进而影响上游读取)
if
errors
.
Is
(
ev
.
err
,
context
.
Canceled
)
||
errors
.
Is
(
ev
.
err
,
context
.
DeadlineExceeded
)
{
logger
.
LegacyPrintf
(
"service.gateway"
,
"Context canceled during streaming, returning collected usage"
)
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
true
},
nil
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
true
},
fmt
.
Errorf
(
"stream usage incomplete: %w"
,
ev
.
err
)
}
// 客户端已通过写入失败检测到断开,上游也出错了,返回已收集的 usage
if
clientDisconnected
{
logger
.
LegacyPrintf
(
"service.gateway"
,
"Upstream read error after client disconnect: %v, returning collected usage"
,
ev
.
err
)
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
true
},
nil
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
true
},
fmt
.
Errorf
(
"stream usage incomplete after disconnect: %w"
,
ev
.
err
)
}
// 客户端未断开,正常的错误处理
if
errors
.
Is
(
ev
.
err
,
bufio
.
ErrTooLong
)
{
...
...
@@ -6226,9 +6732,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
continue
}
if
clientDisconnected
{
// 客户端已断开,上游也超时了,返回已收集的 usage
logger
.
LegacyPrintf
(
"service.gateway"
,
"Upstream timeout after client disconnect, returning collected usage"
)
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
true
},
nil
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
true
},
fmt
.
Errorf
(
"stream usage incomplete after timeout"
)
}
logger
.
LegacyPrintf
(
"service.gateway"
,
"Stream data interval timeout: account=%d model=%s interval=%s"
,
account
.
ID
,
originalModel
,
streamInterval
)
// 处理流超时,可能标记账户为临时不可调度或错误状态
...
...
@@ -6590,15 +7094,16 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID,
// RecordUsageInput 记录使用量的输入参数
type
RecordUsageInput
struct
{
Result
*
ForwardResult
APIKey
*
APIKey
User
*
User
Account
*
Account
Subscription
*
UserSubscription
// 可选:订阅信息
UserAgent
string
// 请求的 User-Agent
IPAddress
string
// 请求的客户端 IP 地址
ForceCacheBilling
bool
// 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
APIKeyService
APIKeyQuotaUpdater
// 可选:用于更新API Key配额
Result
*
ForwardResult
APIKey
*
APIKey
User
*
User
Account
*
Account
Subscription
*
UserSubscription
// 可选:订阅信息
UserAgent
string
// 请求的 User-Agent
IPAddress
string
// 请求的客户端 IP 地址
RequestPayloadHash
string
// 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险
ForceCacheBilling
bool
// 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
APIKeyService
APIKeyQuotaUpdater
// 可选:用于更新API Key配额
}
// APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage
...
...
@@ -6607,6 +7112,14 @@ type APIKeyQuotaUpdater interface {
UpdateRateLimitUsage
(
ctx
context
.
Context
,
apiKeyID
int64
,
cost
float64
)
error
}
type
apiKeyAuthCacheInvalidator
interface
{
InvalidateAuthCacheByKey
(
ctx
context
.
Context
,
key
string
)
}
type
usageLogBestEffortWriter
interface
{
CreateBestEffort
(
ctx
context
.
Context
,
log
*
UsageLog
)
error
}
// postUsageBillingParams 统一扣费所需的参数
type
postUsageBillingParams
struct
{
Cost
*
CostBreakdown
...
...
@@ -6614,6 +7127,7 @@ type postUsageBillingParams struct {
APIKey
*
APIKey
Account
*
Account
Subscription
*
UserSubscription
RequestPayloadHash
string
IsSubscriptionBill
bool
AccountRateMultiplier
float64
APIKeyService
APIKeyQuotaUpdater
...
...
@@ -6625,19 +7139,22 @@ type postUsageBillingParams struct {
// - API Key 限速用量更新
// - 账号配额用量更新(账号口径:TotalCost × 账号计费倍率)
func
postUsageBilling
(
ctx
context
.
Context
,
p
*
postUsageBillingParams
,
deps
*
billingDeps
)
{
billingCtx
,
cancel
:=
detachedBillingContext
(
ctx
)
defer
cancel
()
cost
:=
p
.
Cost
// 1. 订阅 / 余额扣费
if
p
.
IsSubscriptionBill
{
if
cost
.
TotalCost
>
0
{
if
err
:=
deps
.
userSubRepo
.
IncrementUsage
(
c
tx
,
p
.
Subscription
.
ID
,
cost
.
TotalCost
);
err
!=
nil
{
if
err
:=
deps
.
userSubRepo
.
IncrementUsage
(
billingC
tx
,
p
.
Subscription
.
ID
,
cost
.
TotalCost
);
err
!=
nil
{
slog
.
Error
(
"increment subscription usage failed"
,
"subscription_id"
,
p
.
Subscription
.
ID
,
"error"
,
err
)
}
deps
.
billingCacheService
.
QueueUpdateSubscriptionUsage
(
p
.
User
.
ID
,
*
p
.
APIKey
.
GroupID
,
cost
.
TotalCost
)
}
}
else
{
if
cost
.
ActualCost
>
0
{
if
err
:=
deps
.
userRepo
.
DeductBalance
(
c
tx
,
p
.
User
.
ID
,
cost
.
ActualCost
);
err
!=
nil
{
if
err
:=
deps
.
userRepo
.
DeductBalance
(
billingC
tx
,
p
.
User
.
ID
,
cost
.
ActualCost
);
err
!=
nil
{
slog
.
Error
(
"deduct balance failed"
,
"user_id"
,
p
.
User
.
ID
,
"error"
,
err
)
}
deps
.
billingCacheService
.
QueueDeductBalance
(
p
.
User
.
ID
,
cost
.
ActualCost
)
...
...
@@ -6646,31 +7163,187 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
// 2. API Key 配额
if
cost
.
ActualCost
>
0
&&
p
.
APIKey
.
Quota
>
0
&&
p
.
APIKeyService
!=
nil
{
if
err
:=
p
.
APIKeyService
.
UpdateQuotaUsed
(
c
tx
,
p
.
APIKey
.
ID
,
cost
.
ActualCost
);
err
!=
nil
{
if
err
:=
p
.
APIKeyService
.
UpdateQuotaUsed
(
billingC
tx
,
p
.
APIKey
.
ID
,
cost
.
ActualCost
);
err
!=
nil
{
slog
.
Error
(
"update api key quota failed"
,
"api_key_id"
,
p
.
APIKey
.
ID
,
"error"
,
err
)
}
}
// 3. API Key 限速用量
if
cost
.
ActualCost
>
0
&&
p
.
APIKey
.
HasRateLimits
()
&&
p
.
APIKeyService
!=
nil
{
if
err
:=
p
.
APIKeyService
.
UpdateRateLimitUsage
(
c
tx
,
p
.
APIKey
.
ID
,
cost
.
ActualCost
);
err
!=
nil
{
if
err
:=
p
.
APIKeyService
.
UpdateRateLimitUsage
(
billingC
tx
,
p
.
APIKey
.
ID
,
cost
.
ActualCost
);
err
!=
nil
{
slog
.
Error
(
"update api key rate limit usage failed"
,
"api_key_id"
,
p
.
APIKey
.
ID
,
"error"
,
err
)
}
deps
.
billingCacheService
.
QueueUpdateAPIKeyRateLimitUsage
(
p
.
APIKey
.
ID
,
cost
.
ActualCost
)
}
// 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率)
if
cost
.
TotalCost
>
0
&&
p
.
Account
.
Type
==
AccountTypeAPIKey
&&
p
.
Account
.
HasAnyQuotaLimit
()
{
accountCost
:=
cost
.
TotalCost
*
p
.
AccountRateMultiplier
if
err
:=
deps
.
accountRepo
.
IncrementQuotaUsed
(
c
tx
,
p
.
Account
.
ID
,
accountCost
);
err
!=
nil
{
if
err
:=
deps
.
accountRepo
.
IncrementQuotaUsed
(
billingC
tx
,
p
.
Account
.
ID
,
accountCost
);
err
!=
nil
{
slog
.
Error
(
"increment account quota used failed"
,
"account_id"
,
p
.
Account
.
ID
,
"cost"
,
accountCost
,
"error"
,
err
)
}
}
// 5. 更新账号最近使用时间
finalizePostUsageBilling
(
p
,
deps
)
}
func
resolveUsageBillingRequestID
(
ctx
context
.
Context
,
upstreamRequestID
string
)
string
{
if
ctx
!=
nil
{
if
clientRequestID
,
_
:=
ctx
.
Value
(
ctxkey
.
ClientRequestID
)
.
(
string
);
strings
.
TrimSpace
(
clientRequestID
)
!=
""
{
return
"client:"
+
strings
.
TrimSpace
(
clientRequestID
)
}
if
requestID
,
_
:=
ctx
.
Value
(
ctxkey
.
RequestID
)
.
(
string
);
strings
.
TrimSpace
(
requestID
)
!=
""
{
return
"local:"
+
strings
.
TrimSpace
(
requestID
)
}
}
if
requestID
:=
strings
.
TrimSpace
(
upstreamRequestID
);
requestID
!=
""
{
return
requestID
}
return
"generated:"
+
generateRequestID
()
}
func
resolveUsageBillingPayloadFingerprint
(
ctx
context
.
Context
,
requestPayloadHash
string
)
string
{
if
payloadHash
:=
strings
.
TrimSpace
(
requestPayloadHash
);
payloadHash
!=
""
{
return
payloadHash
}
if
ctx
!=
nil
{
if
clientRequestID
,
_
:=
ctx
.
Value
(
ctxkey
.
ClientRequestID
)
.
(
string
);
strings
.
TrimSpace
(
clientRequestID
)
!=
""
{
return
"client:"
+
strings
.
TrimSpace
(
clientRequestID
)
}
if
requestID
,
_
:=
ctx
.
Value
(
ctxkey
.
RequestID
)
.
(
string
);
strings
.
TrimSpace
(
requestID
)
!=
""
{
return
"local:"
+
strings
.
TrimSpace
(
requestID
)
}
}
return
""
}
func
buildUsageBillingCommand
(
requestID
string
,
usageLog
*
UsageLog
,
p
*
postUsageBillingParams
)
*
UsageBillingCommand
{
if
p
==
nil
||
p
.
Cost
==
nil
||
p
.
APIKey
==
nil
||
p
.
User
==
nil
||
p
.
Account
==
nil
{
return
nil
}
cmd
:=
&
UsageBillingCommand
{
RequestID
:
requestID
,
APIKeyID
:
p
.
APIKey
.
ID
,
UserID
:
p
.
User
.
ID
,
AccountID
:
p
.
Account
.
ID
,
AccountType
:
p
.
Account
.
Type
,
RequestPayloadHash
:
strings
.
TrimSpace
(
p
.
RequestPayloadHash
),
}
if
usageLog
!=
nil
{
cmd
.
Model
=
usageLog
.
Model
cmd
.
BillingType
=
usageLog
.
BillingType
cmd
.
InputTokens
=
usageLog
.
InputTokens
cmd
.
OutputTokens
=
usageLog
.
OutputTokens
cmd
.
CacheCreationTokens
=
usageLog
.
CacheCreationTokens
cmd
.
CacheReadTokens
=
usageLog
.
CacheReadTokens
cmd
.
ImageCount
=
usageLog
.
ImageCount
if
usageLog
.
MediaType
!=
nil
{
cmd
.
MediaType
=
*
usageLog
.
MediaType
}
if
usageLog
.
ServiceTier
!=
nil
{
cmd
.
ServiceTier
=
*
usageLog
.
ServiceTier
}
if
usageLog
.
ReasoningEffort
!=
nil
{
cmd
.
ReasoningEffort
=
*
usageLog
.
ReasoningEffort
}
if
usageLog
.
SubscriptionID
!=
nil
{
cmd
.
SubscriptionID
=
usageLog
.
SubscriptionID
}
}
if
p
.
IsSubscriptionBill
&&
p
.
Subscription
!=
nil
&&
p
.
Cost
.
TotalCost
>
0
{
cmd
.
SubscriptionID
=
&
p
.
Subscription
.
ID
cmd
.
SubscriptionCost
=
p
.
Cost
.
TotalCost
}
else
if
p
.
Cost
.
ActualCost
>
0
{
cmd
.
BalanceCost
=
p
.
Cost
.
ActualCost
}
if
p
.
Cost
.
ActualCost
>
0
&&
p
.
APIKey
.
Quota
>
0
&&
p
.
APIKeyService
!=
nil
{
cmd
.
APIKeyQuotaCost
=
p
.
Cost
.
ActualCost
}
if
p
.
Cost
.
ActualCost
>
0
&&
p
.
APIKey
.
HasRateLimits
()
&&
p
.
APIKeyService
!=
nil
{
cmd
.
APIKeyRateLimitCost
=
p
.
Cost
.
ActualCost
}
if
p
.
Cost
.
TotalCost
>
0
&&
p
.
Account
.
Type
==
AccountTypeAPIKey
&&
p
.
Account
.
HasAnyQuotaLimit
()
{
cmd
.
AccountQuotaCost
=
p
.
Cost
.
TotalCost
*
p
.
AccountRateMultiplier
}
cmd
.
Normalize
()
return
cmd
}
func
applyUsageBilling
(
ctx
context
.
Context
,
requestID
string
,
usageLog
*
UsageLog
,
p
*
postUsageBillingParams
,
deps
*
billingDeps
,
repo
UsageBillingRepository
)
(
bool
,
error
)
{
if
p
==
nil
||
deps
==
nil
{
return
false
,
nil
}
cmd
:=
buildUsageBillingCommand
(
requestID
,
usageLog
,
p
)
if
cmd
==
nil
||
cmd
.
RequestID
==
""
||
repo
==
nil
{
postUsageBilling
(
ctx
,
p
,
deps
)
return
true
,
nil
}
billingCtx
,
cancel
:=
detachedBillingContext
(
ctx
)
defer
cancel
()
result
,
err
:=
repo
.
Apply
(
billingCtx
,
cmd
)
if
err
!=
nil
{
return
false
,
err
}
if
result
==
nil
||
!
result
.
Applied
{
deps
.
deferredService
.
ScheduleLastUsedUpdate
(
p
.
Account
.
ID
)
return
false
,
nil
}
if
result
.
APIKeyQuotaExhausted
{
if
invalidator
,
ok
:=
p
.
APIKeyService
.
(
apiKeyAuthCacheInvalidator
);
ok
&&
p
.
APIKey
!=
nil
&&
p
.
APIKey
.
Key
!=
""
{
invalidator
.
InvalidateAuthCacheByKey
(
billingCtx
,
p
.
APIKey
.
Key
)
}
}
finalizePostUsageBilling
(
p
,
deps
)
return
true
,
nil
}
func
finalizePostUsageBilling
(
p
*
postUsageBillingParams
,
deps
*
billingDeps
)
{
if
p
==
nil
||
p
.
Cost
==
nil
||
deps
==
nil
{
return
}
if
p
.
IsSubscriptionBill
{
if
p
.
Cost
.
TotalCost
>
0
&&
p
.
User
!=
nil
&&
p
.
APIKey
!=
nil
&&
p
.
APIKey
.
GroupID
!=
nil
{
deps
.
billingCacheService
.
QueueUpdateSubscriptionUsage
(
p
.
User
.
ID
,
*
p
.
APIKey
.
GroupID
,
p
.
Cost
.
TotalCost
)
}
}
else
if
p
.
Cost
.
ActualCost
>
0
&&
p
.
User
!=
nil
{
deps
.
billingCacheService
.
QueueDeductBalance
(
p
.
User
.
ID
,
p
.
Cost
.
ActualCost
)
}
if
p
.
Cost
.
ActualCost
>
0
&&
p
.
APIKey
!=
nil
&&
p
.
APIKey
.
HasRateLimits
()
{
deps
.
billingCacheService
.
QueueUpdateAPIKeyRateLimitUsage
(
p
.
APIKey
.
ID
,
p
.
Cost
.
ActualCost
)
}
deps
.
deferredService
.
ScheduleLastUsedUpdate
(
p
.
Account
.
ID
)
}
func
detachedBillingContext
(
ctx
context
.
Context
)
(
context
.
Context
,
context
.
CancelFunc
)
{
base
:=
context
.
Background
()
if
ctx
!=
nil
{
base
=
context
.
WithoutCancel
(
ctx
)
}
return
context
.
WithTimeout
(
base
,
postUsageBillingTimeout
)
}
func
detachStreamUpstreamContext
(
ctx
context
.
Context
,
stream
bool
)
(
context
.
Context
,
context
.
CancelFunc
)
{
if
!
stream
{
return
ctx
,
func
()
{}
}
if
ctx
==
nil
{
return
context
.
Background
(),
func
()
{}
}
return
context
.
WithoutCancel
(
ctx
),
func
()
{}
}
// billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供)
type
billingDeps
struct
{
accountRepo
AccountRepository
...
...
@@ -6690,6 +7363,31 @@ func (s *GatewayService) billingDeps() *billingDeps {
}
}
func
writeUsageLogBestEffort
(
ctx
context
.
Context
,
repo
UsageLogRepository
,
usageLog
*
UsageLog
,
logKey
string
)
{
if
repo
==
nil
||
usageLog
==
nil
{
return
}
usageCtx
,
cancel
:=
detachedBillingContext
(
ctx
)
defer
cancel
()
if
writer
,
ok
:=
repo
.
(
usageLogBestEffortWriter
);
ok
{
if
err
:=
writer
.
CreateBestEffort
(
usageCtx
,
usageLog
);
err
!=
nil
{
logger
.
LegacyPrintf
(
logKey
,
"Create usage log failed: %v"
,
err
)
if
IsUsageLogCreateDropped
(
err
)
{
return
}
if
_
,
syncErr
:=
repo
.
Create
(
usageCtx
,
usageLog
);
syncErr
!=
nil
{
logger
.
LegacyPrintf
(
logKey
,
"Create usage log sync fallback failed: %v"
,
syncErr
)
}
}
return
}
if
_
,
err
:=
repo
.
Create
(
usageCtx
,
usageLog
);
err
!=
nil
{
logger
.
LegacyPrintf
(
logKey
,
"Create usage log failed: %v"
,
err
)
}
}
// RecordUsage 记录使用量并扣费(或更新订阅用量)
func
(
s
*
GatewayService
)
RecordUsage
(
ctx
context
.
Context
,
input
*
RecordUsageInput
)
error
{
result
:=
input
.
Result
...
...
@@ -6791,11 +7489,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
mediaType
=
&
result
.
MediaType
}
accountRateMultiplier
:=
account
.
BillingRateMultiplier
()
requestID
:=
resolveUsageBillingRequestID
(
ctx
,
result
.
RequestID
)
usageLog
:=
&
UsageLog
{
UserID
:
user
.
ID
,
APIKeyID
:
apiKey
.
ID
,
AccountID
:
account
.
ID
,
RequestID
:
re
sult
.
Re
questID
,
RequestID
:
requestID
,
Model
:
result
.
Model
,
InputTokens
:
result
.
Usage
.
InputTokens
,
OutputTokens
:
result
.
Usage
.
OutputTokens
,
...
...
@@ -6840,33 +7539,32 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
usageLog
.
SubscriptionID
=
&
subscription
.
ID
}
inserted
,
err
:=
s
.
usageLogRepo
.
Create
(
ctx
,
usageLog
)
if
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.gateway"
,
"Create usage log failed: %v"
,
err
)
}
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
RunMode
==
config
.
RunModeSimple
{
writeUsageLogBestEffort
(
ctx
,
s
.
usageLogRepo
,
usageLog
,
"service.gateway"
)
logger
.
LegacyPrintf
(
"service.gateway"
,
"[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d"
,
usageLog
.
UserID
,
usageLog
.
TotalTokens
())
s
.
deferredService
.
ScheduleLastUsedUpdate
(
account
.
ID
)
return
nil
}
shouldBill
:=
inserted
||
err
!=
nil
if
shouldBill
{
postUsageBilling
(
ctx
,
&
postUsageBillingParams
{
billingErr
:=
func
()
error
{
_
,
err
:=
applyUsageBilling
(
ctx
,
requestID
,
usageLog
,
&
postUsageBillingParams
{
Cost
:
cost
,
User
:
user
,
APIKey
:
apiKey
,
Account
:
account
,
Subscription
:
subscription
,
RequestPayloadHash
:
resolveUsageBillingPayloadFingerprint
(
ctx
,
input
.
RequestPayloadHash
),
IsSubscriptionBill
:
isSubscriptionBilling
,
AccountRateMultiplier
:
accountRateMultiplier
,
APIKeyService
:
input
.
APIKeyService
,
},
s
.
billingDeps
())
}
else
{
s
.
deferredService
.
ScheduleLastUsedUpdate
(
account
.
ID
)
},
s
.
billingDeps
(),
s
.
usageBillingRepo
)
return
err
}()
if
billingErr
!=
nil
{
return
billingErr
}
writeUsageLogBestEffort
(
ctx
,
s
.
usageLogRepo
,
usageLog
,
"service.gateway"
)
return
nil
}
...
...
@@ -6877,13 +7575,14 @@ type RecordUsageLongContextInput struct {
APIKey
*
APIKey
User
*
User
Account
*
Account
Subscription
*
UserSubscription
// 可选:订阅信息
UserAgent
string
// 请求的 User-Agent
IPAddress
string
// 请求的客户端 IP 地址
LongContextThreshold
int
// 长上下文阈值(如 200000)
LongContextMultiplier
float64
// 超出阈值部分的倍率(如 2.0)
ForceCacheBilling
bool
// 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
APIKeyService
*
APIKeyService
// API Key 配额服务(可选)
Subscription
*
UserSubscription
// 可选:订阅信息
UserAgent
string
// 请求的 User-Agent
IPAddress
string
// 请求的客户端 IP 地址
RequestPayloadHash
string
// 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险
LongContextThreshold
int
// 长上下文阈值(如 200000)
LongContextMultiplier
float64
// 超出阈值部分的倍率(如 2.0)
ForceCacheBilling
bool
// 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
APIKeyService
APIKeyQuotaUpdater
// API Key 配额服务(可选)
}
// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini)
...
...
@@ -6966,11 +7665,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
imageSize
=
&
result
.
ImageSize
}
accountRateMultiplier
:=
account
.
BillingRateMultiplier
()
requestID
:=
resolveUsageBillingRequestID
(
ctx
,
result
.
RequestID
)
usageLog
:=
&
UsageLog
{
UserID
:
user
.
ID
,
APIKeyID
:
apiKey
.
ID
,
AccountID
:
account
.
ID
,
RequestID
:
re
sult
.
Re
questID
,
RequestID
:
requestID
,
Model
:
result
.
Model
,
InputTokens
:
result
.
Usage
.
InputTokens
,
OutputTokens
:
result
.
Usage
.
OutputTokens
,
...
...
@@ -7014,33 +7714,32 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
usageLog
.
SubscriptionID
=
&
subscription
.
ID
}
inserted
,
err
:=
s
.
usageLogRepo
.
Create
(
ctx
,
usageLog
)
if
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.gateway"
,
"Create usage log failed: %v"
,
err
)
}
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
RunMode
==
config
.
RunModeSimple
{
writeUsageLogBestEffort
(
ctx
,
s
.
usageLogRepo
,
usageLog
,
"service.gateway"
)
logger
.
LegacyPrintf
(
"service.gateway"
,
"[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d"
,
usageLog
.
UserID
,
usageLog
.
TotalTokens
())
s
.
deferredService
.
ScheduleLastUsedUpdate
(
account
.
ID
)
return
nil
}
shouldBill
:=
inserted
||
err
!=
nil
if
shouldBill
{
postUsageBilling
(
ctx
,
&
postUsageBillingParams
{
billingErr
:=
func
()
error
{
_
,
err
:=
applyUsageBilling
(
ctx
,
requestID
,
usageLog
,
&
postUsageBillingParams
{
Cost
:
cost
,
User
:
user
,
APIKey
:
apiKey
,
Account
:
account
,
Subscription
:
subscription
,
RequestPayloadHash
:
resolveUsageBillingPayloadFingerprint
(
ctx
,
input
.
RequestPayloadHash
),
IsSubscriptionBill
:
isSubscriptionBilling
,
AccountRateMultiplier
:
accountRateMultiplier
,
APIKeyService
:
input
.
APIKeyService
,
},
s
.
billingDeps
())
}
else
{
s
.
deferredService
.
ScheduleLastUsedUpdate
(
account
.
ID
)
},
s
.
billingDeps
(),
s
.
usageBillingRepo
)
return
err
}()
if
billingErr
!=
nil
{
return
billingErr
}
writeUsageLogBestEffort
(
ctx
,
s
.
usageLogRepo
,
usageLog
,
"service.gateway"
)
return
nil
}
...
...
@@ -7064,6 +7763,12 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
return
s
.
forwardCountTokensAnthropicAPIKeyPassthrough
(
ctx
,
c
,
account
,
passthroughBody
)
}
// Bedrock 不支持 count_tokens 端点
if
account
!=
nil
&&
account
.
IsBedrock
()
{
s
.
countTokensError
(
c
,
http
.
StatusNotFound
,
"not_found_error"
,
"count_tokens endpoint is not supported for Bedrock"
)
return
nil
}
body
:=
parsed
.
Body
reqModel
:=
parsed
.
Model
...
...
backend/internal/service/gateway_service_bedrock_beta_test.go
0 → 100644
View file @
01ef7340
package
service
import
(
"context"
"encoding/json"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
)
type
betaPolicySettingRepoStub
struct
{
values
map
[
string
]
string
}
func
(
s
*
betaPolicySettingRepoStub
)
Get
(
ctx
context
.
Context
,
key
string
)
(
*
Setting
,
error
)
{
panic
(
"unexpected Get call"
)
}
func
(
s
*
betaPolicySettingRepoStub
)
GetValue
(
ctx
context
.
Context
,
key
string
)
(
string
,
error
)
{
if
v
,
ok
:=
s
.
values
[
key
];
ok
{
return
v
,
nil
}
return
""
,
ErrSettingNotFound
}
func
(
s
*
betaPolicySettingRepoStub
)
Set
(
ctx
context
.
Context
,
key
,
value
string
)
error
{
panic
(
"unexpected Set call"
)
}
func
(
s
*
betaPolicySettingRepoStub
)
GetMultiple
(
ctx
context
.
Context
,
keys
[]
string
)
(
map
[
string
]
string
,
error
)
{
panic
(
"unexpected GetMultiple call"
)
}
func
(
s
*
betaPolicySettingRepoStub
)
SetMultiple
(
ctx
context
.
Context
,
settings
map
[
string
]
string
)
error
{
panic
(
"unexpected SetMultiple call"
)
}
func
(
s
*
betaPolicySettingRepoStub
)
GetAll
(
ctx
context
.
Context
)
(
map
[
string
]
string
,
error
)
{
panic
(
"unexpected GetAll call"
)
}
func
(
s
*
betaPolicySettingRepoStub
)
Delete
(
ctx
context
.
Context
,
key
string
)
error
{
panic
(
"unexpected Delete call"
)
}
func
TestResolveBedrockBetaTokensForRequest_BlocksOnOriginalAnthropicToken
(
t
*
testing
.
T
)
{
settings
:=
&
BetaPolicySettings
{
Rules
:
[]
BetaPolicyRule
{
{
BetaToken
:
"advanced-tool-use-2025-11-20"
,
Action
:
BetaPolicyActionBlock
,
Scope
:
BetaPolicyScopeAll
,
ErrorMessage
:
"advanced tool use is blocked"
,
},
},
}
raw
,
err
:=
json
.
Marshal
(
settings
)
if
err
!=
nil
{
t
.
Fatalf
(
"marshal settings: %v"
,
err
)
}
svc
:=
&
GatewayService
{
settingService
:
NewSettingService
(
&
betaPolicySettingRepoStub
{
values
:
map
[
string
]
string
{
SettingKeyBetaPolicySettings
:
string
(
raw
),
}},
&
config
.
Config
{},
),
}
account
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeBedrock
}
_
,
err
=
svc
.
resolveBedrockBetaTokensForRequest
(
context
.
Background
(),
account
,
"advanced-tool-use-2025-11-20"
,
[]
byte
(
`{"messages":[{"role":"user","content":"hi"}]}`
),
"us.anthropic.claude-opus-4-6-v1"
,
)
if
err
==
nil
{
t
.
Fatal
(
"expected raw advanced-tool-use token to be blocked before Bedrock transform"
)
}
if
err
.
Error
()
!=
"advanced tool use is blocked"
{
t
.
Fatalf
(
"unexpected error: %v"
,
err
)
}
}
func
TestResolveBedrockBetaTokensForRequest_FiltersAfterBedrockTransform
(
t
*
testing
.
T
)
{
settings
:=
&
BetaPolicySettings
{
Rules
:
[]
BetaPolicyRule
{
{
BetaToken
:
"tool-search-tool-2025-10-19"
,
Action
:
BetaPolicyActionFilter
,
Scope
:
BetaPolicyScopeAll
,
},
},
}
raw
,
err
:=
json
.
Marshal
(
settings
)
if
err
!=
nil
{
t
.
Fatalf
(
"marshal settings: %v"
,
err
)
}
svc
:=
&
GatewayService
{
settingService
:
NewSettingService
(
&
betaPolicySettingRepoStub
{
values
:
map
[
string
]
string
{
SettingKeyBetaPolicySettings
:
string
(
raw
),
}},
&
config
.
Config
{},
),
}
account
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeBedrock
}
betaTokens
,
err
:=
svc
.
resolveBedrockBetaTokensForRequest
(
context
.
Background
(),
account
,
"advanced-tool-use-2025-11-20"
,
[]
byte
(
`{"messages":[{"role":"user","content":"hi"}]}`
),
"us.anthropic.claude-opus-4-6-v1"
,
)
if
err
!=
nil
{
t
.
Fatalf
(
"unexpected error: %v"
,
err
)
}
for
_
,
token
:=
range
betaTokens
{
if
token
==
"tool-search-tool-2025-10-19"
{
t
.
Fatalf
(
"expected transformed Bedrock token to be filtered"
)
}
}
}
// TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedThinking 验证:
// 管理员 block 了 interleaved-thinking,客户端不在 header 中带该 token,
// 但请求体包含 thinking 字段 → 自动注入后应被 block。
func
TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedThinking
(
t
*
testing
.
T
)
{
settings
:=
&
BetaPolicySettings
{
Rules
:
[]
BetaPolicyRule
{
{
BetaToken
:
"interleaved-thinking-2025-05-14"
,
Action
:
BetaPolicyActionBlock
,
Scope
:
BetaPolicyScopeAll
,
ErrorMessage
:
"thinking is blocked"
,
},
},
}
raw
,
err
:=
json
.
Marshal
(
settings
)
if
err
!=
nil
{
t
.
Fatalf
(
"marshal settings: %v"
,
err
)
}
svc
:=
&
GatewayService
{
settingService
:
NewSettingService
(
&
betaPolicySettingRepoStub
{
values
:
map
[
string
]
string
{
SettingKeyBetaPolicySettings
:
string
(
raw
),
}},
&
config
.
Config
{},
),
}
account
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeBedrock
}
// header 中不带 beta token,但 body 中有 thinking 字段
_
,
err
=
svc
.
resolveBedrockBetaTokensForRequest
(
context
.
Background
(),
account
,
""
,
// 空 header
[]
byte
(
`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`
),
"us.anthropic.claude-opus-4-6-v1"
,
)
if
err
==
nil
{
t
.
Fatal
(
"expected body-injected interleaved-thinking to be blocked"
)
}
if
err
.
Error
()
!=
"thinking is blocked"
{
t
.
Fatalf
(
"unexpected error: %v"
,
err
)
}
}
// TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedToolSearch 验证:
// 管理员 block 了 tool-search-tool,客户端不在 header 中带 beta token,
// 但请求体包含 tool search 工具 → 自动注入后应被 block。
func
TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedToolSearch
(
t
*
testing
.
T
)
{
settings
:=
&
BetaPolicySettings
{
Rules
:
[]
BetaPolicyRule
{
{
BetaToken
:
"tool-search-tool-2025-10-19"
,
Action
:
BetaPolicyActionBlock
,
Scope
:
BetaPolicyScopeAll
,
ErrorMessage
:
"tool search is blocked"
,
},
},
}
raw
,
err
:=
json
.
Marshal
(
settings
)
if
err
!=
nil
{
t
.
Fatalf
(
"marshal settings: %v"
,
err
)
}
svc
:=
&
GatewayService
{
settingService
:
NewSettingService
(
&
betaPolicySettingRepoStub
{
values
:
map
[
string
]
string
{
SettingKeyBetaPolicySettings
:
string
(
raw
),
}},
&
config
.
Config
{},
),
}
account
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeBedrock
}
// header 中不带 beta token,但 body 中有 tool_search_tool 工具
_
,
err
=
svc
.
resolveBedrockBetaTokensForRequest
(
context
.
Background
(),
account
,
""
,
[]
byte
(
`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`
),
"us.anthropic.claude-sonnet-4-6"
,
)
if
err
==
nil
{
t
.
Fatal
(
"expected body-injected tool-search-tool to be blocked"
)
}
if
err
.
Error
()
!=
"tool search is blocked"
{
t
.
Fatalf
(
"unexpected error: %v"
,
err
)
}
}
// TestResolveBedrockBetaTokensForRequest_PassesWhenNoBlockRuleMatches 验证:
// body 自动注入的 token 如果没有对应的 block 规则,应正常通过。
func
TestResolveBedrockBetaTokensForRequest_PassesWhenNoBlockRuleMatches
(
t
*
testing
.
T
)
{
settings
:=
&
BetaPolicySettings
{
Rules
:
[]
BetaPolicyRule
{
{
BetaToken
:
"computer-use-2025-11-24"
,
Action
:
BetaPolicyActionBlock
,
Scope
:
BetaPolicyScopeAll
,
ErrorMessage
:
"computer use is blocked"
,
},
},
}
raw
,
err
:=
json
.
Marshal
(
settings
)
if
err
!=
nil
{
t
.
Fatalf
(
"marshal settings: %v"
,
err
)
}
svc
:=
&
GatewayService
{
settingService
:
NewSettingService
(
&
betaPolicySettingRepoStub
{
values
:
map
[
string
]
string
{
SettingKeyBetaPolicySettings
:
string
(
raw
),
}},
&
config
.
Config
{},
),
}
account
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeBedrock
}
// body 中有 thinking(会注入 interleaved-thinking),但 block 规则只针对 computer-use
tokens
,
err
:=
svc
.
resolveBedrockBetaTokensForRequest
(
context
.
Background
(),
account
,
""
,
[]
byte
(
`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`
),
"us.anthropic.claude-opus-4-6-v1"
,
)
if
err
!=
nil
{
t
.
Fatalf
(
"unexpected error: %v"
,
err
)
}
found
:=
false
for
_
,
token
:=
range
tokens
{
if
token
==
"interleaved-thinking-2025-05-14"
{
found
=
true
}
}
if
!
found
{
t
.
Fatal
(
"expected interleaved-thinking token to be present"
)
}
}
backend/internal/service/gateway_service_bedrock_model_support_test.go
0 → 100644
View file @
01ef7340
package
service
import
"testing"
func
TestGatewayServiceIsModelSupportedByAccount_BedrockDefaultMappingRestrictsModels
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
account
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeBedrock
,
Credentials
:
map
[
string
]
any
{
"aws_region"
:
"us-east-1"
,
},
}
if
!
svc
.
isModelSupportedByAccount
(
account
,
"claude-sonnet-4-5"
)
{
t
.
Fatalf
(
"expected default Bedrock alias to be supported"
)
}
if
svc
.
isModelSupportedByAccount
(
account
,
"claude-3-5-sonnet-20241022"
)
{
t
.
Fatalf
(
"expected unsupported alias to be rejected for Bedrock account"
)
}
}
func
TestGatewayServiceIsModelSupportedByAccount_BedrockCustomMappingStillActsAsAllowlist
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
account
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeBedrock
,
Credentials
:
map
[
string
]
any
{
"aws_region"
:
"eu-west-1"
,
"model_mapping"
:
map
[
string
]
any
{
"claude-sonnet-*"
:
"claude-sonnet-4-6"
,
},
},
}
if
!
svc
.
isModelSupportedByAccount
(
account
,
"claude-sonnet-4-6"
)
{
t
.
Fatalf
(
"expected matched custom mapping to be supported"
)
}
if
!
svc
.
isModelSupportedByAccount
(
account
,
"claude-opus-4-6"
)
{
t
.
Fatalf
(
"expected default Bedrock alias fallback to remain supported"
)
}
if
svc
.
isModelSupportedByAccount
(
account
,
"claude-3-5-sonnet-20241022"
)
{
t
.
Fatalf
(
"expected unsupported model to still be rejected"
)
}
}
backend/internal/service/gateway_streaming_test.go
View file @
01ef7340
...
...
@@ -181,7 +181,8 @@ func TestHandleStreamingResponse_EmptyStream(t *testing.T) {
result
,
err
:=
svc
.
handleStreamingResponse
(
context
.
Background
(),
resp
,
c
,
&
Account
{
ID
:
1
},
time
.
Now
(),
"model"
,
"model"
,
false
)
_
=
pr
.
Close
()
require
.
NoError
(
t
,
err
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"missing terminal event"
)
require
.
NotNil
(
t
,
result
)
}
...
...
backend/internal/service/openai_codex_transform.go
View file @
01ef7340
...
...
@@ -129,6 +129,41 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
}
}
// 兼容遗留的 functions 和 function_call,转换为 tools 和 tool_choice
if
functionsRaw
,
ok
:=
reqBody
[
"functions"
];
ok
{
if
functions
,
k
:=
functionsRaw
.
([]
any
);
k
{
tools
:=
make
([]
any
,
0
,
len
(
functions
))
for
_
,
f
:=
range
functions
{
tools
=
append
(
tools
,
map
[
string
]
any
{
"type"
:
"function"
,
"function"
:
f
,
})
}
reqBody
[
"tools"
]
=
tools
}
delete
(
reqBody
,
"functions"
)
result
.
Modified
=
true
}
if
fcRaw
,
ok
:=
reqBody
[
"function_call"
];
ok
{
if
fcStr
,
ok
:=
fcRaw
.
(
string
);
ok
{
// e.g. "auto", "none"
reqBody
[
"tool_choice"
]
=
fcStr
}
else
if
fcObj
,
ok
:=
fcRaw
.
(
map
[
string
]
any
);
ok
{
// e.g. {"name": "my_func"}
if
name
,
ok
:=
fcObj
[
"name"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
name
)
!=
""
{
reqBody
[
"tool_choice"
]
=
map
[
string
]
any
{
"type"
:
"function"
,
"function"
:
map
[
string
]
any
{
"name"
:
name
,
},
}
}
}
delete
(
reqBody
,
"function_call"
)
result
.
Modified
=
true
}
if
normalizeCodexTools
(
reqBody
)
{
result
.
Modified
=
true
}
...
...
@@ -303,6 +338,18 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
continue
}
typ
,
_
:=
m
[
"type"
]
.
(
string
)
// 修复 OpenAI 上游的最新校验:"Expected an ID that begins with 'fc'"
fixIDPrefix
:=
func
(
id
string
)
string
{
if
id
==
""
||
strings
.
HasPrefix
(
id
,
"fc"
)
{
return
id
}
if
strings
.
HasPrefix
(
id
,
"call_"
)
{
return
"fc"
+
strings
.
TrimPrefix
(
id
,
"call_"
)
}
return
"fc_"
+
id
}
if
typ
==
"item_reference"
{
if
!
preserveReferences
{
continue
...
...
@@ -311,6 +358,9 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
for
key
,
value
:=
range
m
{
newItem
[
key
]
=
value
}
if
id
,
ok
:=
newItem
[
"id"
]
.
(
string
);
ok
&&
id
!=
""
{
newItem
[
"id"
]
=
fixIDPrefix
(
id
)
}
filtered
=
append
(
filtered
,
newItem
)
continue
}
...
...
@@ -330,10 +380,20 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
}
if
isCodexToolCallItemType
(
typ
)
{
if
callID
,
ok
:=
m
[
"call_id"
]
.
(
string
);
!
ok
||
strings
.
TrimSpace
(
callID
)
==
""
{
callID
,
ok
:=
m
[
"call_id"
]
.
(
string
)
if
!
ok
||
strings
.
TrimSpace
(
callID
)
==
""
{
if
id
,
ok
:=
m
[
"id"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
id
)
!=
""
{
callID
=
id
ensureCopy
()
newItem
[
"call_id"
]
=
callID
}
}
if
callID
!=
""
{
fixedCallID
:=
fixIDPrefix
(
callID
)
if
fixedCallID
!=
callID
{
ensureCopy
()
newItem
[
"call_id"
]
=
id
newItem
[
"call_id"
]
=
fixedCallID
}
}
}
...
...
@@ -344,6 +404,14 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
if
!
isCodexToolCallItemType
(
typ
)
{
delete
(
newItem
,
"call_id"
)
}
}
else
{
if
id
,
ok
:=
newItem
[
"id"
]
.
(
string
);
ok
&&
id
!=
""
{
fixedID
:=
fixIDPrefix
(
id
)
if
fixedID
!=
id
{
ensureCopy
()
newItem
[
"id"
]
=
fixedID
}
}
}
filtered
=
append
(
filtered
,
newItem
)
...
...
backend/internal/service/openai_codex_transform_test.go
View file @
01ef7340
...
...
@@ -33,12 +33,12 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
first
,
ok
:=
input
[
0
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"item_reference"
,
first
[
"type"
])
require
.
Equal
(
t
,
"ref1"
,
first
[
"id"
])
require
.
Equal
(
t
,
"
fc_
ref1"
,
first
[
"id"
])
// 校验 input[1] 为 map,确保后续字段断言安全。
second
,
ok
:=
input
[
1
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"o1"
,
second
[
"id"
])
require
.
Equal
(
t
,
"
fc_
o1"
,
second
[
"id"
])
}
func
TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved
(
t
*
testing
.
T
)
{
...
...
backend/internal/service/openai_gateway_record_usage_test.go
View file @
01ef7340
...
...
@@ -3,39 +3,68 @@ package service
import
(
"context"
"errors"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
type
openAIRecordUsageLogRepoStub
struct
{
UsageLogRepository
inserted
bool
err
error
calls
int
lastLog
*
UsageLog
inserted
bool
err
error
calls
int
lastLog
*
UsageLog
lastCtxErr
error
}
func
(
s
*
openAIRecordUsageLogRepoStub
)
Create
(
ctx
context
.
Context
,
log
*
UsageLog
)
(
bool
,
error
)
{
s
.
calls
++
s
.
lastLog
=
log
s
.
lastCtxErr
=
ctx
.
Err
()
return
s
.
inserted
,
s
.
err
}
type
openAIRecordUsageBillingRepoStub
struct
{
UsageBillingRepository
result
*
UsageBillingApplyResult
err
error
calls
int
lastCmd
*
UsageBillingCommand
lastCtxErr
error
}
func
(
s
*
openAIRecordUsageBillingRepoStub
)
Apply
(
ctx
context
.
Context
,
cmd
*
UsageBillingCommand
)
(
*
UsageBillingApplyResult
,
error
)
{
s
.
calls
++
s
.
lastCmd
=
cmd
s
.
lastCtxErr
=
ctx
.
Err
()
if
s
.
err
!=
nil
{
return
nil
,
s
.
err
}
if
s
.
result
!=
nil
{
return
s
.
result
,
nil
}
return
&
UsageBillingApplyResult
{
Applied
:
true
},
nil
}
type
openAIRecordUsageUserRepoStub
struct
{
UserRepository
deductCalls
int
deductErr
error
lastAmount
float64
lastCtxErr
error
}
func
(
s
*
openAIRecordUsageUserRepoStub
)
DeductBalance
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
error
{
s
.
deductCalls
++
s
.
lastAmount
=
amount
s
.
lastCtxErr
=
ctx
.
Err
()
return
s
.
deductErr
}
...
...
@@ -44,29 +73,35 @@ type openAIRecordUsageSubRepoStub struct {
incrementCalls
int
incrementErr
error
lastCtxErr
error
}
func
(
s
*
openAIRecordUsageSubRepoStub
)
IncrementUsage
(
ctx
context
.
Context
,
id
int64
,
costUSD
float64
)
error
{
s
.
incrementCalls
++
s
.
lastCtxErr
=
ctx
.
Err
()
return
s
.
incrementErr
}
type
openAIRecordUsageAPIKeyQuotaStub
struct
{
quotaCalls
int
rateLimitCalls
int
err
error
lastAmount
float64
quotaCalls
int
rateLimitCalls
int
err
error
lastAmount
float64
lastQuotaCtxErr
error
lastRateLimitCtxErr
error
}
func
(
s
*
openAIRecordUsageAPIKeyQuotaStub
)
UpdateQuotaUsed
(
ctx
context
.
Context
,
apiKeyID
int64
,
cost
float64
)
error
{
s
.
quotaCalls
++
s
.
lastAmount
=
cost
s
.
lastQuotaCtxErr
=
ctx
.
Err
()
return
s
.
err
}
func
(
s
*
openAIRecordUsageAPIKeyQuotaStub
)
UpdateRateLimitUsage
(
ctx
context
.
Context
,
apiKeyID
int64
,
cost
float64
)
error
{
s
.
rateLimitCalls
++
s
.
lastAmount
=
cost
s
.
lastRateLimitCtxErr
=
ctx
.
Err
()
return
s
.
err
}
...
...
@@ -93,23 +128,38 @@ func i64p(v int64) *int64 {
func
newOpenAIRecordUsageServiceForTest
(
usageRepo
UsageLogRepository
,
userRepo
UserRepository
,
subRepo
UserSubscriptionRepository
,
rateRepo
UserGroupRateRepository
)
*
OpenAIGatewayService
{
cfg
:=
&
config
.
Config
{}
cfg
.
Default
.
RateMultiplier
=
1.1
svc
:=
NewOpenAIGatewayService
(
nil
,
usageRepo
,
nil
,
userRepo
,
subRepo
,
rateRepo
,
nil
,
cfg
,
nil
,
nil
,
NewBillingService
(
cfg
,
nil
),
nil
,
&
BillingCacheService
{},
nil
,
&
DeferredService
{},
nil
,
)
svc
.
userGroupRateResolver
=
newUserGroupRateResolver
(
rateRepo
,
nil
,
resolveUserGroupRateCacheTTL
(
cfg
),
nil
,
"service.openai_gateway.test"
,
)
return
svc
}
return
&
OpenAIGatewayService
{
usageLogRepo
:
usageRepo
,
userRepo
:
userRepo
,
userSubRepo
:
subRepo
,
cfg
:
cfg
,
billingService
:
NewBillingService
(
cfg
,
nil
),
billingCacheService
:
&
BillingCacheService
{},
deferredService
:
&
DeferredService
{},
userGroupRateResolver
:
newUserGroupRateResolver
(
rateRepo
,
nil
,
resolveUserGroupRateCacheTTL
(
cfg
),
nil
,
"service.openai_gateway.test"
,
),
}
func
newOpenAIRecordUsageServiceWithBillingRepoForTest
(
usageRepo
UsageLogRepository
,
billingRepo
UsageBillingRepository
,
userRepo
UserRepository
,
subRepo
UserSubscriptionRepository
,
rateRepo
UserGroupRateRepository
)
*
OpenAIGatewayService
{
svc
:=
newOpenAIRecordUsageServiceForTest
(
usageRepo
,
userRepo
,
subRepo
,
rateRepo
)
svc
.
usageBillingRepo
=
billingRepo
return
svc
}
func
expectedOpenAICost
(
t
*
testing
.
T
,
svc
*
OpenAIGatewayService
,
model
string
,
usage
OpenAIUsage
,
multiplier
float64
)
*
CostBreakdown
{
...
...
@@ -252,9 +302,10 @@ func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateWhenResolver
func
TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling
(
t
*
testing
.
T
)
{
usageRepo
:=
&
openAIRecordUsageLogRepoStub
{
inserted
:
false
}
billingRepo
:=
&
openAIRecordUsageBillingRepoStub
{
result
:
&
UsageBillingApplyResult
{
Applied
:
false
}}
userRepo
:=
&
openAIRecordUsageUserRepoStub
{}
subRepo
:=
&
openAIRecordUsageSubRepoStub
{}
svc
:=
newOpenAIRecordUsageService
ForTest
(
usage
Repo
,
userRepo
,
subRepo
,
nil
)
svc
:=
newOpenAIRecordUsageService
WithBillingRepoForTest
(
usageRepo
,
billing
Repo
,
userRepo
,
subRepo
,
nil
)
err
:=
svc
.
RecordUsage
(
context
.
Background
(),
&
OpenAIRecordUsageInput
{
Result
:
&
OpenAIForwardResult
{
...
...
@@ -272,9 +323,311 @@ func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testin
})
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
billingRepo
.
calls
)
require
.
Equal
(
t
,
1
,
usageRepo
.
calls
)
require
.
Equal
(
t
,
0
,
userRepo
.
deductCalls
)
require
.
Equal
(
t
,
0
,
subRepo
.
incrementCalls
)
}
func
TestOpenAIGatewayServiceRecordUsage_DuplicateBillingKeySkipsBillingWithRepo
(
t
*
testing
.
T
)
{
usageRepo
:=
&
openAIRecordUsageLogRepoStub
{
inserted
:
false
}
billingRepo
:=
&
openAIRecordUsageBillingRepoStub
{
result
:
&
UsageBillingApplyResult
{
Applied
:
false
}}
userRepo
:=
&
openAIRecordUsageUserRepoStub
{}
subRepo
:=
&
openAIRecordUsageSubRepoStub
{}
quotaSvc
:=
&
openAIRecordUsageAPIKeyQuotaStub
{}
svc
:=
newOpenAIRecordUsageServiceWithBillingRepoForTest
(
usageRepo
,
billingRepo
,
userRepo
,
subRepo
,
nil
)
err
:=
svc
.
RecordUsage
(
context
.
Background
(),
&
OpenAIRecordUsageInput
{
Result
:
&
OpenAIForwardResult
{
RequestID
:
"resp_duplicate_billing_key"
,
Usage
:
OpenAIUsage
{
InputTokens
:
8
,
OutputTokens
:
4
,
},
Model
:
"gpt-5.1"
,
Duration
:
time
.
Second
,
},
APIKey
:
&
APIKey
{
ID
:
10045
,
Quota
:
100
,
},
User
:
&
User
{
ID
:
20045
},
Account
:
&
Account
{
ID
:
30045
},
APIKeyService
:
quotaSvc
,
})
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
billingRepo
.
calls
)
require
.
Equal
(
t
,
1
,
usageRepo
.
calls
)
require
.
Equal
(
t
,
0
,
userRepo
.
deductCalls
)
require
.
Equal
(
t
,
0
,
subRepo
.
incrementCalls
)
require
.
Equal
(
t
,
0
,
quotaSvc
.
quotaCalls
)
}
func
TestOpenAIGatewayServiceRecordUsage_BillsWhenUsageLogCreateReturnsError
(
t
*
testing
.
T
)
{
usage
:=
OpenAIUsage
{
InputTokens
:
8
,
OutputTokens
:
4
}
usageRepo
:=
&
openAIRecordUsageLogRepoStub
{
inserted
:
false
,
err
:
errors
.
New
(
"usage log batch state uncertain"
)}
userRepo
:=
&
openAIRecordUsageUserRepoStub
{}
subRepo
:=
&
openAIRecordUsageSubRepoStub
{}
svc
:=
newOpenAIRecordUsageServiceForTest
(
usageRepo
,
userRepo
,
subRepo
,
nil
)
err
:=
svc
.
RecordUsage
(
context
.
Background
(),
&
OpenAIRecordUsageInput
{
Result
:
&
OpenAIForwardResult
{
RequestID
:
"resp_usage_log_error"
,
Usage
:
usage
,
Model
:
"gpt-5.1"
,
Duration
:
time
.
Second
,
},
APIKey
:
&
APIKey
{
ID
:
10041
},
User
:
&
User
{
ID
:
20041
},
Account
:
&
Account
{
ID
:
30041
},
})
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
usageRepo
.
calls
)
require
.
Equal
(
t
,
1
,
userRepo
.
deductCalls
)
require
.
Equal
(
t
,
0
,
subRepo
.
incrementCalls
)
}
func
TestOpenAIGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling
(
t
*
testing
.
T
)
{
usageRepo
:=
&
openAIRecordUsageLogRepoStub
{
inserted
:
false
,
err
:
MarkUsageLogCreateNotPersisted
(
context
.
Canceled
)}
userRepo
:=
&
openAIRecordUsageUserRepoStub
{}
subRepo
:=
&
openAIRecordUsageSubRepoStub
{}
quotaSvc
:=
&
openAIRecordUsageAPIKeyQuotaStub
{}
svc
:=
newOpenAIRecordUsageServiceForTest
(
usageRepo
,
userRepo
,
subRepo
,
nil
)
err
:=
svc
.
RecordUsage
(
context
.
Background
(),
&
OpenAIRecordUsageInput
{
Result
:
&
OpenAIForwardResult
{
RequestID
:
"resp_not_persisted"
,
Usage
:
OpenAIUsage
{
InputTokens
:
8
,
OutputTokens
:
4
,
},
Model
:
"gpt-5.1"
,
Duration
:
time
.
Second
,
},
APIKey
:
&
APIKey
{
ID
:
10043
,
Quota
:
100
,
},
User
:
&
User
{
ID
:
20043
},
Account
:
&
Account
{
ID
:
30043
},
APIKeyService
:
quotaSvc
,
})
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
usageRepo
.
calls
)
require
.
Equal
(
t
,
1
,
userRepo
.
deductCalls
)
require
.
Equal
(
t
,
0
,
subRepo
.
incrementCalls
)
require
.
Equal
(
t
,
1
,
quotaSvc
.
quotaCalls
)
}
func
TestOpenAIGatewayServiceRecordUsage_BillingUsesDetachedContext
(
t
*
testing
.
T
)
{
usage
:=
OpenAIUsage
{
InputTokens
:
10
,
OutputTokens
:
6
,
CacheReadInputTokens
:
2
}
usageRepo
:=
&
openAIRecordUsageLogRepoStub
{
inserted
:
false
,
err
:
context
.
DeadlineExceeded
}
userRepo
:=
&
openAIRecordUsageUserRepoStub
{}
subRepo
:=
&
openAIRecordUsageSubRepoStub
{}
quotaSvc
:=
&
openAIRecordUsageAPIKeyQuotaStub
{}
svc
:=
newOpenAIRecordUsageServiceForTest
(
usageRepo
,
userRepo
,
subRepo
,
nil
)
reqCtx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
err
:=
svc
.
RecordUsage
(
reqCtx
,
&
OpenAIRecordUsageInput
{
Result
:
&
OpenAIForwardResult
{
RequestID
:
"resp_detached_billing_ctx"
,
Usage
:
usage
,
Model
:
"gpt-5.1"
,
Duration
:
time
.
Second
,
},
APIKey
:
&
APIKey
{
ID
:
10042
,
Quota
:
100
,
},
User
:
&
User
{
ID
:
20042
},
Account
:
&
Account
{
ID
:
30042
},
APIKeyService
:
quotaSvc
,
})
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
userRepo
.
deductCalls
)
require
.
NoError
(
t
,
userRepo
.
lastCtxErr
)
require
.
Equal
(
t
,
1
,
quotaSvc
.
quotaCalls
)
require
.
NoError
(
t
,
quotaSvc
.
lastQuotaCtxErr
)
}
func
TestOpenAIGatewayServiceRecordUsage_BillingRepoUsesDetachedContext
(
t
*
testing
.
T
)
{
usageRepo
:=
&
openAIRecordUsageLogRepoStub
{}
billingRepo
:=
&
openAIRecordUsageBillingRepoStub
{
result
:
&
UsageBillingApplyResult
{
Applied
:
true
}}
userRepo
:=
&
openAIRecordUsageUserRepoStub
{}
subRepo
:=
&
openAIRecordUsageSubRepoStub
{}
svc
:=
newOpenAIRecordUsageServiceWithBillingRepoForTest
(
usageRepo
,
billingRepo
,
userRepo
,
subRepo
,
nil
)
reqCtx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
err
:=
svc
.
RecordUsage
(
reqCtx
,
&
OpenAIRecordUsageInput
{
Result
:
&
OpenAIForwardResult
{
RequestID
:
"resp_detached_billing_repo_ctx"
,
Usage
:
OpenAIUsage
{
InputTokens
:
8
,
OutputTokens
:
4
,
},
Model
:
"gpt-5.1"
,
Duration
:
time
.
Second
,
},
APIKey
:
&
APIKey
{
ID
:
10046
},
User
:
&
User
{
ID
:
20046
},
Account
:
&
Account
{
ID
:
30046
},
})
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
billingRepo
.
calls
)
require
.
NoError
(
t
,
billingRepo
.
lastCtxErr
)
require
.
Equal
(
t
,
1
,
usageRepo
.
calls
)
require
.
NoError
(
t
,
usageRepo
.
lastCtxErr
)
}
func
TestOpenAIGatewayServiceRecordUsage_BillingFingerprintIncludesRequestPayloadHash
(
t
*
testing
.
T
)
{
usageRepo
:=
&
openAIRecordUsageLogRepoStub
{}
billingRepo
:=
&
openAIRecordUsageBillingRepoStub
{
result
:
&
UsageBillingApplyResult
{
Applied
:
true
}}
svc
:=
newOpenAIRecordUsageServiceWithBillingRepoForTest
(
usageRepo
,
billingRepo
,
&
openAIRecordUsageUserRepoStub
{},
&
openAIRecordUsageSubRepoStub
{},
nil
)
payloadHash
:=
HashUsageRequestPayload
([]
byte
(
`{"model":"gpt-5","input":"hello"}`
))
err
:=
svc
.
RecordUsage
(
context
.
Background
(),
&
OpenAIRecordUsageInput
{
Result
:
&
OpenAIForwardResult
{
RequestID
:
"openai_payload_hash"
,
Usage
:
OpenAIUsage
{
InputTokens
:
10
,
OutputTokens
:
6
,
},
Model
:
"gpt-5"
,
Duration
:
time
.
Second
,
},
APIKey
:
&
APIKey
{
ID
:
501
,
Quota
:
100
},
User
:
&
User
{
ID
:
601
},
Account
:
&
Account
{
ID
:
701
},
RequestPayloadHash
:
payloadHash
,
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
billingRepo
.
lastCmd
)
require
.
Equal
(
t
,
payloadHash
,
billingRepo
.
lastCmd
.
RequestPayloadHash
)
}
func
TestOpenAIGatewayServiceRecordUsage_UsesFallbackRequestIDForBillingAndUsageLog
(
t
*
testing
.
T
)
{
usageRepo
:=
&
openAIRecordUsageLogRepoStub
{}
billingRepo
:=
&
openAIRecordUsageBillingRepoStub
{
result
:
&
UsageBillingApplyResult
{
Applied
:
true
}}
userRepo
:=
&
openAIRecordUsageUserRepoStub
{}
subRepo
:=
&
openAIRecordUsageSubRepoStub
{}
svc
:=
newOpenAIRecordUsageServiceWithBillingRepoForTest
(
usageRepo
,
billingRepo
,
userRepo
,
subRepo
,
nil
)
ctx
:=
context
.
WithValue
(
context
.
Background
(),
ctxkey
.
RequestID
,
"req-local-fallback"
)
err
:=
svc
.
RecordUsage
(
ctx
,
&
OpenAIRecordUsageInput
{
Result
:
&
OpenAIForwardResult
{
RequestID
:
""
,
Usage
:
OpenAIUsage
{
InputTokens
:
8
,
OutputTokens
:
4
,
},
Model
:
"gpt-5.1"
,
Duration
:
time
.
Second
,
},
APIKey
:
&
APIKey
{
ID
:
10047
},
User
:
&
User
{
ID
:
20047
},
Account
:
&
Account
{
ID
:
30047
},
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
billingRepo
.
lastCmd
)
require
.
Equal
(
t
,
"local:req-local-fallback"
,
billingRepo
.
lastCmd
.
RequestID
)
require
.
NotNil
(
t
,
usageRepo
.
lastLog
)
require
.
Equal
(
t
,
"local:req-local-fallback"
,
usageRepo
.
lastLog
.
RequestID
)
}
func
TestOpenAIGatewayServiceRecordUsage_PrefersClientRequestIDOverUpstreamRequestID
(
t
*
testing
.
T
)
{
usageRepo
:=
&
openAIRecordUsageLogRepoStub
{}
billingRepo
:=
&
openAIRecordUsageBillingRepoStub
{
result
:
&
UsageBillingApplyResult
{
Applied
:
true
}}
userRepo
:=
&
openAIRecordUsageUserRepoStub
{}
subRepo
:=
&
openAIRecordUsageSubRepoStub
{}
svc
:=
newOpenAIRecordUsageServiceWithBillingRepoForTest
(
usageRepo
,
billingRepo
,
userRepo
,
subRepo
,
nil
)
ctx
:=
context
.
WithValue
(
context
.
Background
(),
ctxkey
.
ClientRequestID
,
"openai-client-stable-123"
)
err
:=
svc
.
RecordUsage
(
ctx
,
&
OpenAIRecordUsageInput
{
Result
:
&
OpenAIForwardResult
{
RequestID
:
"upstream-openai-volatile-456"
,
Usage
:
OpenAIUsage
{
InputTokens
:
8
,
OutputTokens
:
4
,
},
Model
:
"gpt-5.1"
,
Duration
:
time
.
Second
,
},
APIKey
:
&
APIKey
{
ID
:
10049
},
User
:
&
User
{
ID
:
20049
},
Account
:
&
Account
{
ID
:
30049
},
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
billingRepo
.
lastCmd
)
require
.
Equal
(
t
,
"client:openai-client-stable-123"
,
billingRepo
.
lastCmd
.
RequestID
)
require
.
NotNil
(
t
,
usageRepo
.
lastLog
)
require
.
Equal
(
t
,
"client:openai-client-stable-123"
,
usageRepo
.
lastLog
.
RequestID
)
}
func
TestOpenAIGatewayServiceRecordUsage_GeneratesRequestIDWhenAllSourcesMissing
(
t
*
testing
.
T
)
{
usageRepo
:=
&
openAIRecordUsageLogRepoStub
{}
billingRepo
:=
&
openAIRecordUsageBillingRepoStub
{
result
:
&
UsageBillingApplyResult
{
Applied
:
true
}}
userRepo
:=
&
openAIRecordUsageUserRepoStub
{}
subRepo
:=
&
openAIRecordUsageSubRepoStub
{}
svc
:=
newOpenAIRecordUsageServiceWithBillingRepoForTest
(
usageRepo
,
billingRepo
,
userRepo
,
subRepo
,
nil
)
err
:=
svc
.
RecordUsage
(
context
.
Background
(),
&
OpenAIRecordUsageInput
{
Result
:
&
OpenAIForwardResult
{
RequestID
:
""
,
Usage
:
OpenAIUsage
{
InputTokens
:
8
,
OutputTokens
:
4
,
},
Model
:
"gpt-5.1"
,
Duration
:
time
.
Second
,
},
APIKey
:
&
APIKey
{
ID
:
10050
},
User
:
&
User
{
ID
:
20050
},
Account
:
&
Account
{
ID
:
30050
},
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
billingRepo
.
lastCmd
)
require
.
True
(
t
,
strings
.
HasPrefix
(
billingRepo
.
lastCmd
.
RequestID
,
"generated:"
))
require
.
NotNil
(
t
,
usageRepo
.
lastLog
)
require
.
Equal
(
t
,
billingRepo
.
lastCmd
.
RequestID
,
usageRepo
.
lastLog
.
RequestID
)
}
func
TestOpenAIGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite
(
t
*
testing
.
T
)
{
usageRepo
:=
&
openAIRecordUsageLogRepoStub
{}
billingRepo
:=
&
openAIRecordUsageBillingRepoStub
{
err
:
errors
.
New
(
"billing tx failed"
)}
userRepo
:=
&
openAIRecordUsageUserRepoStub
{}
subRepo
:=
&
openAIRecordUsageSubRepoStub
{}
svc
:=
newOpenAIRecordUsageServiceWithBillingRepoForTest
(
usageRepo
,
billingRepo
,
userRepo
,
subRepo
,
nil
)
err
:=
svc
.
RecordUsage
(
context
.
Background
(),
&
OpenAIRecordUsageInput
{
Result
:
&
OpenAIForwardResult
{
RequestID
:
"resp_billing_fail"
,
Usage
:
OpenAIUsage
{
InputTokens
:
8
,
OutputTokens
:
4
,
},
Model
:
"gpt-5.1"
,
Duration
:
time
.
Second
,
},
APIKey
:
&
APIKey
{
ID
:
10048
},
User
:
&
User
{
ID
:
20048
},
Account
:
&
Account
{
ID
:
30048
},
})
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
1
,
billingRepo
.
calls
)
require
.
Equal
(
t
,
0
,
usageRepo
.
calls
)
}
func
TestOpenAIGatewayServiceRecordUsage_UpdatesAPIKeyQuotaWhenConfigured
(
t
*
testing
.
T
)
{
...
...
backend/internal/service/openai_gateway_service.go
View file @
01ef7340
...
...
@@ -301,6 +301,7 @@ var defaultOpenAICodexSnapshotPersistThrottle = newAccountWriteThrottle(openAICo
type
OpenAIGatewayService
struct
{
accountRepo
AccountRepository
usageLogRepo
UsageLogRepository
usageBillingRepo
UsageBillingRepository
userRepo
UserRepository
userSubRepo
UserSubscriptionRepository
cache
GatewayCache
...
...
@@ -338,6 +339,7 @@ type OpenAIGatewayService struct {
func
NewOpenAIGatewayService
(
accountRepo
AccountRepository
,
usageLogRepo
UsageLogRepository
,
usageBillingRepo
UsageBillingRepository
,
userRepo
UserRepository
,
userSubRepo
UserSubscriptionRepository
,
userGroupRateRepo
UserGroupRateRepository
,
...
...
@@ -355,6 +357,7 @@ func NewOpenAIGatewayService(
svc
:=
&
OpenAIGatewayService
{
accountRepo
:
accountRepo
,
usageLogRepo
:
usageLogRepo
,
usageBillingRepo
:
usageBillingRepo
,
userRepo
:
userRepo
,
userSubRepo
:
userSubRepo
,
cache
:
cache
,
...
...
@@ -2073,7 +2076,9 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
// Build upstream request
upstreamReq
,
err
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
body
,
token
,
reqStream
,
promptCacheKey
,
isCodexCLI
)
upstreamCtx
,
releaseUpstreamCtx
:=
detachStreamUpstreamContext
(
ctx
,
reqStream
)
upstreamReq
,
err
:=
s
.
buildUpstreamRequest
(
upstreamCtx
,
c
,
account
,
body
,
token
,
reqStream
,
promptCacheKey
,
isCodexCLI
)
releaseUpstreamCtx
()
if
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -2265,7 +2270,9 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
return
nil
,
err
}
upstreamReq
,
err
:=
s
.
buildUpstreamRequestOpenAIPassthrough
(
ctx
,
c
,
account
,
body
,
token
)
upstreamCtx
,
releaseUpstreamCtx
:=
detachStreamUpstreamContext
(
ctx
,
reqStream
)
upstreamReq
,
err
:=
s
.
buildUpstreamRequestOpenAIPassthrough
(
upstreamCtx
,
c
,
account
,
body
,
token
)
releaseUpstreamCtx
()
if
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -2602,6 +2609,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
var
firstTokenMs
*
int
clientDisconnected
:=
false
sawDone
:=
false
sawTerminalEvent
:=
false
upstreamRequestID
:=
strings
.
TrimSpace
(
resp
.
Header
.
Get
(
"x-request-id"
))
scanner
:=
bufio
.
NewScanner
(
resp
.
Body
)
...
...
@@ -2621,6 +2629,9 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
if
trimmedData
==
"[DONE]"
{
sawDone
=
true
}
if
openAIStreamEventIsTerminal
(
trimmedData
)
{
sawTerminalEvent
=
true
}
if
firstTokenMs
==
nil
&&
trimmedData
!=
""
&&
trimmedData
!=
"[DONE]"
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
firstTokenMs
=
&
ms
...
...
@@ -2638,19 +2649,14 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
}
}
if
err
:=
scanner
.
Err
();
err
!=
nil
{
if
clientDisconnected
{
logger
.
LegacyPrintf
(
"service.openai_gateway"
,
"[OpenAI passthrough] Upstream read error after client disconnect: account=%d err=%v"
,
account
.
ID
,
err
)
if
sawTerminalEvent
{
return
&
openaiStreamingResultPassthrough
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
if
clientDisconnected
{
return
&
openaiStreamingResultPassthrough
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream usage incomplete after disconnect: %w"
,
err
)
}
if
errors
.
Is
(
err
,
context
.
Canceled
)
||
errors
.
Is
(
err
,
context
.
DeadlineExceeded
)
{
logger
.
LegacyPrintf
(
"service.openai_gateway"
,
"[OpenAI passthrough] 流读取被取消,可能发生断流: account=%d request_id=%s err=%v ctx_err=%v"
,
account
.
ID
,
upstreamRequestID
,
err
,
ctx
.
Err
(),
)
return
&
openaiStreamingResultPassthrough
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
return
&
openaiStreamingResultPassthrough
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream usage incomplete: %w"
,
err
)
}
if
errors
.
Is
(
err
,
bufio
.
ErrTooLong
)
{
logger
.
LegacyPrintf
(
"service.openai_gateway"
,
"[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v"
,
account
.
ID
,
maxLineSize
,
err
)
...
...
@@ -2664,12 +2670,13 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
)
return
&
openaiStreamingResultPassthrough
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream read error: %w"
,
err
)
}
if
!
clientDisconnected
&&
!
sawDone
&&
ctx
.
Err
()
==
nil
{
if
!
clientDisconnected
&&
!
sawDone
&&
!
sawTerminalEvent
&&
ctx
.
Err
()
==
nil
{
logger
.
FromContext
(
ctx
)
.
With
(
zap
.
String
(
"component"
,
"service.openai_gateway"
),
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
String
(
"upstream_request_id"
,
upstreamRequestID
),
)
.
Info
(
"OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流"
)
return
&
openaiStreamingResultPassthrough
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
errors
.
New
(
"stream usage incomplete: missing terminal event"
)
}
return
&
openaiStreamingResultPassthrough
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
...
...
@@ -3203,6 +3210,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
// 否则下游 SDK(例如 OpenCode)会因为类型校验失败而报错。
errorEventSent
:=
false
clientDisconnected
:=
false
// 客户端断开后继续 drain 上游以收集 usage
sawTerminalEvent
:=
false
sendErrorEvent
:=
func
(
reason
string
)
{
if
errorEventSent
||
clientDisconnected
{
return
...
...
@@ -3233,22 +3241,27 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
logger
.
LegacyPrintf
(
"service.openai_gateway"
,
"Client disconnected during final flush, returning collected usage"
)
}
}
if
!
sawTerminalEvent
{
return
resultWithUsage
(),
fmt
.
Errorf
(
"stream usage incomplete: missing terminal event"
)
}
return
resultWithUsage
(),
nil
}
handleScanErr
:=
func
(
scanErr
error
)
(
*
openaiStreamingResult
,
error
,
bool
)
{
if
scanErr
==
nil
{
return
nil
,
nil
,
false
}
if
sawTerminalEvent
{
logger
.
LegacyPrintf
(
"service.openai_gateway"
,
"Upstream scan ended after terminal event: %v"
,
scanErr
)
return
resultWithUsage
(),
nil
,
true
}
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
if
errors
.
Is
(
scanErr
,
context
.
Canceled
)
||
errors
.
Is
(
scanErr
,
context
.
DeadlineExceeded
)
{
logger
.
LegacyPrintf
(
"service.openai_gateway"
,
"Context canceled during streaming, returning collected usage"
)
return
resultWithUsage
(),
nil
,
true
return
resultWithUsage
(),
fmt
.
Errorf
(
"stream usage incomplete: %w"
,
scanErr
),
true
}
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
if
clientDisconnected
{
logger
.
LegacyPrintf
(
"service.openai_gateway"
,
"Upstream read error after client disconnect: %v, returning collected usage"
,
scanErr
)
return
resultWithUsage
(),
nil
,
true
return
resultWithUsage
(),
fmt
.
Errorf
(
"stream usage incomplete after disconnect: %w"
,
scanErr
),
true
}
if
errors
.
Is
(
scanErr
,
bufio
.
ErrTooLong
)
{
logger
.
LegacyPrintf
(
"service.openai_gateway"
,
"SSE line too long: account=%d max_size=%d error=%v"
,
account
.
ID
,
maxLineSize
,
scanErr
)
...
...
@@ -3271,6 +3284,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
}
dataBytes
:=
[]
byte
(
data
)
if
openAIStreamEventIsTerminal
(
data
)
{
sawTerminalEvent
=
true
}
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
if
correctedData
,
corrected
:=
s
.
toolCorrector
.
CorrectToolCallsInSSEBytes
(
dataBytes
);
corrected
{
...
...
@@ -3387,8 +3403,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
continue
}
if
clientDisconnected
{
logger
.
LegacyPrintf
(
"service.openai_gateway"
,
"Upstream timeout after client disconnect, returning collected usage"
)
return
resultWithUsage
(),
nil
return
resultWithUsage
(),
fmt
.
Errorf
(
"stream usage incomplete after timeout"
)
}
logger
.
LegacyPrintf
(
"service.openai_gateway"
,
"Stream data interval timeout: account=%d model=%s interval=%s"
,
account
.
ID
,
originalModel
,
streamInterval
)
// 处理流超时,可能标记账户为临时不可调度或错误状态
...
...
@@ -3486,11 +3501,12 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag
if
usage
==
nil
||
len
(
data
)
==
0
||
bytes
.
Equal
(
data
,
[]
byte
(
"[DONE]"
))
{
return
}
// 选择性解析:仅在数据中包含
completed
事件标识时才进入字段提取。
if
len
(
data
)
<
80
||
!
bytes
.
Contains
(
data
,
[]
byte
(
`"response.completed"`
))
{
// 选择性解析:仅在数据中包含
终止
事件标识时才进入字段提取。
if
len
(
data
)
<
72
{
return
}
if
gjson
.
GetBytes
(
data
,
"type"
)
.
String
()
!=
"response.completed"
{
eventType
:=
gjson
.
GetBytes
(
data
,
"type"
)
.
String
()
if
eventType
!=
"response.completed"
&&
eventType
!=
"response.done"
{
return
}
...
...
@@ -3843,14 +3859,15 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
// OpenAIRecordUsageInput input for recording usage
type
OpenAIRecordUsageInput
struct
{
Result
*
OpenAIForwardResult
APIKey
*
APIKey
User
*
User
Account
*
Account
Subscription
*
UserSubscription
UserAgent
string
// 请求的 User-Agent
IPAddress
string
// 请求的客户端 IP 地址
APIKeyService
APIKeyQuotaUpdater
Result
*
OpenAIForwardResult
APIKey
*
APIKey
User
*
User
Account
*
Account
Subscription
*
UserSubscription
UserAgent
string
// 请求的 User-Agent
IPAddress
string
// 请求的客户端 IP 地址
RequestPayloadHash
string
APIKeyService
APIKeyQuotaUpdater
}
// RecordUsage records usage and deducts balance
...
...
@@ -3916,11 +3933,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
// Create usage log
durationMs
:=
int
(
result
.
Duration
.
Milliseconds
())
accountRateMultiplier
:=
account
.
BillingRateMultiplier
()
requestID
:=
resolveUsageBillingRequestID
(
ctx
,
result
.
RequestID
)
usageLog
:=
&
UsageLog
{
UserID
:
user
.
ID
,
APIKeyID
:
apiKey
.
ID
,
AccountID
:
account
.
ID
,
RequestID
:
re
sult
.
Re
questID
,
RequestID
:
requestID
,
Model
:
billingModel
,
ServiceTier
:
result
.
ServiceTier
,
ReasoningEffort
:
result
.
ReasoningEffort
,
...
...
@@ -3961,29 +3979,32 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
usageLog
.
SubscriptionID
=
&
subscription
.
ID
}
inserted
,
err
:=
s
.
usageLogRepo
.
Create
(
ctx
,
usageLog
)
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
RunMode
==
config
.
RunModeSimple
{
writeUsageLogBestEffort
(
ctx
,
s
.
usageLogRepo
,
usageLog
,
"service.openai_gateway"
)
logger
.
LegacyPrintf
(
"service.openai_gateway"
,
"[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d"
,
usageLog
.
UserID
,
usageLog
.
TotalTokens
())
s
.
deferredService
.
ScheduleLastUsedUpdate
(
account
.
ID
)
return
nil
}
shouldBill
:=
inserted
||
err
!=
nil
if
shouldBill
{
postUsageBilling
(
ctx
,
&
postUsageBillingParams
{
billingErr
:=
func
()
error
{
_
,
err
:=
applyUsageBilling
(
ctx
,
requestID
,
usageLog
,
&
postUsageBillingParams
{
Cost
:
cost
,
User
:
user
,
APIKey
:
apiKey
,
Account
:
account
,
Subscription
:
subscription
,
RequestPayloadHash
:
resolveUsageBillingPayloadFingerprint
(
ctx
,
input
.
RequestPayloadHash
),
IsSubscriptionBill
:
isSubscriptionBilling
,
AccountRateMultiplier
:
accountRateMultiplier
,
APIKeyService
:
input
.
APIKeyService
,
},
s
.
billingDeps
())
}
else
{
s
.
deferredService
.
ScheduleLastUsedUpdate
(
account
.
ID
)
},
s
.
billingDeps
(),
s
.
usageBillingRepo
)
return
err
}()
if
billingErr
!=
nil
{
return
billingErr
}
writeUsageLogBestEffort
(
ctx
,
s
.
usageLogRepo
,
usageLog
,
"service.openai_gateway"
)
return
nil
}
...
...
backend/internal/service/openai_gateway_service_test.go
View file @
01ef7340
...
...
@@ -916,7 +916,7 @@ func TestOpenAIStreamingTimeout(t *testing.T) {
}
}
func
TestOpenAIStreamingContextCanceled
DoesNo
tInjectErrorEvent
(
t
*
testing
.
T
)
{
func
TestOpenAIStreamingContextCanceled
ReturnsIncompleteErrorWithou
tInject
ing
ErrorEvent
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
cfg
:=
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
...
...
@@ -940,8 +940,8 @@ func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) {
}
_
,
err
:=
svc
.
handleStreamingResponse
(
c
.
Request
.
Context
(),
resp
,
c
,
&
Account
{
ID
:
1
},
time
.
Now
(),
"model"
,
"model"
)
if
err
!
=
nil
{
t
.
Fatalf
(
"expected
nil
error, got %v"
,
err
)
if
err
=
=
nil
||
!
strings
.
Contains
(
err
.
Error
(),
"stream usage incomplete"
)
{
t
.
Fatalf
(
"expected
incomplete stream
error, got %v"
,
err
)
}
if
strings
.
Contains
(
rec
.
Body
.
String
(),
"event: error"
)
||
strings
.
Contains
(
rec
.
Body
.
String
(),
"stream_read_error"
)
{
t
.
Fatalf
(
"expected no injected SSE error event, got %q"
,
rec
.
Body
.
String
())
...
...
@@ -993,6 +993,107 @@ func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
}
}
func
TestOpenAIStreamingMissingTerminalEventReturnsIncompleteError
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
cfg
:=
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
StreamDataIntervalTimeout
:
0
,
StreamKeepaliveInterval
:
0
,
MaxLineSize
:
defaultMaxLineSize
,
},
}
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
}
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
pr
,
pw
:=
io
.
Pipe
()
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
pr
,
Header
:
http
.
Header
{},
}
go
func
()
{
defer
func
()
{
_
=
pw
.
Close
()
}()
_
,
_
=
pw
.
Write
([]
byte
(
"data: {
\"
type
\"
:
\"
response.in_progress
\"
,
\"
response
\"
:{}}
\n\n
"
))
}()
_
,
err
:=
svc
.
handleStreamingResponse
(
c
.
Request
.
Context
(),
resp
,
c
,
&
Account
{
ID
:
1
},
time
.
Now
(),
"model"
,
"model"
)
_
=
pr
.
Close
()
if
err
==
nil
||
!
strings
.
Contains
(
err
.
Error
(),
"missing terminal event"
)
{
t
.
Fatalf
(
"expected missing terminal event error, got %v"
,
err
)
}
}
func
TestOpenAIStreamingPassthroughMissingTerminalEventReturnsIncompleteError
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
cfg
:=
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
,
},
}
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
}
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
pr
,
pw
:=
io
.
Pipe
()
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
pr
,
Header
:
http
.
Header
{},
}
go
func
()
{
defer
func
()
{
_
=
pw
.
Close
()
}()
_
,
_
=
pw
.
Write
([]
byte
(
"data: {
\"
type
\"
:
\"
response.in_progress
\"
,
\"
response
\"
:{}}
\n\n
"
))
}()
_
,
err
:=
svc
.
handleStreamingResponsePassthrough
(
c
.
Request
.
Context
(),
resp
,
c
,
&
Account
{
ID
:
1
},
time
.
Now
())
_
=
pr
.
Close
()
if
err
==
nil
||
!
strings
.
Contains
(
err
.
Error
(),
"missing terminal event"
)
{
t
.
Fatalf
(
"expected missing terminal event error, got %v"
,
err
)
}
}
func
TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
cfg
:=
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
,
},
}
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
}
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
pr
,
pw
:=
io
.
Pipe
()
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
pr
,
Header
:
http
.
Header
{},
}
go
func
()
{
defer
func
()
{
_
=
pw
.
Close
()
}()
_
,
_
=
pw
.
Write
([]
byte
(
"data: {
\"
type
\"
:
\"
response.done
\"
,
\"
response
\"
:{
\"
usage
\"
:{
\"
input_tokens
\"
:2,
\"
output_tokens
\"
:3,
\"
input_tokens_details
\"
:{
\"
cached_tokens
\"
:1}}}}
\n\n
"
))
}()
result
,
err
:=
svc
.
handleStreamingResponsePassthrough
(
c
.
Request
.
Context
(),
resp
,
c
,
&
Account
{
ID
:
1
},
time
.
Now
())
_
=
pr
.
Close
()
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
usage
)
require
.
Equal
(
t
,
2
,
result
.
usage
.
InputTokens
)
require
.
Equal
(
t
,
3
,
result
.
usage
.
OutputTokens
)
require
.
Equal
(
t
,
1
,
result
.
usage
.
CacheReadInputTokens
)
}
func
TestOpenAIStreamingTooLong
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
cfg
:=
&
config
.
Config
{
...
...
@@ -1124,7 +1225,7 @@ func TestOpenAIStreamingHeadersOverride(t *testing.T) {
go
func
()
{
defer
func
()
{
_
=
pw
.
Close
()
}()
_
,
_
=
pw
.
Write
([]
byte
(
"data: {}
\n\n
"
))
_
,
_
=
pw
.
Write
([]
byte
(
"data: {
\"
type
\"
:
\"
response.completed
\"
,
\"
response
\"
:{}
}
\n\n
"
))
}()
_
,
err
:=
svc
.
handleStreamingResponse
(
c
.
Request
.
Context
(),
resp
,
c
,
&
Account
{
ID
:
1
},
time
.
Now
(),
"model"
,
"model"
)
...
...
@@ -1674,6 +1775,12 @@ func TestParseSSEUsage_SelectiveParsing(t *testing.T) {
require
.
Equal
(
t
,
3
,
usage
.
InputTokens
)
require
.
Equal
(
t
,
5
,
usage
.
OutputTokens
)
require
.
Equal
(
t
,
2
,
usage
.
CacheReadInputTokens
)
// done 事件同样可能携带最终 usage
svc
.
parseSSEUsage
(
`{"type":"response.done","response":{"usage":{"input_tokens":13,"output_tokens":15,"input_tokens_details":{"cached_tokens":4}}}}`
,
usage
)
require
.
Equal
(
t
,
13
,
usage
.
InputTokens
)
require
.
Equal
(
t
,
15
,
usage
.
OutputTokens
)
require
.
Equal
(
t
,
4
,
usage
.
CacheReadInputTokens
)
}
func
TestExtractCodexFinalResponse_SampleReplay
(
t
*
testing
.
T
)
{
...
...
backend/internal/service/openai_oauth_passthrough_test.go
View file @
01ef7340
...
...
@@ -439,7 +439,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_ResponseHeadersAllowXCodex(t *tes
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/responses"
,
bytes
.
NewReader
(
nil
))
c
.
Request
.
Header
.
Set
(
"User-Agent"
,
"codex_cli_rs/0.1.0"
)
originalBody
:=
[]
byte
(
`{"model":"gpt-5.2","stream":
fals
e,"input":[{"type":"text","text":"hi"}]}`
)
originalBody
:=
[]
byte
(
`{"model":"gpt-5.2","stream":
tru
e,"input":[{"type":"text","text":"hi"}]}`
)
headers
:=
make
(
http
.
Header
)
headers
.
Set
(
"Content-Type"
,
"application/json"
)
...
...
@@ -453,7 +453,14 @@ func TestOpenAIGatewayService_OAuthPassthrough_ResponseHeadersAllowXCodex(t *tes
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
headers
,
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`
)),
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
strings
.
Join
([]
string
{
`data: {"type":"response.output_text.delta","delta":"h"}`
,
""
,
`data: {"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}}`
,
""
,
"data: [DONE]"
,
""
,
},
"
\n
"
))),
}
upstream
:=
&
httpUpstreamRecorder
{
resp
:
resp
}
...
...
@@ -895,7 +902,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_InfoWhenStreamEndsWithoutDone(t *
}
_
,
err
:=
svc
.
Forward
(
context
.
Background
(),
c
,
account
,
originalBody
)
require
.
No
Error
(
t
,
err
)
require
.
Equal
Error
(
t
,
err
,
"stream usage incomplete: missing terminal event"
)
require
.
True
(
t
,
logSink
.
ContainsMessage
(
"上游流在未收到 [DONE] 时结束,疑似断流"
))
require
.
True
(
t
,
logSink
.
ContainsMessageAtLevel
(
"上游流在未收到 [DONE] 时结束,疑似断流"
,
"info"
))
require
.
True
(
t
,
logSink
.
ContainsFieldValue
(
"upstream_request_id"
,
"rid-truncate"
))
...
...
@@ -911,11 +918,16 @@ func TestOpenAIGatewayService_OAuthPassthrough_DefaultFiltersTimeoutHeaders(t *t
c
.
Request
.
Header
.
Set
(
"x-stainless-timeout"
,
"120000"
)
c
.
Request
.
Header
.
Set
(
"X-Test"
,
"keep"
)
originalBody
:=
[]
byte
(
`{"model":"gpt-5.2","stream":
fals
e,"input":[{"type":"text","text":"hi"}]}`
)
originalBody
:=
[]
byte
(
`{"model":"gpt-5.2","stream":
tru
e,"input":[{"type":"text","text":"hi"}]}`
)
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"application/json"
},
"X-Request-Id"
:
[]
string
{
"rid-filter-default"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`
)),
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"text/event-stream"
},
"X-Request-Id"
:
[]
string
{
"rid-filter-default"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
strings
.
Join
([]
string
{
`data: {"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}}`
,
""
,
"data: [DONE]"
,
""
,
},
"
\n
"
))),
}
upstream
:=
&
httpUpstreamRecorder
{
resp
:
resp
}
svc
:=
&
OpenAIGatewayService
{
...
...
@@ -952,11 +964,16 @@ func TestOpenAIGatewayService_OAuthPassthrough_AllowTimeoutHeadersWhenConfigured
c
.
Request
.
Header
.
Set
(
"x-stainless-timeout"
,
"120000"
)
c
.
Request
.
Header
.
Set
(
"X-Test"
,
"keep"
)
originalBody
:=
[]
byte
(
`{"model":"gpt-5.2","stream":
fals
e,"input":[{"type":"text","text":"hi"}]}`
)
originalBody
:=
[]
byte
(
`{"model":"gpt-5.2","stream":
tru
e,"input":[{"type":"text","text":"hi"}]}`
)
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"application/json"
},
"X-Request-Id"
:
[]
string
{
"rid-filter-allow"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`
)),
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"text/event-stream"
},
"X-Request-Id"
:
[]
string
{
"rid-filter-allow"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
strings
.
Join
([]
string
{
`data: {"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}}`
,
""
,
"data: [DONE]"
,
""
,
},
"
\n
"
))),
}
upstream
:=
&
httpUpstreamRecorder
{
resp
:
resp
}
svc
:=
&
OpenAIGatewayService
{
...
...
backend/internal/service/openai_ws_protocol_forward_test.go
View file @
01ef7340
...
...
@@ -392,6 +392,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil
,
nil
,
nil
,
nil
,
cfg
,
nil
,
nil
,
...
...
backend/internal/service/usage_billing.go
0 → 100644
View file @
01ef7340
package
service
import
(
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"strings"
)
// ErrUsageBillingRequestIDRequired indicates a billing command was submitted
// without a request_id (Normalize trims it; an empty value cannot identify
// the request for at-most-once application).
var ErrUsageBillingRequestIDRequired = errors.New("usage billing request_id is required")

// ErrUsageBillingRequestConflict indicates the same request was seen again
// with a different fingerprint — presumably surfaced by
// UsageBillingRepository implementations; confirm against callers.
var ErrUsageBillingRequestConflict = errors.New("usage billing request fingerprint conflict")
// UsageBillingCommand describes one billable request that must be applied at most once.
//
// All fields below except RequestID and RequestFingerprint participate in the
// derived fingerprint (see buildUsageBillingFingerprint); RequestPayloadHash
// is folded in only when non-empty.
type UsageBillingCommand struct {
	// RequestID identifies the request for idempotent application; required.
	RequestID string
	APIKeyID  int64
	// RequestFingerprint, when left empty, is filled in by Normalize from the
	// billable fields of this command.
	RequestFingerprint string
	// RequestPayloadHash is the hex SHA-256 of the request body, typically
	// produced by HashUsageRequestPayload; may be empty.
	RequestPayloadHash string
	UserID             int64
	AccountID          int64
	// SubscriptionID is optional; nil is treated as 0 in the fingerprint.
	SubscriptionID *int64
	AccountType    string
	Model          string
	ServiceTier    string
	ReasoningEffort string
	BillingType    int8

	// Token and media counters for the request.
	InputTokens         int
	OutputTokens        int
	CacheCreationTokens int
	CacheReadTokens     int
	ImageCount          int
	MediaType           string

	// Cost amounts charged against each bucket. Units are not visible here —
	// presumably a currency or credit amount; confirm against the repository.
	BalanceCost         float64
	SubscriptionCost    float64
	APIKeyQuotaCost     float64
	APIKeyRateLimitCost float64
	AccountQuotaCost    float64
}
func
(
c
*
UsageBillingCommand
)
Normalize
()
{
if
c
==
nil
{
return
}
c
.
RequestID
=
strings
.
TrimSpace
(
c
.
RequestID
)
if
strings
.
TrimSpace
(
c
.
RequestFingerprint
)
==
""
{
c
.
RequestFingerprint
=
buildUsageBillingFingerprint
(
c
)
}
}
// buildUsageBillingFingerprint derives a deterministic hex SHA-256 digest
// from the command's billable dimensions. Commands with identical fields
// yield identical fingerprints, which backs at-most-once application
// (see ErrUsageBillingRequestConflict). Returns "" for a nil command.
func buildUsageBillingFingerprint(c *UsageBillingCommand) string {
	if c == nil {
		return ""
	}
	// NOTE: the verb list and argument order below are part of the
	// fingerprint contract — changing either changes every fingerprint.
	// String fields are trimmed so incidental whitespace does not produce
	// a distinct fingerprint; floats are fixed to 10 decimal places.
	raw := fmt.Sprintf(
		"%d|%d|%d|%s|%s|%s|%s|%d|%d|%d|%d|%d|%d|%s|%d|%0.10f|%0.10f|%0.10f|%0.10f|%0.10f",
		c.UserID,
		c.AccountID,
		c.APIKeyID,
		strings.TrimSpace(c.AccountType),
		strings.TrimSpace(c.Model),
		strings.TrimSpace(c.ServiceTier),
		strings.TrimSpace(c.ReasoningEffort),
		c.BillingType,
		c.InputTokens,
		c.OutputTokens,
		c.CacheCreationTokens,
		c.CacheReadTokens,
		c.ImageCount,
		strings.TrimSpace(c.MediaType),
		valueOrZero(c.SubscriptionID), // nil SubscriptionID hashes as 0
		c.BalanceCost,
		c.SubscriptionCost,
		c.APIKeyQuotaCost,
		c.APIKeyRateLimitCost,
		c.AccountQuotaCost,
	)
	// Fold in the payload hash when present so identical billing totals from
	// different request bodies still fingerprint differently.
	if payloadHash := strings.TrimSpace(c.RequestPayloadHash); payloadHash != "" {
		raw += "|" + payloadHash
	}
	sum := sha256.Sum256([]byte(raw))
	return hex.EncodeToString(sum[:])
}
// HashUsageRequestPayload returns the hex-encoded SHA-256 digest of payload,
// or the empty string when there is nothing to hash.
func HashUsageRequestPayload(payload []byte) string {
	if len(payload) > 0 {
		digest := sha256.Sum256(payload)
		return hex.EncodeToString(digest[:])
	}
	return ""
}
// valueOrZero dereferences v, treating a nil pointer as zero.
func valueOrZero(v *int64) int64 {
	if v != nil {
		return *v
	}
	return 0
}
// UsageBillingApplyResult reports the outcome of applying one billing command.
type UsageBillingApplyResult struct {
	// Applied is false when the command was skipped — presumably because the
	// same request was already applied (idempotency); confirm against
	// repository implementations.
	Applied bool
	// APIKeyQuotaExhausted signals the API key's quota ran out as a result of
	// this application — semantics to be confirmed against the repository.
	APIKeyQuotaExhausted bool
}
// UsageBillingRepository applies usage billing commands.
type UsageBillingRepository interface {
	// Apply records cmd and returns whether it was actually applied.
	// Implementations are expected to enforce at-most-once semantics per
	// request (see UsageBillingCommand.RequestFingerprint and
	// ErrUsageBillingRequestConflict) — confirm against implementations.
	Apply(ctx context.Context, cmd *UsageBillingCommand) (*UsageBillingApplyResult, error)
}
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment