mirror of
https://github.com/binwiederhier/ntfy.git
synced 2024-11-22 03:13:33 +01:00
Only set rate visitor if allowed
This commit is contained in:
parent
2329695a47
commit
bfc3983d06
4 changed files with 151 additions and 17 deletions
|
@ -112,7 +112,6 @@ const (
|
|||
encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages
|
||||
jsonBodyBytesLimit = 16384
|
||||
unifiedPushTopicPrefix = "up" // Temporarily, we rate limit all "up*" topics based on the subscriber
|
||||
rateTopicsWildcard = "*" // Allows defining all topics in the request subscriber-rate-limited topics
|
||||
)
|
||||
|
||||
// WebSocket constants
|
||||
|
@ -977,7 +976,9 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
|||
}
|
||||
return nil
|
||||
}
|
||||
registerRateVisitors(topics, rateTopics, v)
|
||||
if err := s.maybeSetRateVisitors(r, v, topics, rateTopics); err != nil {
|
||||
return err
|
||||
}
|
||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||
w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
|
||||
if poll {
|
||||
|
@ -1113,7 +1114,9 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
|||
}
|
||||
return conn.WriteJSON(msg)
|
||||
}
|
||||
registerRateVisitors(topics, rateTopics, v)
|
||||
if err := s.maybeSetRateVisitors(r, v, topics, rateTopics); err != nil {
|
||||
return err
|
||||
}
|
||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||
if poll {
|
||||
return s.sendOldMessages(topics, since, scheduled, v, sub)
|
||||
|
@ -1156,23 +1159,62 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu
|
|||
return
|
||||
}
|
||||
|
||||
// registerRateVisitors sets the rate visitor on a topic, indicating that all messages published to that topic
|
||||
// will be rate limited against the rate visitor instead of the publishing visitor.
|
||||
// maybeSetRateVisitors sets the rate visitor on a topic (v.SetRateVisitor), indicating that all messages published
|
||||
// to that topic will be rate limited against the rate visitor instead of the publishing visitor.
|
||||
//
|
||||
// Setting the rate visitor is ony allowed if
|
||||
// - auth-file is not set (everything is open by default)
|
||||
// - the topic is reserved, and v.user is the owner
|
||||
// - the topic is not reserved, and v.user has write access
|
||||
//
|
||||
// Note: This TEMPORARILY also registers all topics starting with "up" (= UnifiedPush). This is to ease the transition
|
||||
// until the Android app will send the "Rate-Topics" header.
|
||||
func registerRateVisitors(topics []*topic, rateTopics []string, v *visitor) {
|
||||
if len(rateTopics) == 1 && rateTopics[0] == rateTopicsWildcard {
|
||||
for _, t := range topics {
|
||||
t.SetRateVisitor(v)
|
||||
}
|
||||
} else {
|
||||
for _, t := range topics {
|
||||
if util.Contains(rateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) {
|
||||
t.SetRateVisitor(v)
|
||||
}
|
||||
func (s *Server) maybeSetRateVisitors(r *http.Request, v *visitor, topics []*topic, rateTopics []string) error {
|
||||
// Make a list of topics that we'll actually set the RateVisitor on
|
||||
eligibleRateTopics := make([]*topic, 0)
|
||||
for _, t := range topics {
|
||||
if strings.HasPrefix(t.ID, unifiedPushTopicPrefix) || util.Contains(rateTopics, t.ID) {
|
||||
eligibleRateTopics = append(eligibleRateTopics, t)
|
||||
}
|
||||
}
|
||||
if len(eligibleRateTopics) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If access controls are turned off, v has access to everything, and we can set the rate visitor
|
||||
if s.userManager == nil {
|
||||
return s.setRateVisitors(r, v, eligibleRateTopics)
|
||||
}
|
||||
|
||||
// If access controls are enabled, only set rate visitor if
|
||||
// - topic is reserved, and v.user is the owner
|
||||
// - topic is not reserved, and v.user has write access
|
||||
writableRateTopics := make([]*topic, 0)
|
||||
for _, t := range topics {
|
||||
ownerUserID, err := s.userManager.ReservationOwner(t.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ownerUserID == "" {
|
||||
if err := s.userManager.Authorize(v.User(), t.ID, user.PermissionWrite); err == nil {
|
||||
writableRateTopics = append(writableRateTopics, t)
|
||||
}
|
||||
} else if ownerUserID == v.MaybeUserID() {
|
||||
writableRateTopics = append(writableRateTopics, t)
|
||||
}
|
||||
}
|
||||
return s.setRateVisitors(r, v, writableRateTopics)
|
||||
}
|
||||
|
||||
func (s *Server) setRateVisitors(r *http.Request, v *visitor, rateTopics []*topic) error {
|
||||
for _, t := range rateTopics {
|
||||
logvr(v, r).
|
||||
Tag(tagSubscribe).
|
||||
Field("message_topic", t.ID).
|
||||
Debug("Setting visitor as rate visitor for topic %s", t.ID)
|
||||
t.SetRateVisitor(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendOldMessages selects old messages from the messageCache and calls sub for each of them. It uses since as the
|
||||
|
|
|
@ -2040,7 +2040,7 @@ func TestServer_SubscriberRateLimiting_VisitorExpiration(t *testing.T) {
|
|||
r.RemoteAddr = "1.2.3.4"
|
||||
}
|
||||
rr := request(t, s, "GET", "/mytopic/json?poll=1", "", map[string]string{
|
||||
"rate-topics": "*",
|
||||
"rate-topics": "mytopic",
|
||||
}, subscriberFn)
|
||||
require.Equal(t, 200, rr.Code)
|
||||
require.Equal(t, "1.2.3.4", s.topics["mytopic"].rateVisitor.ip.String())
|
||||
|
@ -2065,6 +2065,72 @@ func TestServer_SubscriberRateLimiting_VisitorExpiration(t *testing.T) {
|
|||
require.Nil(t, s.visitors["ip:1.2.3.4"])
|
||||
}
|
||||
|
||||
func TestServer_SubscriberRateLimiting_ProtectedTopics(t *testing.T) {
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, c)
|
||||
|
||||
// Create some ACLs
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "test",
|
||||
MessageLimit: 5,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
|
||||
require.Nil(t, s.userManager.ChangeTier("ben", "test"))
|
||||
require.Nil(t, s.userManager.AllowAccess("ben", "announcements", user.PermissionReadWrite))
|
||||
require.Nil(t, s.userManager.AllowAccess(user.Everyone, "announcements", user.PermissionRead))
|
||||
require.Nil(t, s.userManager.AllowAccess(user.Everyone, "public_topic", user.PermissionReadWrite))
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
|
||||
require.Nil(t, s.userManager.AddReservation("phil", "reserved-for-phil", user.PermissionReadWrite))
|
||||
|
||||
// Set rate visitor as user "phil" on topic
|
||||
// - "reserved-for-phil": Allowed, because I am the owner
|
||||
// - "public_topic": Allowed, because it has read-write permissions for everyone
|
||||
// - "announcements": NOT allowed, because it has read-only permissions for everyone
|
||||
rr := request(t, s, "GET", "/reserved-for-phil,public_topic,announcements/json?poll=1", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
"Rate-Topics": "reserved-for-phil,public_topic,announcements",
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
require.Equal(t, "phil", s.topics["reserved-for-phil"].rateVisitor.user.Name)
|
||||
require.Equal(t, "phil", s.topics["public_topic"].rateVisitor.user.Name)
|
||||
require.Nil(t, s.topics["announcements"].rateVisitor)
|
||||
|
||||
// Set rate visitor as user "ben" on topic
|
||||
// - "reserved-for-phil": NOT allowed, because I am not the owner
|
||||
// - "public_topic": Allowed, because it has read-write permissions for everyone
|
||||
// - "announcements": Allowed, because I have read-write permissions
|
||||
rr = request(t, s, "GET", "/reserved-for-phil,public_topic,announcements/json?poll=1", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
"Rate-Topics": "reserved-for-phil,public_topic,announcements",
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
require.Equal(t, "phil", s.topics["reserved-for-phil"].rateVisitor.user.Name)
|
||||
require.Equal(t, "ben", s.topics["public_topic"].rateVisitor.user.Name)
|
||||
require.Equal(t, "ben", s.topics["announcements"].rateVisitor.user.Name)
|
||||
}
|
||||
|
||||
func TestServer_SubscriberRateLimiting_ProtectedTopics_WithDefaultReadWrite(t *testing.T) {
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.AuthDefault = user.PermissionReadWrite
|
||||
s := newTestServer(t, c)
|
||||
|
||||
// Create some ACLs
|
||||
require.Nil(t, s.userManager.AllowAccess(user.Everyone, "announcements", user.PermissionRead))
|
||||
|
||||
// Set rate visitor as ip:1.2.3.4 on topic
|
||||
// - "up1234": Allowed, because no ACLs and nobody owns the topic
|
||||
// - "announcements": NOT allowed, because it has read-only permissions for everyone
|
||||
rr := request(t, s, "GET", "/up1234,announcements/json?poll=1", "", nil, func(r *http.Request) {
|
||||
r.RemoteAddr = "1.2.3.4"
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
require.Equal(t, "1.2.3.4", s.topics["up1234"].rateVisitor.ip.String())
|
||||
require.Nil(t, s.topics["announcements"].rateVisitor)
|
||||
}
|
||||
|
||||
func newTestConfig(t *testing.T) *Config {
|
||||
conf := NewConfig()
|
||||
conf.BaseURL = "http://127.0.0.1:12345"
|
||||
|
|
|
@ -141,6 +141,7 @@ func (v *visitor) Context() log.Context {
|
|||
func (v *visitor) contextNoLock() log.Context {
|
||||
info := v.infoLightNoLock()
|
||||
fields := log.Context{
|
||||
"visitor_id": visitorID(v.ip, v.user),
|
||||
"visitor_ip": v.ip.String(),
|
||||
"visitor_messages": info.Stats.Messages,
|
||||
"visitor_messages_limit": info.Limits.MessageLimit,
|
||||
|
|
|
@ -201,7 +201,14 @@ const (
|
|||
selectUserReservationsCountQuery = `
|
||||
SELECT COUNT(*)
|
||||
FROM user_access
|
||||
WHERE user_id = owner_user_id AND owner_user_id = (SELECT id FROM user WHERE user = ?)
|
||||
WHERE user_id = owner_user_id
|
||||
AND owner_user_id = (SELECT id FROM user WHERE user = ?)
|
||||
`
|
||||
selectUserReservationsOwnerQuery = `
|
||||
SELECT owner_user_id
|
||||
FROM user_access
|
||||
WHERE topic = ?
|
||||
AND user_id = owner_user_id
|
||||
`
|
||||
selectUserHasReservationQuery = `
|
||||
SELECT COUNT(*)
|
||||
|
@ -1025,6 +1032,24 @@ func (a *Manager) ReservationsCount(username string) (int64, error) {
|
|||
return count, nil
|
||||
}
|
||||
|
||||
// ReservationOwner returns user ID of the user that owns this topic, or an
|
||||
// empty string if it's not owned by anyone
|
||||
func (a *Manager) ReservationOwner(topic string) (string, error) {
|
||||
rows, err := a.db.Query(selectUserReservationsOwnerQuery, topic)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return "", nil
|
||||
}
|
||||
var ownerUserID string
|
||||
if err := rows.Scan(&ownerUserID); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return ownerUserID, nil
|
||||
}
|
||||
|
||||
// ChangePassword changes a user's password
|
||||
func (a *Manager) ChangePassword(username, password string) error {
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), a.bcryptCost)
|
||||
|
|
Loading…
Reference in a new issue