Associate messages with a user

This commit is contained in:
binwiederhier 2022-12-19 21:42:36 -05:00
parent 84785b7a60
commit 2b78a8cb51
7 changed files with 71 additions and 41 deletions

View File

@ -40,6 +40,7 @@ const (
attachment_expires INT NOT NULL, attachment_expires INT NOT NULL,
attachment_url TEXT NOT NULL, attachment_url TEXT NOT NULL,
sender TEXT NOT NULL, sender TEXT NOT NULL,
user TEXT NOT NULL,
encoding TEXT NOT NULL, encoding TEXT NOT NULL,
published INT NOT NULL published INT NOT NULL
); );
@ -49,37 +50,37 @@ const (
COMMIT; COMMIT;
` `
insertMessageQuery = ` insertMessageQuery = `
INSERT INTO messages (mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding, published) INSERT INTO messages (mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding, published)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
` `
pruneMessagesQuery = `DELETE FROM messages WHERE time < ? AND published = 1` pruneMessagesQuery = `DELETE FROM messages WHERE time < ? AND published = 1`
selectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?` // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics selectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?` // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics
selectMessagesSinceTimeQuery = ` selectMessagesSinceTimeQuery = `
SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
FROM messages FROM messages
WHERE topic = ? AND time >= ? AND published = 1 WHERE topic = ? AND time >= ? AND published = 1
ORDER BY time, id ORDER BY time, id
` `
selectMessagesSinceTimeIncludeScheduledQuery = ` selectMessagesSinceTimeIncludeScheduledQuery = `
SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
FROM messages FROM messages
WHERE topic = ? AND time >= ? WHERE topic = ? AND time >= ?
ORDER BY time, id ORDER BY time, id
` `
selectMessagesSinceIDQuery = ` selectMessagesSinceIDQuery = `
SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
FROM messages FROM messages
WHERE topic = ? AND id > ? AND published = 1 WHERE topic = ? AND id > ? AND published = 1
ORDER BY time, id ORDER BY time, id
` `
selectMessagesSinceIDIncludeScheduledQuery = ` selectMessagesSinceIDIncludeScheduledQuery = `
SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
FROM messages FROM messages
WHERE topic = ? AND (id > ? OR published = 0) WHERE topic = ? AND (id > ? OR published = 0)
ORDER BY time, id ORDER BY time, id
` `
selectMessagesDueQuery = ` selectMessagesDueQuery = `
SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
FROM messages FROM messages
WHERE time <= ? AND published = 0 WHERE time <= ? AND published = 0
ORDER BY time, id ORDER BY time, id
@ -88,7 +89,8 @@ const (
selectMessagesCountQuery = `SELECT COUNT(*) FROM messages` selectMessagesCountQuery = `SELECT COUNT(*) FROM messages`
selectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM messages GROUP BY topic` selectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM messages GROUP BY topic`
selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic` selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic`
selectAttachmentsSizeQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE sender = ? AND attachment_expires >= ?` selectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE sender = ? AND attachment_expires >= ?`
selectAttachmentsSizeByUserQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?`
) )
// Schema management queries // Schema management queries
@ -316,6 +318,7 @@ func (c *messageCache) addMessages(ms []*message) error {
attachmentExpires, attachmentExpires,
attachmentURL, attachmentURL,
sender, sender,
m.User,
m.Encoding, m.Encoding,
published, published,
) )
@ -442,11 +445,23 @@ func (c *messageCache) Prune(olderThan time.Time) error {
return nil return nil
} }
func (c *messageCache) AttachmentBytesUsed(sender string) (int64, error) { func (c *messageCache) AttachmentBytesUsedBySender(sender string) (int64, error) {
rows, err := c.db.Query(selectAttachmentsSizeQuery, sender, time.Now().Unix()) rows, err := c.db.Query(selectAttachmentsSizeBySenderQuery, sender, time.Now().Unix())
if err != nil { if err != nil {
return 0, err return 0, err
} }
return c.readAttachmentBytesUsed(rows)
}
func (c *messageCache) AttachmentBytesUsedByUser(user string) (int64, error) {
rows, err := c.db.Query(selectAttachmentsSizeByUserQuery, user, time.Now().Unix())
if err != nil {
return 0, err
}
return c.readAttachmentBytesUsed(rows)
}
func (c *messageCache) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) {
defer rows.Close() defer rows.Close()
var size int64 var size int64
if !rows.Next() { if !rows.Next() {
@ -477,7 +492,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
for rows.Next() { for rows.Next() {
var timestamp, attachmentSize, attachmentExpires int64 var timestamp, attachmentSize, attachmentExpires int64
var priority int var priority int
var id, topic, msg, title, tagsStr, click, icon, actionsStr, attachmentName, attachmentType, attachmentURL, sender, encoding string var id, topic, msg, title, tagsStr, click, icon, actionsStr, attachmentName, attachmentType, attachmentURL, sender, user, encoding string
err := rows.Scan( err := rows.Scan(
&id, &id,
&timestamp, &timestamp,
@ -495,6 +510,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
&attachmentExpires, &attachmentExpires,
&attachmentURL, &attachmentURL,
&sender, &sender,
&user,
&encoding, &encoding,
) )
if err != nil { if err != nil {
@ -538,6 +554,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
Actions: actions, Actions: actions,
Attachment: att, Attachment: att,
Sender: senderIP, // Must parse assuming database must be correct Sender: senderIP, // Must parse assuming database must be correct
User: user,
Encoding: encoding, Encoding: encoding,
}) })
} }
@ -598,6 +615,7 @@ func setupCacheDB(db *sql.DB, startupQueries string) error {
} else if schemaVersion == 8 { } else if schemaVersion == 8 {
return migrateFrom8(db) return migrateFrom8(db)
} }
// TODO add user column
return fmt.Errorf("unexpected schema version found: %d", schemaVersion) return fmt.Errorf("unexpected schema version found: %d", schemaVersion)
} }

View File

@ -343,11 +343,11 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL) require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL)
require.Equal(t, "1.2.3.4", messages[1].Sender.String()) require.Equal(t, "1.2.3.4", messages[1].Sender.String())
size, err := c.AttachmentBytesUsed("1.2.3.4") size, err := c.AttachmentBytesUsedBySender("1.2.3.4")
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, int64(30000), size) require.Equal(t, int64(30000), size)
size, err = c.AttachmentBytesUsed("5.6.7.8") size, err = c.AttachmentBytesUsedBySender("5.6.7.8")
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, int64(0), size) require.Equal(t, int64(0), size)
} }

View File

@ -495,6 +495,10 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
if m.PollID != "" { if m.PollID != "" {
m = newPollRequestMessage(t.ID, m.PollID) m = newPollRequestMessage(t.ID, m.PollID)
} }
if v.user != nil {
log.Info("user is %s", v.user.Name)
m.User = v.user.Name
}
if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil { if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
return nil, err return nil, err
} }
@ -502,8 +506,8 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
m.Message = emptyMessageBody m.Message = emptyMessageBody
} }
delayed := m.Time > time.Now().Unix() delayed := m.Time > time.Now().Unix()
log.Debug("%s Received message: event=%s, body=%d byte(s), delayed=%t, firebase=%t, cache=%t, up=%t, email=%s", log.Debug("%s Received message: event=%s, user=%s, body=%d byte(s), delayed=%t, firebase=%t, cache=%t, up=%t, email=%s",
logMessagePrefix(v, m), m.Event, len(m.Message), delayed, firebase, cache, unifiedpush, email) logMessagePrefix(v, m), m.Event, m.User, len(m.Message), delayed, firebase, cache, unifiedpush, email)
if log.IsTrace() { if log.IsTrace() {
log.Trace("%s Message body: %s", logMessagePrefix(v, m), util.MaybeMarshalJSON(m)) log.Trace("%s Message body: %s", logMessagePrefix(v, m), util.MaybeMarshalJSON(m))
} }

View File

@ -75,8 +75,7 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, r *http.Request, v *vis
Code: v.user.Plan.Code, Code: v.user.Plan.Code,
Upgradable: v.user.Plan.Upgradable, Upgradable: v.user.Plan.Upgradable,
} }
} else { } else if v.user.Role == auth.RoleAdmin {
if v.user.Role == auth.RoleAdmin {
response.Plan = &apiAccountPlan{ response.Plan = &apiAccountPlan{
Code: string(auth.PlanUnlimited), Code: string(auth.PlanUnlimited),
Upgradable: false, Upgradable: false,
@ -87,7 +86,7 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, r *http.Request, v *vis
Upgradable: true, Upgradable: true,
} }
} }
}
} else { } else {
response.Username = auth.Everyone response.Username = auth.Everyone
response.Role = string(auth.RoleAnonymous) response.Role = string(auth.RoleAnonymous)

View File

@ -1151,7 +1151,7 @@ func TestServer_PublishAttachment(t *testing.T) {
require.Equal(t, "", response.Body.String()) require.Equal(t, "", response.Body.String())
// Slightly unrelated cross-test: make sure we add an owner for internal attachments // Slightly unrelated cross-test: make sure we add an owner for internal attachments
size, err := s.messageCache.AttachmentBytesUsed("9.9.9.9") // See request() size, err := s.messageCache.AttachmentBytesUsedBySender("9.9.9.9") // See request()
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, int64(5000), size) require.Equal(t, int64(5000), size)
} }
@ -1180,7 +1180,7 @@ func TestServer_PublishAttachmentShortWithFilename(t *testing.T) {
require.Equal(t, content, response.Body.String()) require.Equal(t, content, response.Body.String())
// Slightly unrelated cross-test: make sure we add an owner for internal attachments // Slightly unrelated cross-test: make sure we add an owner for internal attachments
size, err := s.messageCache.AttachmentBytesUsed("1.2.3.4") size, err := s.messageCache.AttachmentBytesUsedBySender("1.2.3.4")
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, int64(21), size) require.Equal(t, int64(21), size)
} }
@ -1200,7 +1200,7 @@ func TestServer_PublishAttachmentExternalWithoutFilename(t *testing.T) {
require.Equal(t, netip.Addr{}, msg.Sender) require.Equal(t, netip.Addr{}, msg.Sender)
// Slightly unrelated cross-test: make sure we don't add an owner for external attachments // Slightly unrelated cross-test: make sure we don't add an owner for external attachments
size, err := s.messageCache.AttachmentBytesUsed("127.0.0.1") size, err := s.messageCache.AttachmentBytesUsedBySender("127.0.0.1")
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, int64(0), size) require.Equal(t, int64(0), size)
} }

View File

@ -36,8 +36,9 @@ type message struct {
Actions []*action `json:"actions,omitempty"` Actions []*action `json:"actions,omitempty"`
Attachment *attachment `json:"attachment,omitempty"` Attachment *attachment `json:"attachment,omitempty"`
PollID string `json:"poll_id,omitempty"` PollID string `json:"poll_id,omitempty"`
Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting
Encoding string `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes Encoding string `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes
Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting
User string `json:"-"` // Username of the uploader, used to associated attachments
} }
type attachment struct { type attachment struct {

View File

@ -151,12 +151,10 @@ func (v *visitor) IncrEmails() {
} }
func (v *visitor) Stats() (*visitorStats, error) { func (v *visitor) Stats() (*visitorStats, error) {
attachmentsBytesUsed, err := v.messageCache.AttachmentBytesUsed(v.ip.String())
if err != nil {
return nil, err
}
v.mu.Lock() v.mu.Lock()
defer v.mu.Unlock() messages := v.messages
emails := v.emails
v.mu.Unlock()
stats := &visitorStats{} stats := &visitorStats{}
if v.user != nil && v.user.Role == auth.RoleAdmin { if v.user != nil && v.user.Role == auth.RoleAdmin {
stats.Basis = "role" stats.Basis = "role"
@ -174,12 +172,22 @@ func (v *visitor) Stats() (*visitorStats, error) {
stats.Basis = "ip" stats.Basis = "ip"
stats.MessagesLimit = replenishDurationToDailyLimit(v.config.VisitorRequestLimitReplenish) stats.MessagesLimit = replenishDurationToDailyLimit(v.config.VisitorRequestLimitReplenish)
stats.EmailsLimit = replenishDurationToDailyLimit(v.config.VisitorEmailLimitReplenish) stats.EmailsLimit = replenishDurationToDailyLimit(v.config.VisitorEmailLimitReplenish)
stats.AttachmentTotalSizeLimit = v.config.AttachmentTotalSizeLimit stats.AttachmentTotalSizeLimit = v.config.VisitorAttachmentTotalSizeLimit
stats.AttachmentFileSizeLimit = v.config.AttachmentFileSizeLimit stats.AttachmentFileSizeLimit = v.config.AttachmentFileSizeLimit
} }
stats.Messages = v.messages var attachmentsBytesUsed int64
var err error
if v.user != nil {
attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedByUser(v.user.Name)
} else {
attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedBySender(v.ip.String())
}
if err != nil {
return nil, err
}
stats.Messages = messages
stats.MessagesRemaining = zeroIfNegative(stats.MessagesLimit - stats.MessagesLimit) stats.MessagesRemaining = zeroIfNegative(stats.MessagesLimit - stats.MessagesLimit)
stats.Emails = v.emails stats.Emails = emails
stats.EmailsRemaining = zeroIfNegative(stats.EmailsLimit - stats.EmailsRemaining) stats.EmailsRemaining = zeroIfNegative(stats.EmailsLimit - stats.EmailsRemaining)
stats.AttachmentTotalSize = attachmentsBytesUsed stats.AttachmentTotalSize = attachmentsBytesUsed
stats.AttachmentTotalSizeRemaining = zeroIfNegative(stats.AttachmentTotalSizeLimit - stats.AttachmentTotalSize) stats.AttachmentTotalSizeRemaining = zeroIfNegative(stats.AttachmentTotalSizeLimit - stats.AttachmentTotalSize)