Commit e5a77853 authored by Forest's avatar Forest
Browse files

refactor: 调整项目结构为单向依赖

parent b3463769
...@@ -48,8 +48,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -48,8 +48,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
turnstileService := service.NewTurnstileService(settingService, turnstileVerifier) turnstileService := service.NewTurnstileService(settingService, turnstileVerifier)
emailQueueService := service.ProvideEmailQueueService(emailService) emailQueueService := service.ProvideEmailQueueService(emailService)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService) authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
authHandler := handler.NewAuthHandler(authService)
userService := service.NewUserService(userRepository) userService := service.NewUserService(userRepository)
authHandler := handler.NewAuthHandler(authService, userService)
userHandler := handler.NewUserHandler(userService) userHandler := handler.NewUserHandler(userService)
apiKeyRepository := repository.NewApiKeyRepository(db) apiKeyRepository := repository.NewApiKeyRepository(db)
groupRepository := repository.NewGroupRepository(db) groupRepository := repository.NewGroupRepository(db)
......
...@@ -22,12 +22,14 @@ require ( ...@@ -22,12 +22,14 @@ require (
golang.org/x/net v0.47.0 golang.org/x/net v0.47.0
golang.org/x/term v0.37.0 golang.org/x/term v0.37.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
gorm.io/datatypes v1.2.0
gorm.io/driver/postgres v1.5.4 gorm.io/driver/postgres v1.5.4
gorm.io/gorm v1.25.5 gorm.io/gorm v1.25.5
) )
require ( require (
dario.cat/mergo v1.0.2 // indirect dario.cat/mergo v1.0.2 // indirect
filippo.io/edwards25519 v1.1.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/andybalholm/brotli v1.2.0 // indirect github.com/andybalholm/brotli v1.2.0 // indirect
...@@ -57,6 +59,7 @@ require ( ...@@ -57,6 +59,7 @@ require (
github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/go-sql-driver/mysql v1.8.1 // indirect
github.com/goccy/go-json v0.10.2 // indirect github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-querystring v1.1.0 // indirect github.com/google/go-querystring v1.1.0 // indirect
github.com/google/subcommands v1.2.0 // indirect github.com/google/subcommands v1.2.0 // indirect
...@@ -64,8 +67,8 @@ require ( ...@@ -64,8 +67,8 @@ require (
github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect
github.com/icholy/digest v1.1.0 // indirect github.com/icholy/digest v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect
github.com/jackc/pgx/v5 v5.5.4 // indirect github.com/jackc/pgx/v5 v5.5.5 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect github.com/jinzhu/now v1.1.5 // indirect
...@@ -132,4 +135,5 @@ require ( ...@@ -132,4 +135,5 @@ require (
google.golang.org/grpc v1.75.1 // indirect google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.10 // indirect google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect
gorm.io/driver/mysql v1.5.2 // indirect
) )
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA= dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk= github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8= github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8=
...@@ -77,10 +79,17 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn ...@@ -77,10 +79,17 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw= github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw=
github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
...@@ -104,10 +113,10 @@ github.com/imroc/req/v3 v3.56.0 h1:t6YdqqerYBXhZ9+VjqsQs5wlKxdUNEvsgBhxWc1AEEo= ...@@ -104,10 +113,10 @@ github.com/imroc/req/v3 v3.56.0 h1:t6YdqqerYBXhZ9+VjqsQs5wlKxdUNEvsgBhxWc1AEEo=
github.com/imroc/req/v3 v3.56.0/go.mod h1:cUZSooE8hhzFNOrAbdxuemXDQxFXLQTnu3066jr7ZGk= github.com/imroc/req/v3 v3.56.0/go.mod h1:cUZSooE8hhzFNOrAbdxuemXDQxFXLQTnu3066jr7ZGk=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.5.4 h1:Xp2aQS8uXButQdnCMWNmvx6UysWQQC+u1EoizjguY+8= github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw=
github.com/jackc/pgx/v5 v5.5.4/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
...@@ -135,8 +144,12 @@ github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8S ...@@ -135,8 +144,12 @@ github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8S
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.15 h1:vfoHhTN1af61xCRSWzFIWzx2YskyMTwHLrExkBOjvxI=
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o= github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o=
github.com/microsoft/go-mssqldb v0.17.0 h1:Fto83dMZPnYv1Zwx5vHHxpNraeEaUlQ/hhHLgZiaenE=
github.com/microsoft/go-mssqldb v0.17.0/go.mod h1:OkoNGhGEs8EZqchVTtochlXruEhEOaO4S0d2sB5aeGQ=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
...@@ -319,8 +332,17 @@ gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= ...@@ -319,8 +332,17 @@ gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/datatypes v1.2.0 h1:5YT+eokWdIxhJgWHdrb2zYUimyk0+TaFth+7a0ybzco=
gorm.io/datatypes v1.2.0/go.mod h1:o1dh0ZvjIjhH/bngTpypG6lVRJ5chTBxE09FH/71k04=
gorm.io/driver/mysql v1.5.2 h1:QC2HRskSE75wBuOxe0+iCkyJZ+RqpudsQtqkp+IMuXs=
gorm.io/driver/mysql v1.5.2/go.mod h1:pQLhh1Ut/WUAySdTHwBpBv6+JKcj+ua4ZFx1QQTBzb8=
gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo= gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo=
gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0= gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0=
gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU=
gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI=
gorm.io/driver/sqlserver v1.4.1 h1:t4r4r6Jam5E6ejqP7N82qAJIJAht27EGT41HyPfXRw0=
gorm.io/driver/sqlserver v1.4.1/go.mod h1:DJ4P+MeZbc5rvY58PnmN1Lnyvb5gw5NPzGshHDnJLig=
gorm.io/gorm v1.25.2-0.20230530020048-26663ab9bf55/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls= gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q= gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
......
...@@ -3,7 +3,7 @@ package admin ...@@ -3,7 +3,7 @@ package admin
import ( import (
"strconv" "strconv"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
...@@ -102,7 +102,7 @@ type BulkUpdateAccountsRequest struct { ...@@ -102,7 +102,7 @@ type BulkUpdateAccountsRequest struct {
// AccountWithConcurrency extends Account with real-time concurrency info // AccountWithConcurrency extends Account with real-time concurrency info
type AccountWithConcurrency struct { type AccountWithConcurrency struct {
*model.Account *dto.Account
CurrentConcurrency int `json:"current_concurrency"` CurrentConcurrency int `json:"current_concurrency"`
} }
...@@ -137,7 +137,7 @@ func (h *AccountHandler) List(c *gin.Context) { ...@@ -137,7 +137,7 @@ func (h *AccountHandler) List(c *gin.Context) {
result := make([]AccountWithConcurrency, len(accounts)) result := make([]AccountWithConcurrency, len(accounts))
for i := range accounts { for i := range accounts {
result[i] = AccountWithConcurrency{ result[i] = AccountWithConcurrency{
Account: &accounts[i], Account: dto.AccountFromService(&accounts[i]),
CurrentConcurrency: concurrencyCounts[accounts[i].ID], CurrentConcurrency: concurrencyCounts[accounts[i].ID],
} }
} }
...@@ -160,7 +160,7 @@ func (h *AccountHandler) GetByID(c *gin.Context) { ...@@ -160,7 +160,7 @@ func (h *AccountHandler) GetByID(c *gin.Context) {
return return
} }
response.Success(c, account) response.Success(c, dto.AccountFromService(account))
} }
// Create handles creating a new account // Create handles creating a new account
...@@ -188,7 +188,7 @@ func (h *AccountHandler) Create(c *gin.Context) { ...@@ -188,7 +188,7 @@ func (h *AccountHandler) Create(c *gin.Context) {
return return
} }
response.Success(c, account) response.Success(c, dto.AccountFromService(account))
} }
// Update handles updating an account // Update handles updating an account
...@@ -222,7 +222,7 @@ func (h *AccountHandler) Update(c *gin.Context) { ...@@ -222,7 +222,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
return return
} }
response.Success(c, account) response.Success(c, dto.AccountFromService(account))
} }
// Delete handles deleting an account // Delete handles deleting an account
...@@ -425,7 +425,7 @@ func (h *AccountHandler) ClearError(c *gin.Context) { ...@@ -425,7 +425,7 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
return return
} }
response.Success(c, account) response.Success(c, dto.AccountFromService(account))
} }
// BatchCreate handles batch creating accounts // BatchCreate handles batch creating accounts
...@@ -801,7 +801,7 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) { ...@@ -801,7 +801,7 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) {
return return
} }
response.Success(c, account) response.Success(c, dto.AccountFromService(account))
} }
// GetAvailableModels handles getting available models for an account // GetAvailableModels handles getting available models for an account
......
package admin package admin
import ( import (
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"strconv"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
......
...@@ -3,7 +3,7 @@ package admin ...@@ -3,7 +3,7 @@ package admin
import ( import (
"strconv" "strconv"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -69,7 +69,11 @@ func (h *GroupHandler) List(c *gin.Context) { ...@@ -69,7 +69,11 @@ func (h *GroupHandler) List(c *gin.Context) {
return return
} }
response.Paginated(c, groups, total, page, pageSize) outGroups := make([]dto.Group, 0, len(groups))
for i := range groups {
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
}
response.Paginated(c, outGroups, total, page, pageSize)
} }
// GetAll handles getting all active groups without pagination // GetAll handles getting all active groups without pagination
...@@ -77,7 +81,7 @@ func (h *GroupHandler) List(c *gin.Context) { ...@@ -77,7 +81,7 @@ func (h *GroupHandler) List(c *gin.Context) {
func (h *GroupHandler) GetAll(c *gin.Context) { func (h *GroupHandler) GetAll(c *gin.Context) {
platform := c.Query("platform") platform := c.Query("platform")
var groups []model.Group var groups []service.Group
var err error var err error
if platform != "" { if platform != "" {
...@@ -91,7 +95,11 @@ func (h *GroupHandler) GetAll(c *gin.Context) { ...@@ -91,7 +95,11 @@ func (h *GroupHandler) GetAll(c *gin.Context) {
return return
} }
response.Success(c, groups) outGroups := make([]dto.Group, 0, len(groups))
for i := range groups {
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
}
response.Success(c, outGroups)
} }
// GetByID handles getting a group by ID // GetByID handles getting a group by ID
...@@ -109,7 +117,7 @@ func (h *GroupHandler) GetByID(c *gin.Context) { ...@@ -109,7 +117,7 @@ func (h *GroupHandler) GetByID(c *gin.Context) {
return return
} }
response.Success(c, group) response.Success(c, dto.GroupFromService(group))
} }
// Create handles creating a new group // Create handles creating a new group
...@@ -137,7 +145,7 @@ func (h *GroupHandler) Create(c *gin.Context) { ...@@ -137,7 +145,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
return return
} }
response.Success(c, group) response.Success(c, dto.GroupFromService(group))
} }
// Update handles updating a group // Update handles updating a group
...@@ -172,7 +180,7 @@ func (h *GroupHandler) Update(c *gin.Context) { ...@@ -172,7 +180,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
return return
} }
response.Success(c, group) response.Success(c, dto.GroupFromService(group))
} }
// Delete handles deleting a group // Delete handles deleting a group
...@@ -229,5 +237,9 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) { ...@@ -229,5 +237,9 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
return return
} }
response.Paginated(c, keys, total, page, pageSize) outKeys := make([]dto.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *dto.ApiKeyFromService(&keys[i]))
}
response.Paginated(c, outKeys, total, page, pageSize)
} }
package admin package admin
import ( import (
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -31,7 +31,28 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { ...@@ -31,7 +31,28 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
return return
} }
response.Success(c, settings) response.Success(c, dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
SmtpHost: settings.SmtpHost,
SmtpPort: settings.SmtpPort,
SmtpUsername: settings.SmtpUsername,
SmtpPassword: settings.SmtpPassword,
SmtpFrom: settings.SmtpFrom,
SmtpFromName: settings.SmtpFromName,
SmtpUseTLS: settings.SmtpUseTLS,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
TurnstileSecretKey: settings.TurnstileSecretKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
ApiBaseUrl: settings.ApiBaseUrl,
ContactInfo: settings.ContactInfo,
DocUrl: settings.DocUrl,
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
})
} }
// UpdateSettingsRequest 更新设置请求 // UpdateSettingsRequest 更新设置请求
...@@ -87,7 +108,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -87,7 +108,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
req.SmtpPort = 587 req.SmtpPort = 587
} }
settings := &model.SystemSettings{ settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled, RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled, EmailVerifyEnabled: req.EmailVerifyEnabled,
SmtpHost: req.SmtpHost, SmtpHost: req.SmtpHost,
...@@ -122,7 +143,28 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -122,7 +143,28 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
return return
} }
response.Success(c, updatedSettings) response.Success(c, dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
SmtpHost: updatedSettings.SmtpHost,
SmtpPort: updatedSettings.SmtpPort,
SmtpUsername: updatedSettings.SmtpUsername,
SmtpPassword: updatedSettings.SmtpPassword,
SmtpFrom: updatedSettings.SmtpFrom,
SmtpFromName: updatedSettings.SmtpFromName,
SmtpUseTLS: updatedSettings.SmtpUseTLS,
TurnstileEnabled: updatedSettings.TurnstileEnabled,
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
TurnstileSecretKey: updatedSettings.TurnstileSecretKey,
SiteName: updatedSettings.SiteName,
SiteLogo: updatedSettings.SiteLogo,
SiteSubtitle: updatedSettings.SiteSubtitle,
ApiBaseUrl: updatedSettings.ApiBaseUrl,
ContactInfo: updatedSettings.ContactInfo,
DocUrl: updatedSettings.DocUrl,
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
})
} }
// TestSmtpRequest 测试SMTP连接请求 // TestSmtpRequest 测试SMTP连接请求
......
...@@ -3,9 +3,10 @@ package admin ...@@ -3,9 +3,10 @@ package admin
import ( import (
"strconv" "strconv"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -82,7 +83,11 @@ func (h *SubscriptionHandler) List(c *gin.Context) { ...@@ -82,7 +83,11 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
return return
} }
response.PaginatedWithResult(c, subscriptions, toResponsePagination(pagination)) out := make([]dto.UserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
}
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
} }
// GetByID handles getting a subscription by ID // GetByID handles getting a subscription by ID
...@@ -100,7 +105,7 @@ func (h *SubscriptionHandler) GetByID(c *gin.Context) { ...@@ -100,7 +105,7 @@ func (h *SubscriptionHandler) GetByID(c *gin.Context) {
return return
} }
response.Success(c, subscription) response.Success(c, dto.UserSubscriptionFromService(subscription))
} }
// GetProgress handles getting subscription usage progress // GetProgress handles getting subscription usage progress
...@@ -145,7 +150,7 @@ func (h *SubscriptionHandler) Assign(c *gin.Context) { ...@@ -145,7 +150,7 @@ func (h *SubscriptionHandler) Assign(c *gin.Context) {
return return
} }
response.Success(c, subscription) response.Success(c, dto.UserSubscriptionFromService(subscription))
} }
// BulkAssign handles bulk assigning subscriptions to multiple users // BulkAssign handles bulk assigning subscriptions to multiple users
...@@ -196,7 +201,7 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) { ...@@ -196,7 +201,7 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
return return
} }
response.Success(c, subscription) response.Success(c, dto.UserSubscriptionFromService(subscription))
} }
// Revoke handles revoking a subscription // Revoke handles revoking a subscription
...@@ -234,7 +239,11 @@ func (h *SubscriptionHandler) ListByGroup(c *gin.Context) { ...@@ -234,7 +239,11 @@ func (h *SubscriptionHandler) ListByGroup(c *gin.Context) {
return return
} }
response.PaginatedWithResult(c, subscriptions, toResponsePagination(pagination)) out := make([]dto.UserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
}
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
} }
// ListByUser handles listing subscriptions for a specific user // ListByUser handles listing subscriptions for a specific user
...@@ -252,15 +261,18 @@ func (h *SubscriptionHandler) ListByUser(c *gin.Context) { ...@@ -252,15 +261,18 @@ func (h *SubscriptionHandler) ListByUser(c *gin.Context) {
return return
} }
response.Success(c, subscriptions) out := make([]dto.UserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
}
response.Success(c, out)
} }
// Helper function to get admin ID from context // Helper function to get admin ID from context
func getAdminIDFromContext(c *gin.Context) int64 { func getAdminIDFromContext(c *gin.Context) int64 {
if user, exists := c.Get("user"); exists { subject, ok := middleware2.GetAuthSubjectFromContext(c)
if u, ok := user.(*model.User); ok && u != nil { if !ok {
return u.ID return 0
}
} }
return 0 return subject.UserID
} }
...@@ -3,9 +3,10 @@ package handler ...@@ -3,9 +3,10 @@ package handler
import ( import (
"strconv" "strconv"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -40,42 +41,34 @@ type UpdateAPIKeyRequest struct { ...@@ -40,42 +41,34 @@ type UpdateAPIKeyRequest struct {
// List handles listing user's API keys with pagination // List handles listing user's API keys with pagination
// GET /api/v1/api-keys // GET /api/v1/api-keys
func (h *APIKeyHandler) List(c *gin.Context) { func (h *APIKeyHandler) List(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
page, pageSize := response.ParsePagination(c) page, pageSize := response.ParsePagination(c)
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := h.apiKeyService.List(c.Request.Context(), user.ID, params) keys, result, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, params)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Paginated(c, keys, result.Total, page, pageSize) out := make([]dto.ApiKey, 0, len(keys))
for i := range keys {
out = append(out, *dto.ApiKeyFromService(&keys[i]))
}
response.Paginated(c, out, result.Total, page, pageSize)
} }
// GetByID handles getting a single API key // GetByID handles getting a single API key
// GET /api/v1/api-keys/:id // GET /api/v1/api-keys/:id
func (h *APIKeyHandler) GetByID(c *gin.Context) { func (h *APIKeyHandler) GetByID(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
...@@ -92,26 +85,20 @@ func (h *APIKeyHandler) GetByID(c *gin.Context) { ...@@ -92,26 +85,20 @@ func (h *APIKeyHandler) GetByID(c *gin.Context) {
} }
// 验证所有权 // 验证所有权
if key.UserID != user.ID { if key.UserID != subject.UserID {
response.Forbidden(c, "Not authorized to access this key") response.Forbidden(c, "Not authorized to access this key")
return return
} }
response.Success(c, key) response.Success(c, dto.ApiKeyFromService(key))
} }
// Create handles creating a new API key // Create handles creating a new API key
// POST /api/v1/api-keys // POST /api/v1/api-keys
func (h *APIKeyHandler) Create(c *gin.Context) { func (h *APIKeyHandler) Create(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
...@@ -126,27 +113,21 @@ func (h *APIKeyHandler) Create(c *gin.Context) { ...@@ -126,27 +113,21 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
GroupID: req.GroupID, GroupID: req.GroupID,
CustomKey: req.CustomKey, CustomKey: req.CustomKey,
} }
key, err := h.apiKeyService.Create(c.Request.Context(), user.ID, svcReq) key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Success(c, key) response.Success(c, dto.ApiKeyFromService(key))
} }
// Update handles updating an API key // Update handles updating an API key
// PUT /api/v1/api-keys/:id // PUT /api/v1/api-keys/:id
func (h *APIKeyHandler) Update(c *gin.Context) { func (h *APIKeyHandler) Update(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
...@@ -171,27 +152,21 @@ func (h *APIKeyHandler) Update(c *gin.Context) { ...@@ -171,27 +152,21 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
svcReq.Status = &req.Status svcReq.Status = &req.Status
} }
key, err := h.apiKeyService.Update(c.Request.Context(), keyID, user.ID, svcReq) key, err := h.apiKeyService.Update(c.Request.Context(), keyID, subject.UserID, svcReq)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Success(c, key) response.Success(c, dto.ApiKeyFromService(key))
} }
// Delete handles deleting an API key // Delete handles deleting an API key
// DELETE /api/v1/api-keys/:id // DELETE /api/v1/api-keys/:id
func (h *APIKeyHandler) Delete(c *gin.Context) { func (h *APIKeyHandler) Delete(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
...@@ -201,7 +176,7 @@ func (h *APIKeyHandler) Delete(c *gin.Context) { ...@@ -201,7 +176,7 @@ func (h *APIKeyHandler) Delete(c *gin.Context) {
return return
} }
err = h.apiKeyService.Delete(c.Request.Context(), keyID, user.ID) err = h.apiKeyService.Delete(c.Request.Context(), keyID, subject.UserID)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -213,23 +188,21 @@ func (h *APIKeyHandler) Delete(c *gin.Context) { ...@@ -213,23 +188,21 @@ func (h *APIKeyHandler) Delete(c *gin.Context) {
// GetAvailableGroups 获取用户可以绑定的分组列表 // GetAvailableGroups 获取用户可以绑定的分组列表
// GET /api/v1/groups/available // GET /api/v1/groups/available
func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) { func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
groups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), user.ID) groups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), subject.UserID)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Success(c, groups) out := make([]dto.Group, 0, len(groups))
for i := range groups {
out = append(out, *dto.GroupFromService(&groups[i]))
}
response.Success(c, out)
} }
package handler package handler
import ( import (
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -11,12 +12,14 @@ import ( ...@@ -11,12 +12,14 @@ import (
// AuthHandler handles authentication-related requests // AuthHandler handles authentication-related requests
type AuthHandler struct { type AuthHandler struct {
authService *service.AuthService authService *service.AuthService
userService *service.UserService
} }
// NewAuthHandler creates a new AuthHandler // NewAuthHandler creates a new AuthHandler
func NewAuthHandler(authService *service.AuthService) *AuthHandler { func NewAuthHandler(authService *service.AuthService, userService *service.UserService) *AuthHandler {
return &AuthHandler{ return &AuthHandler{
authService: authService, authService: authService,
userService: userService,
} }
} }
...@@ -49,9 +52,9 @@ type LoginRequest struct { ...@@ -49,9 +52,9 @@ type LoginRequest struct {
// AuthResponse 认证响应格式(匹配前端期望) // AuthResponse 认证响应格式(匹配前端期望)
type AuthResponse struct { type AuthResponse struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
TokenType string `json:"token_type"` TokenType string `json:"token_type"`
User *model.User `json:"user"` User *dto.User `json:"user"`
} }
// Register handles user registration // Register handles user registration
...@@ -80,7 +83,7 @@ func (h *AuthHandler) Register(c *gin.Context) { ...@@ -80,7 +83,7 @@ func (h *AuthHandler) Register(c *gin.Context) {
response.Success(c, AuthResponse{ response.Success(c, AuthResponse{
AccessToken: token, AccessToken: token,
TokenType: "Bearer", TokenType: "Bearer",
User: user, User: dto.UserFromService(user),
}) })
} }
...@@ -135,24 +138,24 @@ func (h *AuthHandler) Login(c *gin.Context) { ...@@ -135,24 +138,24 @@ func (h *AuthHandler) Login(c *gin.Context) {
response.Success(c, AuthResponse{ response.Success(c, AuthResponse{
AccessToken: token, AccessToken: token,
TokenType: "Bearer", TokenType: "Bearer",
User: user, User: dto.UserFromService(user),
}) })
} }
// GetCurrentUser handles getting current authenticated user // GetCurrentUser handles getting current authenticated user
// GET /api/v1/auth/me // GET /api/v1/auth/me
func (h *AuthHandler) GetCurrentUser(c *gin.Context) { func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists { if !ok {
response.Unauthorized(c, "User not authenticated") response.Unauthorized(c, "User not authenticated")
return return
} }
user, ok := userValue.(*model.User) user, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if !ok { if err != nil {
response.InternalError(c, "Invalid user context") response.ErrorFrom(c, err)
return return
} }
response.Success(c, user) response.Success(c, dto.UserFromService(user))
} }
package dto
import "github.com/Wei-Shaw/sub2api/internal/service"
func UserFromServiceShallow(u *service.User) *User {
if u == nil {
return nil
}
return &User{
ID: u.ID,
Email: u.Email,
Username: u.Username,
Wechat: u.Wechat,
Notes: u.Notes,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
AllowedGroups: u.AllowedGroups,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
}
}
func UserFromService(u *service.User) *User {
if u == nil {
return nil
}
out := UserFromServiceShallow(u)
if len(u.ApiKeys) > 0 {
out.ApiKeys = make([]ApiKey, 0, len(u.ApiKeys))
for i := range u.ApiKeys {
k := u.ApiKeys[i]
out.ApiKeys = append(out.ApiKeys, *ApiKeyFromService(&k))
}
}
if len(u.Subscriptions) > 0 {
out.Subscriptions = make([]UserSubscription, 0, len(u.Subscriptions))
for i := range u.Subscriptions {
s := u.Subscriptions[i]
out.Subscriptions = append(out.Subscriptions, *UserSubscriptionFromService(&s))
}
}
return out
}
func ApiKeyFromService(k *service.ApiKey) *ApiKey {
if k == nil {
return nil
}
return &ApiKey{
ID: k.ID,
UserID: k.UserID,
Key: k.Key,
Name: k.Name,
GroupID: k.GroupID,
Status: k.Status,
CreatedAt: k.CreatedAt,
UpdatedAt: k.UpdatedAt,
User: UserFromServiceShallow(k.User),
Group: GroupFromServiceShallow(k.Group),
}
}
func GroupFromServiceShallow(g *service.Group) *Group {
if g == nil {
return nil
}
return &Group{
ID: g.ID,
Name: g.Name,
Description: g.Description,
Platform: g.Platform,
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUSD,
WeeklyLimitUSD: g.WeeklyLimitUSD,
MonthlyLimitUSD: g.MonthlyLimitUSD,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
AccountCount: g.AccountCount,
}
}
func GroupFromService(g *service.Group) *Group {
if g == nil {
return nil
}
out := GroupFromServiceShallow(g)
if len(g.AccountGroups) > 0 {
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
for i := range g.AccountGroups {
ag := g.AccountGroups[i]
out.AccountGroups = append(out.AccountGroups, *AccountGroupFromService(&ag))
}
}
return out
}
func AccountFromServiceShallow(a *service.Account) *Account {
if a == nil {
return nil
}
return &Account{
ID: a.ID,
Name: a.Name,
Platform: a.Platform,
Type: a.Type,
Credentials: a.Credentials,
Extra: a.Extra,
ProxyID: a.ProxyID,
Concurrency: a.Concurrency,
Priority: a.Priority,
Status: a.Status,
ErrorMessage: a.ErrorMessage,
LastUsedAt: a.LastUsedAt,
CreatedAt: a.CreatedAt,
UpdatedAt: a.UpdatedAt,
Schedulable: a.Schedulable,
RateLimitedAt: a.RateLimitedAt,
RateLimitResetAt: a.RateLimitResetAt,
OverloadUntil: a.OverloadUntil,
SessionWindowStart: a.SessionWindowStart,
SessionWindowEnd: a.SessionWindowEnd,
SessionWindowStatus: a.SessionWindowStatus,
GroupIDs: a.GroupIDs,
}
}
func AccountFromService(a *service.Account) *Account {
if a == nil {
return nil
}
out := AccountFromServiceShallow(a)
out.Proxy = ProxyFromService(a.Proxy)
if len(a.AccountGroups) > 0 {
out.AccountGroups = make([]AccountGroup, 0, len(a.AccountGroups))
for i := range a.AccountGroups {
ag := a.AccountGroups[i]
out.AccountGroups = append(out.AccountGroups, *AccountGroupFromService(&ag))
}
}
if len(a.Groups) > 0 {
out.Groups = make([]*Group, 0, len(a.Groups))
for _, g := range a.Groups {
out.Groups = append(out.Groups, GroupFromServiceShallow(g))
}
}
return out
}
func AccountGroupFromService(ag *service.AccountGroup) *AccountGroup {
if ag == nil {
return nil
}
return &AccountGroup{
AccountID: ag.AccountID,
GroupID: ag.GroupID,
Priority: ag.Priority,
CreatedAt: ag.CreatedAt,
Account: AccountFromServiceShallow(ag.Account),
Group: GroupFromServiceShallow(ag.Group),
}
}
func ProxyFromService(p *service.Proxy) *Proxy {
if p == nil {
return nil
}
return &Proxy{
ID: p.ID,
Name: p.Name,
Protocol: p.Protocol,
Host: p.Host,
Port: p.Port,
Username: p.Username,
Password: p.Password,
Status: p.Status,
CreatedAt: p.CreatedAt,
UpdatedAt: p.UpdatedAt,
}
}
func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWithAccountCount {
if p == nil {
return nil
}
return &ProxyWithAccountCount{
Proxy: *ProxyFromService(&p.Proxy),
AccountCount: p.AccountCount,
}
}
func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode {
if rc == nil {
return nil
}
return &RedeemCode{
ID: rc.ID,
Code: rc.Code,
Type: rc.Type,
Value: rc.Value,
Status: rc.Status,
UsedBy: rc.UsedBy,
UsedAt: rc.UsedAt,
Notes: rc.Notes,
CreatedAt: rc.CreatedAt,
GroupID: rc.GroupID,
ValidityDays: rc.ValidityDays,
User: UserFromServiceShallow(rc.User),
Group: GroupFromServiceShallow(rc.Group),
}
}
func UsageLogFromService(l *service.UsageLog) *UsageLog {
if l == nil {
return nil
}
return &UsageLog{
ID: l.ID,
UserID: l.UserID,
ApiKeyID: l.ApiKeyID,
AccountID: l.AccountID,
RequestID: l.RequestID,
Model: l.Model,
GroupID: l.GroupID,
SubscriptionID: l.SubscriptionID,
InputTokens: l.InputTokens,
OutputTokens: l.OutputTokens,
CacheCreationTokens: l.CacheCreationTokens,
CacheReadTokens: l.CacheReadTokens,
CacheCreation5mTokens: l.CacheCreation5mTokens,
CacheCreation1hTokens: l.CacheCreation1hTokens,
InputCost: l.InputCost,
OutputCost: l.OutputCost,
CacheCreationCost: l.CacheCreationCost,
CacheReadCost: l.CacheReadCost,
TotalCost: l.TotalCost,
ActualCost: l.ActualCost,
RateMultiplier: l.RateMultiplier,
BillingType: l.BillingType,
Stream: l.Stream,
DurationMs: l.DurationMs,
FirstTokenMs: l.FirstTokenMs,
CreatedAt: l.CreatedAt,
User: UserFromServiceShallow(l.User),
ApiKey: ApiKeyFromService(l.ApiKey),
Account: AccountFromService(l.Account),
Group: GroupFromServiceShallow(l.Group),
Subscription: UserSubscriptionFromService(l.Subscription),
}
}
func SettingFromService(s *service.Setting) *Setting {
if s == nil {
return nil
}
return &Setting{
ID: s.ID,
Key: s.Key,
Value: s.Value,
UpdatedAt: s.UpdatedAt,
}
}
func UserSubscriptionFromService(sub *service.UserSubscription) *UserSubscription {
if sub == nil {
return nil
}
return &UserSubscription{
ID: sub.ID,
UserID: sub.UserID,
GroupID: sub.GroupID,
StartsAt: sub.StartsAt,
ExpiresAt: sub.ExpiresAt,
Status: sub.Status,
DailyWindowStart: sub.DailyWindowStart,
WeeklyWindowStart: sub.WeeklyWindowStart,
MonthlyWindowStart: sub.MonthlyWindowStart,
DailyUsageUSD: sub.DailyUsageUSD,
WeeklyUsageUSD: sub.WeeklyUsageUSD,
MonthlyUsageUSD: sub.MonthlyUsageUSD,
AssignedBy: sub.AssignedBy,
AssignedAt: sub.AssignedAt,
Notes: sub.Notes,
CreatedAt: sub.CreatedAt,
UpdatedAt: sub.UpdatedAt,
User: UserFromServiceShallow(sub.User),
Group: GroupFromServiceShallow(sub.Group),
AssignedByUser: UserFromServiceShallow(sub.AssignedByUser),
}
}
package dto
// SystemSettings represents the admin settings API response payload.
type SystemSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
SmtpHost string `json:"smtp_host"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
SmtpPassword string `json:"smtp_password,omitempty"`
SmtpFrom string `json:"smtp_from_email"`
SmtpFromName string `json:"smtp_from_name"`
SmtpUseTLS bool `json:"smtp_use_tls"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKey string `json:"turnstile_secret_key,omitempty"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
}
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
Version string `json:"version"`
}
package dto
import "time"
type User struct {
ID int64 `json:"id"`
Email string `json:"email"`
Username string `json:"username"`
Wechat string `json:"wechat"`
Notes string `json:"notes"`
Role string `json:"role"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
Status string `json:"status"`
AllowedGroups []int64 `json:"allowed_groups"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ApiKeys []ApiKey `json:"api_keys,omitempty"`
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
}
type ApiKey struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
Key string `json:"key"`
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
}
type Group struct {
ID int64 `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform"`
RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"`
Status string `json:"status"`
SubscriptionType string `json:"subscription_type"`
DailyLimitUSD *float64 `json:"daily_limit_usd"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
AccountCount int64 `json:"account_count,omitempty"`
}
type Account struct {
ID int64 `json:"id"`
Name string `json:"name"`
Platform string `json:"platform"`
Type string `json:"type"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
Status string `json:"status"`
ErrorMessage string `json:"error_message"`
LastUsedAt *time.Time `json:"last_used_at"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
Schedulable bool `json:"schedulable"`
RateLimitedAt *time.Time `json:"rate_limited_at"`
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
OverloadUntil *time.Time `json:"overload_until"`
SessionWindowStart *time.Time `json:"session_window_start"`
SessionWindowEnd *time.Time `json:"session_window_end"`
SessionWindowStatus string `json:"session_window_status"`
Proxy *Proxy `json:"proxy,omitempty"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
GroupIDs []int64 `json:"group_ids,omitempty"`
Groups []*Group `json:"groups,omitempty"`
}
type AccountGroup struct {
AccountID int64 `json:"account_id"`
GroupID int64 `json:"group_id"`
Priority int `json:"priority"`
CreatedAt time.Time `json:"created_at"`
Account *Account `json:"account,omitempty"`
Group *Group `json:"group,omitempty"`
}
type Proxy struct {
ID int64 `json:"id"`
Name string `json:"name"`
Protocol string `json:"protocol"`
Host string `json:"host"`
Port int `json:"port"`
Username string `json:"username"`
Password string `json:"-"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type ProxyWithAccountCount struct {
Proxy
AccountCount int64 `json:"account_count"`
}
type RedeemCode struct {
ID int64 `json:"id"`
Code string `json:"code"`
Type string `json:"type"`
Value float64 `json:"value"`
Status string `json:"status"`
UsedBy *int64 `json:"used_by"`
UsedAt *time.Time `json:"used_at"`
Notes string `json:"notes"`
CreatedAt time.Time `json:"created_at"`
GroupID *int64 `json:"group_id"`
ValidityDays int `json:"validity_days"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
}
type UsageLog struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
ApiKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"`
Model string `json:"model"`
GroupID *int64 `json:"group_id"`
SubscriptionID *int64 `json:"subscription_id"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationTokens int `json:"cache_creation_tokens"`
CacheReadTokens int `json:"cache_read_tokens"`
CacheCreation5mTokens int `json:"cache_creation_5m_tokens"`
CacheCreation1hTokens int `json:"cache_creation_1h_tokens"`
InputCost float64 `json:"input_cost"`
OutputCost float64 `json:"output_cost"`
CacheCreationCost float64 `json:"cache_creation_cost"`
CacheReadCost float64 `json:"cache_read_cost"`
TotalCost float64 `json:"total_cost"`
ActualCost float64 `json:"actual_cost"`
RateMultiplier float64 `json:"rate_multiplier"`
BillingType int8 `json:"billing_type"`
Stream bool `json:"stream"`
DurationMs *int `json:"duration_ms"`
FirstTokenMs *int `json:"first_token_ms"`
CreatedAt time.Time `json:"created_at"`
User *User `json:"user,omitempty"`
ApiKey *ApiKey `json:"api_key,omitempty"`
Account *Account `json:"account,omitempty"`
Group *Group `json:"group,omitempty"`
Subscription *UserSubscription `json:"subscription,omitempty"`
}
type Setting struct {
ID int64 `json:"id"`
Key string `json:"key"`
Value string `json:"value"`
UpdatedAt time.Time `json:"updated_at"`
}
type UserSubscription struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
GroupID int64 `json:"group_id"`
StartsAt time.Time `json:"starts_at"`
ExpiresAt time.Time `json:"expires_at"`
Status string `json:"status"`
DailyWindowStart *time.Time `json:"daily_window_start"`
WeeklyWindowStart *time.Time `json:"weekly_window_start"`
MonthlyWindowStart *time.Time `json:"monthly_window_start"`
DailyUsageUSD float64 `json:"daily_usage_usd"`
WeeklyUsageUSD float64 `json:"weekly_usage_usd"`
MonthlyUsageUSD float64 `json:"monthly_usage_usd"`
AssignedBy *int64 `json:"assigned_by"`
AssignedAt time.Time `json:"assigned_at"`
Notes string `json:"notes"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
AssignedByUser *User `json:"assigned_by_user,omitempty"`
}
...@@ -10,7 +10,6 @@ import ( ...@@ -10,7 +10,6 @@ import (
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
...@@ -47,7 +46,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -47,7 +46,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return return
} }
user, ok := middleware2.GetUserFromContext(c) subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok { if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return return
...@@ -82,8 +81,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -82,8 +81,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
subscription, _ := middleware2.GetSubscriptionFromContext(c) subscription, _ := middleware2.GetSubscriptionFromContext(c)
// 0. 检查wait队列是否已满 // 0. 检查wait队列是否已满
maxWait := service.CalculateMaxWait(user.Concurrency) maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), user.ID, maxWait) canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
if err != nil { if err != nil {
log.Printf("Increment wait count failed: %v", err) log.Printf("Increment wait count failed: %v", err)
// On error, allow request to proceed // On error, allow request to proceed
...@@ -92,10 +91,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -92,10 +91,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return return
} }
// 确保在函数退出时减少wait计数 // 确保在函数退出时减少wait计数
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), user.ID) defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
// 1. 首先获取用户并发槽位 // 1. 首先获取用户并发槽位
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, user, req.Stream, &streamStarted) userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, req.Stream, &streamStarted)
if err != nil { if err != nil {
log.Printf("User concurrency acquire failed: %v", err) log.Printf("User concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "user", streamStarted) h.handleConcurrencyError(c, err, "user", streamStarted)
...@@ -106,7 +105,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -106,7 +105,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
// 2. 【新增】Wait后二次检查余额/订阅 // 2. 【新增】Wait后二次检查余额/订阅
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err) log.Printf("Billing eligibility check failed after wait: %v", err)
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
return return
...@@ -133,7 +132,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -133,7 +132,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
// 3. 获取账号并发槽位 // 3. 获取账号并发槽位
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account, req.Stream, &streamStarted) accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
if err != nil { if err != nil {
log.Printf("Account concurrency acquire failed: %v", err) log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted) h.handleConcurrencyError(c, err, "account", streamStarted)
...@@ -158,7 +157,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -158,7 +157,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result, Result: result,
ApiKey: apiKey, ApiKey: apiKey,
User: user, User: apiKey.User,
Account: account, Account: account,
Subscription: subscription, Subscription: subscription,
}); err != nil { }); err != nil {
...@@ -198,7 +197,7 @@ func (h *GatewayHandler) Usage(c *gin.Context) { ...@@ -198,7 +197,7 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
return return
} }
user, ok := middleware2.GetUserFromContext(c) subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok { if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return return
...@@ -223,7 +222,7 @@ func (h *GatewayHandler) Usage(c *gin.Context) { ...@@ -223,7 +222,7 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
} }
// 余额模式:返回钱包余额 // 余额模式:返回钱包余额
latestUser, err := h.userService.GetByID(c.Request.Context(), user.ID) latestUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if err != nil { if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info") h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info")
return return
...@@ -241,7 +240,7 @@ func (h *GatewayHandler) Usage(c *gin.Context) { ...@@ -241,7 +240,7 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
// 逻辑: // 逻辑:
// 1. 如果日/周/月任一限额达到100%,返回0 // 1. 如果日/周/月任一限额达到100%,返回0
// 2. 否则返回所有已配置周期中剩余额度的最小值 // 2. 否则返回所有已配置周期中剩余额度的最小值
func (h *GatewayHandler) calculateSubscriptionRemaining(group *model.Group, sub *model.UserSubscription) float64 { func (h *GatewayHandler) calculateSubscriptionRemaining(group *service.Group, sub *service.UserSubscription) float64 {
var remainingValues []float64 var remainingValues []float64
// 检查日限额 // 检查日限额
...@@ -334,7 +333,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { ...@@ -334,7 +333,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return return
} }
user, ok := middleware2.GetUserFromContext(c) _, ok = middleware2.GetAuthSubjectFromContext(c)
if !ok { if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return return
...@@ -366,7 +365,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { ...@@ -366,7 +365,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 校验 billing eligibility(订阅/余额) // 校验 billing eligibility(订阅/余额)
// 【注意】不计算并发,但需要校验订阅/余额 // 【注意】不计算并发,但需要校验订阅/余额
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
h.errorResponse(c, http.StatusForbidden, "billing_error", err.Error()) h.errorResponse(c, http.StatusForbidden, "billing_error", err.Error())
return return
} }
......
...@@ -6,7 +6,6 @@ import ( ...@@ -6,7 +6,6 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -69,11 +68,11 @@ func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64 ...@@ -69,11 +68,11 @@ func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary. // AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait. // For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun. // streamStarted is updated if streaming response has begun.
func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, user *model.User, isStream bool, streamStarted *bool) (func(), error) { func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, userID int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
ctx := c.Request.Context() ctx := c.Request.Context()
// Try to acquire immediately // Try to acquire immediately
result, err := h.concurrencyService.AcquireUserSlot(ctx, user.ID, user.Concurrency) result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -83,17 +82,17 @@ func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, user *model. ...@@ -83,17 +82,17 @@ func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, user *model.
} }
// Need to wait - handle streaming ping if needed // Need to wait - handle streaming ping if needed
return h.waitForSlotWithPing(c, "user", user.ID, user.Concurrency, isStream, streamStarted) return h.waitForSlotWithPing(c, "user", userID, maxConcurrency, isStream, streamStarted)
} }
// AcquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary. // AcquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait. // For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun. // streamStarted is updated if streaming response has begun.
func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, account *model.Account, isStream bool, streamStarted *bool) (func(), error) { func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
ctx := c.Request.Context() ctx := c.Request.Context()
// Try to acquire immediately // Try to acquire immediately
result, err := h.concurrencyService.AcquireAccountSlot(ctx, account.ID, account.Concurrency) result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -103,7 +102,7 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, account * ...@@ -103,7 +102,7 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, account *
} }
// Need to wait - handle streaming ping if needed // Need to wait - handle streaming ping if needed
return h.waitForSlotWithPing(c, "account", account.ID, account.Concurrency, isStream, streamStarted) return h.waitForSlotWithPing(c, "account", accountID, maxConcurrency, isStream, streamStarted)
} }
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests. // waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
......
...@@ -46,7 +46,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -46,7 +46,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return return
} }
user, ok := middleware2.GetUserFromContext(c) subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok { if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return return
...@@ -94,8 +94,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -94,8 +94,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
subscription, _ := middleware2.GetSubscriptionFromContext(c) subscription, _ := middleware2.GetSubscriptionFromContext(c)
// 0. Check if wait queue is full // 0. Check if wait queue is full
maxWait := service.CalculateMaxWait(user.Concurrency) maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), user.ID, maxWait) canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
if err != nil { if err != nil {
log.Printf("Increment wait count failed: %v", err) log.Printf("Increment wait count failed: %v", err)
// On error, allow request to proceed // On error, allow request to proceed
...@@ -104,10 +104,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -104,10 +104,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return return
} }
// Ensure wait count is decremented when function exits // Ensure wait count is decremented when function exits
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), user.ID) defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
// 1. First acquire user concurrency slot // 1. First acquire user concurrency slot
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, user, reqStream, &streamStarted) userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
if err != nil { if err != nil {
log.Printf("User concurrency acquire failed: %v", err) log.Printf("User concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "user", streamStarted) h.handleConcurrencyError(c, err, "user", streamStarted)
...@@ -118,7 +118,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -118,7 +118,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
} }
// 2. Re-check billing eligibility after wait // 2. Re-check billing eligibility after wait
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err) log.Printf("Billing eligibility check failed after wait: %v", err)
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
return return
...@@ -138,7 +138,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -138,7 +138,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name) log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
// 3. Acquire account concurrency slot // 3. Acquire account concurrency slot
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account, reqStream, &streamStarted) accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
if err != nil { if err != nil {
log.Printf("Account concurrency acquire failed: %v", err) log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted) h.handleConcurrencyError(c, err, "account", streamStarted)
...@@ -163,7 +163,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -163,7 +163,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result, Result: result,
ApiKey: apiKey, ApiKey: apiKey,
User: user, User: apiKey.User,
Account: account, Account: account,
Subscription: subscription, Subscription: subscription,
}); err != nil { }); err != nil {
......
package handler package handler
import ( import (
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -37,15 +38,9 @@ type RedeemResponse struct { ...@@ -37,15 +38,9 @@ type RedeemResponse struct {
// Redeem handles redeeming a code // Redeem handles redeeming a code
// POST /api/v1/redeem // POST /api/v1/redeem
func (h *RedeemHandler) Redeem(c *gin.Context) { func (h *RedeemHandler) Redeem(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
...@@ -55,38 +50,36 @@ func (h *RedeemHandler) Redeem(c *gin.Context) { ...@@ -55,38 +50,36 @@ func (h *RedeemHandler) Redeem(c *gin.Context) {
return return
} }
result, err := h.redeemService.Redeem(c.Request.Context(), user.ID, req.Code) result, err := h.redeemService.Redeem(c.Request.Context(), subject.UserID, req.Code)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Success(c, result) response.Success(c, dto.RedeemCodeFromService(result))
} }
// GetHistory returns the user's redemption history // GetHistory returns the user's redemption history
// GET /api/v1/redeem/history // GET /api/v1/redeem/history
func (h *RedeemHandler) GetHistory(c *gin.Context) { func (h *RedeemHandler) GetHistory(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
// Default limit is 25 // Default limit is 25
limit := 25 limit := 25
codes, err := h.redeemService.GetUserHistory(c.Request.Context(), user.ID, limit) codes, err := h.redeemService.GetUserHistory(c.Request.Context(), subject.UserID, limit)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Success(c, codes) out := make([]dto.RedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
}
response.Success(c, out)
} }
package handler package handler
import ( import (
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -30,6 +31,17 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { ...@@ -30,6 +31,17 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
return return
} }
settings.Version = h.version response.Success(c, dto.PublicSettings{
response.Success(c, settings) RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
ApiBaseUrl: settings.ApiBaseUrl,
ContactInfo: settings.ContactInfo,
DocUrl: settings.DocUrl,
Version: h.version,
})
} }
package handler package handler
import ( import (
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -25,7 +26,7 @@ type SubscriptionSummaryItem struct { ...@@ -25,7 +26,7 @@ type SubscriptionSummaryItem struct {
// SubscriptionProgressInfo represents subscription with progress info // SubscriptionProgressInfo represents subscription with progress info
type SubscriptionProgressInfo struct { type SubscriptionProgressInfo struct {
Subscription *model.UserSubscription `json:"subscription"` Subscription *dto.UserSubscription `json:"subscription"`
Progress *service.SubscriptionProgress `json:"progress"` Progress *service.SubscriptionProgress `json:"progress"`
} }
...@@ -44,68 +45,58 @@ func NewSubscriptionHandler(subscriptionService *service.SubscriptionService) *S ...@@ -44,68 +45,58 @@ func NewSubscriptionHandler(subscriptionService *service.SubscriptionService) *S
// List handles listing current user's subscriptions // List handles listing current user's subscriptions
// GET /api/v1/subscriptions // GET /api/v1/subscriptions
func (h *SubscriptionHandler) List(c *gin.Context) { func (h *SubscriptionHandler) List(c *gin.Context) {
user, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user in context") response.Unauthorized(c, "User not found in context")
return return
} }
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), u.ID) subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), subject.UserID)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Success(c, subscriptions) out := make([]dto.UserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
}
response.Success(c, out)
} }
// GetActive handles getting current user's active subscriptions // GetActive handles getting current user's active subscriptions
// GET /api/v1/subscriptions/active // GET /api/v1/subscriptions/active
func (h *SubscriptionHandler) GetActive(c *gin.Context) { func (h *SubscriptionHandler) GetActive(c *gin.Context) {
user, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user in context") response.Unauthorized(c, "User not found in context")
return return
} }
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID) subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), subject.UserID)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Success(c, subscriptions) out := make([]dto.UserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
}
response.Success(c, out)
} }
// GetProgress handles getting subscription progress for current user // GetProgress handles getting subscription progress for current user
// GET /api/v1/subscriptions/progress // GET /api/v1/subscriptions/progress
func (h *SubscriptionHandler) GetProgress(c *gin.Context) { func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
user, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user in context") response.Unauthorized(c, "User not found in context")
return return
} }
// Get all active subscriptions with progress // Get all active subscriptions with progress
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID) subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), subject.UserID)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -120,7 +111,7 @@ func (h *SubscriptionHandler) GetProgress(c *gin.Context) { ...@@ -120,7 +111,7 @@ func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
continue continue
} }
result = append(result, SubscriptionProgressInfo{ result = append(result, SubscriptionProgressInfo{
Subscription: sub, Subscription: dto.UserSubscriptionFromService(sub),
Progress: progress, Progress: progress,
}) })
} }
...@@ -131,20 +122,14 @@ func (h *SubscriptionHandler) GetProgress(c *gin.Context) { ...@@ -131,20 +122,14 @@ func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
// GetSummary handles getting a summary of current user's subscription status // GetSummary handles getting a summary of current user's subscription status
// GET /api/v1/subscriptions/summary // GET /api/v1/subscriptions/summary
func (h *SubscriptionHandler) GetSummary(c *gin.Context) { func (h *SubscriptionHandler) GetSummary(c *gin.Context) {
user, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user in context") response.Unauthorized(c, "User not found in context")
return return
} }
// Get all active subscriptions // Get all active subscriptions
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID) subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), subject.UserID)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
......
...@@ -4,10 +4,11 @@ import ( ...@@ -4,10 +4,11 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -30,15 +31,9 @@ func NewUsageHandler(usageService *service.UsageService, apiKeyService *service. ...@@ -30,15 +31,9 @@ func NewUsageHandler(usageService *service.UsageService, apiKeyService *service.
// List handles listing usage records with pagination // List handles listing usage records with pagination
// GET /api/v1/usage // GET /api/v1/usage
func (h *UsageHandler) List(c *gin.Context) { func (h *UsageHandler) List(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
...@@ -58,7 +53,7 @@ func (h *UsageHandler) List(c *gin.Context) { ...@@ -58,7 +53,7 @@ func (h *UsageHandler) List(c *gin.Context) {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
if apiKey.UserID != user.ID { if apiKey.UserID != subject.UserID {
response.Forbidden(c, "Not authorized to access this API key's usage records") response.Forbidden(c, "Not authorized to access this API key's usage records")
return return
} }
...@@ -67,35 +62,33 @@ func (h *UsageHandler) List(c *gin.Context) { ...@@ -67,35 +62,33 @@ func (h *UsageHandler) List(c *gin.Context) {
} }
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
var records []model.UsageLog var records []service.UsageLog
var result *pagination.PaginationResult var result *pagination.PaginationResult
var err error var err error
if apiKeyID > 0 { if apiKeyID > 0 {
records, result, err = h.usageService.ListByApiKey(c.Request.Context(), apiKeyID, params) records, result, err = h.usageService.ListByApiKey(c.Request.Context(), apiKeyID, params)
} else { } else {
records, result, err = h.usageService.ListByUser(c.Request.Context(), user.ID, params) records, result, err = h.usageService.ListByUser(c.Request.Context(), subject.UserID, params)
} }
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Paginated(c, records, result.Total, page, pageSize) out := make([]dto.UsageLog, 0, len(records))
for i := range records {
out = append(out, *dto.UsageLogFromService(&records[i]))
}
response.Paginated(c, out, result.Total, page, pageSize)
} }
// GetByID handles getting a single usage record // GetByID handles getting a single usage record
// GET /api/v1/usage/:id // GET /api/v1/usage/:id
func (h *UsageHandler) GetByID(c *gin.Context) { func (h *UsageHandler) GetByID(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
...@@ -112,26 +105,20 @@ func (h *UsageHandler) GetByID(c *gin.Context) { ...@@ -112,26 +105,20 @@ func (h *UsageHandler) GetByID(c *gin.Context) {
} }
// 验证所有权 // 验证所有权
if record.UserID != user.ID { if record.UserID != subject.UserID {
response.Forbidden(c, "Not authorized to access this record") response.Forbidden(c, "Not authorized to access this record")
return return
} }
response.Success(c, record) response.Success(c, dto.UsageLogFromService(record))
} }
// Stats handles getting usage statistics // Stats handles getting usage statistics
// GET /api/v1/usage/stats // GET /api/v1/usage/stats
func (h *UsageHandler) Stats(c *gin.Context) { func (h *UsageHandler) Stats(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
...@@ -149,7 +136,7 @@ func (h *UsageHandler) Stats(c *gin.Context) { ...@@ -149,7 +136,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
response.NotFound(c, "API key not found") response.NotFound(c, "API key not found")
return return
} }
if apiKey.UserID != user.ID { if apiKey.UserID != subject.UserID {
response.Forbidden(c, "Not authorized to access this API key's statistics") response.Forbidden(c, "Not authorized to access this API key's statistics")
return return
} }
...@@ -201,7 +188,7 @@ func (h *UsageHandler) Stats(c *gin.Context) { ...@@ -201,7 +188,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
if apiKeyID > 0 { if apiKeyID > 0 {
stats, err = h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime) stats, err = h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
} else { } else {
stats, err = h.usageService.GetStatsByUser(c.Request.Context(), user.ID, startTime, endTime) stats, err = h.usageService.GetStatsByUser(c.Request.Context(), subject.UserID, startTime, endTime)
} }
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
...@@ -245,19 +232,13 @@ func parseUserTimeRange(c *gin.Context) (time.Time, time.Time) { ...@@ -245,19 +232,13 @@ func parseUserTimeRange(c *gin.Context) (time.Time, time.Time) {
// DashboardStats handles getting user dashboard statistics // DashboardStats handles getting user dashboard statistics
// GET /api/v1/usage/dashboard/stats // GET /api/v1/usage/dashboard/stats
func (h *UsageHandler) DashboardStats(c *gin.Context) { func (h *UsageHandler) DashboardStats(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), user.ID) stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), subject.UserID)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -269,22 +250,16 @@ func (h *UsageHandler) DashboardStats(c *gin.Context) { ...@@ -269,22 +250,16 @@ func (h *UsageHandler) DashboardStats(c *gin.Context) {
// DashboardTrend handles getting user usage trend data // DashboardTrend handles getting user usage trend data
// GET /api/v1/usage/dashboard/trend // GET /api/v1/usage/dashboard/trend
func (h *UsageHandler) DashboardTrend(c *gin.Context) { func (h *UsageHandler) DashboardTrend(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
startTime, endTime := parseUserTimeRange(c) startTime, endTime := parseUserTimeRange(c)
granularity := c.DefaultQuery("granularity", "day") granularity := c.DefaultQuery("granularity", "day")
trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), user.ID, startTime, endTime, granularity) trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), subject.UserID, startTime, endTime, granularity)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -301,21 +276,15 @@ func (h *UsageHandler) DashboardTrend(c *gin.Context) { ...@@ -301,21 +276,15 @@ func (h *UsageHandler) DashboardTrend(c *gin.Context) {
// DashboardModels handles getting user model usage statistics // DashboardModels handles getting user model usage statistics
// GET /api/v1/usage/dashboard/models // GET /api/v1/usage/dashboard/models
func (h *UsageHandler) DashboardModels(c *gin.Context) { func (h *UsageHandler) DashboardModels(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
startTime, endTime := parseUserTimeRange(c) startTime, endTime := parseUserTimeRange(c)
stats, err := h.usageService.GetUserModelStats(c.Request.Context(), user.ID, startTime, endTime) stats, err := h.usageService.GetUserModelStats(c.Request.Context(), subject.UserID, startTime, endTime)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -336,15 +305,9 @@ type BatchApiKeysUsageRequest struct { ...@@ -336,15 +305,9 @@ type BatchApiKeysUsageRequest struct {
// DashboardApiKeysUsage handles getting usage stats for user's own API keys // DashboardApiKeysUsage handles getting usage stats for user's own API keys
// POST /api/v1/usage/dashboard/api-keys-usage // POST /api/v1/usage/dashboard/api-keys-usage
func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) { func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
...@@ -360,7 +323,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) { ...@@ -360,7 +323,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
} }
// Verify ownership of all requested API keys // Verify ownership of all requested API keys
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, pagination.PaginationParams{Page: 1, PageSize: 1000}) userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, pagination.PaginationParams{Page: 1, PageSize: 1000})
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment