From 5d6051c490a93f8080fdd5b6a31acc1e21e09ad5 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sat, 4 Feb 2023 21:26:01 -0500 Subject: [PATCH] Logging WIP --- cmd/access_test.go | 2 + log/event.go | 2 +- log/log.go | 2 +- log/types.go | 4 +- server/server.go | 13 +++-- server/server_account.go | 7 +-- server/server_matrix.go | 3 +- server/server_payments.go | 100 +++++++++++++++++++++++++++++--------- server/smtp_server.go | 2 +- server/util.go | 87 --------------------------------- server/visitor.go | 10 ++++ 11 files changed, 108 insertions(+), 124 deletions(-) diff --git a/cmd/access_test.go b/cmd/access_test.go index 6e3c5ba3..a9d1c534 100644 --- a/cmd/access_test.go +++ b/cmd/access_test.go @@ -26,6 +26,8 @@ func TestCLI_Access_Grant_And_Publish(t *testing.T) { stdin.WriteString("philpass\nphilpass\nbenpass\nbenpass") require.Nil(t, runUserCommand(app, conf, "add", "--role=admin", "phil")) require.Nil(t, runUserCommand(app, conf, "add", "ben")) + + app, stdin, _, _ = newTestApp() require.Nil(t, runAccessCommand(app, conf, "ben", "announcements", "rw")) require.Nil(t, runAccessCommand(app, conf, "ben", "sometopic", "read")) require.Nil(t, runAccessCommand(app, conf, "everyone", "announcements", "read")) diff --git a/log/event.go b/log/event.go index 81232773..f55d3099 100644 --- a/log/event.go +++ b/log/event.go @@ -76,7 +76,7 @@ func (e *Event) Fields(fields map[string]any) *Event { return e } -func (e *Event) Context(contexts ...Ctx) *Event { +func (e *Event) Context(contexts ...Contexter) *Event { for _, c := range contexts { e.Fields(c.Context()) } diff --git a/log/log.go b/log/log.go index 5eb88035..1a0b90e8 100644 --- a/log/log.go +++ b/log/log.go @@ -42,7 +42,7 @@ func Trace(message string, v ...any) { newEvent().Trace(message, v...) } -func Context(contexts ...Ctx) *Event { +func Context(contexts ...Contexter) *Event { return newEvent().Context(contexts...) } diff --git a/log/types.go b/log/types.go index e43b67cf..c5581650 100644 --- a/log/types.go +++ b/log/types.go @@ -91,7 +91,7 @@ func ToFormat(s string) Format { } } -type Ctx interface { +type Contexter interface { Context() map[string]any } @@ -101,7 +101,7 @@ func (f fieldsCtx) Context() map[string]any { return f } -func NewCtx(fields map[string]any) Ctx { +func NewCtx(fields map[string]any) Contexter { return fieldsCtx(fields) } diff --git a/server/server.go b/server/server.go index 57499533..27fe11eb 100644 --- a/server/server.go +++ b/server/server.go @@ -149,6 +149,7 @@ const ( tagManager = "manager" tagResetter = "resetter" tagWebsocket = "websocket" + tagMatrix = "matrix" ) // New instantiates a new Server. It creates the cache and adds a Firebase @@ -328,9 +329,9 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) { if websocket.IsWebSocketUpgrade(r) { isNormalError := strings.Contains(err.Error(), "i/o timeout") if isNormalError { - logvr(v, r).Tag(tagWebsocket).Debug("WebSocket error (this error is okay, it happens a lot): %s", err.Error()) + logvr(v, r).Tag(tagWebsocket).Err(err).Debug("WebSocket error (this error is okay, it happens a lot): %s", err.Error()) } else { - logvr(v, r).Tag(tagWebsocket).Info("WebSocket error: %s", err.Error()) + logvr(v, r).Tag(tagWebsocket).Err(err).Info("WebSocket error: %s", err.Error()) } return // Do not attempt to write to upgraded connection } @@ -711,7 +712,7 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) { logvm(v, m).Err(err).Warn("Unable to publish poll request") return } else if response.StatusCode != http.StatusOK { - logvm(v, m).Err(err).Warn("Unable to publish poll request, unexpected HTTP status: %d") + logvm(v, m).Err(err).Warn("Unable to publish poll request, unexpected HTTP status: %d", response.StatusCode) return } } @@ -1537,6 +1538,7 @@ func (s *Server) limitRequests(next handleFunc) handleFunc { if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) { return next(w, r, v) } else if err := v.RequestAllowed(); err != nil { + logvr(v, r).Err(err).Fields(requestLimiterFields(v.RequestLimiter())).Trace("Request not allowed by rate limiter") return errHTTPTooManyRequestsLimitRequests } return next(w, r, v) @@ -1601,6 +1603,7 @@ func (s *Server) transformMatrixJSON(next handleFunc) handleFunc { return func(w http.ResponseWriter, r *http.Request, v *visitor) error { newRequest, err := newRequestFromMatrixJSON(r, s.config.BaseURL, s.config.MessageLimit) if err != nil { + logvr(v, r).Tag(tagMatrix).Err(err).Trace("Invalid Matrix request") return err } if err := next(w, newRequest, v); err != nil { @@ -1630,7 +1633,7 @@ func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc u := v.User() for _, t := range topics { if err := s.userManager.Authorize(u, t.ID, perm); err != nil { - logvr(v, r).Err(err).Debug("Unauthorized") + logvr(v, r).Err(err).Field("message_topic", t.ID).Debug("Access to topic %s not authorized", t.ID) return errHTTPForbidden } } @@ -1644,7 +1647,7 @@ func (s *Server) maybeAuthenticate(r *http.Request) (v *visitor, err error) { ip := extractIPAddress(r, s.config.BehindProxy) var u *user.User // may stay nil if no auth header! if u, err = s.authenticate(r); err != nil { - logr(r).Debug("Authentication failed: %s", err.Error()) + logr(r).Err(err).Debug("Authentication failed: %s", err.Error()) err = errHTTPUnauthorized // Always return visitor, even when error occurs! } v = s.visitor(ip, u) diff --git a/server/server_account.go b/server/server_account.go index b4ad2faf..8e6d2b2a 100644 --- a/server/server_account.go +++ b/server/server_account.go @@ -160,7 +160,7 @@ func (s *Server) handleAccountDelete(w http.ResponseWriter, r *http.Request, v * return err } } - if err := s.maybeRemoveMessagesAndExcessReservations(logHTTPPrefix(v, r), u, 0); err != nil { + if err := s.maybeRemoveMessagesAndExcessReservations(r, v, u, 0); err != nil { return err } logvr(v, r).Tag(tagAccount).Info("Marking user %s as deleted", u.Name) @@ -462,18 +462,19 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R // maybeRemoveMessagesAndExcessReservations deletes topic reservations for the given user (if too many for tier), // and marks associated messages for the topics as deleted. This also eventually deletes attachments. // The process relies on the manager to perform the actual deletions (see runManager). -func (s *Server) maybeRemoveMessagesAndExcessReservations(logPrefix string, u *user.User, reservationsLimit int64) error { +func (s *Server) maybeRemoveMessagesAndExcessReservations(r *http.Request, v *visitor, u *user.User, reservationsLimit int64) error { reservations, err := s.userManager.Reservations(u.Name) if err != nil { return err } else if int64(len(reservations)) <= reservationsLimit { + logvr(v, r).Tag(tagAccount).Debug("No excess reservations to remove") return nil } topics := make([]string, 0) for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- { topics = append(topics, reservations[i].Topic) } - log.Info("%s Removing excess reservations for topics %s", logPrefix, strings.Join(topics, ", ")) + logvr(v, r).Tag(tagAccount).Info("Removing excess reservations for topics %s", strings.Join(topics, ", ")) if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil { return err } diff --git a/server/server_matrix.go b/server/server_matrix.go index 99d8dc34..28ca7337 100644 --- a/server/server_matrix.go +++ b/server/server_matrix.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/json" "fmt" - "heckel.io/ntfy/log" "heckel.io/ntfy/util" "io" "net/http" @@ -147,7 +146,7 @@ func writeMatrixDiscoveryResponse(w http.ResponseWriter) error { // writeMatrixError logs and writes the errMatrix to the given http.ResponseWriter as a matrixResponse func writeMatrixError(w http.ResponseWriter, r *http.Request, v *visitor, err *errMatrix) error { - log.Debug("%s Matrix gateway error: %s", logHTTPPrefix(v, r), err.Error()) + logvr(v, r).Tag(tagMatrix).Err(err).Debug("Matrix gateway error") return writeMatrixResponse(w, err.pushKey) } diff --git a/server/server_payments.go b/server/server_payments.go index 43e93c37..647da8cc 100644 --- a/server/server_payments.go +++ b/server/server_payments.go @@ -2,7 +2,6 @@ package server import ( "bytes" - "encoding/json" "errors" "fmt" "github.com/stripe/stripe-go/v74" @@ -121,7 +120,13 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r } else if tier.StripePriceID == "" { return errNotAPaidTier } - log.Info("%s Creating Stripe checkout flow", logHTTPPrefix(v, r)) + logvr(v, r). + Tag(tagPay). + Fields(map[string]any{ + "tier": tier, + "stripe_price_id": tier.StripePriceID, + }). + Info("Creating Stripe checkout flow") var stripeCustomerID *string if u.Billing.StripeCustomerID != "" { stripeCustomerID = &u.Billing.StripeCustomerID @@ -190,6 +195,18 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr return err } v.SetUser(u) + logvr(v, r). + Tag(tagPay). + Fields(map[string]any{ + "tier_id": tier.ID, + "tier_name": tier.Name, + "stripe_price_id": tier.StripePriceID, + "stripe_customer_id": sess.Customer.ID, + "stripe_subscription_id": sub.ID, + "stripe_subscription_status": string(sub.Status), + "stripe_subscription_paid_until": sub.CurrentPeriodEnd, + }). + Info("Stripe checkout flow succeeded, updating user tier and subscription") customerParams := &stripe.CustomerParams{ Params: stripe.Params{ Metadata: map[string]string{ @@ -201,7 +218,7 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr if _, err := s.stripe.UpdateCustomer(sess.Customer.ID, customerParams); err != nil { return err } - if err := s.updateSubscriptionAndTier(logHTTPPrefix(v, r), u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil { + if err := s.updateSubscriptionAndTier(r, v, u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil { return err } http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther) @@ -223,7 +240,15 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r if err != nil { return err } - log.Info("%s Changing billing tier to %s (price %s) for subscription %s", logHTTPPrefix(v, r), tier.Code, tier.StripePriceID, u.Billing.StripeSubscriptionID) + logvr(v, r). + Tag(tagPay). + Fields(map[string]any{ + "new_tier_id": tier.ID, + "new_tier_name": tier.Name, + "new_tier_stripe_price_id": tier.StripePriceID, + // Other stripe_* fields filled by visitor context + }). + Info("Changing Stripe subscription and billing tier to %s/%s (price %s)", tier.ID, tier.Name, tier.StripePriceID) sub, err := s.stripe.GetSubscription(u.Billing.StripeSubscriptionID) if err != nil { return err @@ -250,8 +275,8 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r // handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user, // and cancelling the Stripe subscription entirely func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error { + logvr(v, r).Tag(tagPay).Info("Deleting Stripe subscription") u := v.User() - log.Info("%s Deleting billing subscription %s", logHTTPPrefix(v, r), u.Billing.StripeSubscriptionID) if u.Billing.StripeSubscriptionID != "" { params := &stripe.SubscriptionParams{ CancelAtPeriodEnd: stripe.Bool(true), @@ -267,11 +292,11 @@ func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r // handleAccountBillingPortalSessionCreate creates a session to the customer billing portal, and returns the // redirect URL. The billing portal allows customers to change their payment methods, and cancel the subscription. func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error { + logvr(v, r).Tag(tagPay).Info("Creating Stripe billing portal session") u := v.User() if u.Billing.StripeCustomerID == "" { return errHTTPBadRequestNotAPaidUser } - log.Info("%s Creating billing portal session", logHTTPPrefix(v, r)) params := &stripe.BillingPortalSessionParams{ Customer: stripe.String(u.Billing.StripeCustomerID), ReturnURL: stripe.String(s.config.BaseURL), @@ -289,7 +314,7 @@ func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, // handleAccountBillingWebhook handles incoming Stripe webhooks. It mainly keeps the local user database in sync // with the Stripe view of the world. This endpoint is authorized via the Stripe webhook secret. Note that the // visitor (v) in this endpoint is the Stripe API, so we don't have u available. -func (s *Server) handleAccountBillingWebhook(_ http.ResponseWriter, r *http.Request, _ *visitor) error { +func (s *Server) handleAccountBillingWebhook(_ http.ResponseWriter, r *http.Request, v *visitor) error { stripeSignature := r.Header.Get("Stripe-Signature") if stripeSignature == "" { return errHTTPBadRequestBillingRequestInvalid @@ -308,74 +333,105 @@ func (s *Server) handleAccountBillingWebhook(_ http.ResponseWriter, r *http.Requ } switch event.Type { case "customer.subscription.updated": - return s.handleAccountBillingWebhookSubscriptionUpdated(event.Data.Raw) + return s.handleAccountBillingWebhookSubscriptionUpdated(r, v, event) case "customer.subscription.deleted": - return s.handleAccountBillingWebhookSubscriptionDeleted(event.Data.Raw) + return s.handleAccountBillingWebhookSubscriptionDeleted(r, v, event) default: - log.Warn("STRIPE Unhandled webhook event %s received", event.Type) + logvr(v, r). + Tag(tagPay). + Field("stripe_webhook_type", event.Type). + Warn("Unhandled Stripe webhook event %s received", event.Type) return nil } } -func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMessage) error { - ev, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event))) +func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(r *http.Request, v *visitor, event stripe.Event) error { + ev, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event.Data.Raw))) if err != nil { return err } else if ev.ID == "" || ev.Customer == "" || ev.Status == "" || ev.CurrentPeriodEnd == 0 || ev.Items == nil || len(ev.Items.Data) != 1 || ev.Items.Data[0].Price == nil || ev.Items.Data[0].Price.ID == "" { return errHTTPBadRequestBillingRequestInvalid } subscriptionID, priceID := ev.ID, ev.Items.Data[0].Price.ID - log.Info("%s Updating subscription to status %s, with price %s", logStripePrefix(ev.Customer, ev.ID), ev.Status, priceID) + logvr(v, r). + Tag(tagPay). + Fields(map[string]any{ + "stripe_webhook_type": event.Type, + "stripe_customer_id": ev.Customer, + "stripe_subscription_id": ev.ID, + "stripe_subscription_status": ev.Status, + "stripe_subscription_paid_until": ev.CurrentPeriodEnd, + "stripe_subscription_cancel_at": ev.CancelAt, + "stripe_price_id": priceID, + }). + Info("Updating subscription to status %s, with price %s", ev.Status, priceID) userFn := func() (*user.User, error) { return s.userManager.UserByStripeCustomer(ev.Customer) } + // We retry the user retrieval function, because during the Stripe checkout, there a race between the browser + // checkout success redirect (see handleAccountBillingSubscriptionCreateSuccess), and this webhook. The checkout + // success call is the one that updates the user with the Stripe customer ID. u, err := util.Retry[user.User](userFn, retryUserDelays...) if err != nil { return err } + v.SetUser(u) tier, err := s.userManager.TierByStripePrice(priceID) if err != nil { return err } - if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, tier, ev.Customer, subscriptionID, ev.Status, ev.CurrentPeriodEnd, ev.CancelAt); err != nil { + if err := s.updateSubscriptionAndTier(r, v, u, tier, ev.Customer, subscriptionID, ev.Status, ev.CurrentPeriodEnd, ev.CancelAt); err != nil { return err } s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u)) return nil } -func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMessage) error { - ev, err := util.UnmarshalJSON[apiStripeSubscriptionDeletedEvent](io.NopCloser(bytes.NewReader(event))) +func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(r *http.Request, v *visitor, event stripe.Event) error { + ev, err := util.UnmarshalJSON[apiStripeSubscriptionDeletedEvent](io.NopCloser(bytes.NewReader(event.Data.Raw))) if err != nil { return err } else if ev.Customer == "" { return errHTTPBadRequestBillingRequestInvalid } - log.Info("%s Subscription deleted, downgrading to unpaid tier", logStripePrefix(ev.Customer, ev.ID)) u, err := s.userManager.UserByStripeCustomer(ev.Customer) if err != nil { return err } - if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, nil, ev.Customer, "", "", 0, 0); err != nil { + v.SetUser(u) + logvr(v, r). + Tag(tagPay). + Field("stripe_webhook_type", event.Type). + Info("Subscription deleted, downgrading to unpaid tier") + if err := s.updateSubscriptionAndTier(r, v, u, nil, ev.Customer, "", "", 0, 0); err != nil { return err } s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u)) return nil } -func (s *Server) updateSubscriptionAndTier(logPrefix string, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error { +func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error { reservationsLimit := visitorDefaultReservationsLimit if tier != nil { reservationsLimit = tier.ReservationLimit } - if err := s.maybeRemoveMessagesAndExcessReservations(logPrefix, u, reservationsLimit); err != nil { + if err := s.maybeRemoveMessagesAndExcessReservations(r, v, u, reservationsLimit); err != nil { return err } - if tier == nil { + if tier == nil && u.Tier != nil { + logvr(v, r).Tag(tagPay).Info("Resetting tier for user %s", u.Name) if err := s.userManager.ResetTier(u.Name); err != nil { return err } - } else { + } else if tier != nil && u.TierID() != tier.ID { + logvr(v, r). + Tag(tagPay). + Fields(map[string]any{ + "new_tier_id": tier.ID, + "new_tier_name": tier.Name, + "new_tier_stripe_price_id": tier.StripePriceID, + }). + Info("Changing tier to tier %s (%s) for user %s", tier.ID, tier.Name, u.Name) if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil { return err } diff --git a/server/smtp_server.go b/server/smtp_server.go index 52f8f851..869e8aae 100644 --- a/server/smtp_server.go +++ b/server/smtp_server.go @@ -70,7 +70,7 @@ func (s *smtpSession) AuthPlain(username, password string) error { } func (s *smtpSession) Mail(from string, opts smtp.MailOptions) error { - logem(s.state).Debug("%s MAIL FROM: %s (with options: %#v)", from, opts) + logem(s.state).Debug("MAIL FROM: %s (with options: %#v)", from, opts) return nil } diff --git a/server/util.go b/server/util.go index 2fabf135..8fbfaefa 100644 --- a/server/util.go +++ b/server/util.go @@ -1,15 +1,12 @@ package server import ( - "fmt" - "github.com/emersion/go-smtp" "heckel.io/ntfy/log" "heckel.io/ntfy/util" "io" "net/http" "net/netip" "strings" - "unicode/utf8" ) func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool { @@ -48,90 +45,6 @@ func readQueryParam(r *http.Request, names ...string) string { return "" } -func logr(r *http.Request) *log.Event { - return log.Fields(logFieldsHTTP(r)) -} - -func logv(v *visitor) *log.Event { - return log.Context(v) -} - -func logvr(v *visitor, r *http.Request) *log.Event { - return logv(v).Fields(logFieldsHTTP(r)) -} - -func logvrm(v *visitor, r *http.Request, m *message) *log.Event { - return logvr(v, r).Context(m) -} - -func logvm(v *visitor, m *message) *log.Event { - return logv(v).Context(m) -} - -func logem(state *smtp.ConnectionState) *log.Event { - return log. - Tag(tagSMTP). - Fields(map[string]any{ - "smtp_hostname": state.Hostname, - "smtp_remote_addr": state.RemoteAddr.String(), - }) -} - -func logFieldsHTTP(r *http.Request) map[string]any { - requestURI := r.RequestURI - if requestURI == "" { - requestURI = r.URL.Path - } - return map[string]any{ - "http_method": r.Method, - "http_path": requestURI, - } -} - -func logHTTPPrefix(v *visitor, r *http.Request) string { - requestURI := r.RequestURI - if requestURI == "" { - requestURI = r.URL.Path - } - return fmt.Sprintf("HTTP %s %s %s", v.String(), r.Method, requestURI) -} - -func logStripePrefix(customerID, subscriptionID string) string { - if subscriptionID != "" { - return fmt.Sprintf("STRIPE %s/%s", customerID, subscriptionID) - } - return fmt.Sprintf("STRIPE %s", customerID) -} - -func renderHTTPRequest(r *http.Request) string { - peekLimit := 4096 - lines := fmt.Sprintf("%s %s %s\n", r.Method, r.URL.RequestURI(), r.Proto) - for key, values := range r.Header { - for _, value := range values { - lines += fmt.Sprintf("%s: %s\n", key, value) - } - } - lines += "\n" - body, err := util.Peek(r.Body, peekLimit) - if err != nil { - lines = fmt.Sprintf("(could not read body: %s)\n", err.Error()) - } else if utf8.Valid(body.PeekedBytes) { - lines += string(body.PeekedBytes) - if body.LimitReached { - lines += fmt.Sprintf(" ... (peeked %d bytes)", peekLimit) - } - lines += "\n" - } else { - if body.LimitReached { - lines += fmt.Sprintf("(peeked bytes not UTF-8, peek limit of %d bytes reached, hex: %x ...)\n", peekLimit, body.PeekedBytes) - } else { - lines += fmt.Sprintf("(peeked bytes not UTF-8, %d bytes, hex: %x)\n", len(body.PeekedBytes), body.PeekedBytes) - } - } - r.Body = body // Important: Reset body, so it can be re-read - return strings.TrimSpace(lines) -} - func extractIPAddress(r *http.Request, behindProxy bool) netip.Addr { remoteAddr := r.RemoteAddr addrPort, err := netip.ParseAddrPort(remoteAddr) diff --git a/server/visitor.go b/server/visitor.go index 444e576a..6b96a785 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -159,6 +159,10 @@ func (v *visitor) Context() map[string]any { if v.user != nil { fields["user_id"] = v.user.ID fields["user_name"] = v.user.Name + if v.user.Tier != nil { + fields["tier_id"] = v.user.Tier.ID + fields["tier_name"] = v.user.Tier.Name + } if v.user.Billing.StripeCustomerID != "" { fields["stripe_customer_id"] = v.user.Billing.StripeCustomerID } @@ -178,6 +182,12 @@ func (v *visitor) RequestAllowed() error { return nil } +func (v *visitor) RequestLimiter() *rate.Limiter { + v.mu.Lock() // limiters could be replaced! + defer v.mu.Unlock() + return v.requestLimiter +} + func (v *visitor) FirebaseAllowed() error { v.mu.Lock() defer v.mu.Unlock()