diff --git a/server/server_test.go b/server/server_test.go index 8ab59570..f3789e0f 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -746,7 +746,7 @@ func TestServer_Auth_ViaQuery(t *testing.T) { require.Equal(t, 401, response.Code) } -func TestServer_StatsResetter_User_Without_Tier(t *testing.T) { +func TestServer_StatsResetter(t *testing.T) { // This tests the stats resetter for // - an anonymous user // - a user without a tier (treated like the same as the anonymous user) @@ -841,6 +841,34 @@ func TestServer_StatsResetter_User_Without_Tier(t *testing.T) { require.Equal(t, int64(0), account.Stats.Messages) } +func TestServer_StatsResetter_MessageLimiter(t *testing.T) { + // This tests that the messageLimiter (the only fixed limiter) is reset by the stats resetter + + c := newTestConfigWithAuthFile(t) + s := newTestServer(t, c) + + // Publish some messages, and check stats + for i := 0; i < 3; i++ { + response := request(t, s, "PUT", "/mytopic", "test", nil) + require.Equal(t, 200, response.Code) + } + rr := request(t, s, "GET", "/v1/account", "", nil) + require.Equal(t, 200, rr.Code) + account, err := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body)) + require.Nil(t, err) + require.Equal(t, int64(3), account.Stats.Messages) + require.Equal(t, int64(3), s.visitor(netip.MustParseAddr("9.9.9.9"), nil).messagesLimiter.Value()) + + // Reset stats and check again + s.resetStats() + rr = request(t, s, "GET", "/v1/account", "", nil) + require.Equal(t, 200, rr.Code) + account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body)) + require.Nil(t, err) + require.Equal(t, int64(0), account.Stats.Messages) + require.Equal(t, int64(0), s.visitor(netip.MustParseAddr("9.9.9.9"), nil).messagesLimiter.Value()) +} + type testMailer struct { count int mu sync.Mutex diff --git a/server/visitor.go b/server/visitor.go index 113662f3..b8f89533 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -55,19 +55,19 @@ var ( type visitor struct { config *Config messageCache *messageCache - userManager *user.Manager // May be nil - ip netip.Addr // Visitor IP address - user *user.User // Only set if authenticated user, otherwise nil - messages int64 // Number of messages sent, reset every day - emails int64 // Number of emails sent, reset every day - requestLimiter *rate.Limiter // Rate limiter for (almost) all requests (including messages) - messagesLimiter util.Limiter // Rate limiter for messages, may be nil - emailsLimiter *rate.Limiter // Rate limiter for emails - subscriptionLimiter util.Limiter // Fixed limiter for active subscriptions (ongoing connections) - bandwidthLimiter util.Limiter // Limiter for attachment bandwidth downloads - accountLimiter *rate.Limiter // Rate limiter for account creation, may be nil - firebase time.Time // Next allowed Firebase message - seen time.Time // Last seen time of this visitor (needed for removal of stale visitors) + userManager *user.Manager // May be nil + ip netip.Addr // Visitor IP address + user *user.User // Only set if authenticated user, otherwise nil + messages int64 // Number of messages sent, reset every day + emails int64 // Number of emails sent, reset every day + requestLimiter *rate.Limiter // Rate limiter for (almost) all requests (including messages) + messagesLimiter *util.FixedLimiter // Rate limiter for messages, may be nil + emailsLimiter *rate.Limiter // Rate limiter for emails + subscriptionLimiter util.Limiter // Fixed limiter for active subscriptions (ongoing connections) + bandwidthLimiter util.Limiter // Limiter for attachment bandwidth downloads + accountLimiter *rate.Limiter // Rate limiter for account creation, may be nil + firebase time.Time // Next allowed Firebase message + seen time.Time // Last seen time of this visitor (needed for removal of stale visitors) mu sync.Mutex } @@ -251,10 +251,12 @@ func (v *visitor) ResetStats() { defer v.mu.Unlock() v.messages = 0 v.emails = 0 + if v.messagesLimiter != nil { + v.messagesLimiter.Reset() + } if v.user != nil { v.user.Stats.Messages = 0 v.user.Stats.Emails = 0 - // v.messagesLimiter = ... // FIXME } } diff --git a/util/limit.go b/util/limit.go index 7f39c4c4..dd3b56fb 100644 --- a/util/limit.go +++ b/util/limit.go @@ -50,6 +50,20 @@ func (l *FixedLimiter) Allow(n int64) error { return nil } +// Value returns the current limiter value +func (l *FixedLimiter) Value() int64 { + l.mu.Lock() + defer l.mu.Unlock() + return l.value +} + +// Reset sets the limiter's value back to zero +func (l *FixedLimiter) Reset() { + l.mu.Lock() + defer l.mu.Unlock() + l.value = 0 +} + // RateLimiter is a Limiter that wraps a rate.Limiter, allowing a floating time-based limit. type RateLimiter struct { limiter *rate.Limiter