mirror of
https://github.com/binwiederhier/ntfy.git
synced 2024-11-26 21:25:18 +01:00
Prune excess tokens per user
This commit is contained in:
parent
60f1882bec
commit
7fa63c8e19
3 changed files with 108 additions and 29 deletions
|
@ -45,10 +45,6 @@ import (
|
||||||
"account topic" sync mechanism
|
"account topic" sync mechanism
|
||||||
purge accounts that were not logged into in X
|
purge accounts that were not logged into in X
|
||||||
reset daily limits for users
|
reset daily limits for users
|
||||||
max token issue limit
|
|
||||||
user db startup queries -> foreign keys
|
|
||||||
UI
|
|
||||||
- Feature flag for "reserve topic" feature
|
|
||||||
Sync:
|
Sync:
|
||||||
- "mute" setting
|
- "mute" setting
|
||||||
- figure out what settings are "web" or "phone"
|
- figure out what settings are "web" or "phone"
|
||||||
|
|
|
@ -15,16 +15,18 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
tokenLength = 32
|
|
||||||
bcryptCost = 10
|
bcryptCost = 10
|
||||||
intentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match bcryptCost
|
intentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match bcryptCost
|
||||||
userStatsQueueWriterInterval = 33 * time.Second
|
userStatsQueueWriterInterval = 33 * time.Second
|
||||||
userTokenExpiryDuration = 72 * time.Hour
|
tokenLength = 32
|
||||||
|
tokenExpiryDuration = 72 * time.Hour // Extend tokens by this much
|
||||||
|
tokenMaxCount = 10 // Only keep this many tokens in the table per user
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errNoTokenProvided = errors.New("no token provided")
|
errNoTokenProvided = errors.New("no token provided")
|
||||||
errTopicOwnedByOthers = errors.New("topic owned by others")
|
errTopicOwnedByOthers = errors.New("topic owned by others")
|
||||||
|
errNoRows = errors.New("no rows found")
|
||||||
)
|
)
|
||||||
|
|
||||||
// Manager-related queries
|
// Manager-related queries
|
||||||
|
@ -139,7 +141,7 @@ const (
|
||||||
ORDER BY a_user.topic
|
ORDER BY a_user.topic
|
||||||
`
|
`
|
||||||
selectOtherAccessCountQuery = `
|
selectOtherAccessCountQuery = `
|
||||||
SELECT count(*)
|
SELECT COUNT(*)
|
||||||
FROM user_access
|
FROM user_access
|
||||||
WHERE (topic = ? OR ? LIKE topic)
|
WHERE (topic = ? OR ? LIKE topic)
|
||||||
AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?))
|
AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?))
|
||||||
|
@ -148,10 +150,22 @@ const (
|
||||||
deleteUserAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?)`
|
deleteUserAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?)`
|
||||||
deleteTopicAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?) AND topic = ?`
|
deleteTopicAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?) AND topic = ?`
|
||||||
|
|
||||||
|
selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE (SELECT id FROM user WHERE user = ?)`
|
||||||
insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES ((SELECT id FROM user WHERE user = ?), ?, ?)`
|
insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES ((SELECT id FROM user WHERE user = ?), ?, ?)`
|
||||||
updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
|
updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
|
||||||
deleteTokenQuery = `DELETE FROM user_token WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
|
deleteTokenQuery = `DELETE FROM user_token WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
|
||||||
deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires < ?`
|
deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires < ?`
|
||||||
|
deleteExcessTokensQuery = `
|
||||||
|
DELETE FROM user_token
|
||||||
|
WHERE (user_id, token) NOT IN (
|
||||||
|
SELECT user_id, token
|
||||||
|
FROM user_token
|
||||||
|
WHERE user_id = (SELECT id FROM user WHERE user = ?)
|
||||||
|
ORDER BY expires DESC
|
||||||
|
LIMIT ?
|
||||||
|
)
|
||||||
|
;
|
||||||
|
`
|
||||||
)
|
)
|
||||||
|
|
||||||
// Schema management queries
|
// Schema management queries
|
||||||
|
@ -182,22 +196,21 @@ const (
|
||||||
// Manager is an implementation of Manager. It stores users and access control list
|
// Manager is an implementation of Manager. It stores users and access control list
|
||||||
// in a SQLite database.
|
// in a SQLite database.
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
defaultAccess Permission // Default permission if no ACL matches
|
defaultAccess Permission // Default permission if no ACL matches
|
||||||
statsQueue map[string]*User // Username -> User, for "unimportant" user updates
|
statsQueue map[string]*User // Username -> User, for "unimportant" user updates
|
||||||
tokenExpiryInterval time.Duration // Duration after which tokens expire, and by which tokens are extended
|
mu sync.Mutex
|
||||||
mu sync.Mutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Auther = (*Manager)(nil)
|
var _ Auther = (*Manager)(nil)
|
||||||
|
|
||||||
// NewManager creates a new Manager instance
|
// NewManager creates a new Manager instance
|
||||||
func NewManager(filename, startupQueries string, defaultAccess Permission) (*Manager, error) {
|
func NewManager(filename, startupQueries string, defaultAccess Permission) (*Manager, error) {
|
||||||
return newManager(filename, startupQueries, defaultAccess, userTokenExpiryDuration, userStatsQueueWriterInterval)
|
return newManager(filename, startupQueries, defaultAccess, userStatsQueueWriterInterval)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewManager creates a new Manager instance
|
// NewManager creates a new Manager instance
|
||||||
func newManager(filename, startupQueries string, defaultAccess Permission, tokenExpiryDuration, statsWriterInterval time.Duration) (*Manager, error) {
|
func newManager(filename, startupQueries string, defaultAccess Permission, statsWriterInterval time.Duration) (*Manager, error) {
|
||||||
db, err := sql.Open("sqlite3", filename)
|
db, err := sql.Open("sqlite3", filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -209,10 +222,9 @@ func newManager(filename, startupQueries string, defaultAccess Permission, token
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
manager := &Manager{
|
manager := &Manager{
|
||||||
db: db,
|
db: db,
|
||||||
defaultAccess: defaultAccess,
|
defaultAccess: defaultAccess,
|
||||||
statsQueue: make(map[string]*User),
|
statsQueue: make(map[string]*User),
|
||||||
tokenExpiryInterval: tokenExpiryDuration,
|
|
||||||
}
|
}
|
||||||
go manager.userStatsQueueWriter(statsWriterInterval)
|
go manager.userStatsQueueWriter(statsWriterInterval)
|
||||||
return manager, nil
|
return manager, nil
|
||||||
|
@ -253,10 +265,38 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateToken generates a random token for the given user and returns it. The token expires
|
// CreateToken generates a random token for the given user and returns it. The token expires
|
||||||
// after a fixed duration unless ExtendToken is called.
|
// after a fixed duration unless ExtendToken is called. This function also prunes tokens for the
|
||||||
|
// given user, if there are too many of them.
|
||||||
func (a *Manager) CreateToken(user *User) (*Token, error) {
|
func (a *Manager) CreateToken(user *User) (*Token, error) {
|
||||||
token, expires := util.RandomString(tokenLength), time.Now().Add(userTokenExpiryDuration)
|
token, expires := util.RandomString(tokenLength), time.Now().Add(tokenExpiryDuration)
|
||||||
if _, err := a.db.Exec(insertTokenQuery, user.Name, token, expires.Unix()); err != nil {
|
tx, err := a.db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
if _, err := tx.Exec(insertTokenQuery, user.Name, token, expires.Unix()); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
rows, err := tx.Query(selectTokenCountQuery, user.Name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
if !rows.Next() {
|
||||||
|
return nil, errNoRows
|
||||||
|
}
|
||||||
|
var tokenCount int
|
||||||
|
if err := rows.Scan(&tokenCount); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if tokenCount >= tokenMaxCount {
|
||||||
|
// This pruning logic is done in two queries for efficiency. The SELECT above is a lookup
|
||||||
|
// on two indices, whereas the query below is a full table scan.
|
||||||
|
if _, err := tx.Exec(deleteExcessTokensQuery, user.Name, tokenMaxCount); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &Token{
|
return &Token{
|
||||||
|
@ -270,7 +310,7 @@ func (a *Manager) ExtendToken(user *User) (*Token, error) {
|
||||||
if user.Token == "" {
|
if user.Token == "" {
|
||||||
return nil, errNoTokenProvided
|
return nil, errNoTokenProvided
|
||||||
}
|
}
|
||||||
newExpires := time.Now().Add(userTokenExpiryDuration)
|
newExpires := time.Now().Add(tokenExpiryDuration)
|
||||||
if _, err := a.db.Exec(updateTokenExpiryQuery, newExpires.Unix(), user.Name, user.Token); err != nil {
|
if _, err := a.db.Exec(updateTokenExpiryQuery, newExpires.Unix(), user.Name, user.Token); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -600,7 +640,7 @@ func (a *Manager) CheckAllowAccess(username string, topic string) error {
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
if !rows.Next() {
|
if !rows.Next() {
|
||||||
return errors.New("no rows found")
|
return errNoRows
|
||||||
}
|
}
|
||||||
var otherCount int
|
var otherCount int
|
||||||
if err := rows.Scan(&otherCount); err != nil {
|
if err := rows.Scan(&otherCount); err != nil {
|
||||||
|
|
|
@ -369,8 +369,51 @@ func TestManager_Token_Extend(t *testing.T) {
|
||||||
require.True(t, token.Expires.Unix() < extendedToken.Expires.Unix())
|
require.True(t, token.Expires.Unix() < extendedToken.Expires.Unix())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
|
||||||
|
a := newTestManager(t, PermissionDenyAll)
|
||||||
|
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
||||||
|
|
||||||
|
// Try to extend token for user without token
|
||||||
|
u, err := a.User("ben")
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
// Tokens
|
||||||
|
baseTime := time.Now().Add(24 * time.Hour)
|
||||||
|
tokens := make([]string, 0)
|
||||||
|
for i := 0; i < 12; i++ {
|
||||||
|
token, err := a.CreateToken(u)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.NotEmpty(t, token.Value)
|
||||||
|
tokens = append(tokens, token.Value)
|
||||||
|
|
||||||
|
// Manually modify expiry date to avoid sorting issues (this is a hack)
|
||||||
|
_, err = a.db.Exec(`UPDATE user_token SET expires=? WHERE token=?`, baseTime.Add(time.Duration(i)*time.Minute).Unix(), token.Value)
|
||||||
|
require.Nil(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = a.AuthenticateToken(tokens[0])
|
||||||
|
require.Equal(t, ErrUnauthenticated, err)
|
||||||
|
|
||||||
|
_, err = a.AuthenticateToken(tokens[1])
|
||||||
|
require.Equal(t, ErrUnauthenticated, err)
|
||||||
|
|
||||||
|
for i := 2; i < 12; i++ {
|
||||||
|
userWithToken, err := a.AuthenticateToken(tokens[i])
|
||||||
|
require.Nil(t, err, "token[%d]=%s failed", i, tokens[i])
|
||||||
|
require.Equal(t, "ben", userWithToken.Name)
|
||||||
|
require.Equal(t, tokens[i], userWithToken.Token)
|
||||||
|
}
|
||||||
|
|
||||||
|
var count int
|
||||||
|
rows, err := a.db.Query(`SELECT COUNT(*) FROM user_token`)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.True(t, rows.Next())
|
||||||
|
require.Nil(t, rows.Scan(&count))
|
||||||
|
require.Equal(t, 10, count)
|
||||||
|
}
|
||||||
|
|
||||||
func TestManager_EnqueueStats(t *testing.T) {
|
func TestManager_EnqueueStats(t *testing.T) {
|
||||||
a, err := newManager(filepath.Join(t.TempDir(), "db"), PermissionReadWrite, time.Hour, 1500*time.Millisecond)
|
a, err := newManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
||||||
|
|
||||||
|
@ -400,7 +443,7 @@ func TestManager_EnqueueStats(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_ChangeSettings(t *testing.T) {
|
func TestManager_ChangeSettings(t *testing.T) {
|
||||||
a, err := newManager(filepath.Join(t.TempDir(), "db"), PermissionReadWrite, time.Hour, 1500*time.Millisecond)
|
a, err := newManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
||||||
|
|
||||||
|
@ -482,7 +525,7 @@ func TestSqliteCache_Migration_From1(t *testing.T) {
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
// Create manager to trigger migration
|
// Create manager to trigger migration
|
||||||
a := newTestManagerFromFile(t, filename, PermissionDenyAll, userTokenExpiryDuration, userStatsQueueWriterInterval)
|
a := newTestManagerFromFile(t, filename, "", PermissionDenyAll, userStatsQueueWriterInterval)
|
||||||
checkSchemaVersion(t, a.db)
|
checkSchemaVersion(t, a.db)
|
||||||
|
|
||||||
users, err := a.Users()
|
users, err := a.Users()
|
||||||
|
@ -530,11 +573,11 @@ func checkSchemaVersion(t *testing.T, db *sql.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestManager(t *testing.T, defaultAccess Permission) *Manager {
|
func newTestManager(t *testing.T, defaultAccess Permission) *Manager {
|
||||||
return newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), defaultAccess, userTokenExpiryDuration, userStatsQueueWriterInterval)
|
return newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), "", defaultAccess, userStatsQueueWriterInterval)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestManagerFromFile(t *testing.T, filename string, defaultAccess Permission, tokenExpiryDuration, statsWriterInterval time.Duration) *Manager {
|
func newTestManagerFromFile(t *testing.T, filename, startupQueries string, defaultAccess Permission, statsWriterInterval time.Duration) *Manager {
|
||||||
a, err := newManager(filename, defaultAccess, tokenExpiryDuration, statsWriterInterval)
|
a, err := newManager(filename, startupQueries, defaultAccess, statsWriterInterval)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue