mirror of
https://github.com/binwiederhier/ntfy.git
synced 2025-05-28 01:15:43 +02:00
Rate limits make sense now!
This commit is contained in:
parent
a036814d98
commit
c874a641df
17 changed files with 365 additions and 205 deletions
server
|
@ -8,7 +8,6 @@ import (
|
|||
"fmt"
|
||||
"heckel.io/ntfy/user"
|
||||
"io"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
@ -22,9 +21,14 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/util"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// log.SetLevel(log.DebugLevel)
|
||||
}
|
||||
|
||||
func TestServer_PublishAndPoll(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfig(t))
|
||||
|
||||
|
@ -742,16 +746,31 @@ func TestServer_Auth_ViaQuery(t *testing.T) {
|
|||
require.Equal(t, 401, response.Code)
|
||||
}
|
||||
|
||||
func TestServer_StatsResetter(t *testing.T) {
|
||||
func TestServer_StatsResetter_User_Without_Tier(t *testing.T) {
|
||||
// This tests the stats resetter for
|
||||
// - an anonymous user
|
||||
// - a user without a tier (treated like the same as the anonymous user)
|
||||
// - a user with a tier
|
||||
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.AuthDefault = user.PermissionDenyAll
|
||||
c.VisitorStatsResetTime = time.Now().Add(2 * time.Second)
|
||||
s := newTestServer(t, c)
|
||||
go s.runStatsResetter()
|
||||
|
||||
// Create user with tier (tieruser) and user without tier (phil)
|
||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||
Code: "test",
|
||||
MessageLimit: 5,
|
||||
MessageExpiryDuration: -5 * time.Second, // Second, what a hack!
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
require.Nil(t, s.userManager.AllowAccess("phil", "mytopic", user.PermissionReadWrite))
|
||||
require.Nil(t, s.userManager.AddUser("tieruser", "tieruser", user.RoleUser))
|
||||
require.Nil(t, s.userManager.ChangeTier("tieruser", "test"))
|
||||
|
||||
// Send an anonymous message
|
||||
response := request(t, s, "PUT", "/mytopic", "test", nil)
|
||||
|
||||
// Send messages from user without tier (phil)
|
||||
for i := 0; i < 5; i++ {
|
||||
response := request(t, s, "PUT", "/mytopic", "test", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
|
@ -759,30 +778,66 @@ func TestServer_StatsResetter(t *testing.T) {
|
|||
require.Equal(t, 200, response.Code)
|
||||
}
|
||||
|
||||
response := request(t, s, "GET", "/v1/account", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
// Send messages from user with tier
|
||||
for i := 0; i < 2; i++ {
|
||||
response := request(t, s, "PUT", "/mytopic", "test", map[string]string{
|
||||
"Authorization": util.BasicAuth("tieruser", "tieruser"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
}
|
||||
|
||||
// User stats show 10 messages
|
||||
// User stats show 6 messages (for user without tier)
|
||||
response = request(t, s, "GET", "/v1/account", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
account, err := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(5), account.Stats.Messages)
|
||||
require.Equal(t, int64(6), account.Stats.Messages)
|
||||
|
||||
// User stats show 6 messages (for anonymous visitor)
|
||||
response = request(t, s, "GET", "/v1/account", "", nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(6), account.Stats.Messages)
|
||||
|
||||
// User stats show 2 messages (for user with tier)
|
||||
response = request(t, s, "GET", "/v1/account", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("tieruser", "tieruser"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(2), account.Stats.Messages)
|
||||
|
||||
// Wait for stats resetter to run
|
||||
time.Sleep(2200 * time.Millisecond)
|
||||
|
||||
// User stats show 0 messages now!
|
||||
response = request(t, s, "GET", "/v1/account", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(0), account.Stats.Messages)
|
||||
|
||||
// Since this is a user without a tier, the anonymous user should have the same stats
|
||||
response = request(t, s, "GET", "/v1/account", "", nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(0), account.Stats.Messages)
|
||||
|
||||
// User stats show 0 messages (for user with tier)
|
||||
response = request(t, s, "GET", "/v1/account", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("tieruser", "tieruser"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(0), account.Stats.Messages)
|
||||
}
|
||||
|
||||
type testMailer struct {
|
||||
|
@ -1133,9 +1188,9 @@ func TestServer_PublishWithTierBasedMessageLimitAndExpiry(t *testing.T) {
|
|||
|
||||
// Create tier with certain limits
|
||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||
Code: "test",
|
||||
MessagesLimit: 5,
|
||||
MessagesExpiryDuration: -5 * time.Second, // Second, what a hack!
|
||||
Code: "test",
|
||||
MessageLimit: 5,
|
||||
MessageExpiryDuration: -5 * time.Second, // Second, what a hack!
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
|
||||
|
@ -1363,8 +1418,8 @@ func TestServer_PublishAttachmentWithTierBasedExpiry(t *testing.T) {
|
|||
sevenDays := time.Duration(604800) * time.Second
|
||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||
Code: "test",
|
||||
MessagesLimit: 10,
|
||||
MessagesExpiryDuration: sevenDays,
|
||||
MessageLimit: 10,
|
||||
MessageExpiryDuration: sevenDays,
|
||||
AttachmentFileSizeLimit: 50_000,
|
||||
AttachmentTotalSizeLimit: 200_000,
|
||||
AttachmentExpiryDuration: sevenDays, // 7 days
|
||||
|
@ -1407,8 +1462,8 @@ func TestServer_PublishAttachmentWithTierBasedBandwidthLimit(t *testing.T) {
|
|||
// Create tier with certain limits
|
||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||
Code: "test",
|
||||
MessagesLimit: 10,
|
||||
MessagesExpiryDuration: time.Hour,
|
||||
MessageLimit: 10,
|
||||
MessageExpiryDuration: time.Hour,
|
||||
AttachmentFileSizeLimit: 50_000,
|
||||
AttachmentTotalSizeLimit: 200_000,
|
||||
AttachmentExpiryDuration: time.Hour,
|
||||
|
@ -1450,7 +1505,7 @@ func TestServer_PublishAttachmentWithTierBasedLimits(t *testing.T) {
|
|||
// Create tier with certain limits
|
||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||
Code: "test",
|
||||
MessagesLimit: 100,
|
||||
MessageLimit: 100,
|
||||
AttachmentFileSizeLimit: 50_000,
|
||||
AttachmentTotalSizeLimit: 200_000,
|
||||
AttachmentExpiryDuration: 30 * time.Second,
|
||||
|
@ -1574,7 +1629,7 @@ func TestServer_Visitor_XForwardedFor_None(t *testing.T) {
|
|||
r, _ := http.NewRequest("GET", "/bla", nil)
|
||||
r.RemoteAddr = "8.9.10.11"
|
||||
r.Header.Set("X-Forwarded-For", " ") // Spaces, not empty!
|
||||
v, err := s.visitor(r)
|
||||
v, err := s.maybeAuthenticate(r)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "8.9.10.11", v.ip.String())
|
||||
}
|
||||
|
@ -1586,7 +1641,7 @@ func TestServer_Visitor_XForwardedFor_Single(t *testing.T) {
|
|||
r, _ := http.NewRequest("GET", "/bla", nil)
|
||||
r.RemoteAddr = "8.9.10.11"
|
||||
r.Header.Set("X-Forwarded-For", "1.1.1.1")
|
||||
v, err := s.visitor(r)
|
||||
v, err := s.maybeAuthenticate(r)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "1.1.1.1", v.ip.String())
|
||||
}
|
||||
|
@ -1598,7 +1653,7 @@ func TestServer_Visitor_XForwardedFor_Multiple(t *testing.T) {
|
|||
r, _ := http.NewRequest("GET", "/bla", nil)
|
||||
r.RemoteAddr = "8.9.10.11"
|
||||
r.Header.Set("X-Forwarded-For", "1.2.3.4 , 2.4.4.2,234.5.2.1 ")
|
||||
v, err := s.visitor(r)
|
||||
v, err := s.maybeAuthenticate(r)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "234.5.2.1", v.ip.String())
|
||||
}
|
||||
|
@ -1611,7 +1666,7 @@ func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) {
|
|||
s := newTestServer(t, c)
|
||||
|
||||
// Add lots of messages
|
||||
log.Printf("Adding %d messages", count)
|
||||
log.Info("Adding %d messages", count)
|
||||
start := time.Now()
|
||||
messages := make([]*message, 0)
|
||||
for i := 0; i < count; i++ {
|
||||
|
@ -1621,31 +1676,31 @@ func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) {
|
|||
messages = append(messages, newDefaultMessage(topicID, "some message"))
|
||||
}
|
||||
require.Nil(t, s.messageCache.addMessages(messages))
|
||||
log.Printf("Done: Adding %d messages; took %s", count, time.Since(start).Round(time.Millisecond))
|
||||
log.Info("Done: Adding %d messages; took %s", count, time.Since(start).Round(time.Millisecond))
|
||||
|
||||
// Update stats
|
||||
statsChan := make(chan bool)
|
||||
go func() {
|
||||
log.Printf("Updating stats")
|
||||
log.Info("Updating stats")
|
||||
start := time.Now()
|
||||
s.execManager()
|
||||
log.Printf("Done: Updating stats; took %s", time.Since(start).Round(time.Millisecond))
|
||||
log.Info("Done: Updating stats; took %s", time.Since(start).Round(time.Millisecond))
|
||||
statsChan <- true
|
||||
}()
|
||||
time.Sleep(50 * time.Millisecond) // Make sure it starts first
|
||||
|
||||
// Publish message (during stats update)
|
||||
log.Printf("Publishing message")
|
||||
log.Info("Publishing message")
|
||||
start = time.Now()
|
||||
response := request(t, s, "PUT", "/mytopic", "some body", nil)
|
||||
m := toMessage(t, response.Body.String())
|
||||
assert.Equal(t, "some body", m.Message)
|
||||
assert.True(t, time.Since(start) < 100*time.Millisecond)
|
||||
log.Printf("Done: Publishing message; took %s", time.Since(start).Round(time.Millisecond))
|
||||
log.Info("Done: Publishing message; took %s", time.Since(start).Round(time.Millisecond))
|
||||
|
||||
// Wait for all goroutines
|
||||
<-statsChan
|
||||
log.Printf("Done: Waiting for all locks")
|
||||
log.Info("Done: Waiting for all locks")
|
||||
}
|
||||
|
||||
func newTestConfig(t *testing.T) *Config {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue