1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2024-11-22 19:33:27 +01:00

Prune excess tokens per user

This commit is contained in:
binwiederhier 2023-01-05 20:22:34 -05:00
parent 60f1882bec
commit 7fa63c8e19
3 changed files with 108 additions and 29 deletions

View file

@ -45,10 +45,6 @@ import (
"account topic" sync mechanism
purge accounts that were not logged into in X
reset daily limits for users
max token issue limit
user db startup queries -> foreign keys
UI
- Feature flag for "reserve topic" feature
Sync:
- "mute" setting
- figure out what settings are "web" or "phone"

View file

@ -15,16 +15,18 @@ import (
)
const (
tokenLength = 32
bcryptCost = 10
intentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match bcryptCost
userStatsQueueWriterInterval = 33 * time.Second
userTokenExpiryDuration = 72 * time.Hour
tokenLength = 32
tokenExpiryDuration = 72 * time.Hour // Extend tokens by this much
tokenMaxCount = 10 // Only keep this many tokens in the table per user
)
var (
errNoTokenProvided = errors.New("no token provided")
errTopicOwnedByOthers = errors.New("topic owned by others")
errNoRows = errors.New("no rows found")
)
// Manager-related queries
@ -139,7 +141,7 @@ const (
ORDER BY a_user.topic
`
selectOtherAccessCountQuery = `
SELECT count(*)
SELECT COUNT(*)
FROM user_access
WHERE (topic = ? OR ? LIKE topic)
AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?))
@ -148,10 +150,22 @@ const (
deleteUserAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?)`
deleteTopicAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?) AND topic = ?`
selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE (SELECT id FROM user WHERE user = ?)`
insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES ((SELECT id FROM user WHERE user = ?), ?, ?)`
updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
deleteTokenQuery = `DELETE FROM user_token WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires < ?`
deleteExcessTokensQuery = `
DELETE FROM user_token
WHERE (user_id, token) NOT IN (
SELECT user_id, token
FROM user_token
WHERE user_id = (SELECT id FROM user WHERE user = ?)
ORDER BY expires DESC
LIMIT ?
)
;
`
)
// Schema management queries
@ -182,22 +196,21 @@ const (
// Manager is an implementation of Manager. It stores users and access control list
// in a SQLite database.
type Manager struct {
db *sql.DB
defaultAccess Permission // Default permission if no ACL matches
statsQueue map[string]*User // Username -> User, for "unimportant" user updates
tokenExpiryInterval time.Duration // Duration after which tokens expire, and by which tokens are extended
mu sync.Mutex
db *sql.DB
defaultAccess Permission // Default permission if no ACL matches
statsQueue map[string]*User // Username -> User, for "unimportant" user updates
mu sync.Mutex
}
var _ Auther = (*Manager)(nil)
// NewManager creates a new Manager instance
func NewManager(filename, startupQueries string, defaultAccess Permission) (*Manager, error) {
return newManager(filename, startupQueries, defaultAccess, userTokenExpiryDuration, userStatsQueueWriterInterval)
return newManager(filename, startupQueries, defaultAccess, userStatsQueueWriterInterval)
}
// NewManager creates a new Manager instance
func newManager(filename, startupQueries string, defaultAccess Permission, tokenExpiryDuration, statsWriterInterval time.Duration) (*Manager, error) {
func newManager(filename, startupQueries string, defaultAccess Permission, statsWriterInterval time.Duration) (*Manager, error) {
db, err := sql.Open("sqlite3", filename)
if err != nil {
return nil, err
@ -209,10 +222,9 @@ func newManager(filename, startupQueries string, defaultAccess Permission, token
return nil, err
}
manager := &Manager{
db: db,
defaultAccess: defaultAccess,
statsQueue: make(map[string]*User),
tokenExpiryInterval: tokenExpiryDuration,
db: db,
defaultAccess: defaultAccess,
statsQueue: make(map[string]*User),
}
go manager.userStatsQueueWriter(statsWriterInterval)
return manager, nil
@ -253,10 +265,38 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) {
}
// CreateToken generates a random token for the given user and returns it. The token expires
// after a fixed duration unless ExtendToken is called.
// after a fixed duration unless ExtendToken is called. This function also prunes tokens for the
// given user, if there are too many of them.
func (a *Manager) CreateToken(user *User) (*Token, error) {
token, expires := util.RandomString(tokenLength), time.Now().Add(userTokenExpiryDuration)
if _, err := a.db.Exec(insertTokenQuery, user.Name, token, expires.Unix()); err != nil {
token, expires := util.RandomString(tokenLength), time.Now().Add(tokenExpiryDuration)
tx, err := a.db.Begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
if _, err := tx.Exec(insertTokenQuery, user.Name, token, expires.Unix()); err != nil {
return nil, err
}
rows, err := tx.Query(selectTokenCountQuery, user.Name)
if err != nil {
return nil, err
}
defer rows.Close()
if !rows.Next() {
return nil, errNoRows
}
var tokenCount int
if err := rows.Scan(&tokenCount); err != nil {
return nil, err
}
if tokenCount >= tokenMaxCount {
// This pruning logic is done in two queries for efficiency. The SELECT above is a lookup
// on two indices, whereas the query below is a full table scan.
if _, err := tx.Exec(deleteExcessTokensQuery, user.Name, tokenMaxCount); err != nil {
return nil, err
}
}
if err := tx.Commit(); err != nil {
return nil, err
}
return &Token{
@ -270,7 +310,7 @@ func (a *Manager) ExtendToken(user *User) (*Token, error) {
if user.Token == "" {
return nil, errNoTokenProvided
}
newExpires := time.Now().Add(userTokenExpiryDuration)
newExpires := time.Now().Add(tokenExpiryDuration)
if _, err := a.db.Exec(updateTokenExpiryQuery, newExpires.Unix(), user.Name, user.Token); err != nil {
return nil, err
}
@ -600,7 +640,7 @@ func (a *Manager) CheckAllowAccess(username string, topic string) error {
}
defer rows.Close()
if !rows.Next() {
return errors.New("no rows found")
return errNoRows
}
var otherCount int
if err := rows.Scan(&otherCount); err != nil {

View file

@ -369,8 +369,51 @@ func TestManager_Token_Extend(t *testing.T) {
require.True(t, token.Expires.Unix() < extendedToken.Expires.Unix())
}
func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
// Try to extend token for user without token
u, err := a.User("ben")
require.Nil(t, err)
// Tokens
baseTime := time.Now().Add(24 * time.Hour)
tokens := make([]string, 0)
for i := 0; i < 12; i++ {
token, err := a.CreateToken(u)
require.Nil(t, err)
require.NotEmpty(t, token.Value)
tokens = append(tokens, token.Value)
// Manually modify expiry date to avoid sorting issues (this is a hack)
_, err = a.db.Exec(`UPDATE user_token SET expires=? WHERE token=?`, baseTime.Add(time.Duration(i)*time.Minute).Unix(), token.Value)
require.Nil(t, err)
}
_, err = a.AuthenticateToken(tokens[0])
require.Equal(t, ErrUnauthenticated, err)
_, err = a.AuthenticateToken(tokens[1])
require.Equal(t, ErrUnauthenticated, err)
for i := 2; i < 12; i++ {
userWithToken, err := a.AuthenticateToken(tokens[i])
require.Nil(t, err, "token[%d]=%s failed", i, tokens[i])
require.Equal(t, "ben", userWithToken.Name)
require.Equal(t, tokens[i], userWithToken.Token)
}
var count int
rows, err := a.db.Query(`SELECT COUNT(*) FROM user_token`)
require.Nil(t, err)
require.True(t, rows.Next())
require.Nil(t, rows.Scan(&count))
require.Equal(t, 10, count)
}
func TestManager_EnqueueStats(t *testing.T) {
a, err := newManager(filepath.Join(t.TempDir(), "db"), PermissionReadWrite, time.Hour, 1500*time.Millisecond)
a, err := newManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond)
require.Nil(t, err)
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
@ -400,7 +443,7 @@ func TestManager_EnqueueStats(t *testing.T) {
}
func TestManager_ChangeSettings(t *testing.T) {
a, err := newManager(filepath.Join(t.TempDir(), "db"), PermissionReadWrite, time.Hour, 1500*time.Millisecond)
a, err := newManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond)
require.Nil(t, err)
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
@ -482,7 +525,7 @@ func TestSqliteCache_Migration_From1(t *testing.T) {
require.Nil(t, err)
// Create manager to trigger migration
a := newTestManagerFromFile(t, filename, PermissionDenyAll, userTokenExpiryDuration, userStatsQueueWriterInterval)
a := newTestManagerFromFile(t, filename, "", PermissionDenyAll, userStatsQueueWriterInterval)
checkSchemaVersion(t, a.db)
users, err := a.Users()
@ -530,11 +573,11 @@ func checkSchemaVersion(t *testing.T, db *sql.DB) {
}
func newTestManager(t *testing.T, defaultAccess Permission) *Manager {
return newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), defaultAccess, userTokenExpiryDuration, userStatsQueueWriterInterval)
return newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), "", defaultAccess, userStatsQueueWriterInterval)
}
func newTestManagerFromFile(t *testing.T, filename string, defaultAccess Permission, tokenExpiryDuration, statsWriterInterval time.Duration) *Manager {
a, err := newManager(filename, defaultAccess, tokenExpiryDuration, statsWriterInterval)
func newTestManagerFromFile(t *testing.T, filename, startupQueries string, defaultAccess Permission, statsWriterInterval time.Duration) *Manager {
a, err := newManager(filename, startupQueries, defaultAccess, statsWriterInterval)
require.Nil(t, err)
return a
}