1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2025-06-22 20:38:08 +02:00

Add "last access" to access tokens

This commit is contained in:
binwiederhier 2023-01-28 20:29:06 -05:00
parent 000bf27c87
commit e596834096
15 changed files with 276 additions and 145 deletions

View file

@ -10,6 +10,7 @@ import (
"golang.org/x/crypto/bcrypt"
"heckel.io/ntfy/log"
"heckel.io/ntfy/util"
"net/netip"
"strings"
"sync"
"time"
@ -95,6 +96,8 @@ const (
user_id TEXT NOT NULL,
token TEXT NOT NULL,
label TEXT NOT NULL,
last_access INT NOT NULL,
last_origin TEXT NOT NULL,
expires INT NOT NULL,
PRIMARY KEY (user_id, token),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
@ -127,9 +130,9 @@ const (
selectUserByTokenQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
FROM user u
JOIN user_token t on u.id = t.user_id
JOIN user_token tk on u.id = tk.user_id
LEFT JOIN tier t on t.id = u.tier_id
WHERE t.token = ? AND (t.expires = 0 OR t.expires >= ?)
WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?)
`
selectUserByStripeCustomerIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
@ -218,16 +221,17 @@ const (
AND topic = ?
`
selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
selectTokensQuery = `SELECT token, label, expires FROM user_token WHERE user_id = ?`
selectTokenQuery = `SELECT token, label, expires FROM user_token WHERE user_id = ? AND token = ?`
insertTokenQuery = `INSERT INTO user_token (user_id, token, label, expires) VALUES (?, ?, ?, ?)`
updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
updateTokenLabelQuery = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
deleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
deleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
deleteExcessTokensQuery = `
selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
selectTokensQuery = `SELECT token, label, last_access, last_origin, expires FROM user_token WHERE user_id = ?`
selectTokenQuery = `SELECT token, label, last_access, last_origin, expires FROM user_token WHERE user_id = ? AND token = ?`
insertTokenQuery = `INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires) VALUES (?, ?, ?, ?, ?, ?)`
updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
updateTokenLabelQuery = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
updateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
deleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
deleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
deleteExcessTokensQuery = `
DELETE FROM user_token
WHERE (user_id, token) NOT IN (
SELECT user_id, token
@ -297,16 +301,17 @@ const (
// in a SQLite database.
type Manager struct {
db *sql.DB
defaultAccess Permission // Default permission if no ACL matches
statsQueue map[string]*Stats // "Queue" to asynchronously write user stats to the database (UserID -> Stats)
bcryptCost int // Makes testing easier
defaultAccess Permission // Default permission if no ACL matches
statsQueue map[string]*Stats // "Queue" to asynchronously write user stats to the database (UserID -> Stats)
tokenQueue map[string]*TokenUpdate // "Queue" to asynchronously write token access stats to the database (Token ID -> TokenUpdate)
bcryptCost int // Makes testing easier
mu sync.Mutex
}
var _ Auther = (*Manager)(nil)
// NewManager creates a new Manager instance
func NewManager(filename, startupQueries string, defaultAccess Permission, bcryptCost int, statsWriterInterval time.Duration) (*Manager, error) {
func NewManager(filename, startupQueries string, defaultAccess Permission, bcryptCost int, queueWriterInterval time.Duration) (*Manager, error) {
db, err := sql.Open("sqlite3", filename)
if err != nil {
return nil, err
@ -321,9 +326,10 @@ func NewManager(filename, startupQueries string, defaultAccess Permission, bcryp
db: db,
defaultAccess: defaultAccess,
statsQueue: make(map[string]*Stats),
tokenQueue: make(map[string]*TokenUpdate),
bcryptCost: bcryptCost,
}
go manager.userStatsQueueWriter(statsWriterInterval)
go manager.asyncQueueWriter(queueWriterInterval)
return manager, nil
}
@ -367,14 +373,15 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) {
// CreateToken generates a random token for the given user and returns it. The token expires
// after a fixed duration unless ChangeToken is called. This function also prunes tokens for the
// given user, if there are too many of them.
func (a *Manager) CreateToken(userID, label string, expires time.Time) (*Token, error) {
func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr) (*Token, error) {
token := util.RandomStringPrefix(tokenPrefix, tokenLength)
tx, err := a.db.Begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
if _, err := tx.Exec(insertTokenQuery, userID, token, label, expires.Unix()); err != nil {
access := time.Now()
if _, err := tx.Exec(insertTokenQuery, userID, token, label, access.Unix(), origin.String(), expires.Unix()); err != nil {
return nil, err
}
rows, err := tx.Query(selectTokenCountQuery, userID)
@ -400,9 +407,11 @@ func (a *Manager) CreateToken(userID, label string, expires time.Time) (*Token,
return nil, err
}
return &Token{
Value: token,
Label: label,
Expires: expires,
Value: token,
Label: label,
LastAccess: access,
LastOrigin: origin,
Expires: expires,
}, nil
}
@ -437,20 +446,26 @@ func (a *Manager) Token(userID, token string) (*Token, error) {
}
func (a *Manager) readToken(rows *sql.Rows) (*Token, error) {
var token, label string
var expires int64
var token, label, lastOrigin string
var lastAccess, expires int64
if !rows.Next() {
return nil, ErrTokenNotFound
}
if err := rows.Scan(&token, &label, &expires); err != nil {
if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
}
lastOriginIP, err := netip.ParseAddr(lastOrigin)
if err != nil {
lastOriginIP = netip.IPv4Unspecified()
}
return &Token{
Value: token,
Label: label,
Expires: time.Unix(expires, 0),
Value: token,
Label: label,
LastAccess: time.Unix(lastAccess, 0),
LastOrigin: lastOriginIP,
Expires: time.Unix(expires, 0),
}, nil
}
@ -521,7 +536,7 @@ func (a *Manager) ChangeSettings(user *User) error {
// ResetStats resets all user stats in the user database. This touches all users.
func (a *Manager) ResetStats() error {
a.mu.Lock()
a.mu.Lock() // Includes database query to avoid races!
defer a.mu.Unlock()
if _, err := a.db.Exec(updateUserStatsResetAllQuery); err != nil {
return err
@ -538,12 +553,23 @@ func (a *Manager) EnqueueStats(userID string, stats *Stats) {
a.statsQueue[userID] = stats
}
func (a *Manager) userStatsQueueWriter(interval time.Duration) {
// EnqueueTokenUpdate adds the token update to a queue which writes out token access times
// in batches at a regular interval
func (a *Manager) EnqueueTokenUpdate(tokenID string, update *TokenUpdate) {
a.mu.Lock()
defer a.mu.Unlock()
a.tokenQueue[tokenID] = update
}
func (a *Manager) asyncQueueWriter(interval time.Duration) {
ticker := time.NewTicker(interval)
for range ticker.C {
if err := a.writeUserStatsQueue(); err != nil {
log.Warn("User Manager: Writing user stats queue failed: %s", err.Error())
}
if err := a.writeTokenUpdateQueue(); err != nil {
log.Warn("User Manager: Writing token update queue failed: %s", err.Error())
}
}
}
@ -572,6 +598,31 @@ func (a *Manager) writeUserStatsQueue() error {
return tx.Commit()
}
func (a *Manager) writeTokenUpdateQueue() error {
a.mu.Lock()
if len(a.tokenQueue) == 0 {
a.mu.Unlock()
log.Trace("User Manager: No token updates to commit")
return nil
}
tokenQueue := a.tokenQueue
a.tokenQueue = make(map[string]*TokenUpdate)
a.mu.Unlock()
tx, err := a.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
log.Debug("User Manager: Writing token update queue for %d token(s)", len(tokenQueue))
for tokenID, update := range tokenQueue {
log.Trace("User Manager: Updating token %s with last access time %v", tokenID, update.LastAccess.Unix())
if _, err := tx.Exec(updateTokenLastAccessQuery, update.LastAccess.Unix(), update.LastOrigin.String(), tokenID); err != nil {
return err
}
}
return tx.Commit()
}
// Authorize returns nil if the given user has access to the given topic using the desired
// permission. The user param may be nil to signal an anonymous user.
func (a *Manager) Authorize(user *User, topic string, perm Permission) error {