mirror of
https://github.com/binwiederhier/ntfy.git
synced 2024-11-21 19:03:26 +01:00
WIP: Stripe integration
This commit is contained in:
parent
7007c0a0bd
commit
01fd4754f9
20 changed files with 557 additions and 43 deletions
|
@ -103,7 +103,7 @@ func changeAccess(c *cli.Context, manager *user.Manager, username string, topic
|
|||
read := util.Contains([]string{"read-write", "rw", "read-only", "read", "ro"}, perms)
|
||||
write := util.Contains([]string{"read-write", "rw", "write-only", "write", "wo"}, perms)
|
||||
u, err := manager.User(username)
|
||||
if err == user.ErrNotFound {
|
||||
if err == user.ErrUserNotFound {
|
||||
return fmt.Errorf("user %s does not exist", username)
|
||||
} else if u.Role == user.RoleAdmin {
|
||||
return fmt.Errorf("user %s is an admin user, access control entries have no effect", username)
|
||||
|
@ -173,7 +173,7 @@ func showAllAccess(c *cli.Context, manager *user.Manager) error {
|
|||
|
||||
func showUserAccess(c *cli.Context, manager *user.Manager, username string) error {
|
||||
users, err := manager.User(username)
|
||||
if err == user.ErrNotFound {
|
||||
if err == user.ErrUserNotFound {
|
||||
return fmt.Errorf("user %s does not exist", username)
|
||||
} else if err != nil {
|
||||
return err
|
||||
|
|
21
cmd/serve.go
21
cmd/serve.go
|
@ -5,6 +5,7 @@ package cmd
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/stripe/stripe-go/v74"
|
||||
"heckel.io/ntfy/user"
|
||||
"io/fs"
|
||||
"math"
|
||||
|
@ -61,7 +62,6 @@ var flagsServe = append(
|
|||
altsrc.NewBoolFlag(&cli.BoolFlag{Name: "enable-signup", Aliases: []string{"enable_signup"}, EnvVars: []string{"NTFY_ENABLE_SIGNUP"}, Value: false, Usage: "allows users to sign up via the web app, or API"}),
|
||||
altsrc.NewBoolFlag(&cli.BoolFlag{Name: "enable-login", Aliases: []string{"enable_login"}, EnvVars: []string{"NTFY_ENABLE_LOGIN"}, Value: false, Usage: "allows users to log in via the web app, or API"}),
|
||||
altsrc.NewBoolFlag(&cli.BoolFlag{Name: "enable-reservations", Aliases: []string{"enable_reservations"}, EnvVars: []string{"NTFY_ENABLE_RESERVATIONS"}, Value: false, Usage: "allows users to reserve topics (if their tier allows it)"}),
|
||||
altsrc.NewBoolFlag(&cli.BoolFlag{Name: "enable-payments", Aliases: []string{"enable_payments"}, EnvVars: []string{"NTFY_ENABLE_PAYMENTS"}, Value: false, Usage: "enables payments integration [preliminary option, may change]"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "upstream-base-url", Aliases: []string{"upstream_base_url"}, EnvVars: []string{"NTFY_UPSTREAM_BASE_URL"}, Value: "", Usage: "forward poll request to an upstream server, this is needed for iOS push notifications for self-hosted servers"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-sender-addr", Aliases: []string{"smtp_sender_addr"}, EnvVars: []string{"NTFY_SMTP_SENDER_ADDR"}, Usage: "SMTP server address (host:port) for outgoing emails"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-sender-user", Aliases: []string{"smtp_sender_user"}, EnvVars: []string{"NTFY_SMTP_SENDER_USER"}, Usage: "SMTP user (if e-mail sending is enabled)"}),
|
||||
|
@ -80,6 +80,8 @@ var flagsServe = append(
|
|||
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-email-limit-burst", Aliases: []string{"visitor_email_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_BURST"}, Value: server.DefaultVisitorEmailLimitBurst, Usage: "initial limit of e-mails per visitor"}),
|
||||
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-email-limit-replenish", Aliases: []string{"visitor_email_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_REPLENISH"}, Value: server.DefaultVisitorEmailLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}),
|
||||
altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"behind_proxy", "P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "stripe-key", Aliases: []string{"stripe_key"}, EnvVars: []string{"NTFY_STRIPE_KEY"}, Value: "", Usage: "xxxxxxxxxxxxx"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "stripe-webhook-key", Aliases: []string{"stripe_webhook_key"}, EnvVars: []string{"NTFY_STRIPE_WEBHOOK_KEY"}, Value: "", Usage: "xxxxxxxxxxxx"}),
|
||||
)
|
||||
|
||||
var cmdServe = &cli.Command{
|
||||
|
@ -132,7 +134,6 @@ func execServe(c *cli.Context) error {
|
|||
webRoot := c.String("web-root")
|
||||
enableSignup := c.Bool("enable-signup")
|
||||
enableLogin := c.Bool("enable-login")
|
||||
enablePayments := c.Bool("enable-payments")
|
||||
enableReservations := c.Bool("enable-reservations")
|
||||
upstreamBaseURL := c.String("upstream-base-url")
|
||||
smtpSenderAddr := c.String("smtp-sender-addr")
|
||||
|
@ -152,6 +153,8 @@ func execServe(c *cli.Context) error {
|
|||
visitorEmailLimitBurst := c.Int("visitor-email-limit-burst")
|
||||
visitorEmailLimitReplenish := c.Duration("visitor-email-limit-replenish")
|
||||
behindProxy := c.Bool("behind-proxy")
|
||||
stripeKey := c.String("stripe-key")
|
||||
stripeWebhookKey := c.String("stripe-webhook-key")
|
||||
|
||||
// Check values
|
||||
if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) {
|
||||
|
@ -188,14 +191,17 @@ func execServe(c *cli.Context) error {
|
|||
return errors.New("if upstream-base-url is set, base-url must also be set")
|
||||
} else if upstreamBaseURL != "" && baseURL != "" && baseURL == upstreamBaseURL {
|
||||
return errors.New("base-url and upstream-base-url cannot be identical, you'll likely want to set upstream-base-url to https://ntfy.sh, see https://ntfy.sh/docs/config/#ios-instant-notifications")
|
||||
} else if authFile == "" && (enableSignup || enableLogin || enableReservations || enablePayments) {
|
||||
return errors.New("cannot set enable-signup, enable-login, enable-reserve-topics, or enable-payments if auth-file is not set")
|
||||
} else if authFile == "" && (enableSignup || enableLogin || enableReservations || stripeKey != "") {
|
||||
return errors.New("cannot set enable-signup, enable-login, enable-reserve-topics, or stripe-key if auth-file is not set")
|
||||
} else if enableSignup && !enableLogin {
|
||||
return errors.New("cannot set enable-signup without also setting enable-login")
|
||||
} else if stripeKey != "" && (stripeWebhookKey == "" || baseURL == "") {
|
||||
return errors.New("if stripe-key is set, stripe-webhook-key and base-url must also be set")
|
||||
}
|
||||
|
||||
webRootIsApp := webRoot == "app"
|
||||
enableWeb := webRoot != "disable"
|
||||
enablePayments := stripeKey != ""
|
||||
|
||||
// Default auth permissions
|
||||
authDefault, err := user.ParsePermission(authDefaultAccess)
|
||||
|
@ -239,6 +245,11 @@ func execServe(c *cli.Context) error {
|
|||
visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ips...)
|
||||
}
|
||||
|
||||
// Stripe things
|
||||
if stripeKey != "" {
|
||||
stripe.Key = stripeKey
|
||||
}
|
||||
|
||||
// Run server
|
||||
conf := server.NewConfig()
|
||||
conf.BaseURL = baseURL
|
||||
|
@ -282,6 +293,8 @@ func execServe(c *cli.Context) error {
|
|||
conf.VisitorEmailLimitBurst = visitorEmailLimitBurst
|
||||
conf.VisitorEmailLimitReplenish = visitorEmailLimitReplenish
|
||||
conf.BehindProxy = behindProxy
|
||||
conf.StripeKey = stripeKey
|
||||
conf.StripeWebhookKey = stripeWebhookKey
|
||||
conf.EnableWeb = enableWeb
|
||||
conf.EnableSignup = enableSignup
|
||||
conf.EnableLogin = enableLogin
|
||||
|
|
|
@ -215,7 +215,7 @@ func execUserDel(c *cli.Context) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := manager.User(username); err == user.ErrNotFound {
|
||||
if _, err := manager.User(username); err == user.ErrUserNotFound {
|
||||
return fmt.Errorf("user %s does not exist", username)
|
||||
}
|
||||
if err := manager.RemoveUser(username); err != nil {
|
||||
|
@ -237,7 +237,7 @@ func execUserChangePass(c *cli.Context) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := manager.User(username); err == user.ErrNotFound {
|
||||
if _, err := manager.User(username); err == user.ErrUserNotFound {
|
||||
return fmt.Errorf("user %s does not exist", username)
|
||||
}
|
||||
if password == "" {
|
||||
|
@ -265,7 +265,7 @@ func execUserChangeRole(c *cli.Context) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := manager.User(username); err == user.ErrNotFound {
|
||||
if _, err := manager.User(username); err == user.ErrUserNotFound {
|
||||
return fmt.Errorf("user %s does not exist", username)
|
||||
}
|
||||
if err := manager.ChangeRole(username, role); err != nil {
|
||||
|
@ -289,7 +289,7 @@ func execUserChangeTier(c *cli.Context) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := manager.User(username); err == user.ErrNotFound {
|
||||
if _, err := manager.User(username); err == user.ErrUserNotFound {
|
||||
return fmt.Errorf("user %s does not exist", username)
|
||||
}
|
||||
if tier == tierReset {
|
||||
|
|
4
go.mod
4
go.mod
|
@ -46,6 +46,10 @@ require (
|
|||
github.com/googleapis/gax-go/v2 v2.7.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/russross/blackfriday/v2 v2.1.0 // indirect
|
||||
github.com/stripe/stripe-go/v74 v74.5.0 // indirect
|
||||
github.com/tidwall/gjson v1.14.4 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
|
||||
go.opencensus.io v0.24.0 // indirect
|
||||
golang.org/x/net v0.4.0 // indirect
|
||||
|
|
12
go.sum
12
go.sum
|
@ -95,10 +95,20 @@ github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQD
|
|||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stripe/stripe-go/v74 v74.5.0 h1:YyqTvVQdS34KYGCfVB87EMn9eDV3FCFkSwfdOQhiVL4=
|
||||
github.com/stripe/stripe-go/v74 v74.5.0/go.mod h1:5PoXNp30AJ3tGq57ZcFuaMylzNi8KpwlrYAFmO1fHZw=
|
||||
github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM=
|
||||
github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/urfave/cli/v2 v2.23.7 h1:YHDQ46s3VghFHFf1DdF+Sh7H4RqhcM+t0TmZRJx4oJY=
|
||||
github.com/urfave/cli/v2 v2.23.7/go.mod h1:GHupkWPMM0M/sj1a2b4wUrWBPzazNrIjouW6fmdJLxc=
|
||||
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU=
|
||||
|
@ -119,6 +129,7 @@ golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73r
|
|||
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
|
||||
golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
|
||||
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.0.0-20220708220712-1185a9018129/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
|
@ -135,6 +146,7 @@ golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
|||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
|
|
|
@ -110,6 +110,8 @@ type Config struct {
|
|||
VisitorAccountCreateLimitReplenish time.Duration
|
||||
VisitorStatsResetTime time.Time // Time of the day at which to reset visitor stats
|
||||
BehindProxy bool
|
||||
StripeKey string
|
||||
StripeWebhookKey string
|
||||
EnableWeb bool
|
||||
EnableSignup bool // Enable creation of accounts via API and UI
|
||||
EnableLogin bool
|
||||
|
|
|
@ -58,6 +58,8 @@ var (
|
|||
errHTTPBadRequestJSONInvalid = &errHTTP{40024, http.StatusBadRequest, "invalid request: request body must be valid JSON", ""}
|
||||
errHTTPBadRequestPermissionInvalid = &errHTTP{40025, http.StatusBadRequest, "invalid request: incorrect permission string", ""}
|
||||
errHTTPBadRequestMakesNoSenseForAdmin = &errHTTP{40026, http.StatusBadRequest, "invalid request: this makes no sense for admins", ""}
|
||||
errHTTPBadRequestNotAPaidUser = &errHTTP{40027, http.StatusBadRequest, "invalid request: not a paid user", ""}
|
||||
errHTTPBadRequestInvalidStripeRequest = &errHTTP{40028, http.StatusBadRequest, "invalid request: not a valid Stripe request", ""}
|
||||
errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", ""}
|
||||
errHTTPUnauthorized = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", "https://ntfy.sh/docs/publish/#authentication"}
|
||||
errHTTPForbidden = &errHTTP{40301, http.StatusForbidden, "forbidden", "https://ntfy.sh/docs/publish/#authentication"}
|
||||
|
|
|
@ -36,6 +36,10 @@ import (
|
|||
|
||||
/*
|
||||
TODO
|
||||
payments:
|
||||
- handle overdue payment (-> downgrade after 7 days)
|
||||
- delete stripe subscription when acocunt is deleted
|
||||
|
||||
Limits & rate limiting:
|
||||
users without tier: should the stats be persisted? are they meaningful?
|
||||
-> test that the visitor is based on the IP address!
|
||||
|
@ -43,6 +47,7 @@ import (
|
|||
update last_seen when API is accessed
|
||||
Make sure account endpoints make sense for admins
|
||||
|
||||
triggerChange after publishing a message
|
||||
UI:
|
||||
- flicker of upgrade banner
|
||||
- JS constants
|
||||
|
@ -100,6 +105,11 @@ var (
|
|||
accountSettingsPath = "/v1/account/settings"
|
||||
accountSubscriptionPath = "/v1/account/subscription"
|
||||
accountReservationPath = "/v1/account/reservation"
|
||||
accountBillingPortalPath = "/v1/account/billing/portal"
|
||||
accountBillingWebhookPath = "/v1/account/billing/webhook"
|
||||
accountCheckoutPath = "/v1/account/checkout"
|
||||
accountCheckoutSuccessTemplate = "/v1/account/checkout/success/{CHECKOUT_SESSION_ID}"
|
||||
accountCheckoutSuccessRegex = regexp.MustCompile(`/v1/account/checkout/success/(.+)$`)
|
||||
accountReservationSingleRegex = regexp.MustCompile(`/v1/account/reservation/([-_A-Za-z0-9]{1,64})$`)
|
||||
accountSubscriptionSingleRegex = regexp.MustCompile(`^/v1/account/subscription/([-_A-Za-z0-9]{16})$`)
|
||||
matrixPushPath = "/_matrix/push/v1/notify"
|
||||
|
@ -362,6 +372,14 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
|
|||
return s.ensureUser(s.handleAccountReservationAdd)(w, r, v)
|
||||
} else if r.Method == http.MethodDelete && accountReservationSingleRegex.MatchString(r.URL.Path) {
|
||||
return s.ensureUser(s.handleAccountReservationDelete)(w, r, v)
|
||||
} else if r.Method == http.MethodPost && r.URL.Path == accountCheckoutPath {
|
||||
return s.ensureUser(s.handleAccountCheckoutSessionCreate)(w, r, v)
|
||||
} else if r.Method == http.MethodGet && accountCheckoutSuccessRegex.MatchString(r.URL.Path) {
|
||||
return s.ensureUserManager(s.handleAccountCheckoutSessionSuccessGet)(w, r, v) // No user context!
|
||||
} else if r.Method == http.MethodPost && r.URL.Path == accountBillingPortalPath {
|
||||
return s.ensureUser(s.handleAccountBillingPortalSessionCreate)(w, r, v)
|
||||
} else if r.Method == http.MethodPost && r.URL.Path == accountBillingWebhookPath {
|
||||
return s.ensureUserManager(s.handleAccountBillingWebhookTrigger)(w, r, v)
|
||||
} else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath {
|
||||
return s.handleMatrixDiscovery(w)
|
||||
} else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
|
||||
|
|
|
@ -2,6 +2,14 @@ package server
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/stripe/stripe-go/v74"
|
||||
portalsession "github.com/stripe/stripe-go/v74/billingportal/session"
|
||||
"github.com/stripe/stripe-go/v74/checkout/session"
|
||||
"github.com/stripe/stripe-go/v74/subscription"
|
||||
"github.com/stripe/stripe-go/v74/webhook"
|
||||
"github.com/tidwall/gjson"
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/user"
|
||||
"heckel.io/ntfy/util"
|
||||
"net/http"
|
||||
|
@ -9,6 +17,7 @@ import (
|
|||
|
||||
const (
|
||||
jsonBodyBytesLimit = 4096
|
||||
stripeBodyBytesLimit = 16384
|
||||
subscriptionIDLength = 16
|
||||
createdByAPI = "api"
|
||||
)
|
||||
|
@ -386,3 +395,226 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R
|
|||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleAccountCheckoutSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
req, err := readJSONWithLimit[apiAccountTierChangeRequest](r.Body, jsonBodyBytesLimit)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tier, err := s.userManager.Tier(req.Tier)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if tier.StripePriceID == "" {
|
||||
log.Info("Checkout: Downgrading to no tier")
|
||||
return errors.New("not a paid tier")
|
||||
} else if v.user.Billing != nil && v.user.Billing.StripeSubscriptionID != "" {
|
||||
log.Info("Checkout: Changing tier and subscription to %s", tier.Code)
|
||||
|
||||
// Upgrade/downgrade tier
|
||||
sub, err := subscription.Get(v.user.Billing.StripeSubscriptionID, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
params := &stripe.SubscriptionParams{
|
||||
CancelAtPeriodEnd: stripe.Bool(false),
|
||||
ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)),
|
||||
Items: []*stripe.SubscriptionItemsParams{
|
||||
{
|
||||
ID: stripe.String(sub.Items.Data[0].ID),
|
||||
Price: stripe.String(tier.StripePriceID),
|
||||
},
|
||||
},
|
||||
}
|
||||
_, err = subscription.Update(sub.ID, params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
response := &apiAccountCheckoutResponse{}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
} else {
|
||||
// Checkout flow
|
||||
log.Info("Checkout: No existing subscription, creating checkout flow")
|
||||
}
|
||||
|
||||
successURL := s.config.BaseURL + accountCheckoutSuccessTemplate
|
||||
var stripeCustomerID *string
|
||||
if v.user.Billing != nil {
|
||||
stripeCustomerID = &v.user.Billing.StripeCustomerID
|
||||
}
|
||||
params := &stripe.CheckoutSessionParams{
|
||||
ClientReferenceID: &v.user.Name, // FIXME Should be user ID
|
||||
Customer: stripeCustomerID,
|
||||
SuccessURL: &successURL,
|
||||
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
|
||||
LineItems: []*stripe.CheckoutSessionLineItemParams{
|
||||
{
|
||||
Price: stripe.String(tier.StripePriceID),
|
||||
Quantity: stripe.Int64(1),
|
||||
},
|
||||
},
|
||||
}
|
||||
sess, err := session.New(params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
response := &apiAccountCheckoutResponse{
|
||||
RedirectURL: sess.URL,
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleAccountCheckoutSessionSuccessGet(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
// We don't have a v.user in this endpoint, only a userManager!
|
||||
matches := accountCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path)
|
||||
if len(matches) != 2 {
|
||||
return errHTTPInternalErrorInvalidPath
|
||||
}
|
||||
sessionID := matches[1]
|
||||
// FIXME how do I rate limit this?
|
||||
sess, err := session.Get(sessionID, nil)
|
||||
if err != nil {
|
||||
log.Warn("Stripe: %s", err)
|
||||
return errHTTPBadRequestInvalidStripeRequest
|
||||
} else if sess.Customer == nil || sess.Subscription == nil || sess.ClientReferenceID == "" {
|
||||
log.Warn("Stripe: Unexpected session, customer or subscription not found")
|
||||
return errHTTPBadRequestInvalidStripeRequest
|
||||
}
|
||||
sub, err := subscription.Get(sess.Subscription.ID, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil {
|
||||
log.Error("Stripe: Unexpected subscription, expected exactly one line item")
|
||||
return errHTTPBadRequestInvalidStripeRequest
|
||||
}
|
||||
priceID := sub.Items.Data[0].Price.ID
|
||||
tier, err := s.userManager.TierByStripePrice(priceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u, err := s.userManager.User(sess.ClientReferenceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if u.Billing == nil {
|
||||
u.Billing = &user.Billing{}
|
||||
}
|
||||
u.Billing.StripeCustomerID = sess.Customer.ID
|
||||
u.Billing.StripeSubscriptionID = sess.Subscription.ID
|
||||
if err := s.userManager.ChangeBilling(u); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
|
||||
return err
|
||||
}
|
||||
accountURL := s.config.BaseURL + "/account" // FIXME
|
||||
http.Redirect(w, r, accountURL, http.StatusSeeOther)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
if v.user.Billing == nil {
|
||||
return errHTTPBadRequestNotAPaidUser
|
||||
}
|
||||
params := &stripe.BillingPortalSessionParams{
|
||||
Customer: stripe.String(v.user.Billing.StripeCustomerID),
|
||||
ReturnURL: stripe.String(s.config.BaseURL),
|
||||
}
|
||||
ps, err := portalsession.New(params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
response := &apiAccountBillingPortalRedirectResponse{
|
||||
RedirectURL: ps.URL,
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleAccountBillingWebhookTrigger(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
// We don't have a v.user in this endpoint, only a userManager!
|
||||
stripeSignature := r.Header.Get("Stripe-Signature")
|
||||
if stripeSignature == "" {
|
||||
return errHTTPBadRequestInvalidStripeRequest
|
||||
}
|
||||
body, err := util.Peek(r.Body, stripeBodyBytesLimit)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if body.LimitReached {
|
||||
return errHTTPEntityTooLargeJSONBody
|
||||
}
|
||||
event, err := webhook.ConstructEvent(body.PeekedBytes, stripeSignature, s.config.StripeWebhookKey)
|
||||
if err != nil {
|
||||
log.Warn("Stripe: invalid request: %s", err.Error())
|
||||
return errHTTPBadRequestInvalidStripeRequest
|
||||
} else if event.Data == nil || event.Data.Raw == nil {
|
||||
log.Warn("Stripe: invalid request, data is nil")
|
||||
return errHTTPBadRequestInvalidStripeRequest
|
||||
}
|
||||
log.Info("Stripe: webhook event %s received", event.Type)
|
||||
stripeCustomerID := gjson.GetBytes(event.Data.Raw, "customer")
|
||||
if !stripeCustomerID.Exists() {
|
||||
return errHTTPBadRequestInvalidStripeRequest
|
||||
}
|
||||
switch event.Type {
|
||||
case "checkout.session.completed":
|
||||
// Payment is successful and the subscription is created.
|
||||
// Provision the subscription, save the customer ID.
|
||||
return s.handleAccountBillingWebhookCheckoutCompleted(stripeCustomerID.String(), event.Data.Raw)
|
||||
case "customer.subscription.updated":
|
||||
return s.handleAccountBillingWebhookSubscriptionUpdated(stripeCustomerID.String(), event.Data.Raw)
|
||||
case "invoice.paid":
|
||||
// Continue to provision the subscription as payments continue to be made.
|
||||
// Store the status in your database and check when a user accesses your service.
|
||||
// This approach helps you avoid hitting rate limits.
|
||||
return nil // FIXME
|
||||
case "invoice.payment_failed":
|
||||
// The payment failed or the customer does not have a valid payment method.
|
||||
// The subscription becomes past_due. Notify your customer and send them to the
|
||||
// customer portal to update their payment information.
|
||||
return nil // FIXME
|
||||
default:
|
||||
log.Warn("Stripe: unhandled webhook %s", event.Type)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleAccountBillingWebhookCheckoutCompleted(stripeCustomerID string, event json.RawMessage) error {
|
||||
log.Info("Stripe: checkout completed for customer %s", stripeCustomerID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(stripeCustomerID string, event json.RawMessage) error {
|
||||
status := gjson.GetBytes(event, "status")
|
||||
priceID := gjson.GetBytes(event, "items.data.0.price.id")
|
||||
if !status.Exists() || !priceID.Exists() {
|
||||
return errHTTPBadRequestInvalidStripeRequest
|
||||
}
|
||||
log.Info("Stripe: customer %s: subscription updated to %s, with price %s", stripeCustomerID, status, priceID)
|
||||
u, err := s.userManager.UserByStripeCustomer(stripeCustomerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tier, err := s.userManager.TierByStripePrice(priceID.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -295,3 +295,15 @@ type apiConfigResponse struct {
|
|||
EnableReservations bool `json:"enable_reservations"`
|
||||
DisallowedTopics []string `json:"disallowed_topics"`
|
||||
}
|
||||
|
||||
type apiAccountTierChangeRequest struct {
|
||||
Tier string `json:"tier"`
|
||||
}
|
||||
|
||||
type apiAccountCheckoutResponse struct {
|
||||
RedirectURL string `json:"redirect_url"`
|
||||
}
|
||||
|
||||
type apiAccountBillingPortalRedirectResponse struct {
|
||||
RedirectURL string `json:"redirect_url"`
|
||||
}
|
||||
|
|
110
user/manager.go
110
user/manager.go
|
@ -44,8 +44,11 @@ const (
|
|||
reservations_limit INT NOT NULL,
|
||||
attachment_file_size_limit INT NOT NULL,
|
||||
attachment_total_size_limit INT NOT NULL,
|
||||
attachment_expiry_duration INT NOT NULL
|
||||
attachment_expiry_duration INT NOT NULL,
|
||||
stripe_price_id TEXT
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
|
||||
CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
|
||||
CREATE TABLE IF NOT EXISTS user (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
tier_id INT,
|
||||
|
@ -56,12 +59,16 @@ const (
|
|||
sync_topic TEXT NOT NULL,
|
||||
stats_messages INT NOT NULL DEFAULT (0),
|
||||
stats_emails INT NOT NULL DEFAULT (0),
|
||||
stripe_customer_id TEXT,
|
||||
stripe_subscription_id TEXT,
|
||||
created_by TEXT NOT NULL,
|
||||
created_at INT NOT NULL,
|
||||
last_seen INT NOT NULL,
|
||||
FOREIGN KEY (tier_id) REFERENCES tier (id)
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_user ON user (user);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
|
||||
CREATE TABLE IF NOT EXISTS user_access (
|
||||
user_id INT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
|
@ -93,18 +100,24 @@ const (
|
|||
`
|
||||
|
||||
selectUserByNameQuery = `
|
||||
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration
|
||||
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
|
||||
FROM user u
|
||||
LEFT JOIN tier p on p.id = u.tier_id
|
||||
WHERE user = ?
|
||||
`
|
||||
selectUserByTokenQuery = `
|
||||
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration
|
||||
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
|
||||
FROM user u
|
||||
JOIN user_token t on u.id = t.user_id
|
||||
LEFT JOIN tier p on p.id = u.tier_id
|
||||
WHERE t.token = ? AND t.expires >= ?
|
||||
`
|
||||
selectUserByStripeCustomerIDQuery = `
|
||||
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
|
||||
FROM user u
|
||||
LEFT JOIN tier p on p.id = u.tier_id
|
||||
WHERE u.stripe_customer_id = ?
|
||||
`
|
||||
selectTopicPermsQuery = `
|
||||
SELECT read, write
|
||||
FROM user_access a
|
||||
|
@ -204,9 +217,21 @@ const (
|
|||
INSERT INTO tier (code, name, paid, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
selectTierIDQuery = `SELECT id FROM tier WHERE code = ?`
|
||||
selectTierIDQuery = `SELECT id FROM tier WHERE code = ?`
|
||||
selectTierByCodeQuery = `
|
||||
SELECT code, name, paid, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
|
||||
FROM tier
|
||||
WHERE code = ?
|
||||
`
|
||||
selectTierByPriceIDQuery = `
|
||||
SELECT code, name, paid, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
|
||||
FROM tier
|
||||
WHERE stripe_price_id = ?
|
||||
`
|
||||
updateUserTierQuery = `UPDATE user SET tier_id = ? WHERE user = ?`
|
||||
deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
|
||||
|
||||
updateBillingQuery = `UPDATE user SET stripe_customer_id = ?, stripe_subscription_id = ? WHERE user = ?`
|
||||
)
|
||||
|
||||
// Schema management queries
|
||||
|
@ -543,7 +568,7 @@ func (a *Manager) Users() ([]*User, error) {
|
|||
return users, nil
|
||||
}
|
||||
|
||||
// User returns the user with the given username if it exists, or ErrNotFound otherwise.
|
||||
// User returns the user with the given username if it exists, or ErrUserNotFound otherwise.
|
||||
// You may also pass Everyone to retrieve the anonymous user and its Grant list.
|
||||
func (a *Manager) User(username string) (*User, error) {
|
||||
rows, err := a.db.Query(selectUserByNameQuery, username)
|
||||
|
@ -553,6 +578,14 @@ func (a *Manager) User(username string) (*User, error) {
|
|||
return a.readUser(rows)
|
||||
}
|
||||
|
||||
func (a *Manager) UserByStripeCustomer(stripeCustomerID string) (*User, error) {
|
||||
rows, err := a.db.Query(selectUserByStripeCustomerIDQuery, stripeCustomerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return a.readUser(rows)
|
||||
}
|
||||
|
||||
func (a *Manager) userByToken(token string) (*User, error) {
|
||||
rows, err := a.db.Query(selectUserByTokenQuery, token, time.Now().Unix())
|
||||
if err != nil {
|
||||
|
@ -564,14 +597,14 @@ func (a *Manager) userByToken(token string) (*User, error) {
|
|||
func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
|
||||
defer rows.Close()
|
||||
var username, hash, role, prefs, syncTopic string
|
||||
var tierCode, tierName sql.NullString
|
||||
var stripeCustomerID, stripeSubscriptionID, stripePriceID, tierCode, tierName sql.NullString
|
||||
var paid sql.NullBool
|
||||
var messages, emails int64
|
||||
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64
|
||||
if !rows.Next() {
|
||||
return nil, ErrNotFound
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &tierCode, &tierName, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration); err != nil {
|
||||
if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &tierCode, &tierName, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
|
@ -590,7 +623,14 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
|
|||
if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if stripeCustomerID.Valid && stripeSubscriptionID.Valid {
|
||||
user.Billing = &Billing{
|
||||
StripeCustomerID: stripeCustomerID.String,
|
||||
StripeSubscriptionID: stripeSubscriptionID.String,
|
||||
}
|
||||
}
|
||||
if tierCode.Valid {
|
||||
// See readTier() when this is changed!
|
||||
user.Tier = &Tier{
|
||||
Code: tierCode.String,
|
||||
Name: tierName.String,
|
||||
|
@ -602,6 +642,7 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
|
|||
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
|
||||
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
|
||||
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
|
||||
StripePriceID: stripePriceID.String,
|
||||
}
|
||||
}
|
||||
return user, nil
|
||||
|
@ -826,6 +867,59 @@ func (a *Manager) CreateTier(tier *Tier) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (a *Manager) ChangeBilling(user *User) error {
|
||||
if _, err := a.db.Exec(updateBillingQuery, user.Billing.StripeCustomerID, user.Billing.StripeSubscriptionID, user.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Manager) Tier(code string) (*Tier, error) {
|
||||
rows, err := a.db.Query(selectTierByCodeQuery, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return a.readTier(rows)
|
||||
}
|
||||
|
||||
func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
|
||||
rows, err := a.db.Query(selectTierByPriceIDQuery, priceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return a.readTier(rows)
|
||||
}
|
||||
|
||||
func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
|
||||
defer rows.Close()
|
||||
var code, name string
|
||||
var stripePriceID sql.NullString
|
||||
var paid bool
|
||||
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64
|
||||
if !rows.Next() {
|
||||
return nil, ErrTierNotFound
|
||||
}
|
||||
if err := rows.Scan(&code, &name, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// When changed, note readUser() as well
|
||||
return &Tier{
|
||||
Code: code,
|
||||
Name: name,
|
||||
Paid: paid,
|
||||
MessagesLimit: messagesLimit.Int64,
|
||||
MessagesExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
|
||||
EmailsLimit: emailsLimit.Int64,
|
||||
ReservationsLimit: reservationsLimit.Int64,
|
||||
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
|
||||
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
|
||||
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
|
||||
StripePriceID: stripePriceID.String, // May be empty!
|
||||
}, nil
|
||||
}
|
||||
|
||||
func toSQLWildcard(s string) string {
|
||||
return strings.ReplaceAll(s, "*", "%")
|
||||
}
|
||||
|
|
|
@ -208,7 +208,7 @@ func TestManager_UserManagement(t *testing.T) {
|
|||
// Remove user
|
||||
require.Nil(t, a.RemoveUser("ben"))
|
||||
_, err = a.User("ben")
|
||||
require.Equal(t, ErrNotFound, err)
|
||||
require.Equal(t, ErrUserNotFound, err)
|
||||
|
||||
users, err = a.Users()
|
||||
require.Nil(t, err)
|
||||
|
|
|
@ -16,6 +16,7 @@ type User struct {
|
|||
Prefs *Prefs
|
||||
Tier *Tier
|
||||
Stats *Stats
|
||||
Billing *Billing
|
||||
SyncTopic string
|
||||
Created time.Time
|
||||
LastSeen time.Time
|
||||
|
@ -58,6 +59,7 @@ type Tier struct {
|
|||
AttachmentFileSizeLimit int64
|
||||
AttachmentTotalSizeLimit int64
|
||||
AttachmentExpiryDuration time.Duration
|
||||
StripePriceID string
|
||||
}
|
||||
|
||||
// Subscription represents a user's topic subscription
|
||||
|
@ -81,6 +83,12 @@ type Stats struct {
|
|||
Emails int64
|
||||
}
|
||||
|
||||
// Billing is a struct holding a user's billing information
|
||||
type Billing struct {
|
||||
StripeCustomerID string
|
||||
StripeSubscriptionID string
|
||||
}
|
||||
|
||||
// Grant is a struct that represents an access control entry to a topic by a user
|
||||
type Grant struct {
|
||||
TopicPattern string // May include wildcard (*)
|
||||
|
@ -212,5 +220,6 @@ var (
|
|||
ErrUnauthenticated = errors.New("unauthenticated")
|
||||
ErrUnauthorized = errors.New("unauthorized")
|
||||
ErrInvalidArgument = errors.New("invalid argument")
|
||||
ErrNotFound = errors.New("not found")
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrTierNotFound = errors.New("tier not found")
|
||||
)
|
||||
|
|
|
@ -8,7 +8,7 @@ import {
|
|||
accountTokenUrl,
|
||||
accountUrl, maybeWithAuth, topicUrl,
|
||||
withBasicAuth,
|
||||
withBearerAuth
|
||||
withBearerAuth, accountCheckoutUrl, accountBillingPortalUrl
|
||||
} from "./utils";
|
||||
import session from "./Session";
|
||||
import subscriptionManager from "./SubscriptionManager";
|
||||
|
@ -228,7 +228,7 @@ class AccountApi {
|
|||
this.triggerChange(); // Dangle!
|
||||
}
|
||||
|
||||
async upsertAccess(topic, everyone) {
|
||||
async upsertReservation(topic, everyone) {
|
||||
const url = accountReservationUrl(config.base_url);
|
||||
console.log(`[AccountApi] Upserting user access to topic ${topic}, everyone=${everyone}`);
|
||||
const response = await fetch(url, {
|
||||
|
@ -249,7 +249,7 @@ class AccountApi {
|
|||
this.triggerChange(); // Dangle!
|
||||
}
|
||||
|
||||
async deleteAccess(topic) {
|
||||
async deleteReservation(topic) {
|
||||
const url = accountReservationSingleUrl(config.base_url, topic);
|
||||
console.log(`[AccountApi] Removing topic reservation ${url}`);
|
||||
const response = await fetch(url, {
|
||||
|
@ -264,6 +264,39 @@ class AccountApi {
|
|||
this.triggerChange(); // Dangle!
|
||||
}
|
||||
|
||||
async createCheckoutSession(tier) {
|
||||
const url = accountCheckoutUrl(config.base_url);
|
||||
console.log(`[AccountApi] Creating checkout session`);
|
||||
const response = await fetch(url, {
|
||||
method: "POST",
|
||||
headers: withBearerAuth({}, session.token()),
|
||||
body: JSON.stringify({
|
||||
tier: tier
|
||||
})
|
||||
});
|
||||
if (response.status === 401 || response.status === 403) {
|
||||
throw new UnauthorizedError();
|
||||
} else if (response.status !== 200) {
|
||||
throw new Error(`Unexpected server response ${response.status}`);
|
||||
}
|
||||
return await response.json();
|
||||
}
|
||||
|
||||
async createBillingPortalSession() {
|
||||
const url = accountBillingPortalUrl(config.base_url);
|
||||
console.log(`[AccountApi] Creating billing portal session`);
|
||||
const response = await fetch(url, {
|
||||
method: "POST",
|
||||
headers: withBearerAuth({}, session.token())
|
||||
});
|
||||
if (response.status === 401 || response.status === 403) {
|
||||
throw new UnauthorizedError();
|
||||
} else if (response.status !== 200) {
|
||||
throw new Error(`Unexpected server response ${response.status}`);
|
||||
}
|
||||
return await response.json();
|
||||
}
|
||||
|
||||
async sync() {
|
||||
try {
|
||||
if (!session.token()) {
|
||||
|
|
|
@ -26,6 +26,8 @@ export const accountSubscriptionUrl = (baseUrl) => `${baseUrl}/v1/account/subscr
|
|||
export const accountSubscriptionSingleUrl = (baseUrl, id) => `${baseUrl}/v1/account/subscription/${id}`;
|
||||
export const accountReservationUrl = (baseUrl) => `${baseUrl}/v1/account/reservation`;
|
||||
export const accountReservationSingleUrl = (baseUrl, topic) => `${baseUrl}/v1/account/reservation/${topic}`;
|
||||
export const accountCheckoutUrl = (baseUrl) => `${baseUrl}/v1/account/checkout`;
|
||||
export const accountBillingPortalUrl = (baseUrl) => `${baseUrl}/v1/account/billing/portal`;
|
||||
export const shortUrl = (url) => url.replaceAll(/https?:\/\//g, "");
|
||||
export const expandUrl = (url) => [`https://${url}`, `http://${url}`];
|
||||
export const expandSecureUrl = (url) => `https://${url}`;
|
||||
|
|
|
@ -171,10 +171,28 @@ const Stats = () => {
|
|||
const { t } = useTranslation();
|
||||
const { account } = useContext(AccountContext);
|
||||
const [upgradeDialogOpen, setUpgradeDialogOpen] = useState(false);
|
||||
|
||||
if (!account) {
|
||||
return <></>;
|
||||
}
|
||||
const normalize = (value, max) => Math.min(value / max * 100, 100);
|
||||
|
||||
const normalize = (value, max) => {
|
||||
return Math.min(value / max * 100, 100);
|
||||
};
|
||||
|
||||
const handleManageBilling = async () => {
|
||||
try {
|
||||
const response = await accountApi.createBillingPortalSession();
|
||||
window.location.href = response.redirect_url;
|
||||
} catch (e) {
|
||||
console.log(`[Account] Error changing password`, e);
|
||||
if ((e instanceof UnauthorizedError)) {
|
||||
session.resetAndRedirect(routes.login);
|
||||
}
|
||||
// TODO show error
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Card sx={{p: 3}} aria-label={t("account_usage_title")}>
|
||||
<Typography variant="h5" sx={{marginBottom: 2}}>
|
||||
|
@ -201,12 +219,20 @@ const Stats = () => {
|
|||
>{t("account_usage_tier_upgrade_button")}</Button>
|
||||
}
|
||||
{config.enable_payments && account.role === "user" && account.tier?.paid &&
|
||||
<Button
|
||||
variant="outlined"
|
||||
size="small"
|
||||
onClick={() => setUpgradeDialogOpen(true)}
|
||||
sx={{ml: 1}}
|
||||
>{t("account_usage_tier_change_button")}</Button>
|
||||
<>
|
||||
<Button
|
||||
variant="outlined"
|
||||
size="small"
|
||||
onClick={() => setUpgradeDialogOpen(true)}
|
||||
sx={{ml: 1}}
|
||||
>{t("account_usage_tier_change_button")}</Button>
|
||||
<Button
|
||||
variant="outlined"
|
||||
size="small"
|
||||
onClick={handleManageBilling}
|
||||
sx={{ml: 1}}
|
||||
>Manage billing</Button>
|
||||
</>
|
||||
}
|
||||
<UpgradeDialog
|
||||
open={upgradeDialogOpen}
|
||||
|
|
|
@ -501,7 +501,7 @@ const Reservations = () => {
|
|||
const handleDialogSubmit = async (reservation) => {
|
||||
setDialogOpen(false);
|
||||
try {
|
||||
await accountApi.upsertAccess(reservation.topic, reservation.everyone);
|
||||
await accountApi.upsertReservation(reservation.topic, reservation.everyone);
|
||||
await accountApi.sync();
|
||||
console.debug(`[Preferences] Added topic reservation`, reservation);
|
||||
} catch (e) {
|
||||
|
@ -557,7 +557,7 @@ const ReservationsTable = (props) => {
|
|||
const handleDialogSubmit = async (reservation) => {
|
||||
setDialogOpen(false);
|
||||
try {
|
||||
await accountApi.upsertAccess(reservation.topic, reservation.everyone);
|
||||
await accountApi.upsertReservation(reservation.topic, reservation.everyone);
|
||||
await accountApi.sync();
|
||||
console.debug(`[Preferences] Added topic reservation`, reservation);
|
||||
} catch (e) {
|
||||
|
@ -568,7 +568,7 @@ const ReservationsTable = (props) => {
|
|||
|
||||
const handleDeleteClick = async (reservation) => {
|
||||
try {
|
||||
await accountApi.deleteAccess(reservation.topic);
|
||||
await accountApi.deleteReservation(reservation.topic);
|
||||
await accountApi.sync();
|
||||
console.debug(`[Preferences] Deleted topic reservation`, reservation);
|
||||
} catch (e) {
|
||||
|
|
|
@ -110,7 +110,7 @@ const SubscribePage = (props) => {
|
|||
if (session.exists() && baseUrl === config.base_url && reserveTopicVisible) {
|
||||
console.log(`[SubscribeDialog] Reserving topic ${topic} with everyone access ${everyone}`);
|
||||
try {
|
||||
await accountApi.upsertAccess(topic, everyone);
|
||||
await accountApi.upsertReservation(topic, everyone);
|
||||
// Account sync later after it was added
|
||||
} catch (e) {
|
||||
console.log(`[SubscribeDialog] Error reserving topic`, e);
|
||||
|
|
|
@ -37,9 +37,9 @@ const SubscriptionSettingsDialog = (props) => {
|
|||
|
||||
// Reservation
|
||||
if (reserveTopicVisible) {
|
||||
await accountApi.upsertAccess(subscription.topic, everyone);
|
||||
await accountApi.upsertReservation(subscription.topic, everyone);
|
||||
} else if (!reserveTopicVisible && subscription.reservation) { // Was removed
|
||||
await accountApi.deleteAccess(subscription.topic);
|
||||
await accountApi.deleteReservation(subscription.topic);
|
||||
}
|
||||
|
||||
// Sync account
|
||||
|
|
|
@ -2,28 +2,83 @@ import * as React from 'react';
|
|||
import Dialog from '@mui/material/Dialog';
|
||||
import DialogContent from '@mui/material/DialogContent';
|
||||
import DialogTitle from '@mui/material/DialogTitle';
|
||||
import {useMediaQuery} from "@mui/material";
|
||||
import {CardActionArea, CardContent, useMediaQuery} from "@mui/material";
|
||||
import theme from "./theme";
|
||||
import DialogFooter from "./DialogFooter";
|
||||
import Button from "@mui/material/Button";
|
||||
import accountApi, {TopicReservedError, UnauthorizedError} from "../app/AccountApi";
|
||||
import session from "../app/Session";
|
||||
import routes from "./routes";
|
||||
import {useContext, useState} from "react";
|
||||
import Card from "@mui/material/Card";
|
||||
import Typography from "@mui/material/Typography";
|
||||
import {AccountContext} from "./App";
|
||||
|
||||
const UpgradeDialog = (props) => {
|
||||
const { account } = useContext(AccountContext);
|
||||
const fullScreen = useMediaQuery(theme.breakpoints.down('sm'));
|
||||
const [selected, setSelected] = useState(account?.tier?.code || null);
|
||||
const [errorText, setErrorText] = useState("");
|
||||
|
||||
const handleSuccess = async () => {
|
||||
// TODO
|
||||
const handleCheckout = async () => {
|
||||
try {
|
||||
const response = await accountApi.createCheckoutSession(selected);
|
||||
if (response.redirect_url) {
|
||||
window.location.href = response.redirect_url;
|
||||
} else {
|
||||
await accountApi.sync();
|
||||
}
|
||||
|
||||
} catch (e) {
|
||||
console.log(`[UpgradeDialog] Error creating checkout session`, e);
|
||||
if ((e instanceof UnauthorizedError)) {
|
||||
session.resetAndRedirect(routes.login);
|
||||
}
|
||||
// FIXME show error
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog open={props.open} onClose={props.onCancel} fullScreen={fullScreen}>
|
||||
<Dialog open={props.open} onClose={props.onCancel} maxWidth="md" fullScreen={fullScreen}>
|
||||
<DialogTitle>Upgrade to Pro</DialogTitle>
|
||||
<DialogContent>
|
||||
Content
|
||||
<div style={{
|
||||
display: "flex",
|
||||
flexDirection: "row"
|
||||
}}>
|
||||
<TierCard code={null} name={"Free"} selected={selected === null} onClick={() => setSelected(null)}/>
|
||||
<TierCard code="starter" name={"Starter"} selected={selected === "starter"} onClick={() => setSelected("starter")}/>
|
||||
<TierCard code="pro" name={"Pro"} selected={selected === "pro"} onClick={() => setSelected("pro")}/>
|
||||
<TierCard code="business" name={"Business"} selected={selected === "business"} onClick={() => setSelected("business")}/>
|
||||
</div>
|
||||
</DialogContent>
|
||||
<DialogFooter>
|
||||
Footer
|
||||
<DialogFooter status={errorText}>
|
||||
<Button onClick={handleCheckout}>Checkout</Button>
|
||||
</DialogFooter>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
||||
|
||||
const TierCard = (props) => {
|
||||
const cardStyle = (props.selected) ? {
|
||||
border: "1px solid red",
|
||||
|
||||
} : {};
|
||||
return (
|
||||
<Card sx={{ m: 1, maxWidth: 345 }}>
|
||||
<CardActionArea>
|
||||
<CardContent sx={{...cardStyle}} onClick={props.onClick}>
|
||||
<Typography gutterBottom variant="h5" component="div">
|
||||
{props.name}
|
||||
</Typography>
|
||||
<Typography variant="body2" color="text.secondary">
|
||||
Lizards are a widespread group of squamate reptiles, with over 6,000
|
||||
species, ranging across all continents except Antarctica
|
||||
</Typography>
|
||||
</CardContent>
|
||||
</CardActionArea>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
|
||||
export default UpgradeDialog;
|
||||
|
|
Loading…
Reference in a new issue