diff --git a/go.mod b/go.mod index f31bd218..b668bd0a 100644 --- a/go.mod +++ b/go.mod @@ -25,7 +25,10 @@ require ( require github.com/pkg/errors v0.9.1 // indirect -require firebase.google.com/go/v4 v4.10.0 +require ( + firebase.google.com/go/v4 v4.10.0 + github.com/stripe/stripe-go/v74 v74.5.0 +) require ( cloud.google.com/go v0.107.0 // indirect @@ -46,10 +49,6 @@ require ( github.com/googleapis/gax-go/v2 v2.7.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect - github.com/stripe/stripe-go/v74 v74.5.0 // indirect - github.com/tidwall/gjson v1.14.4 // indirect - github.com/tidwall/match v1.1.1 // indirect - github.com/tidwall/pretty v1.2.1 // indirect github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect go.opencensus.io v0.24.0 // indirect golang.org/x/net v0.4.0 // indirect diff --git a/go.sum b/go.sum index ce4367cb..98096334 100644 --- a/go.sum +++ b/go.sum @@ -102,13 +102,6 @@ github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKs github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stripe/stripe-go/v74 v74.5.0 h1:YyqTvVQdS34KYGCfVB87EMn9eDV3FCFkSwfdOQhiVL4= github.com/stripe/stripe-go/v74 v74.5.0/go.mod h1:5PoXNp30AJ3tGq57ZcFuaMylzNi8KpwlrYAFmO1fHZw= -github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= -github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= -github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= -github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= -github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= -github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/urfave/cli/v2 v2.23.7 h1:YHDQ46s3VghFHFf1DdF+Sh7H4RqhcM+t0TmZRJx4oJY= github.com/urfave/cli/v2 v2.23.7/go.mod h1:GHupkWPMM0M/sj1a2b4wUrWBPzazNrIjouW6fmdJLxc= github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU= diff --git a/server/config.go b/server/config.go index 9df62d8a..b54cb54e 100644 --- a/server/config.go +++ b/server/config.go @@ -19,6 +19,7 @@ const ( DefaultFirebaseKeepaliveInterval = 3 * time.Hour // ~control topic (Android), not too frequently to save battery DefaultFirebasePollInterval = 20 * time.Minute // ~poll topic (iOS), max. 2-3 times per hour (see docs) DefaultFirebaseQuotaExceededPenaltyDuration = 10 * time.Minute // Time that over-users are locked out of Firebase if it returns "quota exceeded" + DefaultStripePriceCacheDuration = time.Hour // Time to keep Stripe prices cached in memory before a refresh is needed ) // Defines all global and per-visitor limits @@ -112,10 +113,12 @@ type Config struct { BehindProxy bool StripeSecretKey string StripeWebhookKey string + StripePriceCacheDuration time.Duration EnableWeb bool EnableSignup bool // Enable creation of accounts via API and UI EnableLogin bool EnableReservations bool // Allow users with role "user" to own/reserve topics + AccessControlAllowOrigin string // CORS header field to restrict access from web clients Version string // injected by App } @@ -132,9 +135,11 @@ func NewConfig() *Config { FirebaseKeyFile: "", CacheFile: "", CacheDuration: DefaultCacheDuration, + CacheStartupQueries: "", CacheBatchSize: 0, CacheBatchTimeout: 0, AuthFile: "", + AuthStartupQueries: "", AuthDefault: user.NewPermission(true, true), AttachmentCacheDir: "", AttachmentTotalSizeLimit: DefaultAttachmentTotalSizeLimit, @@ -142,14 +147,24 @@ func NewConfig() *Config { AttachmentExpiryDuration: DefaultAttachmentExpiryDuration, KeepaliveInterval: DefaultKeepaliveInterval, ManagerInterval: DefaultManagerInterval, - MessageLimit: DefaultMessageLengthLimit, - MinDelay: DefaultMinDelay, - MaxDelay: DefaultMaxDelay, + WebRootIsApp: false, DelayedSenderInterval: DefaultDelayedSenderInterval, FirebaseKeepaliveInterval: DefaultFirebaseKeepaliveInterval, FirebasePollInterval: DefaultFirebasePollInterval, FirebaseQuotaExceededPenaltyDuration: DefaultFirebaseQuotaExceededPenaltyDuration, + UpstreamBaseURL: "", + SMTPSenderAddr: "", + SMTPSenderUser: "", + SMTPSenderPass: "", + SMTPSenderFrom: "", + SMTPServerListen: "", + SMTPServerDomain: "", + SMTPServerAddrPrefix: "", + MessageLimit: DefaultMessageLengthLimit, + MinDelay: DefaultMinDelay, + MaxDelay: DefaultMaxDelay, TotalTopicLimit: DefaultTotalTopicLimit, + TotalAttachmentSizeLimit: 0, VisitorSubscriptionLimit: DefaultVisitorSubscriptionLimit, VisitorAttachmentTotalSizeLimit: DefaultVisitorAttachmentTotalSizeLimit, VisitorAttachmentDailyBandwidthLimit: DefaultVisitorAttachmentDailyBandwidthLimit, @@ -162,7 +177,14 @@ func NewConfig() *Config { VisitorAccountCreateLimitReplenish: DefaultVisitorAccountCreateLimitReplenish, VisitorStatsResetTime: DefaultVisitorStatsResetTime, BehindProxy: false, + StripeSecretKey: "", + StripeWebhookKey: "", + StripePriceCacheDuration: DefaultStripePriceCacheDuration, EnableWeb: true, + EnableSignup: false, + EnableLogin: false, + EnableReservations: false, + AccessControlAllowOrigin: "*", Version: "", } } diff --git a/server/server.go b/server/server.go index d24f4b4b..16323b2d 100644 --- a/server/server.go +++ b/server/server.go @@ -39,21 +39,18 @@ import ( payments: - send dunning emails when overdue - payment methods - - unmarshal to stripe.Subscription instead of gjson - delete subscription when account deleted - delete messages + reserved topics on ResetTier - - move v1/account/tiers to v1/tiers - Limits & rate limiting: users without tier: should the stats be persisted? are they meaningful? -> test that the visitor is based on the IP address! login/account endpoints when ResetStats() is run, reset messagesLimiter (and others)? - update last_seen when API is accessed Make sure account endpoints make sense for admins UI: + - revert home page change - flicker of upgrade banner - JS constants Sync: @@ -82,7 +79,7 @@ type Server struct { userManager *user.Manager // Might be nil! messageCache *messageCache fileCache *fileCache - priceCache map[string]string // Stripe price ID -> formatted price + priceCache *util.LookupCache[map[string]string] // Stripe price ID -> formatted price closeChan chan bool mu sync.Mutex } @@ -144,7 +141,8 @@ const ( emptyMessageBody = "triggered" // Used if message body is empty newMessageBody = "New message" // Used in poll requests as generic message defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment - encodingBase64 = "base64" + encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages + jsonBodyBytesLimit = 16384 ) // WebSocket constants @@ -201,7 +199,7 @@ func New(conf *Config) (*Server, error) { topics: topics, userManager: userManager, visitors: make(map[string]*visitor), - priceCache: make(map[string]string), + priceCache: util.NewLookupCache(fetchStripePrices, conf.StripePriceCacheDuration), }, nil } @@ -454,22 +452,14 @@ func (s *Server) handleEmpty(_ http.ResponseWriter, _ *http.Request, _ *visitor) } func (s *Server) handleTopicAuth(w http.ResponseWriter, _ *http.Request, _ *visitor) error { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests - _, err := io.WriteString(w, `{"success":true}`+"\n") - return err + return s.writeJSON(w, newSuccessResponse()) } func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request, _ *visitor) error { response := &apiHealthResponse{ Healthy: true, } - w.Header().Set("Content-Type", "text/json") - w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests - if err := json.NewEncoder(w).Encode(response); err != nil { - return err - } - return nil + return s.writeJSON(w, response) } func (s *Server) handleWebConfig(w http.ResponseWriter, _ *http.Request, _ *visitor) error { @@ -620,12 +610,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito if err != nil { return err } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests - if err := json.NewEncoder(w).Encode(m); err != nil { - return err - } - return nil + return s.writeJSON(w, m) } func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *visitor) error { @@ -1175,8 +1160,8 @@ func parseSince(r *http.Request, poll bool) (sinceMarker, error) { func (s *Server) handleOptions(w http.ResponseWriter, _ *http.Request, _ *visitor) error { w.Header().Set("Access-Control-Allow-Methods", "GET, PUT, POST, PATCH, DELETE") - w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests - w.Header().Set("Access-Control-Allow-Headers", "*") // CORS, allow auth via JS // FIXME is this terrible? + w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests + w.Header().Set("Access-Control-Allow-Headers", "*") // CORS, allow auth via JS // FIXME is this terrible? return nil } @@ -1482,7 +1467,7 @@ func (s *Server) limitRequests(next handleFunc) handleFunc { // before passing it on to the next handler. This is meant to be used in combination with handlePublish. func (s *Server) transformBodyJSON(next handleFunc) handleFunc { return func(w http.ResponseWriter, r *http.Request, v *visitor) error { - m, err := readJSONWithLimit[publishMessage](r.Body, s.config.MessageLimit) + m, err := readJSONWithLimit[publishMessage](r.Body, s.config.MessageLimit*2) // 2x to account for JSON format overhead if err != nil { return err } @@ -1650,3 +1635,12 @@ func (s *Server) visitorFromIP(ip netip.Addr) *visitor { func (s *Server) visitorFromUser(user *user.User, ip netip.Addr) *visitor { return s.visitorFromID(fmt.Sprintf("user:%s", user.Name), ip, user) } + +func (s *Server) writeJSON(w http.ResponseWriter, v any) error { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests + if err := json.NewEncoder(w).Encode(v); err != nil { + return err + } + return nil +} diff --git a/server/server_account.go b/server/server_account.go index db3adac7..8414c9aa 100644 --- a/server/server_account.go +++ b/server/server_account.go @@ -10,7 +10,6 @@ import ( ) const ( - jsonBodyBytesLimit = 4096 subscriptionIDLength = 16 createdByAPI = "api" syncTopicAccountSyncEvent = "sync" @@ -38,9 +37,7 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v * if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser, createdByAPI); err != nil { // TODO this should return a User return err } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this - return nil + return s.writeJSON(w, newSuccessResponse()) } func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *visitor) error { @@ -118,21 +115,14 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis response.Username = user.Everyone response.Role = string(user.RoleAnonymous) } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this - if err := json.NewEncoder(w).Encode(response); err != nil { - return err - } - return nil + return s.writeJSON(w, response) } func (s *Server) handleAccountDelete(w http.ResponseWriter, _ *http.Request, v *visitor) error { if err := s.userManager.RemoveUser(v.user.Name); err != nil { return err } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this - return nil + return s.writeJSON(w, newSuccessResponse()) } func (s *Server) handleAccountPasswordChange(w http.ResponseWriter, r *http.Request, v *visitor) error { @@ -143,9 +133,7 @@ func (s *Server) handleAccountPasswordChange(w http.ResponseWriter, r *http.Requ if err := s.userManager.ChangePassword(v.user.Name, newPassword.Password); err != nil { return err } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this - return nil + return s.writeJSON(w, newSuccessResponse()) } func (s *Server) handleAccountTokenIssue(w http.ResponseWriter, _ *http.Request, v *visitor) error { @@ -154,16 +142,11 @@ func (s *Server) handleAccountTokenIssue(w http.ResponseWriter, _ *http.Request, if err != nil { return err } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this response := &apiAccountTokenResponse{ Token: token.Value, Expires: token.Expires.Unix(), } - if err := json.NewEncoder(w).Encode(response); err != nil { - return err - } - return nil + return s.writeJSON(w, response) } func (s *Server) handleAccountTokenExtend(w http.ResponseWriter, _ *http.Request, v *visitor) error { @@ -177,16 +160,11 @@ func (s *Server) handleAccountTokenExtend(w http.ResponseWriter, _ *http.Request if err != nil { return err } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this response := &apiAccountTokenResponse{ Token: token.Value, Expires: token.Expires.Unix(), } - if err := json.NewEncoder(w).Encode(response); err != nil { - return err - } - return nil + return s.writeJSON(w, response) } func (s *Server) handleAccountTokenDelete(w http.ResponseWriter, _ *http.Request, v *visitor) error { @@ -197,8 +175,7 @@ func (s *Server) handleAccountTokenDelete(w http.ResponseWriter, _ *http.Request if err := s.userManager.RemoveToken(v.user); err != nil { return err } - w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this - return nil + return s.writeJSON(w, newSuccessResponse()) } func (s *Server) handleAccountSettingsChange(w http.ResponseWriter, r *http.Request, v *visitor) error { @@ -230,9 +207,7 @@ func (s *Server) handleAccountSettingsChange(w http.ResponseWriter, r *http.Requ if err := s.userManager.ChangeSettings(v.user); err != nil { return err } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this - return nil + return s.writeJSON(w, newSuccessResponse()) } func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Request, v *visitor) error { @@ -257,12 +232,7 @@ func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Req return err } } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this - if err := json.NewEncoder(w).Encode(newSubscription); err != nil { - return err - } - return nil + return s.writeJSON(w, newSubscription) } func (s *Server) handleAccountSubscriptionChange(w http.ResponseWriter, r *http.Request, v *visitor) error { @@ -292,12 +262,7 @@ func (s *Server) handleAccountSubscriptionChange(w http.ResponseWriter, r *http. if err := s.userManager.ChangeSettings(v.user); err != nil { return err } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this - if err := json.NewEncoder(w).Encode(subscription); err != nil { - return err - } - return nil + return s.writeJSON(w, subscription) } func (s *Server) handleAccountSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error { @@ -321,9 +286,7 @@ func (s *Server) handleAccountSubscriptionDelete(w http.ResponseWriter, r *http. return err } } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this - return nil + return s.writeJSON(w, newSuccessResponse()) } func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Request, v *visitor) error { @@ -366,9 +329,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ if err := s.userManager.AllowAccess(owner, user.Everyone, req.Topic, everyone.IsRead(), everyone.IsWrite()); err != nil { return err } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this - return nil + return s.writeJSON(w, newSuccessResponse()) } func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.Request, v *visitor) error { @@ -392,9 +353,7 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R if err := s.userManager.ResetAccess(user.Everyone, topic); err != nil { return err } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this - return nil + return s.writeJSON(w, newSuccessResponse()) } func (s *Server) publishSyncEvent(v *visitor) error { diff --git a/server/server_payments.go b/server/server_payments.go index 4e79cdcc..5a3440be 100644 --- a/server/server_payments.go +++ b/server/server_payments.go @@ -1,6 +1,7 @@ package server import ( + "bytes" "encoding/json" "errors" "fmt" @@ -11,19 +12,15 @@ import ( "github.com/stripe/stripe-go/v74/price" "github.com/stripe/stripe-go/v74/subscription" "github.com/stripe/stripe-go/v74/webhook" - "github.com/tidwall/gjson" "heckel.io/ntfy/log" "heckel.io/ntfy/user" "heckel.io/ntfy/util" + "io" "net/http" "net/netip" "time" ) -const ( - stripeBodyBytesLimit = 16384 -) - var ( errNotAPaidTier = errors.New("tier does not have billing price identifier") errMultipleBillingSubscriptions = errors.New("cannot have multiple billing subscriptions") @@ -52,23 +49,15 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ }, }, } + prices, err := s.priceCache.Value() + if err != nil { + return err + } for _, tier := range tiers { - if tier.StripePriceID == "" { + priceStr, ok := prices[tier.StripePriceID] + if tier.StripePriceID == "" || !ok { continue } - priceStr, ok := s.priceCache[tier.StripePriceID] - if !ok { - p, err := price.Get(tier.StripePriceID, nil) - if err != nil { - return err - } - if p.UnitAmount%100 == 0 { - priceStr = fmt.Sprintf("$%d", p.UnitAmount/100) - } else { - priceStr = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100) - } - s.priceCache[tier.StripePriceID] = priceStr // FIXME race, make this sync.Map or something - } response = append(response, &apiAccountBillingTier{ Code: tier.Code, Name: tier.Name, @@ -84,12 +73,7 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ }, }) } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this - if err := json.NewEncoder(w).Encode(response); err != nil { - return err - } - return nil + return s.writeJSON(w, response) } // handleAccountBillingSubscriptionCreate creates a Stripe checkout flow to create a user subscription. The tier @@ -143,12 +127,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r response := &apiAccountBillingSubscriptionCreateResponse{ RedirectURL: sess.URL, } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this - if err := json.NewEncoder(w).Encode(response); err != nil { - return err - } - return nil + return s.writeJSON(w, response) } func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, _ *visitor) error { @@ -219,12 +198,7 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r if err != nil { return err } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this - if err := json.NewEncoder(w).Encode(newSuccessResponse()); err != nil { - return err - } - return nil + return s.writeJSON(w, newSuccessResponse()) } // handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user, @@ -239,12 +213,7 @@ func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r return err } } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this - if err := json.NewEncoder(w).Encode(newSuccessResponse()); err != nil { - return err - } - return nil + return s.writeJSON(w, newSuccessResponse()) } func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error { @@ -262,12 +231,7 @@ func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, response := &apiAccountBillingPortalRedirectResponse{ RedirectURL: ps.URL, } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this - if err := json.NewEncoder(w).Encode(response); err != nil { - return err - } - return nil + return s.writeJSON(w, response) } // handleAccountBillingWebhook handles incoming Stripe webhooks. It mainly keeps the local user database in sync @@ -278,7 +242,7 @@ func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Requ if stripeSignature == "" { return errHTTPBadRequestBillingRequestInvalid } - body, err := util.Peek(r.Body, stripeBodyBytesLimit) + body, err := util.Peek(r.Body, jsonBodyBytesLimit) if err != nil { return err } else if body.LimitReached { @@ -302,25 +266,23 @@ func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Requ } func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMessage) error { - subscriptionID := gjson.GetBytes(event, "id") - customerID := gjson.GetBytes(event, "customer") - status := gjson.GetBytes(event, "status") - currentPeriodEnd := gjson.GetBytes(event, "current_period_end") - cancelAt := gjson.GetBytes(event, "cancel_at") - priceID := gjson.GetBytes(event, "items.data.0.price.id") - if !subscriptionID.Exists() || !status.Exists() || !currentPeriodEnd.Exists() || !cancelAt.Exists() || !priceID.Exists() { + r, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event))) + if err != nil { + return err + } else if r.ID == "" || r.Customer == "" || r.Status == "" || r.CurrentPeriodEnd == 0 || r.Items == nil || len(r.Items.Data) != 1 || r.Items.Data[0].Price == nil || r.Items.Data[0].Price.ID == "" { return errHTTPBadRequestBillingRequestInvalid } - log.Info("Stripe: customer %s: Updating subscription to status %s, with price %s", customerID.String(), status, priceID) - u, err := s.userManager.UserByStripeCustomer(customerID.String()) + subscriptionID, priceID := r.ID, r.Items.Data[0].Price.ID + log.Info("Stripe: customer %s: Updating subscription to status %s, with price %s", r.Customer, r.Status, priceID) + u, err := s.userManager.UserByStripeCustomer(r.Customer) if err != nil { return err } - tier, err := s.userManager.TierByStripePrice(priceID.String()) + tier, err := s.userManager.TierByStripePrice(priceID) if err != nil { return err } - if err := s.updateSubscriptionAndTier(u, customerID.String(), subscriptionID.String(), status.String(), currentPeriodEnd.Int(), cancelAt.Int(), tier.Code); err != nil { + if err := s.updateSubscriptionAndTier(u, r.Customer, subscriptionID, r.Status, r.CurrentPeriodEnd, r.CancelAt, tier.Code); err != nil { return err } s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified())) @@ -328,16 +290,18 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe } func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMessage) error { - customerID := gjson.GetBytes(event, "customer") - if !customerID.Exists() { + r, err := util.UnmarshalJSON[apiStripeSubscriptionDeletedEvent](io.NopCloser(bytes.NewReader(event))) + if err != nil { + return err + } else if r.Customer == "" { return errHTTPBadRequestBillingRequestInvalid } - log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", customerID.String()) - u, err := s.userManager.UserByStripeCustomer(customerID.String()) + log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", r.Customer) + u, err := s.userManager.UserByStripeCustomer(r.Customer) if err != nil { return err } - if err := s.updateSubscriptionAndTier(u, customerID.String(), "", "", 0, 0, ""); err != nil { + if err := s.updateSubscriptionAndTier(u, r.Customer, "", "", 0, 0, ""); err != nil { return err } s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified())) @@ -364,3 +328,27 @@ func (s *Server) updateSubscriptionAndTier(u *user.User, customerID, subscriptio } return nil } + +// fetchStripePrices contacts the Stripe API to retrieve all prices. This is used by the server to cache the prices +// in memory, and ultimately for the web app to display the price table. +func fetchStripePrices() (map[string]string, error) { + log.Debug("Caching prices from Stripe API") + prices := make(map[string]string) + iter := price.List(&stripe.PriceListParams{ + Active: stripe.Bool(true), + }) + for iter.Next() { + p := iter.Price() + if p.UnitAmount%100 == 0 { + prices[p.ID] = fmt.Sprintf("$%d", p.UnitAmount/100) + } else { + prices[p.ID] = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100) + } + log.Trace("- Caching price %s = %v", p.ID, prices[p.ID]) + } + if iter.Err() != nil { + log.Warn("Fetching Stripe prices failed: %s", iter.Err().Error()) + return nil, iter.Err() + } + return prices, nil +} diff --git a/server/server_test.go b/server/server_test.go index 27527883..6d69d6c6 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1463,7 +1463,7 @@ func TestServer_PublishAttachmentBandwidthLimit(t *testing.T) { msg := toMessage(t, response.Body.String()) require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/") - // Get it 4 times successfully + // Value it 4 times successfully path := strings.TrimPrefix(msg.Attachment.URL, "http://127.0.0.1:12345") for i := 1; i <= 4; i++ { // 4 successful downloads response = request(t, s, "GET", path, "", nil) diff --git a/server/types.go b/server/types.go index 76ba6d4c..5d0f04b5 100644 --- a/server/types.go +++ b/server/types.go @@ -336,3 +336,22 @@ func newSuccessResponse() *apiSuccessResponse { Success: true, } } + +type apiStripeSubscriptionUpdatedEvent struct { + ID string `json:"id"` + Customer string `json:"customer"` + Status string `json:"status"` + CurrentPeriodEnd int64 `json:"current_period_end"` + CancelAt int64 `json:"cancel_at"` + Items *struct { + Data []*struct { + Price *struct { + ID string `json:"id"` + } `json:"price"` + } `json:"data"` + } `json:"items"` +} + +type apiStripeSubscriptionDeletedEvent struct { + Customer string `json:"customer"` +} diff --git a/user/manager.go b/user/manager.go index 4c733cdb..8fe2a0f7 100644 --- a/user/manager.go +++ b/user/manager.go @@ -66,7 +66,6 @@ const ( stripe_subscription_cancel_at INT, created_by TEXT NOT NULL, created_at INT NOT NULL, - last_seen INT NOT NULL, FOREIGN KEY (tier_id) REFERENCES tier (id) ); CREATE UNIQUE INDEX idx_user ON user (user); @@ -93,8 +92,8 @@ const ( id INT PRIMARY KEY, version INT NOT NULL ); - INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at, last_seen) - VALUES (1, '*', '', 'anonymous', '', 'system', UNIXEPOCH(), 0) + INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at) + VALUES (1, '*', '', 'anonymous', '', 'system', UNIXEPOCH()) ON CONFLICT (id) DO NOTHING; ` createTablesQueries = `BEGIN; ` + createTablesQueriesNoTx + ` COMMIT;` @@ -130,8 +129,8 @@ const ( ` insertUserQuery = ` - INSERT INTO user (user, pass, role, sync_topic, created_by, created_at, last_seen) - VALUES (?, ?, ?, ?, ?, ?, ?) + INSERT INTO user (user, pass, role, sync_topic, created_by, created_at) + VALUES (?, ?, ?, ?, ?, ?) ` selectUsernamesQuery = ` SELECT user @@ -257,8 +256,8 @@ const ( ALTER TABLE user RENAME TO user_old; ` migrate1To2InsertFromOldTablesAndDropNoTx = ` - INSERT INTO user (user, pass, role, sync_topic, created_by, created_at, last_seen) - SELECT user, pass, role, '', 'admin', UNIXEPOCH(), UNIXEPOCH() FROM user_old; + INSERT INTO user (user, pass, role, sync_topic, created_by, created_at) + SELECT user, pass, role, '', 'admin', UNIXEPOCH() FROM user_old; INSERT INTO user_access (user_id, topic, read, write) SELECT u.id, a.topic, a.read, a.write @@ -531,7 +530,7 @@ func (a *Manager) AddUser(username, password string, role Role, createdBy string return err } syncTopic, now := util.RandomString(syncTopicLength), time.Now().Unix() - if _, err = a.db.Exec(insertUserQuery, username, hash, role, syncTopic, createdBy, now, now); err != nil { + if _, err = a.db.Exec(insertUserQuery, username, hash, role, syncTopic, createdBy, now); err != nil { return err } return nil @@ -589,6 +588,7 @@ func (a *Manager) User(username string) (*User, error) { return a.readUser(rows) } +// UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise. func (a *Manager) UserByStripeCustomer(stripeCustomerID string) (*User, error) { rows, err := a.db.Query(selectUserByStripeCustomerIDQuery, stripeCustomerID) if err != nil { @@ -878,6 +878,7 @@ func (a *Manager) CreateTier(tier *Tier) error { return nil } +// ChangeBilling updates a user's billing fields, namely the Stripe customer ID, and subscription information func (a *Manager) ChangeBilling(user *User) error { if _, err := a.db.Exec(updateBillingQuery, nullString(user.Billing.StripeCustomerID), nullString(user.Billing.StripeSubscriptionID), nullString(string(user.Billing.StripeSubscriptionStatus)), nullInt64(user.Billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(user.Billing.StripeSubscriptionCancelAt.Unix()), user.Name); err != nil { return err @@ -885,6 +886,7 @@ func (a *Manager) ChangeBilling(user *User) error { return nil } +// Tiers returns a list of all Tier structs func (a *Manager) Tiers() ([]*Tier, error) { rows, err := a.db.Query(selectTiersQuery) if err != nil { @@ -904,6 +906,7 @@ func (a *Manager) Tiers() ([]*Tier, error) { return tiers, nil } +// Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist func (a *Manager) Tier(code string) (*Tier, error) { rows, err := a.db.Query(selectTierByCodeQuery, code) if err != nil { @@ -913,6 +916,7 @@ func (a *Manager) Tier(code string) (*Tier, error) { return a.readTier(rows) } +// TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) { rows, err := a.db.Query(selectTierByPriceIDQuery, priceID) if err != nil { diff --git a/util/lookup_cache.go b/util/lookup_cache.go new file mode 100644 index 00000000..a22f36ab --- /dev/null +++ b/util/lookup_cache.go @@ -0,0 +1,52 @@ +package util + +import ( + "sync" + "time" +) + +// LookupCache is a single-value cache with a time-to-live (TTL). The cache has a lookup function +// to retrieve the value and stores it until TTL is reached. +// +// Example: +// +// lookup := func() (string, error) { +// r, _ := http.Get("...") +// s, _ := io.ReadAll(r.Body) +// return string(s), nil +// } +// c := NewLookupCache[string](lookup, time.Hour) +// fmt.Println(c.Get()) // Fetches the string via HTTP +// fmt.Println(c.Get()) // Uses cached value +type LookupCache[T any] struct { + value *T + lookup func() (T, error) + ttl time.Duration + updated time.Time + mu sync.Mutex +} + +// NewLookupCache creates a new LookupCache with a given time-to-live (TTL) +func NewLookupCache[T any](lookup func() (T, error), ttl time.Duration) *LookupCache[T] { + return &LookupCache[T]{ + value: nil, + lookup: lookup, + ttl: ttl, + } +} + +// Value returns the cached value, or retrieves it via the lookup function +func (c *LookupCache[T]) Value() (T, error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.value == nil || (c.ttl > 0 && time.Since(c.updated) > c.ttl) { + value, err := c.lookup() + if err != nil { + var t T + return t, err + } + c.value = &value + c.updated = time.Now() + } + return *c.value, nil +} diff --git a/util/lookup_cache_test.go b/util/lookup_cache_test.go new file mode 100644 index 00000000..5d45af34 --- /dev/null +++ b/util/lookup_cache_test.go @@ -0,0 +1,63 @@ +package util + +import ( + "errors" + "github.com/stretchr/testify/require" + "testing" + "time" +) + +func TestLookupCache_Success(t *testing.T) { + values, i := []string{"first", "second"}, 0 + c := NewLookupCache[string](func() (string, error) { + time.Sleep(300 * time.Millisecond) + v := values[i] + i++ + return v, nil + }, 500*time.Millisecond) + + start := time.Now() + v, err := c.Value() + require.Nil(t, err) + require.Equal(t, values[0], v) + require.True(t, time.Since(start) >= 300*time.Millisecond) + + start = time.Now() + v, err = c.Value() + require.Nil(t, err) + require.Equal(t, values[0], v) + require.True(t, time.Since(start) < 200*time.Millisecond) + + time.Sleep(550 * time.Millisecond) + + start = time.Now() + v, err = c.Value() + require.Nil(t, err) + require.Equal(t, values[1], v) + require.True(t, time.Since(start) >= 300*time.Millisecond) + + start = time.Now() + v, err = c.Value() + require.Nil(t, err) + require.Equal(t, values[1], v) + require.True(t, time.Since(start) < 200*time.Millisecond) +} + +func TestLookupCache_Error(t *testing.T) { + c := NewLookupCache[string](func() (string, error) { + time.Sleep(200 * time.Millisecond) + return "", errors.New("some error") + }, 500*time.Millisecond) + + start := time.Now() + v, err := c.Value() + require.NotNil(t, err) + require.Equal(t, "", v) + require.True(t, time.Since(start) >= 200*time.Millisecond) + + start = time.Now() + v, err = c.Value() + require.NotNil(t, err) + require.Equal(t, "", v) + require.True(t, time.Since(start) >= 200*time.Millisecond) +} diff --git a/web/src/app/AccountApi.js b/web/src/app/AccountApi.js index 6431d680..05b3d6b6 100644 --- a/web/src/app/AccountApi.js +++ b/web/src/app/AccountApi.js @@ -24,11 +24,6 @@ class AccountApi { constructor() { this.timer = null; this.listener = null; // Fired when account is fetched from remote - - // Random ID used to identify this client when sending/receiving "sync" events - // to the sync topic of an account. This ID doesn't matter much, but it will prevent - // a client from reacting to its own message. - this.identity = Math.floor(Math.random() * 2586000); } registerListener(listener) { diff --git a/web/src/app/config.js b/web/src/app/config.js index 0cb0bb1b..bdec53ed 100644 --- a/web/src/app/config.js +++ b/web/src/app/config.js @@ -1,6 +1,8 @@ const config = window.config; -if (config.base_url === "") { +// The backend returns an empty base_url for the config struct, +// so the frontend (hey, that's us!) can use the current location. +if (!config.base_url || config.base_url === "") { config.base_url = window.location.origin; } diff --git a/web/src/app/db.js b/web/src/app/db.js index 31eba294..564ee1ce 100644 --- a/web/src/app/db.js +++ b/web/src/app/db.js @@ -7,6 +7,7 @@ import session from "./Session"; // Notes: // - As per docs, we only declare the indexable columns, not all columns +// The IndexedDB database name is based on the logged-in user const dbName = (session.username()) ? `ntfy-${session.username()}` : "ntfy"; const db = new Dexie(dbName); diff --git a/web/src/components/hooks.js b/web/src/components/hooks.js index 77d6d995..5dd036af 100644 --- a/web/src/components/hooks.js +++ b/web/src/components/hooks.js @@ -35,12 +35,8 @@ export const useConnectionListeners = (subscriptions, users) => { try { const data = JSON.parse(message.message); if (data.event === "sync") { - if (data.source !== accountApi.identity) { - console.log(`[ConnectionListener] Triggering account sync`); - await accountApi.sync(); - } else { - console.log(`[ConnectionListener] I triggered the account sync, ignoring message`); - } + console.log(`[ConnectionListener] Triggering account sync`); + await accountApi.sync(); } else { console.log(`[ConnectionListener] Unknown message type. Doing nothing.`); }