Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
2b70d1d3
Commit
2b70d1d3
authored
Apr 09, 2026
by
IanShaw027
Browse files
merge upstream main into fix/bug-cleanup-main
parents
b37afd68
00c08c57
Changes
60
Show whitespace changes
Inline
Side-by-side
backend/ent/group.go
View file @
2b70d1d3
...
@@ -11,6 +11,7 @@ import (
...
@@ -11,6 +11,7 @@ import (
"entgo.io/ent"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/internal/domain"
)
)
// Group is the model entity for the Group schema.
// Group is the model entity for the Group schema.
...
@@ -76,6 +77,8 @@ type Group struct {
...
@@ -76,6 +77,8 @@ type Group struct {
RequirePrivacySet
bool
`json:"require_privacy_set,omitempty"`
RequirePrivacySet
bool
`json:"require_privacy_set,omitempty"`
// 默认映射模型 ID,当账号级映射找不到时使用此值
// 默认映射模型 ID,当账号级映射找不到时使用此值
DefaultMappedModel
string
`json:"default_mapped_model,omitempty"`
DefaultMappedModel
string
`json:"default_mapped_model,omitempty"`
// OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型
MessagesDispatchModelConfig
domain
.
OpenAIMessagesDispatchModelConfig
`json:"messages_dispatch_model_config,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the GroupQuery when eager-loading is set.
// The values are being populated by the GroupQuery when eager-loading is set.
Edges
GroupEdges
`json:"edges"`
Edges
GroupEdges
`json:"edges"`
...
@@ -182,7 +185,7 @@ func (*Group) scanValues(columns []string) ([]any, error) {
...
@@ -182,7 +185,7 @@ func (*Group) scanValues(columns []string) ([]any, error) {
values
:=
make
([]
any
,
len
(
columns
))
values
:=
make
([]
any
,
len
(
columns
))
for
i
:=
range
columns
{
for
i
:=
range
columns
{
switch
columns
[
i
]
{
switch
columns
[
i
]
{
case
group
.
FieldModelRouting
,
group
.
FieldSupportedModelScopes
:
case
group
.
FieldModelRouting
,
group
.
FieldSupportedModelScopes
,
group
.
FieldMessagesDispatchModelConfig
:
values
[
i
]
=
new
([]
byte
)
values
[
i
]
=
new
([]
byte
)
case
group
.
FieldIsExclusive
,
group
.
FieldClaudeCodeOnly
,
group
.
FieldModelRoutingEnabled
,
group
.
FieldMcpXMLInject
,
group
.
FieldAllowMessagesDispatch
,
group
.
FieldRequireOauthOnly
,
group
.
FieldRequirePrivacySet
:
case
group
.
FieldIsExclusive
,
group
.
FieldClaudeCodeOnly
,
group
.
FieldModelRoutingEnabled
,
group
.
FieldMcpXMLInject
,
group
.
FieldAllowMessagesDispatch
,
group
.
FieldRequireOauthOnly
,
group
.
FieldRequirePrivacySet
:
values
[
i
]
=
new
(
sql
.
NullBool
)
values
[
i
]
=
new
(
sql
.
NullBool
)
...
@@ -403,6 +406,14 @@ func (_m *Group) assignValues(columns []string, values []any) error {
...
@@ -403,6 +406,14 @@ func (_m *Group) assignValues(columns []string, values []any) error {
}
else
if
value
.
Valid
{
}
else
if
value
.
Valid
{
_m
.
DefaultMappedModel
=
value
.
String
_m
.
DefaultMappedModel
=
value
.
String
}
}
case
group
.
FieldMessagesDispatchModelConfig
:
if
value
,
ok
:=
values
[
i
]
.
(
*
[]
byte
);
!
ok
{
return
fmt
.
Errorf
(
"unexpected type %T for field messages_dispatch_model_config"
,
values
[
i
])
}
else
if
value
!=
nil
&&
len
(
*
value
)
>
0
{
if
err
:=
json
.
Unmarshal
(
*
value
,
&
_m
.
MessagesDispatchModelConfig
);
err
!=
nil
{
return
fmt
.
Errorf
(
"unmarshal field messages_dispatch_model_config: %w"
,
err
)
}
}
default
:
default
:
_m
.
selectValues
.
Set
(
columns
[
i
],
values
[
i
])
_m
.
selectValues
.
Set
(
columns
[
i
],
values
[
i
])
}
}
...
@@ -585,6 +596,9 @@ func (_m *Group) String() string {
...
@@ -585,6 +596,9 @@ func (_m *Group) String() string {
builder
.
WriteString
(
", "
)
builder
.
WriteString
(
", "
)
builder
.
WriteString
(
"default_mapped_model="
)
builder
.
WriteString
(
"default_mapped_model="
)
builder
.
WriteString
(
_m
.
DefaultMappedModel
)
builder
.
WriteString
(
_m
.
DefaultMappedModel
)
builder
.
WriteString
(
", "
)
builder
.
WriteString
(
"messages_dispatch_model_config="
)
builder
.
WriteString
(
fmt
.
Sprintf
(
"%v"
,
_m
.
MessagesDispatchModelConfig
))
builder
.
WriteByte
(
')'
)
builder
.
WriteByte
(
')'
)
return
builder
.
String
()
return
builder
.
String
()
}
}
...
...
backend/ent/group/group.go
View file @
2b70d1d3
...
@@ -8,6 +8,7 @@ import (
...
@@ -8,6 +8,7 @@ import (
"entgo.io/ent"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/dialect/sql/sqlgraph"
"github.com/Wei-Shaw/sub2api/internal/domain"
)
)
const
(
const
(
...
@@ -73,6 +74,8 @@ const (
...
@@ -73,6 +74,8 @@ const (
FieldRequirePrivacySet
=
"require_privacy_set"
FieldRequirePrivacySet
=
"require_privacy_set"
// FieldDefaultMappedModel holds the string denoting the default_mapped_model field in the database.
// FieldDefaultMappedModel holds the string denoting the default_mapped_model field in the database.
FieldDefaultMappedModel
=
"default_mapped_model"
FieldDefaultMappedModel
=
"default_mapped_model"
// FieldMessagesDispatchModelConfig holds the string denoting the messages_dispatch_model_config field in the database.
FieldMessagesDispatchModelConfig
=
"messages_dispatch_model_config"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys
=
"api_keys"
EdgeAPIKeys
=
"api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
...
@@ -177,6 +180,7 @@ var Columns = []string{
...
@@ -177,6 +180,7 @@ var Columns = []string{
FieldRequireOauthOnly
,
FieldRequireOauthOnly
,
FieldRequirePrivacySet
,
FieldRequirePrivacySet
,
FieldDefaultMappedModel
,
FieldDefaultMappedModel
,
FieldMessagesDispatchModelConfig
,
}
}
var
(
var
(
...
@@ -252,6 +256,8 @@ var (
...
@@ -252,6 +256,8 @@ var (
DefaultDefaultMappedModel
string
DefaultDefaultMappedModel
string
// DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
// DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
DefaultMappedModelValidator
func
(
string
)
error
DefaultMappedModelValidator
func
(
string
)
error
// DefaultMessagesDispatchModelConfig holds the default value on creation for the "messages_dispatch_model_config" field.
DefaultMessagesDispatchModelConfig
domain
.
OpenAIMessagesDispatchModelConfig
)
)
// OrderOption defines the ordering options for the Group queries.
// OrderOption defines the ordering options for the Group queries.
...
...
backend/ent/group_create.go
View file @
2b70d1d3
...
@@ -18,6 +18,7 @@ import (
...
@@ -18,6 +18,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/domain"
)
)
// GroupCreate is the builder for creating a Group entity.
// GroupCreate is the builder for creating a Group entity.
...
@@ -410,6 +411,20 @@ func (_c *GroupCreate) SetNillableDefaultMappedModel(v *string) *GroupCreate {
...
@@ -410,6 +411,20 @@ func (_c *GroupCreate) SetNillableDefaultMappedModel(v *string) *GroupCreate {
return
_c
return
_c
}
}
// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field.
func
(
_c
*
GroupCreate
)
SetMessagesDispatchModelConfig
(
v
domain
.
OpenAIMessagesDispatchModelConfig
)
*
GroupCreate
{
_c
.
mutation
.
SetMessagesDispatchModelConfig
(
v
)
return
_c
}
// SetNillableMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field if the given value is not nil.
func
(
_c
*
GroupCreate
)
SetNillableMessagesDispatchModelConfig
(
v
*
domain
.
OpenAIMessagesDispatchModelConfig
)
*
GroupCreate
{
if
v
!=
nil
{
_c
.
SetMessagesDispatchModelConfig
(
*
v
)
}
return
_c
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func
(
_c
*
GroupCreate
)
AddAPIKeyIDs
(
ids
...
int64
)
*
GroupCreate
{
func
(
_c
*
GroupCreate
)
AddAPIKeyIDs
(
ids
...
int64
)
*
GroupCreate
{
_c
.
mutation
.
AddAPIKeyIDs
(
ids
...
)
_c
.
mutation
.
AddAPIKeyIDs
(
ids
...
)
...
@@ -611,6 +626,10 @@ func (_c *GroupCreate) defaults() error {
...
@@ -611,6 +626,10 @@ func (_c *GroupCreate) defaults() error {
v
:=
group
.
DefaultDefaultMappedModel
v
:=
group
.
DefaultDefaultMappedModel
_c
.
mutation
.
SetDefaultMappedModel
(
v
)
_c
.
mutation
.
SetDefaultMappedModel
(
v
)
}
}
if
_
,
ok
:=
_c
.
mutation
.
MessagesDispatchModelConfig
();
!
ok
{
v
:=
group
.
DefaultMessagesDispatchModelConfig
_c
.
mutation
.
SetMessagesDispatchModelConfig
(
v
)
}
return
nil
return
nil
}
}
...
@@ -695,6 +714,9 @@ func (_c *GroupCreate) check() error {
...
@@ -695,6 +714,9 @@ func (_c *GroupCreate) check() error {
return
&
ValidationError
{
Name
:
"default_mapped_model"
,
err
:
fmt
.
Errorf
(
`ent: validator failed for field "Group.default_mapped_model": %w`
,
err
)}
return
&
ValidationError
{
Name
:
"default_mapped_model"
,
err
:
fmt
.
Errorf
(
`ent: validator failed for field "Group.default_mapped_model": %w`
,
err
)}
}
}
}
}
if
_
,
ok
:=
_c
.
mutation
.
MessagesDispatchModelConfig
();
!
ok
{
return
&
ValidationError
{
Name
:
"messages_dispatch_model_config"
,
err
:
errors
.
New
(
`ent: missing required field "Group.messages_dispatch_model_config"`
)}
}
return
nil
return
nil
}
}
...
@@ -838,6 +860,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
...
@@ -838,6 +860,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec
.
SetField
(
group
.
FieldDefaultMappedModel
,
field
.
TypeString
,
value
)
_spec
.
SetField
(
group
.
FieldDefaultMappedModel
,
field
.
TypeString
,
value
)
_node
.
DefaultMappedModel
=
value
_node
.
DefaultMappedModel
=
value
}
}
if
value
,
ok
:=
_c
.
mutation
.
MessagesDispatchModelConfig
();
ok
{
_spec
.
SetField
(
group
.
FieldMessagesDispatchModelConfig
,
field
.
TypeJSON
,
value
)
_node
.
MessagesDispatchModelConfig
=
value
}
if
nodes
:=
_c
.
mutation
.
APIKeysIDs
();
len
(
nodes
)
>
0
{
if
nodes
:=
_c
.
mutation
.
APIKeysIDs
();
len
(
nodes
)
>
0
{
edge
:=
&
sqlgraph
.
EdgeSpec
{
edge
:=
&
sqlgraph
.
EdgeSpec
{
Rel
:
sqlgraph
.
O2M
,
Rel
:
sqlgraph
.
O2M
,
...
@@ -1462,6 +1488,18 @@ func (u *GroupUpsert) UpdateDefaultMappedModel() *GroupUpsert {
...
@@ -1462,6 +1488,18 @@ func (u *GroupUpsert) UpdateDefaultMappedModel() *GroupUpsert {
return
u
return
u
}
}
// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field.
func
(
u
*
GroupUpsert
)
SetMessagesDispatchModelConfig
(
v
domain
.
OpenAIMessagesDispatchModelConfig
)
*
GroupUpsert
{
u
.
Set
(
group
.
FieldMessagesDispatchModelConfig
,
v
)
return
u
}
// UpdateMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field to the value that was provided on create.
func
(
u
*
GroupUpsert
)
UpdateMessagesDispatchModelConfig
()
*
GroupUpsert
{
u
.
SetExcluded
(
group
.
FieldMessagesDispatchModelConfig
)
return
u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
// Using this option is equivalent to using:
//
//
...
@@ -2053,6 +2091,20 @@ func (u *GroupUpsertOne) UpdateDefaultMappedModel() *GroupUpsertOne {
...
@@ -2053,6 +2091,20 @@ func (u *GroupUpsertOne) UpdateDefaultMappedModel() *GroupUpsertOne {
})
})
}
}
// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field.
func
(
u
*
GroupUpsertOne
)
SetMessagesDispatchModelConfig
(
v
domain
.
OpenAIMessagesDispatchModelConfig
)
*
GroupUpsertOne
{
return
u
.
Update
(
func
(
s
*
GroupUpsert
)
{
s
.
SetMessagesDispatchModelConfig
(
v
)
})
}
// UpdateMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field to the value that was provided on create.
func
(
u
*
GroupUpsertOne
)
UpdateMessagesDispatchModelConfig
()
*
GroupUpsertOne
{
return
u
.
Update
(
func
(
s
*
GroupUpsert
)
{
s
.
UpdateMessagesDispatchModelConfig
()
})
}
// Exec executes the query.
// Exec executes the query.
func
(
u
*
GroupUpsertOne
)
Exec
(
ctx
context
.
Context
)
error
{
func
(
u
*
GroupUpsertOne
)
Exec
(
ctx
context
.
Context
)
error
{
if
len
(
u
.
create
.
conflict
)
==
0
{
if
len
(
u
.
create
.
conflict
)
==
0
{
...
@@ -2810,6 +2862,20 @@ func (u *GroupUpsertBulk) UpdateDefaultMappedModel() *GroupUpsertBulk {
...
@@ -2810,6 +2862,20 @@ func (u *GroupUpsertBulk) UpdateDefaultMappedModel() *GroupUpsertBulk {
})
})
}
}
// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field.
func
(
u
*
GroupUpsertBulk
)
SetMessagesDispatchModelConfig
(
v
domain
.
OpenAIMessagesDispatchModelConfig
)
*
GroupUpsertBulk
{
return
u
.
Update
(
func
(
s
*
GroupUpsert
)
{
s
.
SetMessagesDispatchModelConfig
(
v
)
})
}
// UpdateMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field to the value that was provided on create.
func
(
u
*
GroupUpsertBulk
)
UpdateMessagesDispatchModelConfig
()
*
GroupUpsertBulk
{
return
u
.
Update
(
func
(
s
*
GroupUpsert
)
{
s
.
UpdateMessagesDispatchModelConfig
()
})
}
// Exec executes the query.
// Exec executes the query.
func
(
u
*
GroupUpsertBulk
)
Exec
(
ctx
context
.
Context
)
error
{
func
(
u
*
GroupUpsertBulk
)
Exec
(
ctx
context
.
Context
)
error
{
if
u
.
create
.
err
!=
nil
{
if
u
.
create
.
err
!=
nil
{
...
...
backend/ent/group_update.go
View file @
2b70d1d3
...
@@ -20,6 +20,7 @@ import (
...
@@ -20,6 +20,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/domain"
)
)
// GroupUpdate is the builder for updating Group entities.
// GroupUpdate is the builder for updating Group entities.
...
@@ -552,6 +553,20 @@ func (_u *GroupUpdate) SetNillableDefaultMappedModel(v *string) *GroupUpdate {
...
@@ -552,6 +553,20 @@ func (_u *GroupUpdate) SetNillableDefaultMappedModel(v *string) *GroupUpdate {
return
_u
return
_u
}
}
// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field.
func
(
_u
*
GroupUpdate
)
SetMessagesDispatchModelConfig
(
v
domain
.
OpenAIMessagesDispatchModelConfig
)
*
GroupUpdate
{
_u
.
mutation
.
SetMessagesDispatchModelConfig
(
v
)
return
_u
}
// SetNillableMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field if the given value is not nil.
func
(
_u
*
GroupUpdate
)
SetNillableMessagesDispatchModelConfig
(
v
*
domain
.
OpenAIMessagesDispatchModelConfig
)
*
GroupUpdate
{
if
v
!=
nil
{
_u
.
SetMessagesDispatchModelConfig
(
*
v
)
}
return
_u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func
(
_u
*
GroupUpdate
)
AddAPIKeyIDs
(
ids
...
int64
)
*
GroupUpdate
{
func
(
_u
*
GroupUpdate
)
AddAPIKeyIDs
(
ids
...
int64
)
*
GroupUpdate
{
_u
.
mutation
.
AddAPIKeyIDs
(
ids
...
)
_u
.
mutation
.
AddAPIKeyIDs
(
ids
...
)
...
@@ -1012,6 +1027,9 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
...
@@ -1012,6 +1027,9 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if
value
,
ok
:=
_u
.
mutation
.
DefaultMappedModel
();
ok
{
if
value
,
ok
:=
_u
.
mutation
.
DefaultMappedModel
();
ok
{
_spec
.
SetField
(
group
.
FieldDefaultMappedModel
,
field
.
TypeString
,
value
)
_spec
.
SetField
(
group
.
FieldDefaultMappedModel
,
field
.
TypeString
,
value
)
}
}
if
value
,
ok
:=
_u
.
mutation
.
MessagesDispatchModelConfig
();
ok
{
_spec
.
SetField
(
group
.
FieldMessagesDispatchModelConfig
,
field
.
TypeJSON
,
value
)
}
if
_u
.
mutation
.
APIKeysCleared
()
{
if
_u
.
mutation
.
APIKeysCleared
()
{
edge
:=
&
sqlgraph
.
EdgeSpec
{
edge
:=
&
sqlgraph
.
EdgeSpec
{
Rel
:
sqlgraph
.
O2M
,
Rel
:
sqlgraph
.
O2M
,
...
@@ -1843,6 +1861,20 @@ func (_u *GroupUpdateOne) SetNillableDefaultMappedModel(v *string) *GroupUpdateO
...
@@ -1843,6 +1861,20 @@ func (_u *GroupUpdateOne) SetNillableDefaultMappedModel(v *string) *GroupUpdateO
return
_u
return
_u
}
}
// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field.
func
(
_u
*
GroupUpdateOne
)
SetMessagesDispatchModelConfig
(
v
domain
.
OpenAIMessagesDispatchModelConfig
)
*
GroupUpdateOne
{
_u
.
mutation
.
SetMessagesDispatchModelConfig
(
v
)
return
_u
}
// SetNillableMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field if the given value is not nil.
func
(
_u
*
GroupUpdateOne
)
SetNillableMessagesDispatchModelConfig
(
v
*
domain
.
OpenAIMessagesDispatchModelConfig
)
*
GroupUpdateOne
{
if
v
!=
nil
{
_u
.
SetMessagesDispatchModelConfig
(
*
v
)
}
return
_u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func
(
_u
*
GroupUpdateOne
)
AddAPIKeyIDs
(
ids
...
int64
)
*
GroupUpdateOne
{
func
(
_u
*
GroupUpdateOne
)
AddAPIKeyIDs
(
ids
...
int64
)
*
GroupUpdateOne
{
_u
.
mutation
.
AddAPIKeyIDs
(
ids
...
)
_u
.
mutation
.
AddAPIKeyIDs
(
ids
...
)
...
@@ -2333,6 +2365,9 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
...
@@ -2333,6 +2365,9 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
if
value
,
ok
:=
_u
.
mutation
.
DefaultMappedModel
();
ok
{
if
value
,
ok
:=
_u
.
mutation
.
DefaultMappedModel
();
ok
{
_spec
.
SetField
(
group
.
FieldDefaultMappedModel
,
field
.
TypeString
,
value
)
_spec
.
SetField
(
group
.
FieldDefaultMappedModel
,
field
.
TypeString
,
value
)
}
}
if
value
,
ok
:=
_u
.
mutation
.
MessagesDispatchModelConfig
();
ok
{
_spec
.
SetField
(
group
.
FieldMessagesDispatchModelConfig
,
field
.
TypeJSON
,
value
)
}
if
_u
.
mutation
.
APIKeysCleared
()
{
if
_u
.
mutation
.
APIKeysCleared
()
{
edge
:=
&
sqlgraph
.
EdgeSpec
{
edge
:=
&
sqlgraph
.
EdgeSpec
{
Rel
:
sqlgraph
.
O2M
,
Rel
:
sqlgraph
.
O2M
,
...
...
backend/ent/migrate/schema.go
View file @
2b70d1d3
...
@@ -407,6 +407,7 @@ var (
...
@@ -407,6 +407,7 @@ var (
{
Name
:
"require_oauth_only"
,
Type
:
field
.
TypeBool
,
Default
:
false
},
{
Name
:
"require_oauth_only"
,
Type
:
field
.
TypeBool
,
Default
:
false
},
{
Name
:
"require_privacy_set"
,
Type
:
field
.
TypeBool
,
Default
:
false
},
{
Name
:
"require_privacy_set"
,
Type
:
field
.
TypeBool
,
Default
:
false
},
{
Name
:
"default_mapped_model"
,
Type
:
field
.
TypeString
,
Size
:
100
,
Default
:
""
},
{
Name
:
"default_mapped_model"
,
Type
:
field
.
TypeString
,
Size
:
100
,
Default
:
""
},
{
Name
:
"messages_dispatch_model_config"
,
Type
:
field
.
TypeJSON
,
SchemaType
:
map
[
string
]
string
{
"postgres"
:
"jsonb"
}},
}
}
// GroupsTable holds the schema information for the "groups" table.
// GroupsTable holds the schema information for the "groups" table.
GroupsTable
=
&
schema
.
Table
{
GroupsTable
=
&
schema
.
Table
{
...
...
backend/ent/mutation.go
View file @
2b70d1d3
...
@@ -8246,6 +8246,7 @@ type GroupMutation struct {
...
@@ -8246,6 +8246,7 @@ type GroupMutation struct {
require_oauth_only *bool
require_oauth_only *bool
require_privacy_set *bool
require_privacy_set *bool
default_mapped_model *string
default_mapped_model *string
messages_dispatch_model_config *domain.OpenAIMessagesDispatchModelConfig
clearedFields map[string]struct{}
clearedFields map[string]struct{}
api_keys map[int64]struct{}
api_keys map[int64]struct{}
removedapi_keys map[int64]struct{}
removedapi_keys map[int64]struct{}
...
@@ -9798,6 +9799,42 @@ func (m *GroupMutation) ResetDefaultMappedModel() {
...
@@ -9798,6 +9799,42 @@ func (m *GroupMutation) ResetDefaultMappedModel() {
m.default_mapped_model = nil
m.default_mapped_model = nil
}
}
// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field.
func (m *GroupMutation) SetMessagesDispatchModelConfig(damdmc domain.OpenAIMessagesDispatchModelConfig) {
m.messages_dispatch_model_config = &damdmc
}
// MessagesDispatchModelConfig returns the value of the "messages_dispatch_model_config" field in the mutation.
func (m *GroupMutation) MessagesDispatchModelConfig() (r domain.OpenAIMessagesDispatchModelConfig, exists bool) {
v := m.messages_dispatch_model_config
if v == nil {
return
}
return *v, true
}
// OldMessagesDispatchModelConfig returns the old "messages_dispatch_model_config" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldMessagesDispatchModelConfig(ctx context.Context) (v domain.OpenAIMessagesDispatchModelConfig, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldMessagesDispatchModelConfig is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldMessagesDispatchModelConfig requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldMessagesDispatchModelConfig: %w", err)
}
return oldValue.MessagesDispatchModelConfig, nil
}
// ResetMessagesDispatchModelConfig resets all changes to the "messages_dispatch_model_config" field.
func (m *GroupMutation) ResetMessagesDispatchModelConfig() {
m.messages_dispatch_model_config = nil
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
if m.api_keys == nil {
if m.api_keys == nil {
...
@@ -10156,7 +10193,7 @@ func (m *GroupMutation) Type() string {
...
@@ -10156,7 +10193,7 @@ func (m *GroupMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
// AddedFields().
func (m *GroupMutation) Fields() []string {
func (m *GroupMutation) Fields() []string {
fields := make([]string, 0,
29
)
fields := make([]string, 0,
30
)
if m.created_at != nil {
if m.created_at != nil {
fields = append(fields, group.FieldCreatedAt)
fields = append(fields, group.FieldCreatedAt)
}
}
...
@@ -10244,6 +10281,9 @@ func (m *GroupMutation) Fields() []string {
...
@@ -10244,6 +10281,9 @@ func (m *GroupMutation) Fields() []string {
if m.default_mapped_model != nil {
if m.default_mapped_model != nil {
fields = append(fields, group.FieldDefaultMappedModel)
fields = append(fields, group.FieldDefaultMappedModel)
}
}
if m.messages_dispatch_model_config != nil {
fields = append(fields, group.FieldMessagesDispatchModelConfig)
}
return fields
return fields
}
}
...
@@ -10310,6 +10350,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
...
@@ -10310,6 +10350,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
return m.RequirePrivacySet()
return m.RequirePrivacySet()
case group.FieldDefaultMappedModel:
case group.FieldDefaultMappedModel:
return m.DefaultMappedModel()
return m.DefaultMappedModel()
case group.FieldMessagesDispatchModelConfig:
return m.MessagesDispatchModelConfig()
}
}
return nil, false
return nil, false
}
}
...
@@ -10377,6 +10419,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
...
@@ -10377,6 +10419,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
return m.OldRequirePrivacySet(ctx)
return m.OldRequirePrivacySet(ctx)
case group.FieldDefaultMappedModel:
case group.FieldDefaultMappedModel:
return m.OldDefaultMappedModel(ctx)
return m.OldDefaultMappedModel(ctx)
case group.FieldMessagesDispatchModelConfig:
return m.OldMessagesDispatchModelConfig(ctx)
}
}
return nil, fmt.Errorf("unknown Group field %s", name)
return nil, fmt.Errorf("unknown Group field %s", name)
}
}
...
@@ -10589,6 +10633,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
...
@@ -10589,6 +10633,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
}
}
m.SetDefaultMappedModel(v)
m.SetDefaultMappedModel(v)
return nil
return nil
case group.FieldMessagesDispatchModelConfig:
v, ok := value.(domain.OpenAIMessagesDispatchModelConfig)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetMessagesDispatchModelConfig(v)
return nil
}
}
return fmt.Errorf("unknown Group field %s", name)
return fmt.Errorf("unknown Group field %s", name)
}
}
...
@@ -10929,6 +10980,9 @@ func (m *GroupMutation) ResetField(name string) error {
...
@@ -10929,6 +10980,9 @@ func (m *GroupMutation) ResetField(name string) error {
case group.FieldDefaultMappedModel:
case group.FieldDefaultMappedModel:
m.ResetDefaultMappedModel()
m.ResetDefaultMappedModel()
return nil
return nil
case group.FieldMessagesDispatchModelConfig:
m.ResetMessagesDispatchModelConfig()
return nil
}
}
return fmt.Errorf("unknown Group field %s", name)
return fmt.Errorf("unknown Group field %s", name)
}
}
...
...
backend/ent/runtime/runtime.go
View file @
2b70d1d3
...
@@ -28,6 +28,7 @@ import (
...
@@ -28,6 +28,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/domain"
)
)
// The init function reads all schema descriptors with runtime code
// The init function reads all schema descriptors with runtime code
...
@@ -468,6 +469,10 @@ func init() {
...
@@ -468,6 +469,10 @@ func init() {
group
.
DefaultDefaultMappedModel
=
groupDescDefaultMappedModel
.
Default
.
(
string
)
group
.
DefaultDefaultMappedModel
=
groupDescDefaultMappedModel
.
Default
.
(
string
)
// group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
// group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
group
.
DefaultMappedModelValidator
=
groupDescDefaultMappedModel
.
Validators
[
0
]
.
(
func
(
string
)
error
)
group
.
DefaultMappedModelValidator
=
groupDescDefaultMappedModel
.
Validators
[
0
]
.
(
func
(
string
)
error
)
// groupDescMessagesDispatchModelConfig is the schema descriptor for messages_dispatch_model_config field.
groupDescMessagesDispatchModelConfig
:=
groupFields
[
26
]
.
Descriptor
()
// group.DefaultMessagesDispatchModelConfig holds the default value on creation for the messages_dispatch_model_config field.
group
.
DefaultMessagesDispatchModelConfig
=
groupDescMessagesDispatchModelConfig
.
Default
.
(
domain
.
OpenAIMessagesDispatchModelConfig
)
idempotencyrecordMixin
:=
schema
.
IdempotencyRecord
{}
.
Mixin
()
idempotencyrecordMixin
:=
schema
.
IdempotencyRecord
{}
.
Mixin
()
idempotencyrecordMixinFields0
:=
idempotencyrecordMixin
[
0
]
.
Fields
()
idempotencyrecordMixinFields0
:=
idempotencyrecordMixin
[
0
]
.
Fields
()
_
=
idempotencyrecordMixinFields0
_
=
idempotencyrecordMixinFields0
...
...
backend/ent/schema/group.go
View file @
2b70d1d3
...
@@ -141,6 +141,10 @@ func (Group) Fields() []ent.Field {
...
@@ -141,6 +141,10 @@ func (Group) Fields() []ent.Field {
MaxLen
(
100
)
.
MaxLen
(
100
)
.
Default
(
""
)
.
Default
(
""
)
.
Comment
(
"默认映射模型 ID,当账号级映射找不到时使用此值"
),
Comment
(
"默认映射模型 ID,当账号级映射找不到时使用此值"
),
field
.
JSON
(
"messages_dispatch_model_config"
,
domain
.
OpenAIMessagesDispatchModelConfig
{})
.
Default
(
domain
.
OpenAIMessagesDispatchModelConfig
{})
.
SchemaType
(
map
[
string
]
string
{
dialect
.
Postgres
:
"jsonb"
})
.
Comment
(
"OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型"
),
}
}
}
}
...
...
backend/internal/config/config.go
View file @
2b70d1d3
...
@@ -65,6 +65,7 @@ type Config struct {
...
@@ -65,6 +65,7 @@ type Config struct {
JWT
JWTConfig
`mapstructure:"jwt"`
JWT
JWTConfig
`mapstructure:"jwt"`
Totp
TotpConfig
`mapstructure:"totp"`
Totp
TotpConfig
`mapstructure:"totp"`
LinuxDo
LinuxDoConnectConfig
`mapstructure:"linuxdo_connect"`
LinuxDo
LinuxDoConnectConfig
`mapstructure:"linuxdo_connect"`
OIDC
OIDCConnectConfig
`mapstructure:"oidc_connect"`
Default
DefaultConfig
`mapstructure:"default"`
Default
DefaultConfig
`mapstructure:"default"`
RateLimit
RateLimitConfig
`mapstructure:"rate_limit"`
RateLimit
RateLimitConfig
`mapstructure:"rate_limit"`
Pricing
PricingConfig
`mapstructure:"pricing"`
Pricing
PricingConfig
`mapstructure:"pricing"`
...
@@ -184,6 +185,34 @@ type LinuxDoConnectConfig struct {
...
@@ -184,6 +185,34 @@ type LinuxDoConnectConfig struct {
UserInfoUsernamePath
string
`mapstructure:"userinfo_username_path"`
UserInfoUsernamePath
string
`mapstructure:"userinfo_username_path"`
}
}
type
OIDCConnectConfig
struct
{
Enabled
bool
`mapstructure:"enabled"`
ProviderName
string
`mapstructure:"provider_name"`
// 显示名: "Keycloak" 等
ClientID
string
`mapstructure:"client_id"`
ClientSecret
string
`mapstructure:"client_secret"`
IssuerURL
string
`mapstructure:"issuer_url"`
DiscoveryURL
string
`mapstructure:"discovery_url"`
AuthorizeURL
string
`mapstructure:"authorize_url"`
TokenURL
string
`mapstructure:"token_url"`
UserInfoURL
string
`mapstructure:"userinfo_url"`
JWKSURL
string
`mapstructure:"jwks_url"`
Scopes
string
`mapstructure:"scopes"`
// 默认 "openid email profile"
RedirectURL
string
`mapstructure:"redirect_url"`
// 后端回调地址(需在提供方后台登记)
FrontendRedirectURL
string
`mapstructure:"frontend_redirect_url"`
// 前端接收 token 的路由(默认:/auth/oidc/callback)
TokenAuthMethod
string
`mapstructure:"token_auth_method"`
// client_secret_post / client_secret_basic / none
UsePKCE
bool
`mapstructure:"use_pkce"`
ValidateIDToken
bool
`mapstructure:"validate_id_token"`
AllowedSigningAlgs
string
`mapstructure:"allowed_signing_algs"`
// 默认 "RS256,ES256,PS256"
ClockSkewSeconds
int
`mapstructure:"clock_skew_seconds"`
// 默认 120
RequireEmailVerified
bool
`mapstructure:"require_email_verified"`
// 默认 false
// 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。
// 为空时,服务端会尝试一组常见字段名。
UserInfoEmailPath
string
`mapstructure:"userinfo_email_path"`
UserInfoIDPath
string
`mapstructure:"userinfo_id_path"`
UserInfoUsernamePath
string
`mapstructure:"userinfo_username_path"`
}
// TokenRefreshConfig OAuth token自动刷新配置
// TokenRefreshConfig OAuth token自动刷新配置
type
TokenRefreshConfig
struct
{
type
TokenRefreshConfig
struct
{
// 是否启用自动刷新
// 是否启用自动刷新
...
@@ -318,6 +347,12 @@ type GatewayConfig struct {
...
@@ -318,6 +347,12 @@ type GatewayConfig struct {
// ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。
// ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。
// 用于网关未透传/改写 User-Agent 时的兼容兜底(默认关闭,避免影响其他客户端)。
// 用于网关未透传/改写 User-Agent 时的兼容兜底(默认关闭,避免影响其他客户端)。
ForceCodexCLI
bool
`mapstructure:"force_codex_cli"`
ForceCodexCLI
bool
`mapstructure:"force_codex_cli"`
// ForcedCodexInstructionsTemplateFile: 服务端强制附加到 Codex 顶层 instructions 的模板文件路径。
// 模板渲染后会直接覆盖最终 instructions;若需要保留客户端 system 转换结果,请在模板中显式引用 {{ .ExistingInstructions }}。
ForcedCodexInstructionsTemplateFile
string
`mapstructure:"forced_codex_instructions_template_file"`
// ForcedCodexInstructionsTemplate: 启动时从模板文件读取并缓存的模板内容。
// 该字段不直接参与配置反序列化,仅用于请求热路径避免重复读盘。
ForcedCodexInstructionsTemplate
string
`mapstructure:"-"`
// OpenAIPassthroughAllowTimeoutHeaders: OpenAI 透传模式是否放行客户端超时头
// OpenAIPassthroughAllowTimeoutHeaders: OpenAI 透传模式是否放行客户端超时头
// 关闭(默认)可避免 x-stainless-timeout 等头导致上游提前断流。
// 关闭(默认)可避免 x-stainless-timeout 等头导致上游提前断流。
OpenAIPassthroughAllowTimeoutHeaders
bool
`mapstructure:"openai_passthrough_allow_timeout_headers"`
OpenAIPassthroughAllowTimeoutHeaders
bool
`mapstructure:"openai_passthrough_allow_timeout_headers"`
...
@@ -972,6 +1007,23 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
...
@@ -972,6 +1007,23 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
cfg
.
LinuxDo
.
UserInfoEmailPath
=
strings
.
TrimSpace
(
cfg
.
LinuxDo
.
UserInfoEmailPath
)
cfg
.
LinuxDo
.
UserInfoEmailPath
=
strings
.
TrimSpace
(
cfg
.
LinuxDo
.
UserInfoEmailPath
)
cfg
.
LinuxDo
.
UserInfoIDPath
=
strings
.
TrimSpace
(
cfg
.
LinuxDo
.
UserInfoIDPath
)
cfg
.
LinuxDo
.
UserInfoIDPath
=
strings
.
TrimSpace
(
cfg
.
LinuxDo
.
UserInfoIDPath
)
cfg
.
LinuxDo
.
UserInfoUsernamePath
=
strings
.
TrimSpace
(
cfg
.
LinuxDo
.
UserInfoUsernamePath
)
cfg
.
LinuxDo
.
UserInfoUsernamePath
=
strings
.
TrimSpace
(
cfg
.
LinuxDo
.
UserInfoUsernamePath
)
cfg
.
OIDC
.
ProviderName
=
strings
.
TrimSpace
(
cfg
.
OIDC
.
ProviderName
)
cfg
.
OIDC
.
ClientID
=
strings
.
TrimSpace
(
cfg
.
OIDC
.
ClientID
)
cfg
.
OIDC
.
ClientSecret
=
strings
.
TrimSpace
(
cfg
.
OIDC
.
ClientSecret
)
cfg
.
OIDC
.
IssuerURL
=
strings
.
TrimSpace
(
cfg
.
OIDC
.
IssuerURL
)
cfg
.
OIDC
.
DiscoveryURL
=
strings
.
TrimSpace
(
cfg
.
OIDC
.
DiscoveryURL
)
cfg
.
OIDC
.
AuthorizeURL
=
strings
.
TrimSpace
(
cfg
.
OIDC
.
AuthorizeURL
)
cfg
.
OIDC
.
TokenURL
=
strings
.
TrimSpace
(
cfg
.
OIDC
.
TokenURL
)
cfg
.
OIDC
.
UserInfoURL
=
strings
.
TrimSpace
(
cfg
.
OIDC
.
UserInfoURL
)
cfg
.
OIDC
.
JWKSURL
=
strings
.
TrimSpace
(
cfg
.
OIDC
.
JWKSURL
)
cfg
.
OIDC
.
Scopes
=
strings
.
TrimSpace
(
cfg
.
OIDC
.
Scopes
)
cfg
.
OIDC
.
RedirectURL
=
strings
.
TrimSpace
(
cfg
.
OIDC
.
RedirectURL
)
cfg
.
OIDC
.
FrontendRedirectURL
=
strings
.
TrimSpace
(
cfg
.
OIDC
.
FrontendRedirectURL
)
cfg
.
OIDC
.
TokenAuthMethod
=
strings
.
ToLower
(
strings
.
TrimSpace
(
cfg
.
OIDC
.
TokenAuthMethod
))
cfg
.
OIDC
.
AllowedSigningAlgs
=
strings
.
TrimSpace
(
cfg
.
OIDC
.
AllowedSigningAlgs
)
cfg
.
OIDC
.
UserInfoEmailPath
=
strings
.
TrimSpace
(
cfg
.
OIDC
.
UserInfoEmailPath
)
cfg
.
OIDC
.
UserInfoIDPath
=
strings
.
TrimSpace
(
cfg
.
OIDC
.
UserInfoIDPath
)
cfg
.
OIDC
.
UserInfoUsernamePath
=
strings
.
TrimSpace
(
cfg
.
OIDC
.
UserInfoUsernamePath
)
cfg
.
Dashboard
.
KeyPrefix
=
strings
.
TrimSpace
(
cfg
.
Dashboard
.
KeyPrefix
)
cfg
.
Dashboard
.
KeyPrefix
=
strings
.
TrimSpace
(
cfg
.
Dashboard
.
KeyPrefix
)
cfg
.
CORS
.
AllowedOrigins
=
normalizeStringSlice
(
cfg
.
CORS
.
AllowedOrigins
)
cfg
.
CORS
.
AllowedOrigins
=
normalizeStringSlice
(
cfg
.
CORS
.
AllowedOrigins
)
cfg
.
Security
.
ResponseHeaders
.
AdditionalAllowed
=
normalizeStringSlice
(
cfg
.
Security
.
ResponseHeaders
.
AdditionalAllowed
)
cfg
.
Security
.
ResponseHeaders
.
AdditionalAllowed
=
normalizeStringSlice
(
cfg
.
Security
.
ResponseHeaders
.
AdditionalAllowed
)
...
@@ -983,6 +1035,14 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
...
@@ -983,6 +1035,14 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
cfg
.
Log
.
Environment
=
strings
.
TrimSpace
(
cfg
.
Log
.
Environment
)
cfg
.
Log
.
Environment
=
strings
.
TrimSpace
(
cfg
.
Log
.
Environment
)
cfg
.
Log
.
StacktraceLevel
=
strings
.
ToLower
(
strings
.
TrimSpace
(
cfg
.
Log
.
StacktraceLevel
))
cfg
.
Log
.
StacktraceLevel
=
strings
.
ToLower
(
strings
.
TrimSpace
(
cfg
.
Log
.
StacktraceLevel
))
cfg
.
Log
.
Output
.
FilePath
=
strings
.
TrimSpace
(
cfg
.
Log
.
Output
.
FilePath
)
cfg
.
Log
.
Output
.
FilePath
=
strings
.
TrimSpace
(
cfg
.
Log
.
Output
.
FilePath
)
cfg
.
Gateway
.
ForcedCodexInstructionsTemplateFile
=
strings
.
TrimSpace
(
cfg
.
Gateway
.
ForcedCodexInstructionsTemplateFile
)
if
cfg
.
Gateway
.
ForcedCodexInstructionsTemplateFile
!=
""
{
content
,
err
:=
os
.
ReadFile
(
cfg
.
Gateway
.
ForcedCodexInstructionsTemplateFile
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"read forced codex instructions template %q: %w"
,
cfg
.
Gateway
.
ForcedCodexInstructionsTemplateFile
,
err
)
}
cfg
.
Gateway
.
ForcedCodexInstructionsTemplate
=
string
(
content
)
}
// 兼容旧键 gateway.openai_ws.sticky_previous_response_ttl_seconds。
// 兼容旧键 gateway.openai_ws.sticky_previous_response_ttl_seconds。
// 新键未配置(<=0)时回退旧键;新键优先。
// 新键未配置(<=0)时回退旧键;新键优先。
...
@@ -1142,6 +1202,30 @@ func setDefaults() {
...
@@ -1142,6 +1202,30 @@ func setDefaults() {
viper
.
SetDefault
(
"linuxdo_connect.userinfo_id_path"
,
""
)
viper
.
SetDefault
(
"linuxdo_connect.userinfo_id_path"
,
""
)
viper
.
SetDefault
(
"linuxdo_connect.userinfo_username_path"
,
""
)
viper
.
SetDefault
(
"linuxdo_connect.userinfo_username_path"
,
""
)
// Generic OIDC OAuth 登录
viper
.
SetDefault
(
"oidc_connect.enabled"
,
false
)
viper
.
SetDefault
(
"oidc_connect.provider_name"
,
"OIDC"
)
viper
.
SetDefault
(
"oidc_connect.client_id"
,
""
)
viper
.
SetDefault
(
"oidc_connect.client_secret"
,
""
)
viper
.
SetDefault
(
"oidc_connect.issuer_url"
,
""
)
viper
.
SetDefault
(
"oidc_connect.discovery_url"
,
""
)
viper
.
SetDefault
(
"oidc_connect.authorize_url"
,
""
)
viper
.
SetDefault
(
"oidc_connect.token_url"
,
""
)
viper
.
SetDefault
(
"oidc_connect.userinfo_url"
,
""
)
viper
.
SetDefault
(
"oidc_connect.jwks_url"
,
""
)
viper
.
SetDefault
(
"oidc_connect.scopes"
,
"openid email profile"
)
viper
.
SetDefault
(
"oidc_connect.redirect_url"
,
""
)
viper
.
SetDefault
(
"oidc_connect.frontend_redirect_url"
,
"/auth/oidc/callback"
)
viper
.
SetDefault
(
"oidc_connect.token_auth_method"
,
"client_secret_post"
)
viper
.
SetDefault
(
"oidc_connect.use_pkce"
,
false
)
viper
.
SetDefault
(
"oidc_connect.validate_id_token"
,
true
)
viper
.
SetDefault
(
"oidc_connect.allowed_signing_algs"
,
"RS256,ES256,PS256"
)
viper
.
SetDefault
(
"oidc_connect.clock_skew_seconds"
,
120
)
viper
.
SetDefault
(
"oidc_connect.require_email_verified"
,
false
)
viper
.
SetDefault
(
"oidc_connect.userinfo_email_path"
,
""
)
viper
.
SetDefault
(
"oidc_connect.userinfo_id_path"
,
""
)
viper
.
SetDefault
(
"oidc_connect.userinfo_username_path"
,
""
)
// Database
// Database
viper
.
SetDefault
(
"database.host"
,
"localhost"
)
viper
.
SetDefault
(
"database.host"
,
"localhost"
)
viper
.
SetDefault
(
"database.port"
,
5432
)
viper
.
SetDefault
(
"database.port"
,
5432
)
...
@@ -1578,6 +1662,87 @@ func (c *Config) Validate() error {
...
@@ -1578,6 +1662,87 @@ func (c *Config) Validate() error {
warnIfInsecureURL
(
"linuxdo_connect.redirect_url"
,
c
.
LinuxDo
.
RedirectURL
)
warnIfInsecureURL
(
"linuxdo_connect.redirect_url"
,
c
.
LinuxDo
.
RedirectURL
)
warnIfInsecureURL
(
"linuxdo_connect.frontend_redirect_url"
,
c
.
LinuxDo
.
FrontendRedirectURL
)
warnIfInsecureURL
(
"linuxdo_connect.frontend_redirect_url"
,
c
.
LinuxDo
.
FrontendRedirectURL
)
}
}
if
c
.
OIDC
.
Enabled
{
if
strings
.
TrimSpace
(
c
.
OIDC
.
ClientID
)
==
""
{
return
fmt
.
Errorf
(
"oidc_connect.client_id is required when oidc_connect.enabled=true"
)
}
if
strings
.
TrimSpace
(
c
.
OIDC
.
IssuerURL
)
==
""
{
return
fmt
.
Errorf
(
"oidc_connect.issuer_url is required when oidc_connect.enabled=true"
)
}
if
strings
.
TrimSpace
(
c
.
OIDC
.
RedirectURL
)
==
""
{
return
fmt
.
Errorf
(
"oidc_connect.redirect_url is required when oidc_connect.enabled=true"
)
}
if
strings
.
TrimSpace
(
c
.
OIDC
.
FrontendRedirectURL
)
==
""
{
return
fmt
.
Errorf
(
"oidc_connect.frontend_redirect_url is required when oidc_connect.enabled=true"
)
}
if
!
scopeContainsOpenID
(
c
.
OIDC
.
Scopes
)
{
return
fmt
.
Errorf
(
"oidc_connect.scopes must contain openid"
)
}
method
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
c
.
OIDC
.
TokenAuthMethod
))
switch
method
{
case
""
,
"client_secret_post"
,
"client_secret_basic"
,
"none"
:
default
:
return
fmt
.
Errorf
(
"oidc_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none"
)
}
if
method
==
"none"
&&
!
c
.
OIDC
.
UsePKCE
{
return
fmt
.
Errorf
(
"oidc_connect.use_pkce must be true when oidc_connect.token_auth_method=none"
)
}
if
(
method
==
""
||
method
==
"client_secret_post"
||
method
==
"client_secret_basic"
)
&&
strings
.
TrimSpace
(
c
.
OIDC
.
ClientSecret
)
==
""
{
return
fmt
.
Errorf
(
"oidc_connect.client_secret is required when oidc_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic"
)
}
if
c
.
OIDC
.
ClockSkewSeconds
<
0
||
c
.
OIDC
.
ClockSkewSeconds
>
600
{
return
fmt
.
Errorf
(
"oidc_connect.clock_skew_seconds must be between 0 and 600"
)
}
if
c
.
OIDC
.
ValidateIDToken
&&
strings
.
TrimSpace
(
c
.
OIDC
.
AllowedSigningAlgs
)
==
""
{
return
fmt
.
Errorf
(
"oidc_connect.allowed_signing_algs is required when oidc_connect.validate_id_token=true"
)
}
if
err
:=
ValidateAbsoluteHTTPURL
(
c
.
OIDC
.
IssuerURL
);
err
!=
nil
{
return
fmt
.
Errorf
(
"oidc_connect.issuer_url invalid: %w"
,
err
)
}
if
v
:=
strings
.
TrimSpace
(
c
.
OIDC
.
DiscoveryURL
);
v
!=
""
{
if
err
:=
ValidateAbsoluteHTTPURL
(
v
);
err
!=
nil
{
return
fmt
.
Errorf
(
"oidc_connect.discovery_url invalid: %w"
,
err
)
}
}
if
v
:=
strings
.
TrimSpace
(
c
.
OIDC
.
AuthorizeURL
);
v
!=
""
{
if
err
:=
ValidateAbsoluteHTTPURL
(
v
);
err
!=
nil
{
return
fmt
.
Errorf
(
"oidc_connect.authorize_url invalid: %w"
,
err
)
}
}
if
v
:=
strings
.
TrimSpace
(
c
.
OIDC
.
TokenURL
);
v
!=
""
{
if
err
:=
ValidateAbsoluteHTTPURL
(
v
);
err
!=
nil
{
return
fmt
.
Errorf
(
"oidc_connect.token_url invalid: %w"
,
err
)
}
}
if
v
:=
strings
.
TrimSpace
(
c
.
OIDC
.
UserInfoURL
);
v
!=
""
{
if
err
:=
ValidateAbsoluteHTTPURL
(
v
);
err
!=
nil
{
return
fmt
.
Errorf
(
"oidc_connect.userinfo_url invalid: %w"
,
err
)
}
}
if
v
:=
strings
.
TrimSpace
(
c
.
OIDC
.
JWKSURL
);
v
!=
""
{
if
err
:=
ValidateAbsoluteHTTPURL
(
v
);
err
!=
nil
{
return
fmt
.
Errorf
(
"oidc_connect.jwks_url invalid: %w"
,
err
)
}
}
if
err
:=
ValidateAbsoluteHTTPURL
(
c
.
OIDC
.
RedirectURL
);
err
!=
nil
{
return
fmt
.
Errorf
(
"oidc_connect.redirect_url invalid: %w"
,
err
)
}
if
err
:=
ValidateFrontendRedirectURL
(
c
.
OIDC
.
FrontendRedirectURL
);
err
!=
nil
{
return
fmt
.
Errorf
(
"oidc_connect.frontend_redirect_url invalid: %w"
,
err
)
}
warnIfInsecureURL
(
"oidc_connect.issuer_url"
,
c
.
OIDC
.
IssuerURL
)
warnIfInsecureURL
(
"oidc_connect.discovery_url"
,
c
.
OIDC
.
DiscoveryURL
)
warnIfInsecureURL
(
"oidc_connect.authorize_url"
,
c
.
OIDC
.
AuthorizeURL
)
warnIfInsecureURL
(
"oidc_connect.token_url"
,
c
.
OIDC
.
TokenURL
)
warnIfInsecureURL
(
"oidc_connect.userinfo_url"
,
c
.
OIDC
.
UserInfoURL
)
warnIfInsecureURL
(
"oidc_connect.jwks_url"
,
c
.
OIDC
.
JWKSURL
)
warnIfInsecureURL
(
"oidc_connect.redirect_url"
,
c
.
OIDC
.
RedirectURL
)
warnIfInsecureURL
(
"oidc_connect.frontend_redirect_url"
,
c
.
OIDC
.
FrontendRedirectURL
)
}
if
c
.
Billing
.
CircuitBreaker
.
Enabled
{
if
c
.
Billing
.
CircuitBreaker
.
Enabled
{
if
c
.
Billing
.
CircuitBreaker
.
FailureThreshold
<=
0
{
if
c
.
Billing
.
CircuitBreaker
.
FailureThreshold
<=
0
{
return
fmt
.
Errorf
(
"billing.circuit_breaker.failure_threshold must be positive"
)
return
fmt
.
Errorf
(
"billing.circuit_breaker.failure_threshold must be positive"
)
...
@@ -2196,6 +2361,15 @@ func ValidateFrontendRedirectURL(raw string) error {
...
@@ -2196,6 +2361,15 @@ func ValidateFrontendRedirectURL(raw string) error {
return
nil
return
nil
}
}
func
scopeContainsOpenID
(
scopes
string
)
bool
{
for
_
,
scope
:=
range
strings
.
Fields
(
strings
.
ToLower
(
strings
.
TrimSpace
(
scopes
)))
{
if
scope
==
"openid"
{
return
true
}
}
return
false
}
// isHTTPScheme 检查是否为 HTTP 或 HTTPS 协议
// isHTTPScheme 检查是否为 HTTP 或 HTTPS 协议
func
isHTTPScheme
(
scheme
string
)
bool
{
func
isHTTPScheme
(
scheme
string
)
bool
{
return
strings
.
EqualFold
(
scheme
,
"http"
)
||
strings
.
EqualFold
(
scheme
,
"https"
)
return
strings
.
EqualFold
(
scheme
,
"http"
)
||
strings
.
EqualFold
(
scheme
,
"https"
)
...
...
backend/internal/config/config_test.go
View file @
2b70d1d3
package
config
package
config
import
(
import
(
"os"
"path/filepath"
"strings"
"strings"
"testing"
"testing"
"time"
"time"
...
@@ -223,6 +225,23 @@ func TestLoadSchedulingConfigFromEnv(t *testing.T) {
...
@@ -223,6 +225,23 @@ func TestLoadSchedulingConfigFromEnv(t *testing.T) {
}
}
}
}
func
TestLoadForcedCodexInstructionsTemplate
(
t
*
testing
.
T
)
{
resetViperWithJWTSecret
(
t
)
tempDir
:=
t
.
TempDir
()
templatePath
:=
filepath
.
Join
(
tempDir
,
"codex-instructions.md.tmpl"
)
configPath
:=
filepath
.
Join
(
tempDir
,
"config.yaml"
)
require
.
NoError
(
t
,
os
.
WriteFile
(
templatePath
,
[]
byte
(
"server-prefix
\n\n
{{ .ExistingInstructions }}"
),
0
o644
))
require
.
NoError
(
t
,
os
.
WriteFile
(
configPath
,
[]
byte
(
"gateway:
\n
forced_codex_instructions_template_file:
\"
"
+
templatePath
+
"
\"\n
"
),
0
o644
))
t
.
Setenv
(
"DATA_DIR"
,
tempDir
)
cfg
,
err
:=
Load
()
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
templatePath
,
cfg
.
Gateway
.
ForcedCodexInstructionsTemplateFile
)
require
.
Equal
(
t
,
"server-prefix
\n\n
{{ .ExistingInstructions }}"
,
cfg
.
Gateway
.
ForcedCodexInstructionsTemplate
)
}
func
TestLoadDefaultSecurityToggles
(
t
*
testing
.
T
)
{
func
TestLoadDefaultSecurityToggles
(
t
*
testing
.
T
)
{
resetViperWithJWTSecret
(
t
)
resetViperWithJWTSecret
(
t
)
...
@@ -351,6 +370,60 @@ func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
...
@@ -351,6 +370,60 @@ func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
}
}
}
}
func
TestValidateOIDCScopesMustContainOpenID
(
t
*
testing
.
T
)
{
resetViperWithJWTSecret
(
t
)
cfg
,
err
:=
Load
()
if
err
!=
nil
{
t
.
Fatalf
(
"Load() error: %v"
,
err
)
}
cfg
.
OIDC
.
Enabled
=
true
cfg
.
OIDC
.
ClientID
=
"oidc-client"
cfg
.
OIDC
.
ClientSecret
=
"oidc-secret"
cfg
.
OIDC
.
IssuerURL
=
"https://issuer.example.com"
cfg
.
OIDC
.
AuthorizeURL
=
"https://issuer.example.com/auth"
cfg
.
OIDC
.
TokenURL
=
"https://issuer.example.com/token"
cfg
.
OIDC
.
JWKSURL
=
"https://issuer.example.com/jwks"
cfg
.
OIDC
.
RedirectURL
=
"https://example.com/api/v1/auth/oauth/oidc/callback"
cfg
.
OIDC
.
FrontendRedirectURL
=
"/auth/oidc/callback"
cfg
.
OIDC
.
Scopes
=
"profile email"
err
=
cfg
.
Validate
()
if
err
==
nil
{
t
.
Fatalf
(
"Validate() expected error when scopes do not include openid, got nil"
)
}
if
!
strings
.
Contains
(
err
.
Error
(),
"oidc_connect.scopes"
)
{
t
.
Fatalf
(
"Validate() expected oidc_connect.scopes error, got: %v"
,
err
)
}
}
func
TestValidateOIDCAllowsIssuerOnlyEndpointsWithDiscoveryFallback
(
t
*
testing
.
T
)
{
resetViperWithJWTSecret
(
t
)
cfg
,
err
:=
Load
()
if
err
!=
nil
{
t
.
Fatalf
(
"Load() error: %v"
,
err
)
}
cfg
.
OIDC
.
Enabled
=
true
cfg
.
OIDC
.
ClientID
=
"oidc-client"
cfg
.
OIDC
.
ClientSecret
=
"oidc-secret"
cfg
.
OIDC
.
IssuerURL
=
"https://issuer.example.com"
cfg
.
OIDC
.
AuthorizeURL
=
""
cfg
.
OIDC
.
TokenURL
=
""
cfg
.
OIDC
.
JWKSURL
=
""
cfg
.
OIDC
.
RedirectURL
=
"https://example.com/api/v1/auth/oauth/oidc/callback"
cfg
.
OIDC
.
FrontendRedirectURL
=
"/auth/oidc/callback"
cfg
.
OIDC
.
Scopes
=
"openid email profile"
cfg
.
OIDC
.
ValidateIDToken
=
true
err
=
cfg
.
Validate
()
if
err
!=
nil
{
t
.
Fatalf
(
"Validate() expected issuer-only OIDC config to pass with discovery fallback, got: %v"
,
err
)
}
}
func
TestLoadDefaultDashboardCacheConfig
(
t
*
testing
.
T
)
{
func
TestLoadDefaultDashboardCacheConfig
(
t
*
testing
.
T
)
{
resetViperWithJWTSecret
(
t
)
resetViperWithJWTSecret
(
t
)
...
...
backend/internal/domain/openai_messages_dispatch.go
0 → 100644
View file @
2b70d1d3
package
domain
// OpenAIMessagesDispatchModelConfig controls how Anthropic /v1/messages
// requests are mapped onto OpenAI/Codex models.
type
OpenAIMessagesDispatchModelConfig
struct
{
OpusMappedModel
string
`json:"opus_mapped_model,omitempty"`
SonnetMappedModel
string
`json:"sonnet_mapped_model,omitempty"`
HaikuMappedModel
string
`json:"haiku_mapped_model,omitempty"`
ExactModelMappings
map
[
string
]
string
`json:"exact_model_mappings,omitempty"`
}
backend/internal/handler/admin/group_handler.go
View file @
2b70d1d3
...
@@ -109,6 +109,7 @@ type CreateGroupRequest struct {
...
@@ -109,6 +109,7 @@ type CreateGroupRequest struct {
RequireOAuthOnly
bool
`json:"require_oauth_only"`
RequireOAuthOnly
bool
`json:"require_oauth_only"`
RequirePrivacySet
bool
`json:"require_privacy_set"`
RequirePrivacySet
bool
`json:"require_privacy_set"`
DefaultMappedModel
string
`json:"default_mapped_model"`
DefaultMappedModel
string
`json:"default_mapped_model"`
MessagesDispatchModelConfig
service
.
OpenAIMessagesDispatchModelConfig
`json:"messages_dispatch_model_config"`
// 从指定分组复制账号(创建后自动绑定)
// 从指定分组复制账号(创建后自动绑定)
CopyAccountsFromGroupIDs
[]
int64
`json:"copy_accounts_from_group_ids"`
CopyAccountsFromGroupIDs
[]
int64
`json:"copy_accounts_from_group_ids"`
}
}
...
@@ -143,6 +144,7 @@ type UpdateGroupRequest struct {
...
@@ -143,6 +144,7 @@ type UpdateGroupRequest struct {
RequireOAuthOnly
*
bool
`json:"require_oauth_only"`
RequireOAuthOnly
*
bool
`json:"require_oauth_only"`
RequirePrivacySet
*
bool
`json:"require_privacy_set"`
RequirePrivacySet
*
bool
`json:"require_privacy_set"`
DefaultMappedModel
*
string
`json:"default_mapped_model"`
DefaultMappedModel
*
string
`json:"default_mapped_model"`
MessagesDispatchModelConfig
*
service
.
OpenAIMessagesDispatchModelConfig
`json:"messages_dispatch_model_config"`
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs
[]
int64
`json:"copy_accounts_from_group_ids"`
CopyAccountsFromGroupIDs
[]
int64
`json:"copy_accounts_from_group_ids"`
}
}
...
@@ -259,6 +261,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
...
@@ -259,6 +261,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
RequireOAuthOnly
:
req
.
RequireOAuthOnly
,
RequireOAuthOnly
:
req
.
RequireOAuthOnly
,
RequirePrivacySet
:
req
.
RequirePrivacySet
,
RequirePrivacySet
:
req
.
RequirePrivacySet
,
DefaultMappedModel
:
req
.
DefaultMappedModel
,
DefaultMappedModel
:
req
.
DefaultMappedModel
,
MessagesDispatchModelConfig
:
req
.
MessagesDispatchModelConfig
,
CopyAccountsFromGroupIDs
:
req
.
CopyAccountsFromGroupIDs
,
CopyAccountsFromGroupIDs
:
req
.
CopyAccountsFromGroupIDs
,
})
})
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -309,6 +312,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
...
@@ -309,6 +312,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
RequireOAuthOnly
:
req
.
RequireOAuthOnly
,
RequireOAuthOnly
:
req
.
RequireOAuthOnly
,
RequirePrivacySet
:
req
.
RequirePrivacySet
,
RequirePrivacySet
:
req
.
RequirePrivacySet
,
DefaultMappedModel
:
req
.
DefaultMappedModel
,
DefaultMappedModel
:
req
.
DefaultMappedModel
,
MessagesDispatchModelConfig
:
req
.
MessagesDispatchModelConfig
,
CopyAccountsFromGroupIDs
:
req
.
CopyAccountsFromGroupIDs
,
CopyAccountsFromGroupIDs
:
req
.
CopyAccountsFromGroupIDs
,
})
})
if
err
!=
nil
{
if
err
!=
nil
{
...
...
backend/internal/handler/admin/setting_handler.go
View file @
2b70d1d3
...
@@ -35,6 +35,15 @@ func generateMenuItemID() (string, error) {
...
@@ -35,6 +35,15 @@ func generateMenuItemID() (string, error) {
return
hex
.
EncodeToString
(
b
),
nil
return
hex
.
EncodeToString
(
b
),
nil
}
}
func
scopesContainOpenID
(
scopes
string
)
bool
{
for
_
,
scope
:=
range
strings
.
Fields
(
strings
.
ToLower
(
strings
.
TrimSpace
(
scopes
)))
{
if
scope
==
"openid"
{
return
true
}
}
return
false
}
// SettingHandler 系统设置处理器
// SettingHandler 系统设置处理器
type
SettingHandler
struct
{
type
SettingHandler
struct
{
settingService
*
service
.
SettingService
settingService
*
service
.
SettingService
...
@@ -96,6 +105,28 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
...
@@ -96,6 +105,28 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
LinuxDoConnectClientID
:
settings
.
LinuxDoConnectClientID
,
LinuxDoConnectClientID
:
settings
.
LinuxDoConnectClientID
,
LinuxDoConnectClientSecretConfigured
:
settings
.
LinuxDoConnectClientSecretConfigured
,
LinuxDoConnectClientSecretConfigured
:
settings
.
LinuxDoConnectClientSecretConfigured
,
LinuxDoConnectRedirectURL
:
settings
.
LinuxDoConnectRedirectURL
,
LinuxDoConnectRedirectURL
:
settings
.
LinuxDoConnectRedirectURL
,
OIDCConnectEnabled
:
settings
.
OIDCConnectEnabled
,
OIDCConnectProviderName
:
settings
.
OIDCConnectProviderName
,
OIDCConnectClientID
:
settings
.
OIDCConnectClientID
,
OIDCConnectClientSecretConfigured
:
settings
.
OIDCConnectClientSecretConfigured
,
OIDCConnectIssuerURL
:
settings
.
OIDCConnectIssuerURL
,
OIDCConnectDiscoveryURL
:
settings
.
OIDCConnectDiscoveryURL
,
OIDCConnectAuthorizeURL
:
settings
.
OIDCConnectAuthorizeURL
,
OIDCConnectTokenURL
:
settings
.
OIDCConnectTokenURL
,
OIDCConnectUserInfoURL
:
settings
.
OIDCConnectUserInfoURL
,
OIDCConnectJWKSURL
:
settings
.
OIDCConnectJWKSURL
,
OIDCConnectScopes
:
settings
.
OIDCConnectScopes
,
OIDCConnectRedirectURL
:
settings
.
OIDCConnectRedirectURL
,
OIDCConnectFrontendRedirectURL
:
settings
.
OIDCConnectFrontendRedirectURL
,
OIDCConnectTokenAuthMethod
:
settings
.
OIDCConnectTokenAuthMethod
,
OIDCConnectUsePKCE
:
settings
.
OIDCConnectUsePKCE
,
OIDCConnectValidateIDToken
:
settings
.
OIDCConnectValidateIDToken
,
OIDCConnectAllowedSigningAlgs
:
settings
.
OIDCConnectAllowedSigningAlgs
,
OIDCConnectClockSkewSeconds
:
settings
.
OIDCConnectClockSkewSeconds
,
OIDCConnectRequireEmailVerified
:
settings
.
OIDCConnectRequireEmailVerified
,
OIDCConnectUserInfoEmailPath
:
settings
.
OIDCConnectUserInfoEmailPath
,
OIDCConnectUserInfoIDPath
:
settings
.
OIDCConnectUserInfoIDPath
,
OIDCConnectUserInfoUsernamePath
:
settings
.
OIDCConnectUserInfoUsernamePath
,
SiteName
:
settings
.
SiteName
,
SiteName
:
settings
.
SiteName
,
SiteLogo
:
settings
.
SiteLogo
,
SiteLogo
:
settings
.
SiteLogo
,
SiteSubtitle
:
settings
.
SiteSubtitle
,
SiteSubtitle
:
settings
.
SiteSubtitle
,
...
@@ -166,6 +197,30 @@ type UpdateSettingsRequest struct {
...
@@ -166,6 +197,30 @@ type UpdateSettingsRequest struct {
LinuxDoConnectClientSecret
string
`json:"linuxdo_connect_client_secret"`
LinuxDoConnectClientSecret
string
`json:"linuxdo_connect_client_secret"`
LinuxDoConnectRedirectURL
string
`json:"linuxdo_connect_redirect_url"`
LinuxDoConnectRedirectURL
string
`json:"linuxdo_connect_redirect_url"`
// Generic OIDC OAuth 登录
OIDCConnectEnabled
bool
`json:"oidc_connect_enabled"`
OIDCConnectProviderName
string
`json:"oidc_connect_provider_name"`
OIDCConnectClientID
string
`json:"oidc_connect_client_id"`
OIDCConnectClientSecret
string
`json:"oidc_connect_client_secret"`
OIDCConnectIssuerURL
string
`json:"oidc_connect_issuer_url"`
OIDCConnectDiscoveryURL
string
`json:"oidc_connect_discovery_url"`
OIDCConnectAuthorizeURL
string
`json:"oidc_connect_authorize_url"`
OIDCConnectTokenURL
string
`json:"oidc_connect_token_url"`
OIDCConnectUserInfoURL
string
`json:"oidc_connect_userinfo_url"`
OIDCConnectJWKSURL
string
`json:"oidc_connect_jwks_url"`
OIDCConnectScopes
string
`json:"oidc_connect_scopes"`
OIDCConnectRedirectURL
string
`json:"oidc_connect_redirect_url"`
OIDCConnectFrontendRedirectURL
string
`json:"oidc_connect_frontend_redirect_url"`
OIDCConnectTokenAuthMethod
string
`json:"oidc_connect_token_auth_method"`
OIDCConnectUsePKCE
bool
`json:"oidc_connect_use_pkce"`
OIDCConnectValidateIDToken
bool
`json:"oidc_connect_validate_id_token"`
OIDCConnectAllowedSigningAlgs
string
`json:"oidc_connect_allowed_signing_algs"`
OIDCConnectClockSkewSeconds
int
`json:"oidc_connect_clock_skew_seconds"`
OIDCConnectRequireEmailVerified
bool
`json:"oidc_connect_require_email_verified"`
OIDCConnectUserInfoEmailPath
string
`json:"oidc_connect_userinfo_email_path"`
OIDCConnectUserInfoIDPath
string
`json:"oidc_connect_userinfo_id_path"`
OIDCConnectUserInfoUsernamePath
string
`json:"oidc_connect_userinfo_username_path"`
// OEM设置
// OEM设置
SiteName
string
`json:"site_name"`
SiteName
string
`json:"site_name"`
SiteLogo
string
`json:"site_logo"`
SiteLogo
string
`json:"site_logo"`
...
@@ -335,6 +390,122 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
...
@@ -335,6 +390,122 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
}
}
// Generic OIDC 参数验证
if
req
.
OIDCConnectEnabled
{
req
.
OIDCConnectProviderName
=
strings
.
TrimSpace
(
req
.
OIDCConnectProviderName
)
req
.
OIDCConnectClientID
=
strings
.
TrimSpace
(
req
.
OIDCConnectClientID
)
req
.
OIDCConnectClientSecret
=
strings
.
TrimSpace
(
req
.
OIDCConnectClientSecret
)
req
.
OIDCConnectIssuerURL
=
strings
.
TrimSpace
(
req
.
OIDCConnectIssuerURL
)
req
.
OIDCConnectDiscoveryURL
=
strings
.
TrimSpace
(
req
.
OIDCConnectDiscoveryURL
)
req
.
OIDCConnectAuthorizeURL
=
strings
.
TrimSpace
(
req
.
OIDCConnectAuthorizeURL
)
req
.
OIDCConnectTokenURL
=
strings
.
TrimSpace
(
req
.
OIDCConnectTokenURL
)
req
.
OIDCConnectUserInfoURL
=
strings
.
TrimSpace
(
req
.
OIDCConnectUserInfoURL
)
req
.
OIDCConnectJWKSURL
=
strings
.
TrimSpace
(
req
.
OIDCConnectJWKSURL
)
req
.
OIDCConnectScopes
=
strings
.
TrimSpace
(
req
.
OIDCConnectScopes
)
req
.
OIDCConnectRedirectURL
=
strings
.
TrimSpace
(
req
.
OIDCConnectRedirectURL
)
req
.
OIDCConnectFrontendRedirectURL
=
strings
.
TrimSpace
(
req
.
OIDCConnectFrontendRedirectURL
)
req
.
OIDCConnectTokenAuthMethod
=
strings
.
ToLower
(
strings
.
TrimSpace
(
req
.
OIDCConnectTokenAuthMethod
))
req
.
OIDCConnectAllowedSigningAlgs
=
strings
.
TrimSpace
(
req
.
OIDCConnectAllowedSigningAlgs
)
req
.
OIDCConnectUserInfoEmailPath
=
strings
.
TrimSpace
(
req
.
OIDCConnectUserInfoEmailPath
)
req
.
OIDCConnectUserInfoIDPath
=
strings
.
TrimSpace
(
req
.
OIDCConnectUserInfoIDPath
)
req
.
OIDCConnectUserInfoUsernamePath
=
strings
.
TrimSpace
(
req
.
OIDCConnectUserInfoUsernamePath
)
if
req
.
OIDCConnectProviderName
==
""
{
req
.
OIDCConnectProviderName
=
"OIDC"
}
if
req
.
OIDCConnectClientID
==
""
{
response
.
BadRequest
(
c
,
"OIDC Client ID is required when enabled"
)
return
}
if
req
.
OIDCConnectIssuerURL
==
""
{
response
.
BadRequest
(
c
,
"OIDC Issuer URL is required when enabled"
)
return
}
if
err
:=
config
.
ValidateAbsoluteHTTPURL
(
req
.
OIDCConnectIssuerURL
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"OIDC Issuer URL must be an absolute http(s) URL"
)
return
}
if
req
.
OIDCConnectDiscoveryURL
!=
""
{
if
err
:=
config
.
ValidateAbsoluteHTTPURL
(
req
.
OIDCConnectDiscoveryURL
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"OIDC Discovery URL must be an absolute http(s) URL"
)
return
}
}
if
req
.
OIDCConnectAuthorizeURL
!=
""
{
if
err
:=
config
.
ValidateAbsoluteHTTPURL
(
req
.
OIDCConnectAuthorizeURL
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"OIDC Authorize URL must be an absolute http(s) URL"
)
return
}
}
if
req
.
OIDCConnectTokenURL
!=
""
{
if
err
:=
config
.
ValidateAbsoluteHTTPURL
(
req
.
OIDCConnectTokenURL
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"OIDC Token URL must be an absolute http(s) URL"
)
return
}
}
if
req
.
OIDCConnectUserInfoURL
!=
""
{
if
err
:=
config
.
ValidateAbsoluteHTTPURL
(
req
.
OIDCConnectUserInfoURL
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"OIDC UserInfo URL must be an absolute http(s) URL"
)
return
}
}
if
req
.
OIDCConnectRedirectURL
==
""
{
response
.
BadRequest
(
c
,
"OIDC Redirect URL is required when enabled"
)
return
}
if
err
:=
config
.
ValidateAbsoluteHTTPURL
(
req
.
OIDCConnectRedirectURL
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"OIDC Redirect URL must be an absolute http(s) URL"
)
return
}
if
req
.
OIDCConnectFrontendRedirectURL
==
""
{
response
.
BadRequest
(
c
,
"OIDC Frontend Redirect URL is required when enabled"
)
return
}
if
err
:=
config
.
ValidateFrontendRedirectURL
(
req
.
OIDCConnectFrontendRedirectURL
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"OIDC Frontend Redirect URL is invalid"
)
return
}
if
!
scopesContainOpenID
(
req
.
OIDCConnectScopes
)
{
response
.
BadRequest
(
c
,
"OIDC scopes must contain openid"
)
return
}
switch
req
.
OIDCConnectTokenAuthMethod
{
case
""
,
"client_secret_post"
,
"client_secret_basic"
,
"none"
:
default
:
response
.
BadRequest
(
c
,
"OIDC Token Auth Method must be one of client_secret_post/client_secret_basic/none"
)
return
}
if
req
.
OIDCConnectTokenAuthMethod
==
"none"
&&
!
req
.
OIDCConnectUsePKCE
{
response
.
BadRequest
(
c
,
"OIDC PKCE must be enabled when token_auth_method=none"
)
return
}
if
req
.
OIDCConnectClockSkewSeconds
<
0
||
req
.
OIDCConnectClockSkewSeconds
>
600
{
response
.
BadRequest
(
c
,
"OIDC clock skew seconds must be between 0 and 600"
)
return
}
if
req
.
OIDCConnectValidateIDToken
{
if
req
.
OIDCConnectAllowedSigningAlgs
==
""
{
response
.
BadRequest
(
c
,
"OIDC Allowed Signing Algs is required when validate_id_token=true"
)
return
}
}
if
req
.
OIDCConnectJWKSURL
!=
""
{
if
err
:=
config
.
ValidateAbsoluteHTTPURL
(
req
.
OIDCConnectJWKSURL
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"OIDC JWKS URL must be an absolute http(s) URL"
)
return
}
}
if
req
.
OIDCConnectTokenAuthMethod
==
""
||
req
.
OIDCConnectTokenAuthMethod
==
"client_secret_post"
||
req
.
OIDCConnectTokenAuthMethod
==
"client_secret_basic"
{
if
req
.
OIDCConnectClientSecret
==
""
{
if
previousSettings
.
OIDCConnectClientSecret
==
""
{
response
.
BadRequest
(
c
,
"OIDC Client Secret is required when enabled"
)
return
}
req
.
OIDCConnectClientSecret
=
previousSettings
.
OIDCConnectClientSecret
}
}
}
// “购买订阅”页面配置验证
// “购买订阅”页面配置验证
purchaseEnabled
:=
previousSettings
.
PurchaseSubscriptionEnabled
purchaseEnabled
:=
previousSettings
.
PurchaseSubscriptionEnabled
if
req
.
PurchaseSubscriptionEnabled
!=
nil
{
if
req
.
PurchaseSubscriptionEnabled
!=
nil
{
...
@@ -565,6 +736,28 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
...
@@ -565,6 +736,28 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
LinuxDoConnectClientID
:
req
.
LinuxDoConnectClientID
,
LinuxDoConnectClientID
:
req
.
LinuxDoConnectClientID
,
LinuxDoConnectClientSecret
:
req
.
LinuxDoConnectClientSecret
,
LinuxDoConnectClientSecret
:
req
.
LinuxDoConnectClientSecret
,
LinuxDoConnectRedirectURL
:
req
.
LinuxDoConnectRedirectURL
,
LinuxDoConnectRedirectURL
:
req
.
LinuxDoConnectRedirectURL
,
OIDCConnectEnabled
:
req
.
OIDCConnectEnabled
,
OIDCConnectProviderName
:
req
.
OIDCConnectProviderName
,
OIDCConnectClientID
:
req
.
OIDCConnectClientID
,
OIDCConnectClientSecret
:
req
.
OIDCConnectClientSecret
,
OIDCConnectIssuerURL
:
req
.
OIDCConnectIssuerURL
,
OIDCConnectDiscoveryURL
:
req
.
OIDCConnectDiscoveryURL
,
OIDCConnectAuthorizeURL
:
req
.
OIDCConnectAuthorizeURL
,
OIDCConnectTokenURL
:
req
.
OIDCConnectTokenURL
,
OIDCConnectUserInfoURL
:
req
.
OIDCConnectUserInfoURL
,
OIDCConnectJWKSURL
:
req
.
OIDCConnectJWKSURL
,
OIDCConnectScopes
:
req
.
OIDCConnectScopes
,
OIDCConnectRedirectURL
:
req
.
OIDCConnectRedirectURL
,
OIDCConnectFrontendRedirectURL
:
req
.
OIDCConnectFrontendRedirectURL
,
OIDCConnectTokenAuthMethod
:
req
.
OIDCConnectTokenAuthMethod
,
OIDCConnectUsePKCE
:
req
.
OIDCConnectUsePKCE
,
OIDCConnectValidateIDToken
:
req
.
OIDCConnectValidateIDToken
,
OIDCConnectAllowedSigningAlgs
:
req
.
OIDCConnectAllowedSigningAlgs
,
OIDCConnectClockSkewSeconds
:
req
.
OIDCConnectClockSkewSeconds
,
OIDCConnectRequireEmailVerified
:
req
.
OIDCConnectRequireEmailVerified
,
OIDCConnectUserInfoEmailPath
:
req
.
OIDCConnectUserInfoEmailPath
,
OIDCConnectUserInfoIDPath
:
req
.
OIDCConnectUserInfoIDPath
,
OIDCConnectUserInfoUsernamePath
:
req
.
OIDCConnectUserInfoUsernamePath
,
SiteName
:
req
.
SiteName
,
SiteName
:
req
.
SiteName
,
SiteLogo
:
req
.
SiteLogo
,
SiteLogo
:
req
.
SiteLogo
,
SiteSubtitle
:
req
.
SiteSubtitle
,
SiteSubtitle
:
req
.
SiteSubtitle
,
...
@@ -682,6 +875,28 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
...
@@ -682,6 +875,28 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
LinuxDoConnectClientID
:
updatedSettings
.
LinuxDoConnectClientID
,
LinuxDoConnectClientID
:
updatedSettings
.
LinuxDoConnectClientID
,
LinuxDoConnectClientSecretConfigured
:
updatedSettings
.
LinuxDoConnectClientSecretConfigured
,
LinuxDoConnectClientSecretConfigured
:
updatedSettings
.
LinuxDoConnectClientSecretConfigured
,
LinuxDoConnectRedirectURL
:
updatedSettings
.
LinuxDoConnectRedirectURL
,
LinuxDoConnectRedirectURL
:
updatedSettings
.
LinuxDoConnectRedirectURL
,
OIDCConnectEnabled
:
updatedSettings
.
OIDCConnectEnabled
,
OIDCConnectProviderName
:
updatedSettings
.
OIDCConnectProviderName
,
OIDCConnectClientID
:
updatedSettings
.
OIDCConnectClientID
,
OIDCConnectClientSecretConfigured
:
updatedSettings
.
OIDCConnectClientSecretConfigured
,
OIDCConnectIssuerURL
:
updatedSettings
.
OIDCConnectIssuerURL
,
OIDCConnectDiscoveryURL
:
updatedSettings
.
OIDCConnectDiscoveryURL
,
OIDCConnectAuthorizeURL
:
updatedSettings
.
OIDCConnectAuthorizeURL
,
OIDCConnectTokenURL
:
updatedSettings
.
OIDCConnectTokenURL
,
OIDCConnectUserInfoURL
:
updatedSettings
.
OIDCConnectUserInfoURL
,
OIDCConnectJWKSURL
:
updatedSettings
.
OIDCConnectJWKSURL
,
OIDCConnectScopes
:
updatedSettings
.
OIDCConnectScopes
,
OIDCConnectRedirectURL
:
updatedSettings
.
OIDCConnectRedirectURL
,
OIDCConnectFrontendRedirectURL
:
updatedSettings
.
OIDCConnectFrontendRedirectURL
,
OIDCConnectTokenAuthMethod
:
updatedSettings
.
OIDCConnectTokenAuthMethod
,
OIDCConnectUsePKCE
:
updatedSettings
.
OIDCConnectUsePKCE
,
OIDCConnectValidateIDToken
:
updatedSettings
.
OIDCConnectValidateIDToken
,
OIDCConnectAllowedSigningAlgs
:
updatedSettings
.
OIDCConnectAllowedSigningAlgs
,
OIDCConnectClockSkewSeconds
:
updatedSettings
.
OIDCConnectClockSkewSeconds
,
OIDCConnectRequireEmailVerified
:
updatedSettings
.
OIDCConnectRequireEmailVerified
,
OIDCConnectUserInfoEmailPath
:
updatedSettings
.
OIDCConnectUserInfoEmailPath
,
OIDCConnectUserInfoIDPath
:
updatedSettings
.
OIDCConnectUserInfoIDPath
,
OIDCConnectUserInfoUsernamePath
:
updatedSettings
.
OIDCConnectUserInfoUsernamePath
,
SiteName
:
updatedSettings
.
SiteName
,
SiteName
:
updatedSettings
.
SiteName
,
SiteLogo
:
updatedSettings
.
SiteLogo
,
SiteLogo
:
updatedSettings
.
SiteLogo
,
SiteSubtitle
:
updatedSettings
.
SiteSubtitle
,
SiteSubtitle
:
updatedSettings
.
SiteSubtitle
,
...
@@ -802,6 +1017,72 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
...
@@ -802,6 +1017,72 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if
before
.
LinuxDoConnectRedirectURL
!=
after
.
LinuxDoConnectRedirectURL
{
if
before
.
LinuxDoConnectRedirectURL
!=
after
.
LinuxDoConnectRedirectURL
{
changed
=
append
(
changed
,
"linuxdo_connect_redirect_url"
)
changed
=
append
(
changed
,
"linuxdo_connect_redirect_url"
)
}
}
if
before
.
OIDCConnectEnabled
!=
after
.
OIDCConnectEnabled
{
changed
=
append
(
changed
,
"oidc_connect_enabled"
)
}
if
before
.
OIDCConnectProviderName
!=
after
.
OIDCConnectProviderName
{
changed
=
append
(
changed
,
"oidc_connect_provider_name"
)
}
if
before
.
OIDCConnectClientID
!=
after
.
OIDCConnectClientID
{
changed
=
append
(
changed
,
"oidc_connect_client_id"
)
}
if
req
.
OIDCConnectClientSecret
!=
""
{
changed
=
append
(
changed
,
"oidc_connect_client_secret"
)
}
if
before
.
OIDCConnectIssuerURL
!=
after
.
OIDCConnectIssuerURL
{
changed
=
append
(
changed
,
"oidc_connect_issuer_url"
)
}
if
before
.
OIDCConnectDiscoveryURL
!=
after
.
OIDCConnectDiscoveryURL
{
changed
=
append
(
changed
,
"oidc_connect_discovery_url"
)
}
if
before
.
OIDCConnectAuthorizeURL
!=
after
.
OIDCConnectAuthorizeURL
{
changed
=
append
(
changed
,
"oidc_connect_authorize_url"
)
}
if
before
.
OIDCConnectTokenURL
!=
after
.
OIDCConnectTokenURL
{
changed
=
append
(
changed
,
"oidc_connect_token_url"
)
}
if
before
.
OIDCConnectUserInfoURL
!=
after
.
OIDCConnectUserInfoURL
{
changed
=
append
(
changed
,
"oidc_connect_userinfo_url"
)
}
if
before
.
OIDCConnectJWKSURL
!=
after
.
OIDCConnectJWKSURL
{
changed
=
append
(
changed
,
"oidc_connect_jwks_url"
)
}
if
before
.
OIDCConnectScopes
!=
after
.
OIDCConnectScopes
{
changed
=
append
(
changed
,
"oidc_connect_scopes"
)
}
if
before
.
OIDCConnectRedirectURL
!=
after
.
OIDCConnectRedirectURL
{
changed
=
append
(
changed
,
"oidc_connect_redirect_url"
)
}
if
before
.
OIDCConnectFrontendRedirectURL
!=
after
.
OIDCConnectFrontendRedirectURL
{
changed
=
append
(
changed
,
"oidc_connect_frontend_redirect_url"
)
}
if
before
.
OIDCConnectTokenAuthMethod
!=
after
.
OIDCConnectTokenAuthMethod
{
changed
=
append
(
changed
,
"oidc_connect_token_auth_method"
)
}
if
before
.
OIDCConnectUsePKCE
!=
after
.
OIDCConnectUsePKCE
{
changed
=
append
(
changed
,
"oidc_connect_use_pkce"
)
}
if
before
.
OIDCConnectValidateIDToken
!=
after
.
OIDCConnectValidateIDToken
{
changed
=
append
(
changed
,
"oidc_connect_validate_id_token"
)
}
if
before
.
OIDCConnectAllowedSigningAlgs
!=
after
.
OIDCConnectAllowedSigningAlgs
{
changed
=
append
(
changed
,
"oidc_connect_allowed_signing_algs"
)
}
if
before
.
OIDCConnectClockSkewSeconds
!=
after
.
OIDCConnectClockSkewSeconds
{
changed
=
append
(
changed
,
"oidc_connect_clock_skew_seconds"
)
}
if
before
.
OIDCConnectRequireEmailVerified
!=
after
.
OIDCConnectRequireEmailVerified
{
changed
=
append
(
changed
,
"oidc_connect_require_email_verified"
)
}
if
before
.
OIDCConnectUserInfoEmailPath
!=
after
.
OIDCConnectUserInfoEmailPath
{
changed
=
append
(
changed
,
"oidc_connect_userinfo_email_path"
)
}
if
before
.
OIDCConnectUserInfoIDPath
!=
after
.
OIDCConnectUserInfoIDPath
{
changed
=
append
(
changed
,
"oidc_connect_userinfo_id_path"
)
}
if
before
.
OIDCConnectUserInfoUsernamePath
!=
after
.
OIDCConnectUserInfoUsernamePath
{
changed
=
append
(
changed
,
"oidc_connect_userinfo_username_path"
)
}
if
before
.
SiteName
!=
after
.
SiteName
{
if
before
.
SiteName
!=
after
.
SiteName
{
changed
=
append
(
changed
,
"site_name"
)
changed
=
append
(
changed
,
"site_name"
)
}
}
...
...
backend/internal/handler/auth_oidc_oauth.go
0 → 100644
View file @
2b70d1d3
package
handler
import
(
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"log"
"math/big"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/imroc/req/v3"
"github.com/tidwall/gjson"
)
const
(
oidcOAuthCookiePath
=
"/api/v1/auth/oauth/oidc"
oidcOAuthStateCookieName
=
"oidc_oauth_state"
oidcOAuthVerifierCookie
=
"oidc_oauth_verifier"
oidcOAuthRedirectCookie
=
"oidc_oauth_redirect"
oidcOAuthNonceCookie
=
"oidc_oauth_nonce"
oidcOAuthCookieMaxAgeSec
=
10
*
60
// 10 minutes
oidcOAuthDefaultRedirectTo
=
"/dashboard"
oidcOAuthDefaultFrontendCB
=
"/auth/oidc/callback"
)
type
oidcTokenResponse
struct
{
AccessToken
string
`json:"access_token"`
TokenType
string
`json:"token_type"`
ExpiresIn
int64
`json:"expires_in"`
RefreshToken
string
`json:"refresh_token,omitempty"`
Scope
string
`json:"scope,omitempty"`
IDToken
string
`json:"id_token,omitempty"`
}
type
oidcTokenExchangeError
struct
{
StatusCode
int
ProviderError
string
ProviderDescription
string
Body
string
}
func
(
e
*
oidcTokenExchangeError
)
Error
()
string
{
if
e
==
nil
{
return
""
}
parts
:=
[]
string
{
fmt
.
Sprintf
(
"token exchange status=%d"
,
e
.
StatusCode
)}
if
strings
.
TrimSpace
(
e
.
ProviderError
)
!=
""
{
parts
=
append
(
parts
,
"error="
+
strings
.
TrimSpace
(
e
.
ProviderError
))
}
if
strings
.
TrimSpace
(
e
.
ProviderDescription
)
!=
""
{
parts
=
append
(
parts
,
"error_description="
+
strings
.
TrimSpace
(
e
.
ProviderDescription
))
}
return
strings
.
Join
(
parts
,
" "
)
}
type
oidcIDTokenClaims
struct
{
Email
string
`json:"email,omitempty"`
EmailVerified
*
bool
`json:"email_verified,omitempty"`
PreferredUsername
string
`json:"preferred_username,omitempty"`
Name
string
`json:"name,omitempty"`
Nonce
string
`json:"nonce,omitempty"`
Azp
string
`json:"azp,omitempty"`
jwt
.
RegisteredClaims
}
type
oidcUserInfoClaims
struct
{
Email
string
Username
string
Subject
string
EmailVerified
*
bool
}
type
oidcJWKSet
struct
{
Keys
[]
oidcJWK
`json:"keys"`
}
type
oidcJWK
struct
{
Kty
string
`json:"kty"`
Kid
string
`json:"kid"`
Use
string
`json:"use"`
Alg
string
`json:"alg"`
N
string
`json:"n"`
E
string
`json:"e"`
Crv
string
`json:"crv"`
X
string
`json:"x"`
Y
string
`json:"y"`
}
// OIDCOAuthStart 启动通用 OIDC OAuth 登录流程。
// GET /api/v1/auth/oauth/oidc/start?redirect=/dashboard
func
(
h
*
AuthHandler
)
OIDCOAuthStart
(
c
*
gin
.
Context
)
{
cfg
,
err
:=
h
.
getOIDCOAuthConfig
(
c
.
Request
.
Context
())
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
state
,
err
:=
oauth
.
GenerateState
()
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
infraerrors
.
InternalServer
(
"OAUTH_STATE_GEN_FAILED"
,
"failed to generate oauth state"
)
.
WithCause
(
err
))
return
}
redirectTo
:=
sanitizeFrontendRedirectPath
(
c
.
Query
(
"redirect"
))
if
redirectTo
==
""
{
redirectTo
=
oidcOAuthDefaultRedirectTo
}
secureCookie
:=
isRequestHTTPS
(
c
)
oidcSetCookie
(
c
,
oidcOAuthStateCookieName
,
encodeCookieValue
(
state
),
oidcOAuthCookieMaxAgeSec
,
secureCookie
)
oidcSetCookie
(
c
,
oidcOAuthRedirectCookie
,
encodeCookieValue
(
redirectTo
),
oidcOAuthCookieMaxAgeSec
,
secureCookie
)
codeChallenge
:=
""
if
cfg
.
UsePKCE
{
verifier
,
genErr
:=
oauth
.
GenerateCodeVerifier
()
if
genErr
!=
nil
{
response
.
ErrorFrom
(
c
,
infraerrors
.
InternalServer
(
"OAUTH_PKCE_GEN_FAILED"
,
"failed to generate pkce verifier"
)
.
WithCause
(
genErr
))
return
}
codeChallenge
=
oauth
.
GenerateCodeChallenge
(
verifier
)
oidcSetCookie
(
c
,
oidcOAuthVerifierCookie
,
encodeCookieValue
(
verifier
),
oidcOAuthCookieMaxAgeSec
,
secureCookie
)
}
nonce
:=
""
if
cfg
.
ValidateIDToken
{
nonce
,
err
=
oauth
.
GenerateState
()
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
infraerrors
.
InternalServer
(
"OAUTH_NONCE_GEN_FAILED"
,
"failed to generate oauth nonce"
)
.
WithCause
(
err
))
return
}
oidcSetCookie
(
c
,
oidcOAuthNonceCookie
,
encodeCookieValue
(
nonce
),
oidcOAuthCookieMaxAgeSec
,
secureCookie
)
}
redirectURI
:=
strings
.
TrimSpace
(
cfg
.
RedirectURL
)
if
redirectURI
==
""
{
response
.
ErrorFrom
(
c
,
infraerrors
.
InternalServer
(
"OAUTH_CONFIG_INVALID"
,
"oauth redirect url not configured"
))
return
}
authURL
,
err
:=
buildOIDCAuthorizeURL
(
cfg
,
state
,
nonce
,
codeChallenge
,
redirectURI
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
infraerrors
.
InternalServer
(
"OAUTH_BUILD_URL_FAILED"
,
"failed to build oauth authorization url"
)
.
WithCause
(
err
))
return
}
c
.
Redirect
(
http
.
StatusFound
,
authURL
)
}
// OIDCOAuthCallback 处理 OIDC 回调:校验 id_token、创建/登录用户并重定向到前端。
// GET /api/v1/auth/oauth/oidc/callback?code=...&state=...
func
(
h
*
AuthHandler
)
OIDCOAuthCallback
(
c
*
gin
.
Context
)
{
cfg
,
cfgErr
:=
h
.
getOIDCOAuthConfig
(
c
.
Request
.
Context
())
if
cfgErr
!=
nil
{
response
.
ErrorFrom
(
c
,
cfgErr
)
return
}
frontendCallback
:=
strings
.
TrimSpace
(
cfg
.
FrontendRedirectURL
)
if
frontendCallback
==
""
{
frontendCallback
=
oidcOAuthDefaultFrontendCB
}
if
providerErr
:=
strings
.
TrimSpace
(
c
.
Query
(
"error"
));
providerErr
!=
""
{
redirectOAuthError
(
c
,
frontendCallback
,
"provider_error"
,
providerErr
,
c
.
Query
(
"error_description"
))
return
}
code
:=
strings
.
TrimSpace
(
c
.
Query
(
"code"
))
state
:=
strings
.
TrimSpace
(
c
.
Query
(
"state"
))
if
code
==
""
||
state
==
""
{
redirectOAuthError
(
c
,
frontendCallback
,
"missing_params"
,
"missing code/state"
,
""
)
return
}
secureCookie
:=
isRequestHTTPS
(
c
)
defer
func
()
{
oidcClearCookie
(
c
,
oidcOAuthStateCookieName
,
secureCookie
)
oidcClearCookie
(
c
,
oidcOAuthVerifierCookie
,
secureCookie
)
oidcClearCookie
(
c
,
oidcOAuthRedirectCookie
,
secureCookie
)
oidcClearCookie
(
c
,
oidcOAuthNonceCookie
,
secureCookie
)
}()
expectedState
,
err
:=
readCookieDecoded
(
c
,
oidcOAuthStateCookieName
)
if
err
!=
nil
||
expectedState
==
""
||
state
!=
expectedState
{
redirectOAuthError
(
c
,
frontendCallback
,
"invalid_state"
,
"invalid oauth state"
,
""
)
return
}
redirectTo
,
_
:=
readCookieDecoded
(
c
,
oidcOAuthRedirectCookie
)
redirectTo
=
sanitizeFrontendRedirectPath
(
redirectTo
)
if
redirectTo
==
""
{
redirectTo
=
oidcOAuthDefaultRedirectTo
}
codeVerifier
:=
""
if
cfg
.
UsePKCE
{
codeVerifier
,
_
=
readCookieDecoded
(
c
,
oidcOAuthVerifierCookie
)
if
codeVerifier
==
""
{
redirectOAuthError
(
c
,
frontendCallback
,
"missing_verifier"
,
"missing pkce verifier"
,
""
)
return
}
}
expectedNonce
:=
""
if
cfg
.
ValidateIDToken
{
expectedNonce
,
_
=
readCookieDecoded
(
c
,
oidcOAuthNonceCookie
)
if
expectedNonce
==
""
{
redirectOAuthError
(
c
,
frontendCallback
,
"missing_nonce"
,
"missing oauth nonce"
,
""
)
return
}
}
redirectURI
:=
strings
.
TrimSpace
(
cfg
.
RedirectURL
)
if
redirectURI
==
""
{
redirectOAuthError
(
c
,
frontendCallback
,
"config_error"
,
"oauth redirect url not configured"
,
""
)
return
}
tokenResp
,
err
:=
oidcExchangeCode
(
c
.
Request
.
Context
(),
cfg
,
code
,
redirectURI
,
codeVerifier
)
if
err
!=
nil
{
description
:=
""
var
exchangeErr
*
oidcTokenExchangeError
if
errors
.
As
(
err
,
&
exchangeErr
)
&&
exchangeErr
!=
nil
{
log
.
Printf
(
"[OIDC OAuth] token exchange failed: status=%d provider_error=%q provider_description=%q body=%s"
,
exchangeErr
.
StatusCode
,
exchangeErr
.
ProviderError
,
exchangeErr
.
ProviderDescription
,
truncateLogValue
(
exchangeErr
.
Body
,
2048
),
)
description
=
exchangeErr
.
Error
()
}
else
{
log
.
Printf
(
"[OIDC OAuth] token exchange failed: %v"
,
err
)
description
=
err
.
Error
()
}
redirectOAuthError
(
c
,
frontendCallback
,
"token_exchange_failed"
,
"failed to exchange oauth code"
,
singleLine
(
description
))
return
}
if
cfg
.
ValidateIDToken
&&
strings
.
TrimSpace
(
tokenResp
.
IDToken
)
==
""
{
redirectOAuthError
(
c
,
frontendCallback
,
"missing_id_token"
,
"missing id_token"
,
""
)
return
}
idClaims
,
err
:=
oidcParseAndValidateIDToken
(
c
.
Request
.
Context
(),
cfg
,
tokenResp
.
IDToken
,
expectedNonce
)
if
err
!=
nil
{
log
.
Printf
(
"[OIDC OAuth] id_token validation failed: %v"
,
err
)
redirectOAuthError
(
c
,
frontendCallback
,
"invalid_id_token"
,
"failed to validate id_token"
,
""
)
return
}
userInfoClaims
,
err
:=
oidcFetchUserInfo
(
c
.
Request
.
Context
(),
cfg
,
tokenResp
)
if
err
!=
nil
{
log
.
Printf
(
"[OIDC OAuth] userinfo fetch failed: %v"
,
err
)
redirectOAuthError
(
c
,
frontendCallback
,
"userinfo_failed"
,
"failed to fetch user info"
,
""
)
return
}
subject
:=
strings
.
TrimSpace
(
idClaims
.
Subject
)
if
subject
==
""
{
subject
=
strings
.
TrimSpace
(
userInfoClaims
.
Subject
)
}
if
subject
==
""
{
redirectOAuthError
(
c
,
frontendCallback
,
"missing_subject"
,
"missing subject claim"
,
""
)
return
}
issuer
:=
strings
.
TrimSpace
(
idClaims
.
Issuer
)
if
issuer
==
""
{
issuer
=
strings
.
TrimSpace
(
cfg
.
IssuerURL
)
}
if
issuer
==
""
{
redirectOAuthError
(
c
,
frontendCallback
,
"missing_issuer"
,
"missing issuer claim"
,
""
)
return
}
emailVerified
:=
userInfoClaims
.
EmailVerified
if
emailVerified
==
nil
{
emailVerified
=
idClaims
.
EmailVerified
}
if
cfg
.
RequireEmailVerified
{
if
emailVerified
==
nil
||
!*
emailVerified
{
redirectOAuthError
(
c
,
frontendCallback
,
"email_not_verified"
,
"email is not verified"
,
""
)
return
}
}
identityKey
:=
oidcIdentityKey
(
issuer
,
subject
)
email
:=
oidcSelectLoginEmail
(
userInfoClaims
.
Email
,
idClaims
.
Email
,
identityKey
)
username
:=
firstNonEmpty
(
userInfoClaims
.
Username
,
idClaims
.
PreferredUsername
,
idClaims
.
Name
,
oidcFallbackUsername
(
subject
),
)
// 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
tokenPair
,
_
,
err
:=
h
.
authService
.
LoginOrRegisterOAuthWithTokenPair
(
c
.
Request
.
Context
(),
email
,
username
,
""
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
service
.
ErrOAuthInvitationRequired
)
{
pendingToken
,
tokenErr
:=
h
.
authService
.
CreatePendingOAuthToken
(
email
,
username
)
if
tokenErr
!=
nil
{
redirectOAuthError
(
c
,
frontendCallback
,
"login_failed"
,
"service_error"
,
""
)
return
}
fragment
:=
url
.
Values
{}
fragment
.
Set
(
"error"
,
"invitation_required"
)
fragment
.
Set
(
"pending_oauth_token"
,
pendingToken
)
fragment
.
Set
(
"redirect"
,
redirectTo
)
redirectWithFragment
(
c
,
frontendCallback
,
fragment
)
return
}
redirectOAuthError
(
c
,
frontendCallback
,
"login_failed"
,
infraerrors
.
Reason
(
err
),
infraerrors
.
Message
(
err
))
return
}
fragment
:=
url
.
Values
{}
fragment
.
Set
(
"access_token"
,
tokenPair
.
AccessToken
)
fragment
.
Set
(
"refresh_token"
,
tokenPair
.
RefreshToken
)
fragment
.
Set
(
"expires_in"
,
fmt
.
Sprintf
(
"%d"
,
tokenPair
.
ExpiresIn
))
fragment
.
Set
(
"token_type"
,
"Bearer"
)
fragment
.
Set
(
"redirect"
,
redirectTo
)
redirectWithFragment
(
c
,
frontendCallback
,
fragment
)
}
type
completeOIDCOAuthRequest
struct
{
PendingOAuthToken
string
`json:"pending_oauth_token" binding:"required"`
InvitationCode
string
`json:"invitation_code" binding:"required"`
}
// CompleteOIDCOAuthRegistration completes a pending OAuth registration by validating
// the invitation code and creating the user account.
// POST /api/v1/auth/oauth/oidc/complete-registration
func
(
h
*
AuthHandler
)
CompleteOIDCOAuthRegistration
(
c
*
gin
.
Context
)
{
var
req
completeOIDCOAuthRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
"INVALID_REQUEST"
,
"message"
:
err
.
Error
()})
return
}
email
,
username
,
err
:=
h
.
authService
.
VerifyPendingOAuthToken
(
req
.
PendingOAuthToken
)
if
err
!=
nil
{
c
.
JSON
(
http
.
StatusUnauthorized
,
gin
.
H
{
"error"
:
"INVALID_TOKEN"
,
"message"
:
"invalid or expired registration token"
})
return
}
tokenPair
,
_
,
err
:=
h
.
authService
.
LoginOrRegisterOAuthWithTokenPair
(
c
.
Request
.
Context
(),
email
,
username
,
req
.
InvitationCode
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"access_token"
:
tokenPair
.
AccessToken
,
"refresh_token"
:
tokenPair
.
RefreshToken
,
"expires_in"
:
tokenPair
.
ExpiresIn
,
"token_type"
:
"Bearer"
,
})
}
func
(
h
*
AuthHandler
)
getOIDCOAuthConfig
(
ctx
context
.
Context
)
(
config
.
OIDCConnectConfig
,
error
)
{
if
h
!=
nil
&&
h
.
settingSvc
!=
nil
{
return
h
.
settingSvc
.
GetOIDCConnectOAuthConfig
(
ctx
)
}
if
h
==
nil
||
h
.
cfg
==
nil
{
return
config
.
OIDCConnectConfig
{},
infraerrors
.
ServiceUnavailable
(
"CONFIG_NOT_READY"
,
"config not loaded"
)
}
if
!
h
.
cfg
.
OIDC
.
Enabled
{
return
config
.
OIDCConnectConfig
{},
infraerrors
.
NotFound
(
"OAUTH_DISABLED"
,
"oauth login is disabled"
)
}
return
h
.
cfg
.
OIDC
,
nil
}
func
oidcExchangeCode
(
ctx
context
.
Context
,
cfg
config
.
OIDCConnectConfig
,
code
string
,
redirectURI
string
,
codeVerifier
string
,
)
(
*
oidcTokenResponse
,
error
)
{
client
:=
req
.
C
()
.
SetTimeout
(
30
*
time
.
Second
)
form
:=
url
.
Values
{}
form
.
Set
(
"grant_type"
,
"authorization_code"
)
form
.
Set
(
"client_id"
,
cfg
.
ClientID
)
form
.
Set
(
"code"
,
code
)
form
.
Set
(
"redirect_uri"
,
redirectURI
)
if
cfg
.
UsePKCE
{
form
.
Set
(
"code_verifier"
,
codeVerifier
)
}
r
:=
client
.
R
()
.
SetContext
(
ctx
)
.
SetHeader
(
"Accept"
,
"application/json"
)
switch
strings
.
ToLower
(
strings
.
TrimSpace
(
cfg
.
TokenAuthMethod
))
{
case
""
,
"client_secret_post"
:
form
.
Set
(
"client_secret"
,
cfg
.
ClientSecret
)
case
"client_secret_basic"
:
r
.
SetBasicAuth
(
cfg
.
ClientID
,
cfg
.
ClientSecret
)
case
"none"
:
default
:
return
nil
,
fmt
.
Errorf
(
"unsupported token_auth_method: %s"
,
cfg
.
TokenAuthMethod
)
}
resp
,
err
:=
r
.
SetFormDataFromValues
(
form
)
.
Post
(
cfg
.
TokenURL
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"request token: %w"
,
err
)
}
body
:=
strings
.
TrimSpace
(
resp
.
String
())
if
!
resp
.
IsSuccessState
()
{
providerErr
,
providerDesc
:=
parseOAuthProviderError
(
body
)
return
nil
,
&
oidcTokenExchangeError
{
StatusCode
:
resp
.
StatusCode
,
ProviderError
:
providerErr
,
ProviderDescription
:
providerDesc
,
Body
:
body
,
}
}
tokenResp
,
ok
:=
oidcParseTokenResponse
(
body
)
if
!
ok
{
return
nil
,
&
oidcTokenExchangeError
{
StatusCode
:
resp
.
StatusCode
,
Body
:
body
}
}
if
strings
.
TrimSpace
(
tokenResp
.
TokenType
)
==
""
{
tokenResp
.
TokenType
=
"Bearer"
}
if
strings
.
TrimSpace
(
tokenResp
.
AccessToken
)
==
""
&&
strings
.
TrimSpace
(
tokenResp
.
IDToken
)
==
""
{
return
nil
,
&
oidcTokenExchangeError
{
StatusCode
:
resp
.
StatusCode
,
Body
:
body
}
}
return
tokenResp
,
nil
}
func
oidcParseTokenResponse
(
body
string
)
(
*
oidcTokenResponse
,
bool
)
{
body
=
strings
.
TrimSpace
(
body
)
if
body
==
""
{
return
nil
,
false
}
accessToken
:=
strings
.
TrimSpace
(
getGJSON
(
body
,
"access_token"
))
idToken
:=
strings
.
TrimSpace
(
getGJSON
(
body
,
"id_token"
))
if
accessToken
!=
""
||
idToken
!=
""
{
tokenType
:=
strings
.
TrimSpace
(
getGJSON
(
body
,
"token_type"
))
refreshToken
:=
strings
.
TrimSpace
(
getGJSON
(
body
,
"refresh_token"
))
scope
:=
strings
.
TrimSpace
(
getGJSON
(
body
,
"scope"
))
expiresIn
:=
gjson
.
Get
(
body
,
"expires_in"
)
.
Int
()
return
&
oidcTokenResponse
{
AccessToken
:
accessToken
,
TokenType
:
tokenType
,
ExpiresIn
:
expiresIn
,
RefreshToken
:
refreshToken
,
Scope
:
scope
,
IDToken
:
idToken
,
},
true
}
values
,
err
:=
url
.
ParseQuery
(
body
)
if
err
!=
nil
{
return
nil
,
false
}
accessToken
=
strings
.
TrimSpace
(
values
.
Get
(
"access_token"
))
idToken
=
strings
.
TrimSpace
(
values
.
Get
(
"id_token"
))
if
accessToken
==
""
&&
idToken
==
""
{
return
nil
,
false
}
expiresIn
:=
int64
(
0
)
if
raw
:=
strings
.
TrimSpace
(
values
.
Get
(
"expires_in"
));
raw
!=
""
{
if
v
,
parseErr
:=
strconv
.
ParseInt
(
raw
,
10
,
64
);
parseErr
==
nil
{
expiresIn
=
v
}
}
return
&
oidcTokenResponse
{
AccessToken
:
accessToken
,
TokenType
:
strings
.
TrimSpace
(
values
.
Get
(
"token_type"
)),
ExpiresIn
:
expiresIn
,
RefreshToken
:
strings
.
TrimSpace
(
values
.
Get
(
"refresh_token"
)),
Scope
:
strings
.
TrimSpace
(
values
.
Get
(
"scope"
)),
IDToken
:
idToken
,
},
true
}
func
oidcFetchUserInfo
(
ctx
context
.
Context
,
cfg
config
.
OIDCConnectConfig
,
token
*
oidcTokenResponse
,
)
(
*
oidcUserInfoClaims
,
error
)
{
if
strings
.
TrimSpace
(
cfg
.
UserInfoURL
)
==
""
{
return
&
oidcUserInfoClaims
{},
nil
}
if
token
==
nil
||
strings
.
TrimSpace
(
token
.
AccessToken
)
==
""
{
return
nil
,
errors
.
New
(
"missing access_token for userinfo request"
)
}
client
:=
req
.
C
()
.
SetTimeout
(
30
*
time
.
Second
)
authorization
,
err
:=
buildBearerAuthorization
(
token
.
TokenType
,
token
.
AccessToken
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"invalid token for userinfo request: %w"
,
err
)
}
resp
,
err
:=
client
.
R
()
.
SetContext
(
ctx
)
.
SetHeader
(
"Accept"
,
"application/json"
)
.
SetHeader
(
"Authorization"
,
authorization
)
.
Get
(
cfg
.
UserInfoURL
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"request userinfo: %w"
,
err
)
}
if
!
resp
.
IsSuccessState
()
{
return
nil
,
fmt
.
Errorf
(
"userinfo status=%d"
,
resp
.
StatusCode
)
}
return
oidcParseUserInfo
(
resp
.
String
(),
cfg
),
nil
}
func
oidcParseUserInfo
(
body
string
,
cfg
config
.
OIDCConnectConfig
)
*
oidcUserInfoClaims
{
claims
:=
&
oidcUserInfoClaims
{}
claims
.
Email
=
firstNonEmpty
(
getGJSON
(
body
,
cfg
.
UserInfoEmailPath
),
getGJSON
(
body
,
"email"
),
getGJSON
(
body
,
"user.email"
),
getGJSON
(
body
,
"data.email"
),
getGJSON
(
body
,
"attributes.email"
),
)
claims
.
Username
=
firstNonEmpty
(
getGJSON
(
body
,
cfg
.
UserInfoUsernamePath
),
getGJSON
(
body
,
"preferred_username"
),
getGJSON
(
body
,
"username"
),
getGJSON
(
body
,
"name"
),
getGJSON
(
body
,
"user.username"
),
getGJSON
(
body
,
"user.name"
),
)
claims
.
Subject
=
firstNonEmpty
(
getGJSON
(
body
,
cfg
.
UserInfoIDPath
),
getGJSON
(
body
,
"sub"
),
getGJSON
(
body
,
"id"
),
getGJSON
(
body
,
"user_id"
),
getGJSON
(
body
,
"uid"
),
getGJSON
(
body
,
"user.id"
),
)
if
verified
,
ok
:=
getGJSONBool
(
body
,
"email_verified"
);
ok
{
claims
.
EmailVerified
=
&
verified
}
claims
.
Email
=
strings
.
TrimSpace
(
claims
.
Email
)
claims
.
Username
=
strings
.
TrimSpace
(
claims
.
Username
)
claims
.
Subject
=
strings
.
TrimSpace
(
claims
.
Subject
)
return
claims
}
func
getGJSONBool
(
body
string
,
path
string
)
(
bool
,
bool
)
{
path
=
strings
.
TrimSpace
(
path
)
if
path
==
""
{
return
false
,
false
}
res
:=
gjson
.
Get
(
body
,
path
)
if
!
res
.
Exists
()
{
return
false
,
false
}
return
res
.
Bool
(),
true
}
func
buildOIDCAuthorizeURL
(
cfg
config
.
OIDCConnectConfig
,
state
,
nonce
,
codeChallenge
,
redirectURI
string
)
(
string
,
error
)
{
u
,
err
:=
url
.
Parse
(
cfg
.
AuthorizeURL
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"parse authorize_url: %w"
,
err
)
}
q
:=
u
.
Query
()
q
.
Set
(
"response_type"
,
"code"
)
q
.
Set
(
"client_id"
,
cfg
.
ClientID
)
q
.
Set
(
"redirect_uri"
,
redirectURI
)
if
strings
.
TrimSpace
(
cfg
.
Scopes
)
!=
""
{
q
.
Set
(
"scope"
,
cfg
.
Scopes
)
}
q
.
Set
(
"state"
,
state
)
if
strings
.
TrimSpace
(
nonce
)
!=
""
{
q
.
Set
(
"nonce"
,
nonce
)
}
if
cfg
.
UsePKCE
{
q
.
Set
(
"code_challenge"
,
codeChallenge
)
q
.
Set
(
"code_challenge_method"
,
"S256"
)
}
u
.
RawQuery
=
q
.
Encode
()
return
u
.
String
(),
nil
}
func
oidcParseAndValidateIDToken
(
ctx
context
.
Context
,
cfg
config
.
OIDCConnectConfig
,
idToken
string
,
expectedNonce
string
)
(
*
oidcIDTokenClaims
,
error
)
{
idToken
=
strings
.
TrimSpace
(
idToken
)
if
idToken
==
""
{
return
nil
,
errors
.
New
(
"missing id_token"
)
}
allowed
:=
oidcAllowedSigningAlgs
(
cfg
.
AllowedSigningAlgs
)
if
len
(
allowed
)
==
0
{
return
nil
,
errors
.
New
(
"empty allowed signing algorithms"
)
}
jwks
,
err
:=
oidcFetchJWKSet
(
ctx
,
cfg
.
JWKSURL
)
if
err
!=
nil
{
return
nil
,
err
}
leeway
:=
time
.
Duration
(
cfg
.
ClockSkewSeconds
)
*
time
.
Second
claims
:=
&
oidcIDTokenClaims
{}
parsed
,
err
:=
jwt
.
ParseWithClaims
(
idToken
,
claims
,
func
(
token
*
jwt
.
Token
)
(
any
,
error
)
{
alg
:=
strings
.
TrimSpace
(
token
.
Method
.
Alg
())
if
!
containsString
(
allowed
,
alg
)
{
return
nil
,
fmt
.
Errorf
(
"unexpected signing algorithm: %s"
,
alg
)
}
kid
,
_
:=
token
.
Header
[
"kid"
]
.
(
string
)
return
oidcFindPublicKey
(
jwks
,
strings
.
TrimSpace
(
kid
),
alg
)
},
jwt
.
WithValidMethods
(
allowed
),
jwt
.
WithAudience
(
cfg
.
ClientID
),
jwt
.
WithIssuer
(
cfg
.
IssuerURL
),
jwt
.
WithLeeway
(
leeway
),
)
if
err
!=
nil
{
return
nil
,
err
}
if
!
parsed
.
Valid
{
return
nil
,
errors
.
New
(
"id_token invalid"
)
}
if
strings
.
TrimSpace
(
claims
.
Subject
)
==
""
{
return
nil
,
errors
.
New
(
"id_token missing sub"
)
}
if
expectedNonce
!=
""
&&
strings
.
TrimSpace
(
claims
.
Nonce
)
!=
strings
.
TrimSpace
(
expectedNonce
)
{
return
nil
,
errors
.
New
(
"id_token nonce mismatch"
)
}
if
len
(
claims
.
Audience
)
>
1
{
if
strings
.
TrimSpace
(
claims
.
Azp
)
==
""
||
strings
.
TrimSpace
(
claims
.
Azp
)
!=
strings
.
TrimSpace
(
cfg
.
ClientID
)
{
return
nil
,
errors
.
New
(
"id_token azp mismatch"
)
}
}
return
claims
,
nil
}
func
oidcAllowedSigningAlgs
(
raw
string
)
[]
string
{
if
strings
.
TrimSpace
(
raw
)
==
""
{
return
[]
string
{
"RS256"
,
"ES256"
,
"PS256"
}
}
seen
:=
make
(
map
[
string
]
struct
{})
out
:=
make
([]
string
,
0
,
4
)
for
_
,
part
:=
range
strings
.
Split
(
raw
,
","
)
{
alg
:=
strings
.
ToUpper
(
strings
.
TrimSpace
(
part
))
if
alg
==
""
{
continue
}
if
_
,
ok
:=
seen
[
alg
];
ok
{
continue
}
seen
[
alg
]
=
struct
{}{}
out
=
append
(
out
,
alg
)
}
return
out
}
func
oidcFetchJWKSet
(
ctx
context
.
Context
,
jwksURL
string
)
(
*
oidcJWKSet
,
error
)
{
jwksURL
=
strings
.
TrimSpace
(
jwksURL
)
if
jwksURL
==
""
{
return
nil
,
errors
.
New
(
"missing jwks_url"
)
}
resp
,
err
:=
req
.
C
()
.
SetTimeout
(
30
*
time
.
Second
)
.
R
()
.
SetContext
(
ctx
)
.
SetHeader
(
"Accept"
,
"application/json"
)
.
Get
(
jwksURL
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"request jwks: %w"
,
err
)
}
if
!
resp
.
IsSuccessState
()
{
return
nil
,
fmt
.
Errorf
(
"jwks status=%d"
,
resp
.
StatusCode
)
}
set
:=
&
oidcJWKSet
{}
if
err
:=
json
.
Unmarshal
(
resp
.
Bytes
(),
set
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"parse jwks: %w"
,
err
)
}
if
len
(
set
.
Keys
)
==
0
{
return
nil
,
errors
.
New
(
"jwks empty keys"
)
}
return
set
,
nil
}
func
oidcFindPublicKey
(
set
*
oidcJWKSet
,
kid
,
alg
string
)
(
any
,
error
)
{
if
set
==
nil
{
return
nil
,
errors
.
New
(
"jwks not loaded"
)
}
alg
=
strings
.
ToUpper
(
strings
.
TrimSpace
(
alg
))
kid
=
strings
.
TrimSpace
(
kid
)
var
lastErr
error
for
i
:=
range
set
.
Keys
{
k
:=
set
.
Keys
[
i
]
if
strings
.
TrimSpace
(
k
.
Use
)
!=
""
&&
!
strings
.
EqualFold
(
strings
.
TrimSpace
(
k
.
Use
),
"sig"
)
{
continue
}
if
kid
!=
""
&&
strings
.
TrimSpace
(
k
.
Kid
)
!=
kid
{
continue
}
if
strings
.
TrimSpace
(
k
.
Alg
)
!=
""
&&
!
strings
.
EqualFold
(
strings
.
TrimSpace
(
k
.
Alg
),
alg
)
{
continue
}
pk
,
err
:=
k
.
publicKey
()
if
err
!=
nil
{
lastErr
=
err
continue
}
if
pk
!=
nil
{
return
pk
,
nil
}
}
if
lastErr
!=
nil
{
return
nil
,
lastErr
}
if
kid
!=
""
{
return
nil
,
fmt
.
Errorf
(
"jwk not found for kid=%s"
,
kid
)
}
return
nil
,
errors
.
New
(
"jwk not found"
)
}
func
(
k
oidcJWK
)
publicKey
()
(
any
,
error
)
{
switch
strings
.
ToUpper
(
strings
.
TrimSpace
(
k
.
Kty
))
{
case
"RSA"
:
n
,
err
:=
decodeBase64URLBigInt
(
k
.
N
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"decode rsa n: %w"
,
err
)
}
eBytes
,
err
:=
base64
.
RawURLEncoding
.
DecodeString
(
strings
.
TrimSpace
(
k
.
E
))
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"decode rsa e: %w"
,
err
)
}
if
len
(
eBytes
)
==
0
{
return
nil
,
errors
.
New
(
"empty rsa e"
)
}
e
:=
0
for
_
,
b
:=
range
eBytes
{
e
=
(
e
<<
8
)
|
int
(
b
)
}
if
e
<=
0
{
return
nil
,
errors
.
New
(
"invalid rsa exponent"
)
}
if
n
.
Sign
()
<=
0
{
return
nil
,
errors
.
New
(
"invalid rsa modulus"
)
}
return
&
rsa
.
PublicKey
{
N
:
n
,
E
:
e
},
nil
case
"EC"
:
var
curve
elliptic
.
Curve
switch
strings
.
TrimSpace
(
k
.
Crv
)
{
case
"P-256"
:
curve
=
elliptic
.
P256
()
case
"P-384"
:
curve
=
elliptic
.
P384
()
case
"P-521"
:
curve
=
elliptic
.
P521
()
default
:
return
nil
,
fmt
.
Errorf
(
"unsupported ec curve: %s"
,
k
.
Crv
)
}
x
,
err
:=
decodeBase64URLBigInt
(
k
.
X
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"decode ec x: %w"
,
err
)
}
y
,
err
:=
decodeBase64URLBigInt
(
k
.
Y
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"decode ec y: %w"
,
err
)
}
if
!
curve
.
IsOnCurve
(
x
,
y
)
{
return
nil
,
errors
.
New
(
"ec point is not on curve"
)
}
return
&
ecdsa
.
PublicKey
{
Curve
:
curve
,
X
:
x
,
Y
:
y
},
nil
default
:
return
nil
,
fmt
.
Errorf
(
"unsupported jwk kty: %s"
,
k
.
Kty
)
}
}
func
decodeBase64URLBigInt
(
raw
string
)
(
*
big
.
Int
,
error
)
{
buf
,
err
:=
base64
.
RawURLEncoding
.
DecodeString
(
strings
.
TrimSpace
(
raw
))
if
err
!=
nil
{
return
nil
,
err
}
if
len
(
buf
)
==
0
{
return
nil
,
errors
.
New
(
"empty value"
)
}
return
new
(
big
.
Int
)
.
SetBytes
(
buf
),
nil
}
func
containsString
(
values
[]
string
,
target
string
)
bool
{
target
=
strings
.
TrimSpace
(
target
)
for
_
,
v
:=
range
values
{
if
strings
.
EqualFold
(
strings
.
TrimSpace
(
v
),
target
)
{
return
true
}
}
return
false
}
func
oidcIdentityKey
(
issuer
,
subject
string
)
string
{
issuer
=
strings
.
TrimSpace
(
strings
.
ToLower
(
issuer
))
subject
=
strings
.
TrimSpace
(
subject
)
return
issuer
+
"
\x1f
"
+
subject
}
func
oidcSyntheticEmailFromIdentityKey
(
identityKey
string
)
string
{
identityKey
=
strings
.
TrimSpace
(
identityKey
)
if
identityKey
==
""
{
return
""
}
sum
:=
sha256
.
Sum256
([]
byte
(
identityKey
))
return
"oidc-"
+
hex
.
EncodeToString
(
sum
[
:
16
])
+
service
.
OIDCConnectSyntheticEmailDomain
}
func
oidcSelectLoginEmail
(
userInfoEmail
,
idTokenEmail
,
identityKey
string
)
string
{
email
:=
strings
.
TrimSpace
(
firstNonEmpty
(
userInfoEmail
,
idTokenEmail
))
if
email
!=
""
{
return
email
}
return
oidcSyntheticEmailFromIdentityKey
(
identityKey
)
}
func
oidcFallbackUsername
(
subject
string
)
string
{
subject
=
strings
.
TrimSpace
(
subject
)
if
subject
==
""
{
return
"oidc_user"
}
sum
:=
sha256
.
Sum256
([]
byte
(
subject
))
return
"oidc_"
+
hex
.
EncodeToString
(
sum
[
:
])[
:
12
]
}
func
oidcSetCookie
(
c
*
gin
.
Context
,
name
,
value
string
,
maxAgeSec
int
,
secure
bool
)
{
http
.
SetCookie
(
c
.
Writer
,
&
http
.
Cookie
{
Name
:
name
,
Value
:
value
,
Path
:
oidcOAuthCookiePath
,
MaxAge
:
maxAgeSec
,
HttpOnly
:
true
,
Secure
:
secure
,
SameSite
:
http
.
SameSiteLaxMode
,
})
}
func
oidcClearCookie
(
c
*
gin
.
Context
,
name
string
,
secure
bool
)
{
http
.
SetCookie
(
c
.
Writer
,
&
http
.
Cookie
{
Name
:
name
,
Value
:
""
,
Path
:
oidcOAuthCookiePath
,
MaxAge
:
-
1
,
HttpOnly
:
true
,
Secure
:
secure
,
SameSite
:
http
.
SameSiteLaxMode
,
})
}
backend/internal/handler/auth_oidc_oauth_test.go
0 → 100644
View file @
2b70d1d3
package
handler
import
(
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"math/big"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"
)
func
TestOIDCSyntheticEmailStableAndDistinct
(
t
*
testing
.
T
)
{
k1
:=
oidcIdentityKey
(
"https://issuer.example.com"
,
"subject-a"
)
k2
:=
oidcIdentityKey
(
"https://issuer.example.com"
,
"subject-b"
)
e1
:=
oidcSyntheticEmailFromIdentityKey
(
k1
)
e1Again
:=
oidcSyntheticEmailFromIdentityKey
(
k1
)
e2
:=
oidcSyntheticEmailFromIdentityKey
(
k2
)
require
.
Equal
(
t
,
e1
,
e1Again
)
require
.
NotEqual
(
t
,
e1
,
e2
)
require
.
Contains
(
t
,
e1
,
"@oidc-connect.invalid"
)
}
func
TestOIDCSelectLoginEmailPrefersRealEmail
(
t
*
testing
.
T
)
{
identityKey
:=
oidcIdentityKey
(
"https://issuer.example.com"
,
"subject-a"
)
email
:=
oidcSelectLoginEmail
(
"user@example.com"
,
"idtoken@example.com"
,
identityKey
)
require
.
Equal
(
t
,
"user@example.com"
,
email
)
email
=
oidcSelectLoginEmail
(
""
,
"idtoken@example.com"
,
identityKey
)
require
.
Equal
(
t
,
"idtoken@example.com"
,
email
)
email
=
oidcSelectLoginEmail
(
""
,
""
,
identityKey
)
require
.
Contains
(
t
,
email
,
"@oidc-connect.invalid"
)
require
.
Equal
(
t
,
oidcSyntheticEmailFromIdentityKey
(
identityKey
),
email
)
}
func
TestBuildOIDCAuthorizeURLIncludesNonceAndPKCE
(
t
*
testing
.
T
)
{
cfg
:=
config
.
OIDCConnectConfig
{
AuthorizeURL
:
"https://issuer.example.com/auth"
,
ClientID
:
"cid"
,
Scopes
:
"openid email profile"
,
UsePKCE
:
true
,
}
u
,
err
:=
buildOIDCAuthorizeURL
(
cfg
,
"state123"
,
"nonce123"
,
"challenge123"
,
"https://app.example.com/callback"
)
require
.
NoError
(
t
,
err
)
require
.
Contains
(
t
,
u
,
"nonce=nonce123"
)
require
.
Contains
(
t
,
u
,
"code_challenge=challenge123"
)
require
.
Contains
(
t
,
u
,
"code_challenge_method=S256"
)
require
.
Contains
(
t
,
u
,
"scope=openid+email+profile"
)
}
func
TestOIDCParseAndValidateIDToken
(
t
*
testing
.
T
)
{
priv
,
err
:=
rsa
.
GenerateKey
(
rand
.
Reader
,
2048
)
require
.
NoError
(
t
,
err
)
kid
:=
"kid-1"
jwks
:=
oidcJWKSet
{
Keys
:
[]
oidcJWK
{
buildRSAJWK
(
kid
,
&
priv
.
PublicKey
)}}
srv
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
require
.
NoError
(
t
,
json
.
NewEncoder
(
w
)
.
Encode
(
jwks
))
}))
defer
srv
.
Close
()
now
:=
time
.
Now
()
claims
:=
oidcIDTokenClaims
{
Nonce
:
"nonce-ok"
,
Azp
:
"client-1"
,
RegisteredClaims
:
jwt
.
RegisteredClaims
{
Issuer
:
"https://issuer.example.com"
,
Subject
:
"subject-1"
,
Audience
:
jwt
.
ClaimStrings
{
"client-1"
,
"another-aud"
},
IssuedAt
:
jwt
.
NewNumericDate
(
now
),
NotBefore
:
jwt
.
NewNumericDate
(
now
.
Add
(
-
30
*
time
.
Second
)),
ExpiresAt
:
jwt
.
NewNumericDate
(
now
.
Add
(
5
*
time
.
Minute
)),
},
}
tok
:=
jwt
.
NewWithClaims
(
jwt
.
SigningMethodRS256
,
claims
)
tok
.
Header
[
"kid"
]
=
kid
signed
,
err
:=
tok
.
SignedString
(
priv
)
require
.
NoError
(
t
,
err
)
cfg
:=
config
.
OIDCConnectConfig
{
ClientID
:
"client-1"
,
IssuerURL
:
"https://issuer.example.com"
,
JWKSURL
:
srv
.
URL
,
AllowedSigningAlgs
:
"RS256"
,
ClockSkewSeconds
:
120
,
}
parsed
,
err
:=
oidcParseAndValidateIDToken
(
context
.
Background
(),
cfg
,
signed
,
"nonce-ok"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"subject-1"
,
parsed
.
Subject
)
require
.
Equal
(
t
,
"https://issuer.example.com"
,
parsed
.
Issuer
)
_
,
err
=
oidcParseAndValidateIDToken
(
context
.
Background
(),
cfg
,
signed
,
"bad-nonce"
)
require
.
Error
(
t
,
err
)
}
func
buildRSAJWK
(
kid
string
,
pub
*
rsa
.
PublicKey
)
oidcJWK
{
n
:=
base64
.
RawURLEncoding
.
EncodeToString
(
pub
.
N
.
Bytes
())
e
:=
base64
.
RawURLEncoding
.
EncodeToString
(
big
.
NewInt
(
int64
(
pub
.
E
))
.
Bytes
())
return
oidcJWK
{
Kty
:
"RSA"
,
Kid
:
kid
,
Use
:
"sig"
,
Alg
:
"RS256"
,
N
:
n
,
E
:
e
,
}
}
backend/internal/handler/dto/mappers.go
View file @
2b70d1d3
...
@@ -138,6 +138,7 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
...
@@ -138,6 +138,7 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
ModelRoutingEnabled
:
g
.
ModelRoutingEnabled
,
ModelRoutingEnabled
:
g
.
ModelRoutingEnabled
,
MCPXMLInject
:
g
.
MCPXMLInject
,
MCPXMLInject
:
g
.
MCPXMLInject
,
DefaultMappedModel
:
g
.
DefaultMappedModel
,
DefaultMappedModel
:
g
.
DefaultMappedModel
,
MessagesDispatchModelConfig
:
g
.
MessagesDispatchModelConfig
,
SupportedModelScopes
:
g
.
SupportedModelScopes
,
SupportedModelScopes
:
g
.
SupportedModelScopes
,
AccountCount
:
g
.
AccountCount
,
AccountCount
:
g
.
AccountCount
,
ActiveAccountCount
:
g
.
ActiveAccountCount
,
ActiveAccountCount
:
g
.
ActiveAccountCount
,
...
...
backend/internal/handler/dto/settings.go
View file @
2b70d1d3
...
@@ -51,6 +51,29 @@ type SystemSettings struct {
...
@@ -51,6 +51,29 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured
bool
`json:"linuxdo_connect_client_secret_configured"`
LinuxDoConnectClientSecretConfigured
bool
`json:"linuxdo_connect_client_secret_configured"`
LinuxDoConnectRedirectURL
string
`json:"linuxdo_connect_redirect_url"`
LinuxDoConnectRedirectURL
string
`json:"linuxdo_connect_redirect_url"`
OIDCConnectEnabled
bool
`json:"oidc_connect_enabled"`
OIDCConnectProviderName
string
`json:"oidc_connect_provider_name"`
OIDCConnectClientID
string
`json:"oidc_connect_client_id"`
OIDCConnectClientSecretConfigured
bool
`json:"oidc_connect_client_secret_configured"`
OIDCConnectIssuerURL
string
`json:"oidc_connect_issuer_url"`
OIDCConnectDiscoveryURL
string
`json:"oidc_connect_discovery_url"`
OIDCConnectAuthorizeURL
string
`json:"oidc_connect_authorize_url"`
OIDCConnectTokenURL
string
`json:"oidc_connect_token_url"`
OIDCConnectUserInfoURL
string
`json:"oidc_connect_userinfo_url"`
OIDCConnectJWKSURL
string
`json:"oidc_connect_jwks_url"`
OIDCConnectScopes
string
`json:"oidc_connect_scopes"`
OIDCConnectRedirectURL
string
`json:"oidc_connect_redirect_url"`
OIDCConnectFrontendRedirectURL
string
`json:"oidc_connect_frontend_redirect_url"`
OIDCConnectTokenAuthMethod
string
`json:"oidc_connect_token_auth_method"`
OIDCConnectUsePKCE
bool
`json:"oidc_connect_use_pkce"`
OIDCConnectValidateIDToken
bool
`json:"oidc_connect_validate_id_token"`
OIDCConnectAllowedSigningAlgs
string
`json:"oidc_connect_allowed_signing_algs"`
OIDCConnectClockSkewSeconds
int
`json:"oidc_connect_clock_skew_seconds"`
OIDCConnectRequireEmailVerified
bool
`json:"oidc_connect_require_email_verified"`
OIDCConnectUserInfoEmailPath
string
`json:"oidc_connect_userinfo_email_path"`
OIDCConnectUserInfoIDPath
string
`json:"oidc_connect_userinfo_id_path"`
OIDCConnectUserInfoUsernamePath
string
`json:"oidc_connect_userinfo_username_path"`
SiteName
string
`json:"site_name"`
SiteName
string
`json:"site_name"`
SiteLogo
string
`json:"site_logo"`
SiteLogo
string
`json:"site_logo"`
SiteSubtitle
string
`json:"site_subtitle"`
SiteSubtitle
string
`json:"site_subtitle"`
...
@@ -132,6 +155,9 @@ type PublicSettings struct {
...
@@ -132,6 +155,9 @@ type PublicSettings struct {
CustomMenuItems
[]
CustomMenuItem
`json:"custom_menu_items"`
CustomMenuItems
[]
CustomMenuItem
`json:"custom_menu_items"`
CustomEndpoints
[]
CustomEndpoint
`json:"custom_endpoints"`
CustomEndpoints
[]
CustomEndpoint
`json:"custom_endpoints"`
LinuxDoOAuthEnabled
bool
`json:"linuxdo_oauth_enabled"`
LinuxDoOAuthEnabled
bool
`json:"linuxdo_oauth_enabled"`
OIDCOAuthEnabled
bool
`json:"oidc_oauth_enabled"`
OIDCOAuthProviderName
string
`json:"oidc_oauth_provider_name"`
SoraClientEnabled
bool
`json:"sora_client_enabled"`
BackendModeEnabled
bool
`json:"backend_mode_enabled"`
BackendModeEnabled
bool
`json:"backend_mode_enabled"`
Version
string
`json:"version"`
Version
string
`json:"version"`
}
}
...
...
backend/internal/handler/dto/types.go
View file @
2b70d1d3
package
dto
package
dto
import
"time"
import
(
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
)
type
User
struct
{
type
User
struct
{
ID
int64
`json:"id"`
ID
int64
`json:"id"`
...
@@ -113,6 +117,7 @@ type AdminGroup struct {
...
@@ -113,6 +117,7 @@ type AdminGroup struct {
// OpenAI Messages 调度配置(仅 openai 平台使用)
// OpenAI Messages 调度配置(仅 openai 平台使用)
DefaultMappedModel
string
`json:"default_mapped_model"`
DefaultMappedModel
string
`json:"default_mapped_model"`
MessagesDispatchModelConfig
domain
.
OpenAIMessagesDispatchModelConfig
`json:"messages_dispatch_model_config"`
// 支持的模型系列(仅 antigravity 平台使用)
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes
[]
string
`json:"supported_model_scopes"`
SupportedModelScopes
[]
string
`json:"supported_model_scopes"`
...
...
backend/internal/handler/openai_gateway_handler.go
View file @
2b70d1d3
...
@@ -47,6 +47,13 @@ func resolveOpenAIForwardDefaultMappedModel(apiKey *service.APIKey, fallbackMode
...
@@ -47,6 +47,13 @@ func resolveOpenAIForwardDefaultMappedModel(apiKey *service.APIKey, fallbackMode
return
strings
.
TrimSpace
(
apiKey
.
Group
.
DefaultMappedModel
)
return
strings
.
TrimSpace
(
apiKey
.
Group
.
DefaultMappedModel
)
}
}
func
resolveOpenAIMessagesDispatchMappedModel
(
apiKey
*
service
.
APIKey
,
requestedModel
string
)
string
{
if
apiKey
==
nil
||
apiKey
.
Group
==
nil
{
return
""
}
return
strings
.
TrimSpace
(
apiKey
.
Group
.
ResolveMessagesDispatchModel
(
requestedModel
))
}
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
func
NewOpenAIGatewayHandler
(
func
NewOpenAIGatewayHandler
(
gatewayService
*
service
.
OpenAIGatewayService
,
gatewayService
*
service
.
OpenAIGatewayService
,
...
@@ -551,6 +558,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
...
@@ -551,6 +558,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
}
}
reqModel
:=
modelResult
.
String
()
reqModel
:=
modelResult
.
String
()
routingModel
:=
service
.
NormalizeOpenAICompatRequestedModel
(
reqModel
)
routingModel
:=
service
.
NormalizeOpenAICompatRequestedModel
(
reqModel
)
preferredMappedModel
:=
resolveOpenAIMessagesDispatchMappedModel
(
apiKey
,
reqModel
)
reqStream
:=
gjson
.
GetBytes
(
body
,
"stream"
)
.
Bool
()
reqStream
:=
gjson
.
GetBytes
(
body
,
"stream"
)
.
Bool
()
reqLog
=
reqLog
.
With
(
zap
.
String
(
"model"
,
reqModel
),
zap
.
Bool
(
"stream"
,
reqStream
))
reqLog
=
reqLog
.
With
(
zap
.
String
(
"model"
,
reqModel
),
zap
.
Bool
(
"stream"
,
reqStream
))
...
@@ -609,17 +617,20 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
...
@@ -609,17 +617,20 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
failedAccountIDs
:=
make
(
map
[
int64
]
struct
{})
failedAccountIDs
:=
make
(
map
[
int64
]
struct
{})
sameAccountRetryCount
:=
make
(
map
[
int64
]
int
)
sameAccountRetryCount
:=
make
(
map
[
int64
]
int
)
var
lastFailoverErr
*
service
.
UpstreamFailoverError
var
lastFailoverErr
*
service
.
UpstreamFailoverError
effectiveMappedModel
:=
preferredMappedModel
for
{
for
{
// 清除上一次迭代的降级模型标记,避免残留影响本次迭代
currentRoutingModel
:=
routingModel
c
.
Set
(
"openai_messages_fallback_model"
,
""
)
if
effectiveMappedModel
!=
""
{
currentRoutingModel
=
effectiveMappedModel
}
reqLog
.
Debug
(
"openai_messages.account_selecting"
,
zap
.
Int
(
"excluded_account_count"
,
len
(
failedAccountIDs
)))
reqLog
.
Debug
(
"openai_messages.account_selecting"
,
zap
.
Int
(
"excluded_account_count"
,
len
(
failedAccountIDs
)))
selection
,
scheduleDecision
,
err
:=
h
.
gatewayService
.
SelectAccountWithScheduler
(
selection
,
scheduleDecision
,
err
:=
h
.
gatewayService
.
SelectAccountWithScheduler
(
c
.
Request
.
Context
(),
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
apiKey
.
GroupID
,
""
,
// no previous_response_id
""
,
// no previous_response_id
sessionHash
,
sessionHash
,
r
outingModel
,
currentR
outingModel
,
failedAccountIDs
,
failedAccountIDs
,
service
.
OpenAIUpstreamTransportAny
,
service
.
OpenAIUpstreamTransportAny
,
)
)
...
@@ -628,29 +639,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
...
@@ -628,29 +639,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
zap
.
Error
(
err
),
zap
.
Error
(
err
),
zap
.
Int
(
"excluded_account_count"
,
len
(
failedAccountIDs
)),
zap
.
Int
(
"excluded_account_count"
,
len
(
failedAccountIDs
)),
)
)
// 首次调度失败 + 有默认映射模型 → 用默认模型重试
if
len
(
failedAccountIDs
)
==
0
{
if
len
(
failedAccountIDs
)
==
0
{
defaultModel
:=
""
if
apiKey
.
Group
!=
nil
{
defaultModel
=
apiKey
.
Group
.
DefaultMappedModel
}
if
defaultModel
!=
""
&&
defaultModel
!=
routingModel
{
reqLog
.
Info
(
"openai_messages.fallback_to_default_model"
,
zap
.
String
(
"default_mapped_model"
,
defaultModel
),
)
selection
,
scheduleDecision
,
err
=
h
.
gatewayService
.
SelectAccountWithScheduler
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
""
,
sessionHash
,
defaultModel
,
failedAccountIDs
,
service
.
OpenAIUpstreamTransportAny
,
)
if
err
==
nil
&&
selection
!=
nil
{
c
.
Set
(
"openai_messages_fallback_model"
,
defaultModel
)
}
}
if
err
!=
nil
{
if
err
!=
nil
{
h
.
anthropicStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"Service temporarily unavailable"
,
streamStarted
)
h
.
anthropicStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"Service temporarily unavailable"
,
streamStarted
)
return
return
...
@@ -682,9 +671,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
...
@@ -682,9 +671,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
service
.
SetOpsLatencyMs
(
c
,
service
.
OpsRoutingLatencyMsKey
,
time
.
Since
(
routingStart
)
.
Milliseconds
())
service
.
SetOpsLatencyMs
(
c
,
service
.
OpsRoutingLatencyMsKey
,
time
.
Since
(
routingStart
)
.
Milliseconds
())
forwardStart
:=
time
.
Now
()
forwardStart
:=
time
.
Now
()
// Forward 层需要始终拿到 group 默认映射模型,这样未命中账号级映射的
defaultMappedModel
:=
strings
.
TrimSpace
(
effectiveMappedModel
)
// Claude 兼容模型才不会在后续 Codex 规范化中意外退化到 gpt-5.1。
defaultMappedModel
:=
resolveOpenAIForwardDefaultMappedModel
(
apiKey
,
c
.
GetString
(
"openai_messages_fallback_model"
))
// 应用渠道模型映射到请求体
// 应用渠道模型映射到请求体
forwardBody
:=
body
forwardBody
:=
body
if
channelMappingMsg
.
Mapped
{
if
channelMappingMsg
.
Mapped
{
...
...
backend/internal/handler/openai_gateway_handler_test.go
View file @
2b70d1d3
...
@@ -360,7 +360,7 @@ func TestResolveOpenAIForwardDefaultMappedModel(t *testing.T) {
...
@@ -360,7 +360,7 @@ func TestResolveOpenAIForwardDefaultMappedModel(t *testing.T) {
require
.
Equal
(
t
,
"gpt-5.2"
,
resolveOpenAIForwardDefaultMappedModel
(
apiKey
,
" gpt-5.2 "
))
require
.
Equal
(
t
,
"gpt-5.2"
,
resolveOpenAIForwardDefaultMappedModel
(
apiKey
,
" gpt-5.2 "
))
})
})
t
.
Run
(
"uses_group_default_
on_normal_path
"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"uses_group_default_
when_explicit_fallback_absent
"
,
func
(
t
*
testing
.
T
)
{
apiKey
:=
&
service
.
APIKey
{
apiKey
:=
&
service
.
APIKey
{
Group
:
&
service
.
Group
{
DefaultMappedModel
:
"gpt-5.4"
},
Group
:
&
service
.
Group
{
DefaultMappedModel
:
"gpt-5.4"
},
}
}
...
@@ -376,6 +376,45 @@ func TestResolveOpenAIForwardDefaultMappedModel(t *testing.T) {
...
@@ -376,6 +376,45 @@ func TestResolveOpenAIForwardDefaultMappedModel(t *testing.T) {
})
})
}
}
func
TestResolveOpenAIMessagesDispatchMappedModel
(
t
*
testing
.
T
)
{
t
.
Run
(
"exact_claude_model_override_wins"
,
func
(
t
*
testing
.
T
)
{
apiKey
:=
&
service
.
APIKey
{
Group
:
&
service
.
Group
{
MessagesDispatchModelConfig
:
service
.
OpenAIMessagesDispatchModelConfig
{
SonnetMappedModel
:
"gpt-5.2"
,
ExactModelMappings
:
map
[
string
]
string
{
"claude-sonnet-4-5-20250929"
:
"gpt-5.4-mini-high"
,
},
},
},
}
require
.
Equal
(
t
,
"gpt-5.4-mini"
,
resolveOpenAIMessagesDispatchMappedModel
(
apiKey
,
"claude-sonnet-4-5-20250929"
))
})
t
.
Run
(
"uses_family_default_when_no_override"
,
func
(
t
*
testing
.
T
)
{
apiKey
:=
&
service
.
APIKey
{
Group
:
&
service
.
Group
{}}
require
.
Equal
(
t
,
"gpt-5.4"
,
resolveOpenAIMessagesDispatchMappedModel
(
apiKey
,
"claude-opus-4-6"
))
require
.
Equal
(
t
,
"gpt-5.3-codex"
,
resolveOpenAIMessagesDispatchMappedModel
(
apiKey
,
"claude-sonnet-4-5-20250929"
))
require
.
Equal
(
t
,
"gpt-5.4-mini"
,
resolveOpenAIMessagesDispatchMappedModel
(
apiKey
,
"claude-haiku-4-5-20251001"
))
})
t
.
Run
(
"returns_empty_for_non_claude_or_missing_group"
,
func
(
t
*
testing
.
T
)
{
require
.
Empty
(
t
,
resolveOpenAIMessagesDispatchMappedModel
(
nil
,
"claude-sonnet-4-5-20250929"
))
require
.
Empty
(
t
,
resolveOpenAIMessagesDispatchMappedModel
(
&
service
.
APIKey
{},
"claude-sonnet-4-5-20250929"
))
require
.
Empty
(
t
,
resolveOpenAIMessagesDispatchMappedModel
(
&
service
.
APIKey
{
Group
:
&
service
.
Group
{}},
"gpt-5.4"
))
})
t
.
Run
(
"does_not_fall_back_to_group_default_mapped_model"
,
func
(
t
*
testing
.
T
)
{
apiKey
:=
&
service
.
APIKey
{
Group
:
&
service
.
Group
{
DefaultMappedModel
:
"gpt-5.4"
,
},
}
require
.
Empty
(
t
,
resolveOpenAIMessagesDispatchMappedModel
(
apiKey
,
"gpt-5.4"
))
require
.
Equal
(
t
,
"gpt-5.3-codex"
,
resolveOpenAIMessagesDispatchMappedModel
(
apiKey
,
"claude-sonnet-4-5-20250929"
))
})
}
func
TestOpenAIResponses_MissingDependencies_ReturnsServiceUnavailable
(
t
*
testing
.
T
)
{
func
TestOpenAIResponses_MissingDependencies_ReturnsServiceUnavailable
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
gin
.
SetMode
(
gin
.
TestMode
)
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment