Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
0b746501
Commit
0b746501
authored
Apr 16, 2026
by
陈曦
Browse files
1. merge upstream v0.1.113 2.提交migration相关文件
parents
45061102
be7551b9
Changes
225
Expand all
Show whitespace changes
Inline
Side-by-side
backend/internal/pkg/apicompat/chatcompletions_to_responses.go
View file @
0b746501
...
...
@@ -28,6 +28,7 @@ func ChatCompletionsToResponses(req *ChatCompletionsRequest) (*ResponsesRequest,
out
:=
&
ResponsesRequest
{
Model
:
req
.
Model
,
Instructions
:
req
.
Instructions
,
Input
:
inputJSON
,
Temperature
:
req
.
Temperature
,
TopP
:
req
.
TopP
,
...
...
backend/internal/pkg/apicompat/types.go
View file @
0b746501
...
...
@@ -152,6 +152,7 @@ type AnthropicDelta struct {
// ResponsesRequest is the request body for POST /v1/responses.
type
ResponsesRequest
struct
{
Model
string
`json:"model"`
Instructions
string
`json:"instructions,omitempty"`
Input
json
.
RawMessage
`json:"input"`
// string or []ResponsesInputItem
MaxOutputTokens
*
int
`json:"max_output_tokens,omitempty"`
Temperature
*
float64
`json:"temperature,omitempty"`
...
...
@@ -337,6 +338,7 @@ type ResponsesStreamEvent struct {
type
ChatCompletionsRequest
struct
{
Model
string
`json:"model"`
Messages
[]
ChatMessage
`json:"messages"`
Instructions
string
`json:"instructions,omitempty"`
// OpenAI Responses API compat
MaxTokens
*
int
`json:"max_tokens,omitempty"`
MaxCompletionTokens
*
int
`json:"max_completion_tokens,omitempty"`
Temperature
*
float64
`json:"temperature,omitempty"`
...
...
backend/internal/pkg/logger/logger_test.go
View file @
0b746501
...
...
@@ -10,7 +10,13 @@ import (
)
func
TestInit_DualOutput
(
t
*
testing
.
T
)
{
tmpDir
:=
t
.
TempDir
()
// Use os.MkdirTemp instead of t.TempDir to avoid cleanup failures
// when lumberjack holds file handles on Windows.
tmpDir
,
err
:=
os
.
MkdirTemp
(
""
,
"logger-test-*"
)
if
err
!=
nil
{
t
.
Fatalf
(
"create temp dir: %v"
,
err
)
}
t
.
Cleanup
(
func
()
{
_
=
os
.
RemoveAll
(
tmpDir
)
})
logPath
:=
filepath
.
Join
(
tmpDir
,
"logs"
,
"sub2api.log"
)
origStdout
:=
os
.
Stdout
...
...
@@ -57,7 +63,9 @@ func TestInit_DualOutput(t *testing.T) {
L
()
.
Info
(
"dual-output-info"
)
L
()
.
Warn
(
"dual-output-warn"
)
Sync
()
// Skip Sync() — on Windows, fsync on pipes deadlocks (FlushFileBuffers).
// The log data is already in the pipe buffer; closing writers is sufficient.
_
=
stdoutW
.
Close
()
_
=
stderrW
.
Close
()
...
...
@@ -166,7 +174,9 @@ func TestInit_CallerShouldPointToCallsite(t *testing.T) {
}
L
()
.
Info
(
"caller-check"
)
Sync
()
// Skip Sync() — on Windows, fsync on pipes deadlocks (FlushFileBuffers).
os
.
Stdout
=
origStdout
os
.
Stderr
=
origStderr
_
=
stdoutW
.
Close
()
logBytes
,
_
:=
io
.
ReadAll
(
stdoutR
)
...
...
backend/internal/pkg/logger/stdlog_bridge_test.go
View file @
0b746501
...
...
@@ -77,7 +77,7 @@ func TestStdLogBridgeRoutesLevels(t *testing.T) {
log
.
Printf
(
"service started"
)
log
.
Printf
(
"Warning: queue full"
)
log
.
Printf
(
"Forward request failed: timeout"
)
Sync
()
// Skip Sync() — on Windows, fsync on pipes deadlocks (FlushFileBuffers).
_
=
stdoutW
.
Close
()
_
=
stderrW
.
Close
()
...
...
@@ -139,7 +139,7 @@ func TestLegacyPrintfRoutesLevels(t *testing.T) {
LegacyPrintf
(
"service.test"
,
"request started"
)
LegacyPrintf
(
"service.test"
,
"Warning: queue full"
)
LegacyPrintf
(
"service.test"
,
"forward failed: timeout"
)
Sync
()
// Skip Sync() — on Windows, fsync on pipes deadlocks (FlushFileBuffers).
_
=
stdoutW
.
Close
()
_
=
stderrW
.
Close
()
...
...
backend/internal/pkg/usagestats/usage_log_types.go
View file @
0b746501
...
...
@@ -58,6 +58,7 @@ type DashboardStats struct {
TotalTokens
int64
`json:"total_tokens"`
TotalCost
float64
`json:"total_cost"`
// 累计标准计费
TotalActualCost
float64
`json:"total_actual_cost"`
// 累计实际扣除
TotalAccountCost
float64
`json:"total_account_cost"`
// 累计账号成本
// 今日 Token 使用统计
TodayRequests
int64
`json:"today_requests"`
...
...
@@ -68,6 +69,7 @@ type DashboardStats struct {
TodayTokens
int64
`json:"today_tokens"`
TodayCost
float64
`json:"today_cost"`
// 今日标准计费
TodayActualCost
float64
`json:"today_actual_cost"`
// 今日实际扣除
TodayAccountCost
float64
`json:"today_account_cost"`
// 今日账号成本
// 系统运行统计
AverageDurationMs
float64
`json:"average_duration_ms"`
// 平均响应时间
...
...
@@ -101,6 +103,7 @@ type ModelStat struct {
TotalTokens
int64
`json:"total_tokens"`
Cost
float64
`json:"cost"`
// 标准计费
ActualCost
float64
`json:"actual_cost"`
// 实际扣除
AccountCost
float64
`json:"account_cost"`
// 账号成本
}
// EndpointStat represents usage statistics for a single request endpoint.
...
...
@@ -127,6 +130,7 @@ type GroupStat struct {
TotalTokens
int64
`json:"total_tokens"`
Cost
float64
`json:"cost"`
// 标准计费
ActualCost
float64
`json:"actual_cost"`
// 实际扣除
AccountCost
float64
`json:"account_cost"`
// 账号成本
}
// UserUsageTrendPoint represents user usage trend data point
...
...
@@ -166,6 +170,7 @@ type UserBreakdownItem struct {
TotalTokens
int64
`json:"total_tokens"`
Cost
float64
`json:"cost"`
// 标准计费
ActualCost
float64
`json:"actual_cost"`
// 实际扣除
AccountCost
float64
`json:"account_cost"`
// 账号成本
}
// UserBreakdownDimension specifies the dimension to filter for user breakdown.
...
...
backend/internal/pkg/websearch/brave.go
0 → 100644
View file @
0b746501
package
websearch
import
(
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
)
const
(
braveSearchEndpoint
=
"https://api.search.brave.com/res/v1/web/search"
braveMaxCount
=
20
braveProviderName
=
"brave"
)
// braveSearchURL is pre-parsed at init time; url.Parse cannot fail on a constant literal.
var
braveSearchURL
,
_
=
url
.
Parse
(
braveSearchEndpoint
)
//nolint:errcheck
// BraveProvider implements web search via the Brave Search API.
type
BraveProvider
struct
{
apiKey
string
httpClient
*
http
.
Client
}
// NewBraveProvider creates a Brave Search provider.
// The caller is responsible for configuring the http.Client with proxy/timeouts.
func
NewBraveProvider
(
apiKey
string
,
httpClient
*
http
.
Client
)
*
BraveProvider
{
if
httpClient
==
nil
{
httpClient
=
http
.
DefaultClient
}
return
&
BraveProvider
{
apiKey
:
apiKey
,
httpClient
:
httpClient
}
}
func
(
b
*
BraveProvider
)
Name
()
string
{
return
braveProviderName
}
func
(
b
*
BraveProvider
)
Search
(
ctx
context
.
Context
,
req
SearchRequest
)
(
*
SearchResponse
,
error
)
{
count
:=
req
.
MaxResults
if
count
<=
0
{
count
=
defaultMaxResults
}
if
count
>
braveMaxCount
{
count
=
braveMaxCount
}
u
:=
*
braveSearchURL
// copy the pre-parsed URL
q
:=
u
.
Query
()
q
.
Set
(
"q"
,
req
.
Query
)
q
.
Set
(
"count"
,
strconv
.
Itoa
(
count
))
u
.
RawQuery
=
q
.
Encode
()
httpReq
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodGet
,
u
.
String
(),
nil
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"brave: build request: %w"
,
err
)
}
httpReq
.
Header
.
Set
(
"X-Subscription-Token"
,
b
.
apiKey
)
httpReq
.
Header
.
Set
(
"Accept"
,
"application/json"
)
resp
,
err
:=
b
.
httpClient
.
Do
(
httpReq
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"brave: request failed: %w"
,
err
)
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
body
,
err
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
maxResponseSize
))
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"brave: read body: %w"
,
err
)
}
if
resp
.
StatusCode
!=
http
.
StatusOK
{
return
nil
,
fmt
.
Errorf
(
"brave: status %d: %s"
,
resp
.
StatusCode
,
truncateBody
(
body
))
}
var
raw
braveResponse
if
err
:=
json
.
Unmarshal
(
body
,
&
raw
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"brave: decode response: %w"
,
err
)
}
results
:=
make
([]
SearchResult
,
0
,
len
(
raw
.
Web
.
Results
))
for
_
,
r
:=
range
raw
.
Web
.
Results
{
results
=
append
(
results
,
SearchResult
{
URL
:
r
.
URL
,
Title
:
r
.
Title
,
Snippet
:
r
.
Description
,
PageAge
:
r
.
Age
,
})
}
return
&
SearchResponse
{
Results
:
results
,
Query
:
req
.
Query
},
nil
}
// braveResponse is the minimal structure of the Brave Search API response.
type
braveResponse
struct
{
Web
struct
{
Results
[]
braveResult
`json:"results"`
}
`json:"web"`
}
type
braveResult
struct
{
URL
string
`json:"url"`
Title
string
`json:"title"`
Description
string
`json:"description"`
Age
string
`json:"age"`
}
backend/internal/pkg/websearch/brave_test.go
0 → 100644
View file @
0b746501
package
websearch
import
(
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
)
func
TestBraveProvider_Name
(
t
*
testing
.
T
)
{
p
:=
NewBraveProvider
(
"key"
,
nil
)
require
.
Equal
(
t
,
"brave"
,
p
.
Name
())
}
func
TestBraveProvider_Search_Success
(
t
*
testing
.
T
)
{
srv
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
require
.
Equal
(
t
,
"test-key"
,
r
.
Header
.
Get
(
"X-Subscription-Token"
))
require
.
Equal
(
t
,
"application/json"
,
r
.
Header
.
Get
(
"Accept"
))
require
.
Equal
(
t
,
"golang"
,
r
.
URL
.
Query
()
.
Get
(
"q"
))
require
.
Equal
(
t
,
"3"
,
r
.
URL
.
Query
()
.
Get
(
"count"
))
resp
:=
braveResponse
{}
resp
.
Web
.
Results
=
[]
braveResult
{
{
URL
:
"https://go.dev"
,
Title
:
"Go"
,
Description
:
"Go lang"
,
Age
:
"1 day"
},
{
URL
:
"https://pkg.go.dev"
,
Title
:
"Pkg"
,
Description
:
"Packages"
},
{
URL
:
"https://tour.go.dev"
,
Title
:
"Tour"
,
Description
:
"A Tour of Go"
,
Age
:
"3 days"
},
}
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
=
json
.
NewEncoder
(
w
)
.
Encode
(
resp
)
}))
defer
srv
.
Close
()
p
:=
NewBraveProvider
(
"test-key"
,
srv
.
Client
())
// Override the endpoint for testing
origURL
:=
*
braveSearchURL
u
,
_
:=
http
.
NewRequest
(
"GET"
,
srv
.
URL
,
nil
)
*
braveSearchURL
=
*
u
.
URL
defer
func
()
{
*
braveSearchURL
=
origURL
}()
resp
,
err
:=
p
.
Search
(
context
.
Background
(),
SearchRequest
{
Query
:
"golang"
,
MaxResults
:
3
})
require
.
NoError
(
t
,
err
)
require
.
Len
(
t
,
resp
.
Results
,
3
)
require
.
Equal
(
t
,
"https://go.dev"
,
resp
.
Results
[
0
]
.
URL
)
require
.
Equal
(
t
,
"Go lang"
,
resp
.
Results
[
0
]
.
Snippet
)
require
.
Equal
(
t
,
"1 day"
,
resp
.
Results
[
0
]
.
PageAge
)
}
func
TestBraveProvider_Search_DefaultMaxResults
(
t
*
testing
.
T
)
{
var
receivedCount
string
srv
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
receivedCount
=
r
.
URL
.
Query
()
.
Get
(
"count"
)
resp
:=
braveResponse
{}
_
=
json
.
NewEncoder
(
w
)
.
Encode
(
resp
)
}))
defer
srv
.
Close
()
p
:=
NewBraveProvider
(
"key"
,
srv
.
Client
())
origURL
:=
*
braveSearchURL
u
,
_
:=
http
.
NewRequest
(
"GET"
,
srv
.
URL
,
nil
)
*
braveSearchURL
=
*
u
.
URL
defer
func
()
{
*
braveSearchURL
=
origURL
}()
_
,
_
=
p
.
Search
(
context
.
Background
(),
SearchRequest
{
Query
:
"test"
,
MaxResults
:
0
})
require
.
Equal
(
t
,
"5"
,
receivedCount
)
}
func
TestBraveProvider_Search_HTTPError
(
t
*
testing
.
T
)
{
srv
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
_
*
http
.
Request
)
{
w
.
WriteHeader
(
429
)
_
,
_
=
w
.
Write
([]
byte
(
"rate limited"
))
}))
defer
srv
.
Close
()
p
:=
NewBraveProvider
(
"key"
,
srv
.
Client
())
origURL
:=
*
braveSearchURL
u
,
_
:=
http
.
NewRequest
(
"GET"
,
srv
.
URL
,
nil
)
*
braveSearchURL
=
*
u
.
URL
defer
func
()
{
*
braveSearchURL
=
origURL
}()
_
,
err
:=
p
.
Search
(
context
.
Background
(),
SearchRequest
{
Query
:
"test"
})
require
.
ErrorContains
(
t
,
err
,
"brave: status 429"
)
}
func
TestBraveProvider_Search_InvalidJSON
(
t
*
testing
.
T
)
{
srv
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
_
*
http
.
Request
)
{
_
,
_
=
w
.
Write
([]
byte
(
"not json"
))
}))
defer
srv
.
Close
()
p
:=
NewBraveProvider
(
"key"
,
srv
.
Client
())
origURL
:=
*
braveSearchURL
u
,
_
:=
http
.
NewRequest
(
"GET"
,
srv
.
URL
,
nil
)
*
braveSearchURL
=
*
u
.
URL
defer
func
()
{
*
braveSearchURL
=
origURL
}()
_
,
err
:=
p
.
Search
(
context
.
Background
(),
SearchRequest
{
Query
:
"test"
})
require
.
ErrorContains
(
t
,
err
,
"brave: decode response"
)
}
func
TestBraveProvider_Search_EmptyResults
(
t
*
testing
.
T
)
{
srv
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
_
*
http
.
Request
)
{
resp
:=
braveResponse
{}
_
=
json
.
NewEncoder
(
w
)
.
Encode
(
resp
)
}))
defer
srv
.
Close
()
p
:=
NewBraveProvider
(
"key"
,
srv
.
Client
())
origURL
:=
*
braveSearchURL
u
,
_
:=
http
.
NewRequest
(
"GET"
,
srv
.
URL
,
nil
)
*
braveSearchURL
=
*
u
.
URL
defer
func
()
{
*
braveSearchURL
=
origURL
}()
resp
,
err
:=
p
.
Search
(
context
.
Background
(),
SearchRequest
{
Query
:
"test"
})
require
.
NoError
(
t
,
err
)
require
.
Empty
(
t
,
resp
.
Results
)
}
backend/internal/pkg/websearch/helpers.go
0 → 100644
View file @
0b746501
package
websearch
const
(
maxResponseSize
=
1
<<
20
// 1 MB
errorBodyTruncLen
=
200
)
// truncateBody returns a truncated string of body for error messages.
func
truncateBody
(
body
[]
byte
)
string
{
if
len
(
body
)
<=
errorBodyTruncLen
{
return
string
(
body
)
}
return
string
(
body
[
:
errorBodyTruncLen
])
+
"...(truncated)"
}
backend/internal/pkg/websearch/helpers_test.go
0 → 100644
View file @
0b746501
package
websearch
import
(
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func
TestTruncateBody_Short
(
t
*
testing
.
T
)
{
body
:=
[]
byte
(
"short body"
)
require
.
Equal
(
t
,
"short body"
,
truncateBody
(
body
))
}
func
TestTruncateBody_Long
(
t
*
testing
.
T
)
{
body
:=
[]
byte
(
strings
.
Repeat
(
"x"
,
500
))
result
:=
truncateBody
(
body
)
require
.
Len
(
t
,
result
,
errorBodyTruncLen
+
len
(
"...(truncated)"
))
require
.
True
(
t
,
strings
.
HasSuffix
(
result
,
"...(truncated)"
))
}
func
TestTruncateBody_ExactBoundary
(
t
*
testing
.
T
)
{
body
:=
[]
byte
(
strings
.
Repeat
(
"x"
,
errorBodyTruncLen
))
require
.
Equal
(
t
,
string
(
body
),
truncateBody
(
body
))
}
backend/internal/pkg/websearch/manager.go
0 → 100644
View file @
0b746501
package
websearch
import
(
"context"
"crypto/tls"
"errors"
"fmt"
"log/slog"
"math/rand"
"net"
"net/http"
"net/url"
"sort"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
"github.com/redis/go-redis/v9"
)
// ProviderConfig holds the configuration for a single search provider.
type
ProviderConfig
struct
{
Type
string
`json:"type"`
// ProviderTypeBrave | ProviderTypeTavily
APIKey
string
`json:"api_key"`
// secret
QuotaLimit
int64
`json:"quota_limit"`
// 0 = unlimited
SubscribedAt
*
int64
`json:"subscribed_at,omitempty"`
// subscription start (unix seconds); quota resets monthly from this date
ProxyURL
string
`json:"-"`
// resolved proxy URL (not persisted)
ProxyID
int64
`json:"-"`
// resolved proxy ID for unavailability tracking
ExpiresAt
*
int64
`json:"expires_at,omitempty"`
// optional expiration (unix seconds)
}
// Manager selects providers by quota-weighted load balancing and tracks quota via Redis.
type
Manager
struct
{
configs
[]
ProviderConfig
redis
*
redis
.
Client
clientMu
sync
.
Mutex
clientCache
map
[
string
]
*
http
.
Client
}
// Timeout constants for proxy and search operations.
const
(
proxyDialTimeout
=
3
*
time
.
Second
// proxy TCP connection timeout
proxyTLSTimeout
=
3
*
time
.
Second
// TLS handshake timeout
searchDataTimeout
=
60
*
time
.
Second
// response data transfer timeout
searchRequestTimeout
=
searchDataTimeout
+
proxyDialTimeout
quotaKeyPrefix
=
"websearch:quota:"
proxyUnavailableKey
=
"websearch:proxy_unavailable:%d"
proxyUnavailableTTL
=
5
*
time
.
Minute
quotaTTLBuffer
=
24
*
time
.
Hour
defaultQuotaTTL
=
31
*
24
*
time
.
Hour
+
quotaTTLBuffer
// fallback when no subscription date
maxCachedClients
=
100
)
// ErrProxyUnavailable indicates the search failed due to a proxy connectivity issue.
// Callers may use this to trigger account switching instead of direct fallback.
var
ErrProxyUnavailable
=
errors
.
New
(
"websearch: proxy unavailable"
)
// quotaIncrScript atomically increments the counter and sets TTL on first creation.
var
quotaIncrScript
=
redis
.
NewScript
(
`
local val = redis.call('INCR', KEYS[1])
if val == 1 then
redis.call('EXPIRE', KEYS[1], ARGV[1])
else
local ttl = redis.call('TTL', KEYS[1])
if ttl == -1 then
redis.call('EXPIRE', KEYS[1], ARGV[1])
end
end
return val
`
)
// NewManager creates a Manager with the given provider configs and Redis client.
// Provider order is preserved as-is; selectByQuotaWeight handles load balancing.
func
NewManager
(
configs
[]
ProviderConfig
,
redisClient
*
redis
.
Client
)
*
Manager
{
copied
:=
make
([]
ProviderConfig
,
len
(
configs
))
copy
(
copied
,
configs
)
return
&
Manager
{
configs
:
copied
,
redis
:
redisClient
,
clientCache
:
make
(
map
[
string
]
*
http
.
Client
),
}
}
// SearchWithBestProvider selects a provider using quota-weighted load balancing,
// reserves quota, executes the search, and rolls back quota on failure.
// If the search fails due to a proxy error, the proxy is marked unavailable for 5 minutes.
func
(
m
*
Manager
)
SearchWithBestProvider
(
ctx
context
.
Context
,
req
SearchRequest
)
(
*
SearchResponse
,
string
,
error
)
{
if
strings
.
TrimSpace
(
req
.
Query
)
==
""
{
return
nil
,
""
,
fmt
.
Errorf
(
"websearch: empty search query"
)
}
candidates
:=
m
.
filterAvailableProviders
(
ctx
,
req
.
ProxyURL
)
if
len
(
candidates
)
==
0
{
return
nil
,
""
,
fmt
.
Errorf
(
"websearch: no available provider (all exhausted, expired, or proxy unavailable)"
)
}
selected
:=
m
.
selectByQuotaWeight
(
ctx
,
candidates
)
for
_
,
cfg
:=
range
selected
{
allowed
,
incremented
:=
m
.
tryReserveQuota
(
ctx
,
cfg
)
if
!
allowed
{
continue
}
resp
,
err
:=
m
.
executeSearch
(
ctx
,
cfg
,
req
)
if
err
!=
nil
{
if
incremented
{
m
.
rollbackQuota
(
ctx
,
cfg
)
}
if
isProxyError
(
err
)
{
m
.
markProxyUnavailable
(
ctx
,
cfg
,
req
.
ProxyURL
)
if
req
.
ProxyURL
!=
""
{
// Account-level proxy is shared by all providers — no point
// trying others with the same broken proxy; signal account switch.
slog
.
Warn
(
"websearch: account proxy error, aborting failover"
,
"provider"
,
cfg
.
Type
,
"error"
,
err
)
return
nil
,
""
,
fmt
.
Errorf
(
"%w: %s"
,
ErrProxyUnavailable
,
err
.
Error
())
}
// Provider-specific proxy failed — try the next provider which
// may use a different (or no) proxy.
slog
.
Warn
(
"websearch: provider proxy error, trying next provider"
,
"provider"
,
cfg
.
Type
,
"error"
,
err
)
continue
}
slog
.
Warn
(
"websearch: provider search failed"
,
"provider"
,
cfg
.
Type
,
"error"
,
err
)
continue
}
return
resp
,
cfg
.
Type
,
nil
}
return
nil
,
""
,
fmt
.
Errorf
(
"websearch: no available provider (all exhausted or failed)"
)
}
// filterAvailableProviders returns providers that have API keys, are not expired,
// and whose proxies are not marked unavailable.
func
(
m
*
Manager
)
filterAvailableProviders
(
ctx
context
.
Context
,
accountProxyURL
string
)
[]
ProviderConfig
{
var
out
[]
ProviderConfig
for
_
,
cfg
:=
range
m
.
configs
{
if
!
m
.
isProviderAvailable
(
cfg
)
{
continue
}
proxyID
:=
resolveProxyID
(
cfg
,
accountProxyURL
)
if
proxyID
>
0
&&
!
m
.
isProxyAvailable
(
ctx
,
proxyID
)
{
slog
.
Debug
(
"websearch: proxy marked unavailable, skipping"
,
"provider"
,
cfg
.
Type
,
"proxy_id"
,
proxyID
)
continue
}
out
=
append
(
out
,
cfg
)
}
return
out
}
// weighted is a provider candidate with computed quota weight.
type
weighted
struct
{
cfg
ProviderConfig
weight
int64
}
// selectByQuotaWeight orders candidates by remaining quota weight.
// Providers with quota_limit=0 (no limit set) get weight 0 and are placed last.
// Among providers with quota, higher remaining quota = higher priority.
func
(
m
*
Manager
)
selectByQuotaWeight
(
ctx
context
.
Context
,
candidates
[]
ProviderConfig
)
[]
ProviderConfig
{
items
:=
m
.
computeWeights
(
ctx
,
candidates
)
withQuota
,
withoutQuota
:=
partitionByQuota
(
items
)
sortByStableRandomWeight
(
withQuota
)
return
mergeWeightedResults
(
withQuota
,
withoutQuota
,
len
(
candidates
))
}
func
(
m
*
Manager
)
computeWeights
(
ctx
context
.
Context
,
candidates
[]
ProviderConfig
)
[]
weighted
{
items
:=
make
([]
weighted
,
0
,
len
(
candidates
))
for
_
,
cfg
:=
range
candidates
{
w
:=
int64
(
0
)
if
cfg
.
QuotaLimit
>
0
{
used
,
_
:=
m
.
GetUsage
(
ctx
,
cfg
.
Type
)
if
remaining
:=
cfg
.
QuotaLimit
-
used
;
remaining
>
0
{
w
=
remaining
}
}
items
=
append
(
items
,
weighted
{
cfg
:
cfg
,
weight
:
w
})
}
return
items
}
func
partitionByQuota
(
items
[]
weighted
)
(
withQuota
,
withoutQuota
[]
weighted
)
{
for
_
,
item
:=
range
items
{
if
item
.
weight
>
0
{
withQuota
=
append
(
withQuota
,
item
)
}
else
{
withoutQuota
=
append
(
withoutQuota
,
item
)
}
}
return
}
// sortByStableRandomWeight assigns a fixed random factor to each item before sorting,
// ensuring deterministic sort behavior (transitivity) within a single call.
func
sortByStableRandomWeight
(
items
[]
weighted
)
{
if
len
(
items
)
<=
1
{
return
}
type
entry
struct
{
item
weighted
factor
float64
}
entries
:=
make
([]
entry
,
len
(
items
))
for
i
,
item
:=
range
items
{
entries
[
i
]
=
entry
{
item
:
item
,
factor
:
float64
(
item
.
weight
)
*
(
0.5
+
rand
.
Float64
())}
}
sort
.
Slice
(
entries
,
func
(
i
,
j
int
)
bool
{
return
entries
[
i
]
.
factor
>
entries
[
j
]
.
factor
})
for
i
,
e
:=
range
entries
{
items
[
i
]
=
e
.
item
}
}
func
mergeWeightedResults
(
withQuota
,
withoutQuota
[]
weighted
,
capacity
int
)
[]
ProviderConfig
{
result
:=
make
([]
ProviderConfig
,
0
,
capacity
)
for
_
,
item
:=
range
withQuota
{
result
=
append
(
result
,
item
.
cfg
)
}
for
_
,
item
:=
range
withoutQuota
{
result
=
append
(
result
,
item
.
cfg
)
}
return
result
}
func
(
m
*
Manager
)
isProviderAvailable
(
cfg
ProviderConfig
)
bool
{
if
cfg
.
APIKey
==
""
{
return
false
}
if
cfg
.
ExpiresAt
!=
nil
&&
time
.
Now
()
.
Unix
()
>
*
cfg
.
ExpiresAt
{
slog
.
Info
(
"websearch: provider expired, skipping"
,
"provider"
,
cfg
.
Type
,
"expires_at"
,
*
cfg
.
ExpiresAt
)
return
false
}
return
true
}
// --- Proxy availability tracking ---
// markProxyUnavailable marks the effective proxy as unavailable for proxyUnavailableTTL.
func
(
m
*
Manager
)
markProxyUnavailable
(
ctx
context
.
Context
,
cfg
ProviderConfig
,
accountProxyURL
string
)
{
proxyID
:=
resolveProxyID
(
cfg
,
accountProxyURL
)
if
proxyID
<=
0
||
m
.
redis
==
nil
{
return
}
key
:=
fmt
.
Sprintf
(
proxyUnavailableKey
,
proxyID
)
if
err
:=
m
.
redis
.
Set
(
ctx
,
key
,
"1"
,
proxyUnavailableTTL
)
.
Err
();
err
!=
nil
{
slog
.
Warn
(
"websearch: failed to mark proxy unavailable"
,
"proxy_id"
,
proxyID
,
"error"
,
err
)
}
}
// isProxyAvailable checks whether a proxy is currently marked as unavailable.
func
(
m
*
Manager
)
isProxyAvailable
(
ctx
context
.
Context
,
proxyID
int64
)
bool
{
if
m
.
redis
==
nil
||
proxyID
<=
0
{
return
true
}
key
:=
fmt
.
Sprintf
(
proxyUnavailableKey
,
proxyID
)
val
,
err
:=
m
.
redis
.
Get
(
ctx
,
key
)
.
Result
()
if
err
!=
nil
{
return
true
// Redis error → assume available
}
return
val
==
""
}
// resolveProxyID determines the effective proxy ID for a provider+account combination.
func
resolveProxyID
(
cfg
ProviderConfig
,
accountProxyURL
string
)
int64
{
if
accountProxyURL
!=
""
{
return
0
// account proxy has no ID in provider config
}
return
cfg
.
ProxyID
}
// isProxyError checks whether the error is likely caused by proxy or network connectivity
// (as opposed to an API-level error from the search provider).
func
isProxyError
(
err
error
)
bool
{
if
err
==
nil
{
return
false
}
// Network-level errors (timeout, connection refused, DNS failure)
var
netErr
net
.
Error
if
errors
.
As
(
err
,
&
netErr
)
{
return
true
}
var
opErr
*
net
.
OpError
if
errors
.
As
(
err
,
&
opErr
)
{
return
true
}
// TLS handshake failures (often caused by proxy intercepting/blocking)
var
tlsErr
*
tls
.
RecordHeaderError
if
errors
.
As
(
err
,
&
tlsErr
)
{
return
true
}
// String-based detection for wrapped errors
msg
:=
strings
.
ToLower
(
err
.
Error
())
return
strings
.
Contains
(
msg
,
"proxy"
)
||
strings
.
Contains
(
msg
,
"socks"
)
||
strings
.
Contains
(
msg
,
"connection refused"
)
||
strings
.
Contains
(
msg
,
"no such host"
)
||
strings
.
Contains
(
msg
,
"i/o timeout"
)
||
strings
.
Contains
(
msg
,
"tls handshake"
)
||
strings
.
Contains
(
msg
,
"certificate"
)
}
// --- Quota management ---
func
(
m
*
Manager
)
tryReserveQuota
(
ctx
context
.
Context
,
cfg
ProviderConfig
)
(
bool
,
bool
)
{
if
cfg
.
QuotaLimit
<=
0
{
return
true
,
false
}
if
m
.
redis
==
nil
{
slog
.
Warn
(
"websearch: Redis unavailable, quota check skipped"
,
"provider"
,
cfg
.
Type
)
return
true
,
false
}
key
:=
quotaRedisKey
(
cfg
.
Type
)
ttlSec
:=
int
(
quotaTTLFromSubscription
(
cfg
.
SubscribedAt
)
.
Seconds
())
newVal
,
err
:=
quotaIncrScript
.
Run
(
ctx
,
m
.
redis
,
[]
string
{
key
},
ttlSec
)
.
Int64
()
if
err
!=
nil
{
slog
.
Warn
(
"websearch: quota Lua INCR failed, allowing request"
,
"provider"
,
cfg
.
Type
,
"error"
,
err
)
return
true
,
false
}
if
newVal
>
cfg
.
QuotaLimit
{
if
decrErr
:=
m
.
redis
.
Decr
(
ctx
,
key
)
.
Err
();
decrErr
!=
nil
{
slog
.
Warn
(
"websearch: quota over-limit DECR failed"
,
"provider"
,
cfg
.
Type
,
"error"
,
decrErr
)
}
slog
.
Info
(
"websearch: provider quota exhausted"
,
"provider"
,
cfg
.
Type
,
"used"
,
newVal
,
"limit"
,
cfg
.
QuotaLimit
)
return
false
,
false
}
return
true
,
true
}
func
(
m
*
Manager
)
rollbackQuota
(
ctx
context
.
Context
,
cfg
ProviderConfig
)
{
if
cfg
.
QuotaLimit
<=
0
||
m
.
redis
==
nil
{
return
}
key
:=
quotaRedisKey
(
cfg
.
Type
)
if
err
:=
m
.
redis
.
Decr
(
ctx
,
key
)
.
Err
();
err
!=
nil
{
slog
.
Warn
(
"websearch: quota rollback DECR failed"
,
"provider"
,
cfg
.
Type
,
"error"
,
err
)
}
}
// --- Search execution ---
// TestSearch executes a search using the first available provider without reserving quota.
// Intended for admin test functionality only.
func
(
m
*
Manager
)
TestSearch
(
ctx
context
.
Context
,
req
SearchRequest
)
(
*
SearchResponse
,
string
,
error
)
{
if
strings
.
TrimSpace
(
req
.
Query
)
==
""
{
return
nil
,
""
,
fmt
.
Errorf
(
"websearch: empty search query"
)
}
for
_
,
cfg
:=
range
m
.
configs
{
if
!
m
.
isProviderAvailable
(
cfg
)
{
continue
}
resp
,
err
:=
m
.
executeSearch
(
ctx
,
cfg
,
req
)
if
err
!=
nil
{
continue
}
return
resp
,
cfg
.
Type
,
nil
}
return
nil
,
""
,
fmt
.
Errorf
(
"websearch: no available provider"
)
}
func
(
m
*
Manager
)
executeSearch
(
ctx
context
.
Context
,
cfg
ProviderConfig
,
req
SearchRequest
)
(
*
SearchResponse
,
error
)
{
proxyURL
:=
cfg
.
ProxyURL
if
req
.
ProxyURL
!=
""
{
proxyURL
=
req
.
ProxyURL
}
client
,
err
:=
m
.
getOrCreateHTTPClient
(
proxyURL
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"websearch: %w"
,
err
)
}
provider
:=
m
.
buildProvider
(
cfg
,
client
)
return
provider
.
Search
(
ctx
,
req
)
}
// --- HTTP client cache ---
func
(
m
*
Manager
)
getOrCreateHTTPClient
(
proxyURL
string
)
(
*
http
.
Client
,
error
)
{
m
.
clientMu
.
Lock
()
defer
m
.
clientMu
.
Unlock
()
if
c
,
ok
:=
m
.
clientCache
[
proxyURL
];
ok
{
return
c
,
nil
}
if
len
(
m
.
clientCache
)
>=
maxCachedClients
{
m
.
clientCache
=
make
(
map
[
string
]
*
http
.
Client
)
}
c
,
err
:=
newHTTPClient
(
proxyURL
)
if
err
!=
nil
{
return
nil
,
err
}
m
.
clientCache
[
proxyURL
]
=
c
return
c
,
nil
}
// newHTTPClient creates an HTTP client with proper timeout settings.
// Uses proxyutil.ConfigureTransportProxy for unified proxy protocol support
// (HTTP/HTTPS/SOCKS5/SOCKS5H).
// Returns error if proxyURL is invalid — never falls back to direct connection.
func
newHTTPClient
(
proxyURL
string
)
(
*
http
.
Client
,
error
)
{
transport
:=
&
http
.
Transport
{
TLSClientConfig
:
&
tls
.
Config
{
MinVersion
:
tls
.
VersionTLS12
},
DialContext
:
(
&
net
.
Dialer
{
Timeout
:
proxyDialTimeout
})
.
DialContext
,
TLSHandshakeTimeout
:
proxyTLSTimeout
,
ResponseHeaderTimeout
:
searchDataTimeout
,
}
if
proxyURL
!=
""
{
parsed
,
err
:=
url
.
Parse
(
proxyURL
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"invalid proxy URL %q: %w"
,
proxyURL
,
err
)
}
if
err
:=
proxyutil
.
ConfigureTransportProxy
(
transport
,
parsed
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"configure proxy: %w"
,
err
)
}
}
return
&
http
.
Client
{
Transport
:
transport
,
Timeout
:
searchRequestTimeout
},
nil
}
// GetUsage returns the current usage count for the given provider.
func
(
m
*
Manager
)
GetUsage
(
ctx
context
.
Context
,
providerType
string
)
(
int64
,
error
)
{
if
m
.
redis
==
nil
{
return
0
,
nil
}
key
:=
quotaRedisKey
(
providerType
)
val
,
err
:=
m
.
redis
.
Get
(
ctx
,
key
)
.
Int64
()
if
err
==
redis
.
Nil
{
return
0
,
nil
}
return
val
,
err
}
// GetAllUsage returns usage for every configured provider.
func
(
m
*
Manager
)
GetAllUsage
(
ctx
context
.
Context
)
map
[
string
]
int64
{
result
:=
make
(
map
[
string
]
int64
,
len
(
m
.
configs
))
for
_
,
cfg
:=
range
m
.
configs
{
used
,
_
:=
m
.
GetUsage
(
ctx
,
cfg
.
Type
)
result
[
cfg
.
Type
]
=
used
}
return
result
}
// ResetUsage deletes the Redis quota key for the given provider, resetting usage to 0.
func
(
m
*
Manager
)
ResetUsage
(
ctx
context
.
Context
,
providerType
string
)
error
{
if
m
.
redis
==
nil
{
return
nil
}
key
:=
quotaRedisKey
(
providerType
)
return
m
.
redis
.
Del
(
ctx
,
key
)
.
Err
()
}
// --- Provider factory ---
func
(
m
*
Manager
)
buildProvider
(
cfg
ProviderConfig
,
client
*
http
.
Client
)
Provider
{
switch
cfg
.
Type
{
case
braveProviderName
:
return
NewBraveProvider
(
cfg
.
APIKey
,
client
)
case
tavilyProviderName
:
return
NewTavilyProvider
(
cfg
.
APIKey
,
client
)
default
:
slog
.
Warn
(
"websearch: unknown provider type, falling back to brave"
,
"type"
,
cfg
.
Type
)
return
NewBraveProvider
(
cfg
.
APIKey
,
client
)
}
}
// --- Redis key helpers ---
func
quotaRedisKey
(
providerType
string
)
string
{
return
quotaKeyPrefix
+
providerType
}
// quotaTTLFromSubscription calculates the TTL for the quota counter based on
// the provider's subscription start date. Quota resets monthly from that date.
// When the Redis key expires naturally, the next INCR creates a fresh counter (lazy refresh).
func
quotaTTLFromSubscription
(
subscribedAt
*
int64
)
time
.
Duration
{
if
subscribedAt
==
nil
||
*
subscribedAt
==
0
{
return
defaultQuotaTTL
}
next
:=
nextMonthlyReset
(
time
.
Unix
(
*
subscribedAt
,
0
)
.
UTC
())
ttl
:=
time
.
Until
(
next
)
+
quotaTTLBuffer
if
ttl
<=
quotaTTLBuffer
{
// Already past the reset — next cycle
ttl
=
defaultQuotaTTL
}
return
ttl
}
// nextMonthlyReset returns the next monthly reset time based on the subscription start date.
// E.g., subscribed on Jan 15 → resets on Feb 15, Mar 15, etc.
// Handles day-of-month overflow: Jan 31 → Feb 28 (not Mar 3).
func
nextMonthlyReset
(
subscribedAt
time
.
Time
)
time
.
Time
{
now
:=
time
.
Now
()
.
UTC
()
if
subscribedAt
.
IsZero
()
{
return
now
.
AddDate
(
0
,
1
,
0
)
}
months
:=
(
now
.
Year
()
-
subscribedAt
.
Year
())
*
12
+
int
(
now
.
Month
()
-
subscribedAt
.
Month
())
if
months
<
0
{
months
=
0
}
candidate
:=
addMonthsClamped
(
subscribedAt
,
months
)
if
candidate
.
After
(
now
)
{
return
candidate
}
return
addMonthsClamped
(
subscribedAt
,
months
+
1
)
}
// addMonthsClamped adds N months to a date, clamping the day to the last day of the target month.
// E.g., Jan 31 + 1 month = Feb 28 (not Mar 3).
func
addMonthsClamped
(
t
time
.
Time
,
months
int
)
time
.
Time
{
y
,
m
,
d
:=
t
.
Date
()
targetMonth
:=
time
.
Month
(
int
(
m
)
+
months
)
targetYear
:=
y
+
int
(
targetMonth
-
1
)
/
12
targetMonth
=
(
targetMonth
-
1
)
%
12
+
1
// Last day of the target month
lastDay
:=
time
.
Date
(
targetYear
,
targetMonth
+
1
,
0
,
0
,
0
,
0
,
0
,
time
.
UTC
)
.
Day
()
if
d
>
lastDay
{
d
=
lastDay
}
return
time
.
Date
(
targetYear
,
targetMonth
,
d
,
0
,
0
,
0
,
0
,
time
.
UTC
)
}
backend/internal/pkg/websearch/manager_test.go
0 → 100644
View file @
0b746501
package
websearch
import
(
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func
TestNewManager_PreservesOrder
(
t
*
testing
.
T
)
{
configs
:=
[]
ProviderConfig
{
{
Type
:
"brave"
,
APIKey
:
"k3"
},
{
Type
:
"tavily"
,
APIKey
:
"k1"
},
}
m
:=
NewManager
(
configs
,
nil
)
require
.
Equal
(
t
,
"brave"
,
m
.
configs
[
0
]
.
Type
)
require
.
Equal
(
t
,
"tavily"
,
m
.
configs
[
1
]
.
Type
)
}
func
TestManager_SearchWithBestProvider_EmptyQuery
(
t
*
testing
.
T
)
{
m
:=
NewManager
([]
ProviderConfig
{{
Type
:
"brave"
,
APIKey
:
"k"
}},
nil
)
_
,
_
,
err
:=
m
.
SearchWithBestProvider
(
context
.
Background
(),
SearchRequest
{
Query
:
""
})
require
.
ErrorContains
(
t
,
err
,
"empty search query"
)
_
,
_
,
err
=
m
.
SearchWithBestProvider
(
context
.
Background
(),
SearchRequest
{
Query
:
" "
})
require
.
ErrorContains
(
t
,
err
,
"empty search query"
)
}
func
TestManager_SearchWithBestProvider_SkipEmptyAPIKey
(
t
*
testing
.
T
)
{
m
:=
NewManager
([]
ProviderConfig
{{
Type
:
"brave"
,
APIKey
:
""
}},
nil
)
_
,
_
,
err
:=
m
.
SearchWithBestProvider
(
context
.
Background
(),
SearchRequest
{
Query
:
"test"
})
require
.
ErrorContains
(
t
,
err
,
"no available provider"
)
}
func
TestManager_SearchWithBestProvider_SkipExpired
(
t
*
testing
.
T
)
{
past
:=
time
.
Now
()
.
Add
(
-
1
*
time
.
Hour
)
.
Unix
()
m
:=
NewManager
([]
ProviderConfig
{
{
Type
:
"brave"
,
APIKey
:
"k"
,
ExpiresAt
:
&
past
},
},
nil
)
_
,
_
,
err
:=
m
.
SearchWithBestProvider
(
context
.
Background
(),
SearchRequest
{
Query
:
"test"
})
require
.
ErrorContains
(
t
,
err
,
"no available provider"
)
}
func
TestManager_SearchWithBestProvider_UsesFirstAvailable
(
t
*
testing
.
T
)
{
srvBrave
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
_
*
http
.
Request
)
{
resp
:=
braveResponse
{}
resp
.
Web
.
Results
=
[]
braveResult
{{
URL
:
"https://brave.com"
,
Title
:
"Brave"
,
Description
:
"from brave"
}}
_
=
json
.
NewEncoder
(
w
)
.
Encode
(
resp
)
}))
defer
srvBrave
.
Close
()
origURL
:=
*
braveSearchURL
u
,
_
:=
http
.
NewRequest
(
"GET"
,
srvBrave
.
URL
,
nil
)
*
braveSearchURL
=
*
u
.
URL
defer
func
()
{
*
braveSearchURL
=
origURL
}()
m
:=
NewManager
([]
ProviderConfig
{
{
Type
:
"brave"
,
APIKey
:
"k1"
},
{
Type
:
"tavily"
,
APIKey
:
"k2"
},
},
nil
)
m
.
clientCache
[
srvBrave
.
URL
]
=
srvBrave
.
Client
()
m
.
clientCache
[
""
]
=
srvBrave
.
Client
()
resp
,
providerName
,
err
:=
m
.
SearchWithBestProvider
(
context
.
Background
(),
SearchRequest
{
Query
:
"test"
})
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"brave"
,
providerName
)
require
.
Len
(
t
,
resp
.
Results
,
1
)
require
.
Equal
(
t
,
"from brave"
,
resp
.
Results
[
0
]
.
Snippet
)
}
func
TestManager_SearchWithBestProvider_NilRedis
(
t
*
testing
.
T
)
{
srv
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
_
*
http
.
Request
)
{
resp
:=
braveResponse
{}
resp
.
Web
.
Results
=
[]
braveResult
{{
URL
:
"https://test.com"
,
Title
:
"Test"
,
Description
:
"result"
}}
_
=
json
.
NewEncoder
(
w
)
.
Encode
(
resp
)
}))
defer
srv
.
Close
()
origURL
:=
*
braveSearchURL
u
,
_
:=
http
.
NewRequest
(
"GET"
,
srv
.
URL
,
nil
)
*
braveSearchURL
=
*
u
.
URL
defer
func
()
{
*
braveSearchURL
=
origURL
}()
m
:=
NewManager
([]
ProviderConfig
{
{
Type
:
"brave"
,
APIKey
:
"k"
,
QuotaLimit
:
100
},
},
nil
)
m
.
clientCache
[
""
]
=
srv
.
Client
()
resp
,
_
,
err
:=
m
.
SearchWithBestProvider
(
context
.
Background
(),
SearchRequest
{
Query
:
"test"
})
require
.
NoError
(
t
,
err
)
require
.
Len
(
t
,
resp
.
Results
,
1
)
}
func
TestManager_GetUsage_NilRedis
(
t
*
testing
.
T
)
{
m
:=
NewManager
(
nil
,
nil
)
used
,
err
:=
m
.
GetUsage
(
context
.
Background
(),
"brave"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
0
),
used
)
}
func
TestManager_GetAllUsage_NilRedis
(
t
*
testing
.
T
)
{
m
:=
NewManager
([]
ProviderConfig
{
{
Type
:
"brave"
},
},
nil
)
usage
:=
m
.
GetAllUsage
(
context
.
Background
())
require
.
Equal
(
t
,
int64
(
0
),
usage
[
"brave"
])
}
// --- Quota TTL from subscription ---
func
TestQuotaTTLFromSubscription_NilSubscription
(
t
*
testing
.
T
)
{
ttl
:=
quotaTTLFromSubscription
(
nil
)
require
.
Equal
(
t
,
defaultQuotaTTL
,
ttl
)
}
func
TestQuotaTTLFromSubscription_ZeroSubscription
(
t
*
testing
.
T
)
{
zero
:=
int64
(
0
)
ttl
:=
quotaTTLFromSubscription
(
&
zero
)
require
.
Equal
(
t
,
defaultQuotaTTL
,
ttl
)
}
func
TestQuotaTTLFromSubscription_ValidSubscription
(
t
*
testing
.
T
)
{
// Subscribed 10 days ago — next reset in ~20 days
sub
:=
time
.
Now
()
.
Add
(
-
10
*
24
*
time
.
Hour
)
.
Unix
()
ttl
:=
quotaTTLFromSubscription
(
&
sub
)
require
.
Greater
(
t
,
ttl
,
15
*
24
*
time
.
Hour
)
// at least 15 days
require
.
Less
(
t
,
ttl
,
25
*
24
*
time
.
Hour
+
quotaTTLBuffer
)
}
func
TestNextMonthlyReset_SubscribedRecentPast
(
t
*
testing
.
T
)
{
// Subscribed on the 10th of this month (always valid day)
now
:=
time
.
Now
()
.
UTC
()
sub
:=
time
.
Date
(
now
.
Year
(),
now
.
Month
(),
10
,
0
,
0
,
0
,
0
,
time
.
UTC
)
next
:=
nextMonthlyReset
(
sub
)
require
.
True
(
t
,
next
.
After
(
now
)
||
next
.
Equal
(
now
),
"next reset should be in the future or now"
)
require
.
True
(
t
,
next
.
Before
(
now
.
AddDate
(
0
,
1
,
1
)))
}
func
TestNextMonthlyReset_SubscribedLongAgo
(
t
*
testing
.
T
)
{
// Subscribed 6 months ago on the 1st
sub
:=
time
.
Now
()
.
UTC
()
.
AddDate
(
0
,
-
6
,
0
)
sub
=
time
.
Date
(
sub
.
Year
(),
sub
.
Month
(),
1
,
0
,
0
,
0
,
0
,
time
.
UTC
)
next
:=
nextMonthlyReset
(
sub
)
require
.
True
(
t
,
next
.
After
(
time
.
Now
()
.
UTC
()))
// Should be within the next 31 days
require
.
True
(
t
,
next
.
Before
(
time
.
Now
()
.
UTC
()
.
AddDate
(
0
,
1
,
1
)))
}
func
TestNextMonthlyReset_FutureSubscription
(
t
*
testing
.
T
)
{
sub
:=
time
.
Now
()
.
UTC
()
.
AddDate
(
0
,
0
,
5
)
next
:=
nextMonthlyReset
(
sub
)
require
.
True
(
t
,
next
.
After
(
time
.
Now
()
.
UTC
()))
}
func
TestAddMonthsClamped_Jan31ToFeb
(
t
*
testing
.
T
)
{
sub
:=
time
.
Date
(
2026
,
1
,
31
,
0
,
0
,
0
,
0
,
time
.
UTC
)
next
:=
addMonthsClamped
(
sub
,
1
)
require
.
Equal
(
t
,
time
.
Month
(
2
),
next
.
Month
())
require
.
Equal
(
t
,
28
,
next
.
Day
())
// Feb 28 (2026 is not a leap year)
}
func
TestAddMonthsClamped_Jan31ToFebLeapYear
(
t
*
testing
.
T
)
{
sub
:=
time
.
Date
(
2028
,
1
,
31
,
0
,
0
,
0
,
0
,
time
.
UTC
)
next
:=
addMonthsClamped
(
sub
,
1
)
require
.
Equal
(
t
,
time
.
Month
(
2
),
next
.
Month
())
require
.
Equal
(
t
,
29
,
next
.
Day
())
// Feb 29 (2028 is a leap year)
}
func
TestAddMonthsClamped_Mar31ToApr
(
t
*
testing
.
T
)
{
sub
:=
time
.
Date
(
2026
,
3
,
31
,
0
,
0
,
0
,
0
,
time
.
UTC
)
next
:=
addMonthsClamped
(
sub
,
1
)
require
.
Equal
(
t
,
time
.
Month
(
4
),
next
.
Month
())
require
.
Equal
(
t
,
30
,
next
.
Day
())
// Apr has 30 days
}
func
TestAddMonthsClamped_NormalDay
(
t
*
testing
.
T
)
{
sub
:=
time
.
Date
(
2026
,
1
,
15
,
0
,
0
,
0
,
0
,
time
.
UTC
)
next
:=
addMonthsClamped
(
sub
,
1
)
require
.
Equal
(
t
,
time
.
Month
(
2
),
next
.
Month
())
require
.
Equal
(
t
,
15
,
next
.
Day
())
// no clamping needed
}
// --- Redis key ---
func
TestQuotaRedisKey_Format
(
t
*
testing
.
T
)
{
key
:=
quotaRedisKey
(
"brave"
)
require
.
Equal
(
t
,
"websearch:quota:brave"
,
key
)
}
// --- isProviderAvailable ---
func
TestIsProviderAvailable_EmptyAPIKey
(
t
*
testing
.
T
)
{
m
:=
NewManager
(
nil
,
nil
)
require
.
False
(
t
,
m
.
isProviderAvailable
(
ProviderConfig
{
APIKey
:
""
}))
}
func
TestIsProviderAvailable_Expired
(
t
*
testing
.
T
)
{
m
:=
NewManager
(
nil
,
nil
)
past
:=
time
.
Now
()
.
Add
(
-
1
*
time
.
Hour
)
.
Unix
()
require
.
False
(
t
,
m
.
isProviderAvailable
(
ProviderConfig
{
APIKey
:
"k"
,
ExpiresAt
:
&
past
}))
}
func
TestIsProviderAvailable_Valid
(
t
*
testing
.
T
)
{
m
:=
NewManager
(
nil
,
nil
)
future
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
.
Unix
()
require
.
True
(
t
,
m
.
isProviderAvailable
(
ProviderConfig
{
APIKey
:
"k"
,
ExpiresAt
:
&
future
}))
require
.
True
(
t
,
m
.
isProviderAvailable
(
ProviderConfig
{
APIKey
:
"k"
}))
// no expiry
}
// --- resolveProxyID ---
func
TestResolveProxyID_AccountProxyOverrides
(
t
*
testing
.
T
)
{
cfg
:=
ProviderConfig
{
ProxyID
:
42
}
require
.
Equal
(
t
,
int64
(
0
),
resolveProxyID
(
cfg
,
"http://account-proxy:8080"
))
require
.
Equal
(
t
,
int64
(
42
),
resolveProxyID
(
cfg
,
""
))
}
// --- isProxyError ---
func
TestIsProxyError_Nil
(
t
*
testing
.
T
)
{
require
.
False
(
t
,
isProxyError
(
nil
))
}
func
TestIsProxyError_ConnectionRefused
(
t
*
testing
.
T
)
{
require
.
True
(
t
,
isProxyError
(
fmt
.
Errorf
(
"dial tcp: connection refused"
)))
}
func
TestIsProxyError_Timeout
(
t
*
testing
.
T
)
{
require
.
True
(
t
,
isProxyError
(
fmt
.
Errorf
(
"i/o timeout while connecting to proxy"
)))
}
func
TestIsProxyError_SOCKS
(
t
*
testing
.
T
)
{
require
.
True
(
t
,
isProxyError
(
fmt
.
Errorf
(
"socks connect failed"
)))
}
func
TestIsProxyError_TLSHandshake
(
t
*
testing
.
T
)
{
require
.
True
(
t
,
isProxyError
(
fmt
.
Errorf
(
"tls handshake timeout"
)))
}
func
TestIsProxyError_APIError_NotProxy
(
t
*
testing
.
T
)
{
require
.
False
(
t
,
isProxyError
(
fmt
.
Errorf
(
"API rate limit exceeded"
)))
}
// --- isProxyAvailable (nil Redis) ---
func
TestIsProxyAvailable_NilRedis
(
t
*
testing
.
T
)
{
m
:=
NewManager
(
nil
,
nil
)
require
.
True
(
t
,
m
.
isProxyAvailable
(
context
.
Background
(),
42
))
}
func
TestIsProxyAvailable_ZeroID
(
t
*
testing
.
T
)
{
m
:=
NewManager
(
nil
,
nil
)
require
.
True
(
t
,
m
.
isProxyAvailable
(
context
.
Background
(),
0
))
}
// --- selectByQuotaWeight ---
func
TestSelectByQuotaWeight_NoQuotaLast
(
t
*
testing
.
T
)
{
m
:=
NewManager
(
nil
,
nil
)
candidates
:=
[]
ProviderConfig
{
{
Type
:
"brave"
,
APIKey
:
"k1"
,
QuotaLimit
:
0
},
{
Type
:
"tavily"
,
APIKey
:
"k2"
,
QuotaLimit
:
100
},
}
result
:=
m
.
selectByQuotaWeight
(
context
.
Background
(),
candidates
)
require
.
Len
(
t
,
result
,
2
)
require
.
Equal
(
t
,
"tavily"
,
result
[
0
]
.
Type
)
require
.
Equal
(
t
,
"brave"
,
result
[
1
]
.
Type
)
}
func
TestSelectByQuotaWeight_AllNoQuota
(
t
*
testing
.
T
)
{
m
:=
NewManager
(
nil
,
nil
)
candidates
:=
[]
ProviderConfig
{
{
Type
:
"brave"
,
APIKey
:
"k1"
,
QuotaLimit
:
0
},
{
Type
:
"tavily"
,
APIKey
:
"k2"
,
QuotaLimit
:
0
},
}
result
:=
m
.
selectByQuotaWeight
(
context
.
Background
(),
candidates
)
require
.
Len
(
t
,
result
,
2
)
}
func
TestSelectByQuotaWeight_Empty
(
t
*
testing
.
T
)
{
m
:=
NewManager
(
nil
,
nil
)
result
:=
m
.
selectByQuotaWeight
(
context
.
Background
(),
nil
)
require
.
Empty
(
t
,
result
)
}
// --- newHTTPClient ---
func
TestNewHTTPClient_NoProxy
(
t
*
testing
.
T
)
{
c
,
err
:=
newHTTPClient
(
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
c
)
}
func
TestNewHTTPClient_InvalidProxy
(
t
*
testing
.
T
)
{
_
,
err
:=
newHTTPClient
(
"://bad-url"
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"invalid proxy URL"
)
}
func
TestNewHTTPClient_ValidHTTPProxy
(
t
*
testing
.
T
)
{
c
,
err
:=
newHTTPClient
(
"http://proxy.example.com:8080"
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
c
)
}
func
TestNewHTTPClient_ValidSOCKS5Proxy
(
t
*
testing
.
T
)
{
c
,
err
:=
newHTTPClient
(
"socks5://proxy.example.com:1080"
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
c
)
}
// --- ResetUsage ---
func
TestManager_ResetUsage_NilRedis
(
t
*
testing
.
T
)
{
m
:=
NewManager
(
nil
,
nil
)
err
:=
m
.
ResetUsage
(
context
.
Background
(),
"brave"
)
require
.
NoError
(
t
,
err
)
}
backend/internal/pkg/websearch/provider.go
0 → 100644
View file @
0b746501
package
websearch
import
"context"
// Provider is the interface every search backend must implement.
type
Provider
interface
{
// Name returns the provider identifier ("brave" or "tavily").
Name
()
string
// Search executes a web search and returns results.
Search
(
ctx
context
.
Context
,
req
SearchRequest
)
(
*
SearchResponse
,
error
)
}
backend/internal/pkg/websearch/tavily.go
0 → 100644
View file @
0b746501
package
websearch
import
(
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
)
const
(
tavilySearchEndpoint
=
"https://api.tavily.com/search"
tavilyProviderName
=
"tavily"
tavilySearchDepthBasic
=
"basic"
)
// TavilyProvider implements web search via the Tavily Search API.
type
TavilyProvider
struct
{
apiKey
string
httpClient
*
http
.
Client
}
// NewTavilyProvider creates a Tavily Search provider.
// The caller is responsible for configuring the http.Client with proxy/timeouts.
func
NewTavilyProvider
(
apiKey
string
,
httpClient
*
http
.
Client
)
*
TavilyProvider
{
if
httpClient
==
nil
{
httpClient
=
http
.
DefaultClient
}
return
&
TavilyProvider
{
apiKey
:
apiKey
,
httpClient
:
httpClient
}
}
func
(
t
*
TavilyProvider
)
Name
()
string
{
return
tavilyProviderName
}
func
(
t
*
TavilyProvider
)
Search
(
ctx
context
.
Context
,
req
SearchRequest
)
(
*
SearchResponse
,
error
)
{
maxResults
:=
req
.
MaxResults
if
maxResults
<=
0
{
maxResults
=
defaultMaxResults
}
payload
:=
tavilyRequest
{
APIKey
:
t
.
apiKey
,
Query
:
req
.
Query
,
MaxResults
:
maxResults
,
SearchDepth
:
tavilySearchDepthBasic
,
}
bodyBytes
,
err
:=
json
.
Marshal
(
payload
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"tavily: encode request: %w"
,
err
)
}
httpReq
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodPost
,
tavilySearchEndpoint
,
bytes
.
NewReader
(
bodyBytes
))
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"tavily: build request: %w"
,
err
)
}
httpReq
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
resp
,
err
:=
t
.
httpClient
.
Do
(
httpReq
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"tavily: request failed: %w"
,
err
)
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
body
,
err
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
maxResponseSize
))
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"tavily: read body: %w"
,
err
)
}
if
resp
.
StatusCode
!=
http
.
StatusOK
{
return
nil
,
fmt
.
Errorf
(
"tavily: status %d: %s"
,
resp
.
StatusCode
,
truncateBody
(
body
))
}
var
raw
tavilyResponse
if
err
:=
json
.
Unmarshal
(
body
,
&
raw
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"tavily: decode response: %w"
,
err
)
}
results
:=
make
([]
SearchResult
,
0
,
len
(
raw
.
Results
))
for
_
,
r
:=
range
raw
.
Results
{
results
=
append
(
results
,
SearchResult
{
URL
:
r
.
URL
,
Title
:
r
.
Title
,
Snippet
:
r
.
Content
,
})
}
return
&
SearchResponse
{
Results
:
results
,
Query
:
req
.
Query
},
nil
}
type
tavilyRequest
struct
{
APIKey
string
`json:"api_key"`
Query
string
`json:"query"`
MaxResults
int
`json:"max_results"`
SearchDepth
string
`json:"search_depth"`
}
type
tavilyResponse
struct
{
Results
[]
tavilyResult
`json:"results"`
}
type
tavilyResult
struct
{
URL
string
`json:"url"`
Title
string
`json:"title"`
Content
string
`json:"content"`
Score
float64
`json:"score"`
}
backend/internal/pkg/websearch/tavily_test.go
0 → 100644
View file @
0b746501
package
websearch
import
(
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
func
TestTavilyProvider_Name
(
t
*
testing
.
T
)
{
p
:=
NewTavilyProvider
(
"key"
,
nil
)
require
.
Equal
(
t
,
"tavily"
,
p
.
Name
())
}
func
TestTavilyProvider_Search_RequestConstruction
(
t
*
testing
.
T
)
{
// Verify tavilyRequest struct fields map correctly
req
:=
tavilyRequest
{
APIKey
:
"test-key"
,
Query
:
"golang"
,
MaxResults
:
3
,
SearchDepth
:
tavilySearchDepthBasic
,
}
data
,
err
:=
json
.
Marshal
(
req
)
require
.
NoError
(
t
,
err
)
var
parsed
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
data
,
&
parsed
))
require
.
Equal
(
t
,
"test-key"
,
parsed
[
"api_key"
])
require
.
Equal
(
t
,
"golang"
,
parsed
[
"query"
])
require
.
Equal
(
t
,
float64
(
3
),
parsed
[
"max_results"
])
require
.
Equal
(
t
,
"basic"
,
parsed
[
"search_depth"
])
}
func
TestTavilyProvider_Search_ResponseParsing
(
t
*
testing
.
T
)
{
rawResp
:=
`{"results":[{"url":"https://go.dev","title":"Go","content":"Go programming language","score":0.95}]}`
var
resp
tavilyResponse
require
.
NoError
(
t
,
json
.
Unmarshal
([]
byte
(
rawResp
),
&
resp
))
require
.
Len
(
t
,
resp
.
Results
,
1
)
require
.
Equal
(
t
,
"https://go.dev"
,
resp
.
Results
[
0
]
.
URL
)
require
.
Equal
(
t
,
"Go programming language"
,
resp
.
Results
[
0
]
.
Content
)
require
.
InDelta
(
t
,
0.95
,
resp
.
Results
[
0
]
.
Score
,
0.001
)
// Verify mapping to SearchResult
results
:=
make
([]
SearchResult
,
0
,
len
(
resp
.
Results
))
for
_
,
r
:=
range
resp
.
Results
{
results
=
append
(
results
,
SearchResult
{
URL
:
r
.
URL
,
Title
:
r
.
Title
,
Snippet
:
r
.
Content
,
})
}
require
.
Equal
(
t
,
"Go programming language"
,
results
[
0
]
.
Snippet
)
require
.
Equal
(
t
,
""
,
results
[
0
]
.
PageAge
)
}
func
TestTavilyProvider_Search_EmptyResults
(
t
*
testing
.
T
)
{
var
resp
tavilyResponse
require
.
NoError
(
t
,
json
.
Unmarshal
([]
byte
(
`{"results":[]}`
),
&
resp
))
require
.
Empty
(
t
,
resp
.
Results
)
}
func
TestTavilyProvider_Search_InvalidJSON
(
t
*
testing
.
T
)
{
var
resp
tavilyResponse
require
.
Error
(
t
,
json
.
Unmarshal
([]
byte
(
"not json"
),
&
resp
))
}
backend/internal/pkg/websearch/types.go
0 → 100644
View file @
0b746501
package
websearch
// SearchResult represents a single web search result.
type
SearchResult
struct
{
URL
string
`json:"url"`
Title
string
`json:"title"`
Snippet
string
`json:"snippet"`
PageAge
string
`json:"page_age,omitempty"`
}
// SearchRequest describes a web search to perform.
type
SearchRequest
struct
{
Query
string
MaxResults
int
// defaults to defaultMaxResults if <= 0
ProxyURL
string
// optional HTTP proxy URL
}
// SearchResponse holds the results of a web search.
type
SearchResponse
struct
{
Results
[]
SearchResult
Query
string
// the query that was actually executed
}
const
defaultMaxResults
=
5
// Provider type identifiers.
const
(
ProviderTypeBrave
=
"brave"
ProviderTypeTavily
=
"tavily"
)
backend/internal/repository/api_key_repo.go
View file @
0b746501
...
...
@@ -138,10 +138,17 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
WithUser
(
func
(
q
*
dbent
.
UserQuery
)
{
q
.
Select
(
user
.
FieldID
,
user
.
FieldEmail
,
user
.
FieldUsername
,
user
.
FieldStatus
,
user
.
FieldRole
,
user
.
FieldBalance
,
user
.
FieldConcurrency
,
user
.
FieldBalanceNotifyEnabled
,
user
.
FieldBalanceNotifyThresholdType
,
user
.
FieldBalanceNotifyThreshold
,
user
.
FieldBalanceNotifyExtraEmails
,
user
.
FieldTotalRecharged
,
)
})
.
WithGroup
(
func
(
q
*
dbent
.
GroupQuery
)
{
...
...
@@ -639,7 +646,7 @@ func userEntityToService(u *dbent.User) *service.User {
if
u
==
nil
{
return
nil
}
return
&
service
.
User
{
out
:=
&
service
.
User
{
ID
:
u
.
ID
,
Email
:
u
.
Email
,
Username
:
u
.
Username
,
...
...
@@ -652,9 +659,18 @@ func userEntityToService(u *dbent.User) *service.User {
TotpSecretEncrypted
:
u
.
TotpSecretEncrypted
,
TotpEnabled
:
u
.
TotpEnabled
,
TotpEnabledAt
:
u
.
TotpEnabledAt
,
BalanceNotifyEnabled
:
u
.
BalanceNotifyEnabled
,
BalanceNotifyThresholdType
:
u
.
BalanceNotifyThresholdType
,
BalanceNotifyThreshold
:
u
.
BalanceNotifyThreshold
,
TotalRecharged
:
u
.
TotalRecharged
,
CreatedAt
:
u
.
CreatedAt
,
UpdatedAt
:
u
.
UpdatedAt
,
}
// Parse extra emails JSON (supports both old []string and new []NotifyEmailEntry format)
if
u
.
BalanceNotifyExtraEmails
!=
""
&&
u
.
BalanceNotifyExtraEmails
!=
"[]"
{
out
.
BalanceNotifyExtraEmails
=
service
.
ParseNotifyEmails
(
u
.
BalanceNotifyExtraEmails
)
}
return
out
}
func
groupEntityToService
(
g
*
dbent
.
Group
)
*
service
.
Group
{
...
...
backend/internal/repository/channel_repo.go
View file @
0b746501
...
...
@@ -41,10 +41,14 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
if
err
!=
nil
{
return
err
}
featuresConfigJSON
,
err
:=
marshalFeaturesConfig
(
channel
.
FeaturesConfig
)
if
err
!=
nil
{
return
err
}
err
=
tx
.
QueryRowContext
(
ctx
,
`INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models) VALUES ($1, $2, $3, $4, $5, $6)
`INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models
, features, features_config, apply_pricing_to_account_stats
) VALUES ($1, $2, $3, $4, $5, $6
, $7, $8, $9
)
RETURNING id, created_at, updated_at`
,
channel
.
Name
,
channel
.
Description
,
channel
.
Status
,
modelMappingJSON
,
channel
.
BillingModelSource
,
channel
.
RestrictModels
,
channel
.
Name
,
channel
.
Description
,
channel
.
Status
,
modelMappingJSON
,
channel
.
BillingModelSource
,
channel
.
RestrictModels
,
channel
.
Features
,
featuresConfigJSON
,
channel
.
ApplyPricingToAccountStats
,
)
.
Scan
(
&
channel
.
ID
,
&
channel
.
CreatedAt
,
&
channel
.
UpdatedAt
)
if
err
!=
nil
{
if
isUniqueViolation
(
err
)
{
...
...
@@ -67,17 +71,24 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
}
}
// 设置账号统计定价规则
if
len
(
channel
.
AccountStatsPricingRules
)
>
0
{
if
err
:=
replaceAccountStatsPricingRulesTx
(
ctx
,
tx
,
channel
.
ID
,
channel
.
AccountStatsPricingRules
);
err
!=
nil
{
return
err
}
}
return
nil
})
}
func
(
r
*
channelRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
Channel
,
error
)
{
ch
:=
&
service
.
Channel
{}
var
modelMappingJSON
[]
byte
var
modelMappingJSON
,
featuresConfigJSON
[]
byte
err
:=
r
.
db
.
QueryRowContext
(
ctx
,
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models,
features, features_config, apply_pricing_to_account_stats,
created_at, updated_at
FROM channels WHERE id = $1`
,
id
,
)
.
Scan
(
&
ch
.
ID
,
&
ch
.
Name
,
&
ch
.
Description
,
&
ch
.
Status
,
&
modelMappingJSON
,
&
ch
.
BillingModelSource
,
&
ch
.
RestrictModels
,
&
ch
.
CreatedAt
,
&
ch
.
UpdatedAt
)
)
.
Scan
(
&
ch
.
ID
,
&
ch
.
Name
,
&
ch
.
Description
,
&
ch
.
Status
,
&
modelMappingJSON
,
&
ch
.
BillingModelSource
,
&
ch
.
RestrictModels
,
&
ch
.
Features
,
&
featuresConfigJSON
,
&
ch
.
ApplyPricingToAccountStats
,
&
ch
.
CreatedAt
,
&
ch
.
UpdatedAt
)
if
err
==
sql
.
ErrNoRows
{
return
nil
,
service
.
ErrChannelNotFound
}
...
...
@@ -85,6 +96,7 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha
return
nil
,
fmt
.
Errorf
(
"get channel: %w"
,
err
)
}
ch
.
ModelMapping
=
unmarshalModelMapping
(
modelMappingJSON
)
ch
.
FeaturesConfig
=
unmarshalFeaturesConfig
(
featuresConfigJSON
)
groupIDs
,
err
:=
r
.
GetGroupIDs
(
ctx
,
id
)
if
err
!=
nil
{
...
...
@@ -98,6 +110,12 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha
}
ch
.
ModelPricing
=
pricing
statsPricingRules
,
err
:=
r
.
loadAccountStatsPricingRules
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
err
}
ch
.
AccountStatsPricingRules
=
statsPricingRules
return
ch
,
nil
}
...
...
@@ -107,10 +125,14 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel
if
err
!=
nil
{
return
err
}
featuresConfigJSON
,
err
:=
marshalFeaturesConfig
(
channel
.
FeaturesConfig
)
if
err
!=
nil
{
return
err
}
result
,
err
:=
tx
.
ExecContext
(
ctx
,
`UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, updated_at = NOW()
WHERE id = $
7
`
,
channel
.
Name
,
channel
.
Description
,
channel
.
Status
,
modelMappingJSON
,
channel
.
BillingModelSource
,
channel
.
RestrictModels
,
channel
.
ID
,
`UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6,
features = $7, features_config = $8, apply_pricing_to_account_stats = $9,
updated_at = NOW()
WHERE id = $
10
`
,
channel
.
Name
,
channel
.
Description
,
channel
.
Status
,
modelMappingJSON
,
channel
.
BillingModelSource
,
channel
.
RestrictModels
,
channel
.
Features
,
featuresConfigJSON
,
channel
.
ApplyPricingToAccountStats
,
channel
.
ID
,
)
if
err
!=
nil
{
if
isUniqueViolation
(
err
)
{
...
...
@@ -137,6 +159,13 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel
}
}
// 更新账号统计定价规则
if
channel
.
AccountStatsPricingRules
!=
nil
{
if
err
:=
replaceAccountStatsPricingRulesTx
(
ctx
,
tx
,
channel
.
ID
,
channel
.
AccountStatsPricingRules
);
err
!=
nil
{
return
err
}
}
return
nil
})
}
...
...
@@ -187,7 +216,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
// 查询 channel 列表
dataQuery
:=
fmt
.
Sprintf
(
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.created_at, c.updated_at
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models,
c.features, c.features_config, c.apply_pricing_to_account_stats,
c.created_at, c.updated_at
FROM channels c WHERE %s ORDER BY %s LIMIT $%d OFFSET $%d`
,
whereClause
,
channelListOrderBy
(
params
),
argIdx
,
argIdx
+
1
,
)
...
...
@@ -203,11 +232,12 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
var
channelIDs
[]
int64
for
rows
.
Next
()
{
var
ch
service
.
Channel
var
modelMappingJSON
[]
byte
if
err
:=
rows
.
Scan
(
&
ch
.
ID
,
&
ch
.
Name
,
&
ch
.
Description
,
&
ch
.
Status
,
&
modelMappingJSON
,
&
ch
.
BillingModelSource
,
&
ch
.
RestrictModels
,
&
ch
.
CreatedAt
,
&
ch
.
UpdatedAt
);
err
!=
nil
{
var
modelMappingJSON
,
featuresConfigJSON
[]
byte
if
err
:=
rows
.
Scan
(
&
ch
.
ID
,
&
ch
.
Name
,
&
ch
.
Description
,
&
ch
.
Status
,
&
modelMappingJSON
,
&
ch
.
BillingModelSource
,
&
ch
.
RestrictModels
,
&
ch
.
Features
,
&
featuresConfigJSON
,
&
ch
.
ApplyPricingToAccountStats
,
&
ch
.
CreatedAt
,
&
ch
.
UpdatedAt
);
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"scan channel: %w"
,
err
)
}
ch
.
ModelMapping
=
unmarshalModelMapping
(
modelMappingJSON
)
ch
.
FeaturesConfig
=
unmarshalFeaturesConfig
(
featuresConfigJSON
)
channels
=
append
(
channels
,
ch
)
channelIDs
=
append
(
channelIDs
,
ch
.
ID
)
}
...
...
@@ -225,9 +255,14 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
if
err
!=
nil
{
return
nil
,
nil
,
err
}
statsRulesMap
,
err
:=
r
.
batchLoadAccountStatsPricingRules
(
ctx
,
channelIDs
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
for
i
:=
range
channels
{
channels
[
i
]
.
GroupIDs
=
groupMap
[
channels
[
i
]
.
ID
]
channels
[
i
]
.
ModelPricing
=
pricingMap
[
channels
[
i
]
.
ID
]
channels
[
i
]
.
AccountStatsPricingRules
=
statsRulesMap
[
channels
[
i
]
.
ID
]
}
}
...
...
@@ -273,7 +308,7 @@ func channelListOrderBy(params pagination.PaginationParams) string {
func
(
r
*
channelRepository
)
ListAll
(
ctx
context
.
Context
)
([]
service
.
Channel
,
error
)
{
rows
,
err
:=
r
.
db
.
QueryContext
(
ctx
,
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at FROM channels ORDER BY id`
,
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models,
features, features_config, apply_pricing_to_account_stats,
created_at, updated_at FROM channels ORDER BY id`
,
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query all channels: %w"
,
err
)
...
...
@@ -284,11 +319,12 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err
var
channelIDs
[]
int64
for
rows
.
Next
()
{
var
ch
service
.
Channel
var
modelMappingJSON
[]
byte
if
err
:=
rows
.
Scan
(
&
ch
.
ID
,
&
ch
.
Name
,
&
ch
.
Description
,
&
ch
.
Status
,
&
modelMappingJSON
,
&
ch
.
BillingModelSource
,
&
ch
.
RestrictModels
,
&
ch
.
CreatedAt
,
&
ch
.
UpdatedAt
);
err
!=
nil
{
var
modelMappingJSON
,
featuresConfigJSON
[]
byte
if
err
:=
rows
.
Scan
(
&
ch
.
ID
,
&
ch
.
Name
,
&
ch
.
Description
,
&
ch
.
Status
,
&
modelMappingJSON
,
&
ch
.
BillingModelSource
,
&
ch
.
RestrictModels
,
&
ch
.
Features
,
&
featuresConfigJSON
,
&
ch
.
ApplyPricingToAccountStats
,
&
ch
.
CreatedAt
,
&
ch
.
UpdatedAt
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"scan channel: %w"
,
err
)
}
ch
.
ModelMapping
=
unmarshalModelMapping
(
modelMappingJSON
)
ch
.
FeaturesConfig
=
unmarshalFeaturesConfig
(
featuresConfigJSON
)
channels
=
append
(
channels
,
ch
)
channelIDs
=
append
(
channelIDs
,
ch
.
ID
)
}
...
...
@@ -312,9 +348,16 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err
return
nil
,
err
}
// 批量加载账号统计定价规则
statsRulesMap
,
err
:=
r
.
batchLoadAccountStatsPricingRules
(
ctx
,
channelIDs
)
if
err
!=
nil
{
return
nil
,
err
}
for
i
:=
range
channels
{
channels
[
i
]
.
GroupIDs
=
groupMap
[
channels
[
i
]
.
ID
]
channels
[
i
]
.
ModelPricing
=
pricingMap
[
channels
[
i
]
.
ID
]
channels
[
i
]
.
AccountStatsPricingRules
=
statsRulesMap
[
channels
[
i
]
.
ID
]
}
return
channels
,
nil
...
...
@@ -456,6 +499,28 @@ func unmarshalModelMapping(data []byte) map[string]map[string]string {
return
m
}
func
marshalFeaturesConfig
(
m
map
[
string
]
any
)
([]
byte
,
error
)
{
if
len
(
m
)
==
0
{
return
[]
byte
(
"{}"
),
nil
}
data
,
err
:=
json
.
Marshal
(
m
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"marshal features_config: %w"
,
err
)
}
return
data
,
nil
}
func
unmarshalFeaturesConfig
(
data
[]
byte
)
map
[
string
]
any
{
if
len
(
data
)
==
0
{
return
nil
}
var
m
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
data
,
&
m
);
err
!=
nil
{
return
nil
}
return
m
}
// GetGroupPlatforms 批量查询分组 ID 对应的平台
func
(
r
*
channelRepository
)
GetGroupPlatforms
(
ctx
context
.
Context
,
groupIDs
[]
int64
)
(
map
[
int64
]
string
,
error
)
{
if
len
(
groupIDs
)
==
0
{
...
...
backend/internal/repository/channel_repo_account_stats_pricing.go
0 → 100644
View file @
0b746501
This diff is collapsed.
Click to expand it.
backend/internal/repository/dashboard_aggregation_repo.go
View file @
0b746501
This diff is collapsed.
Click to expand it.
backend/internal/repository/email_cache.go
View file @
0b746501
This diff is collapsed.
Click to expand it.
Prev
1
2
3
4
5
6
7
…
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment