mirror of
https://github.com/binwiederhier/ntfy.git
synced 2024-11-22 19:33:27 +01:00
Prune excess tokens per user
This commit is contained in:
parent
60f1882bec
commit
7fa63c8e19
3 changed files with 108 additions and 29 deletions
|
@ -45,10 +45,6 @@ import (
|
|||
"account topic" sync mechanism
|
||||
purge accounts that were not logged into in X
|
||||
reset daily limits for users
|
||||
max token issue limit
|
||||
user db startup queries -> foreign keys
|
||||
UI
|
||||
- Feature flag for "reserve topic" feature
|
||||
Sync:
|
||||
- "mute" setting
|
||||
- figure out what settings are "web" or "phone"
|
||||
|
|
|
@ -15,16 +15,18 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
tokenLength = 32
|
||||
bcryptCost = 10
|
||||
intentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match bcryptCost
|
||||
userStatsQueueWriterInterval = 33 * time.Second
|
||||
userTokenExpiryDuration = 72 * time.Hour
|
||||
tokenLength = 32
|
||||
tokenExpiryDuration = 72 * time.Hour // Extend tokens by this much
|
||||
tokenMaxCount = 10 // Only keep this many tokens in the table per user
|
||||
)
|
||||
|
||||
var (
|
||||
errNoTokenProvided = errors.New("no token provided")
|
||||
errTopicOwnedByOthers = errors.New("topic owned by others")
|
||||
errNoRows = errors.New("no rows found")
|
||||
)
|
||||
|
||||
// Manager-related queries
|
||||
|
@ -139,7 +141,7 @@ const (
|
|||
ORDER BY a_user.topic
|
||||
`
|
||||
selectOtherAccessCountQuery = `
|
||||
SELECT count(*)
|
||||
SELECT COUNT(*)
|
||||
FROM user_access
|
||||
WHERE (topic = ? OR ? LIKE topic)
|
||||
AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?))
|
||||
|
@ -148,10 +150,22 @@ const (
|
|||
deleteUserAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?)`
|
||||
deleteTopicAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?) AND topic = ?`
|
||||
|
||||
selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE (SELECT id FROM user WHERE user = ?)`
|
||||
insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES ((SELECT id FROM user WHERE user = ?), ?, ?)`
|
||||
updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
|
||||
deleteTokenQuery = `DELETE FROM user_token WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
|
||||
deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires < ?`
|
||||
deleteExcessTokensQuery = `
|
||||
DELETE FROM user_token
|
||||
WHERE (user_id, token) NOT IN (
|
||||
SELECT user_id, token
|
||||
FROM user_token
|
||||
WHERE user_id = (SELECT id FROM user WHERE user = ?)
|
||||
ORDER BY expires DESC
|
||||
LIMIT ?
|
||||
)
|
||||
;
|
||||
`
|
||||
)
|
||||
|
||||
// Schema management queries
|
||||
|
@ -185,7 +199,6 @@ type Manager struct {
|
|||
db *sql.DB
|
||||
defaultAccess Permission // Default permission if no ACL matches
|
||||
statsQueue map[string]*User // Username -> User, for "unimportant" user updates
|
||||
tokenExpiryInterval time.Duration // Duration after which tokens expire, and by which tokens are extended
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
|
@ -193,11 +206,11 @@ var _ Auther = (*Manager)(nil)
|
|||
|
||||
// NewManager creates a new Manager instance
|
||||
func NewManager(filename, startupQueries string, defaultAccess Permission) (*Manager, error) {
|
||||
return newManager(filename, startupQueries, defaultAccess, userTokenExpiryDuration, userStatsQueueWriterInterval)
|
||||
return newManager(filename, startupQueries, defaultAccess, userStatsQueueWriterInterval)
|
||||
}
|
||||
|
||||
// NewManager creates a new Manager instance
|
||||
func newManager(filename, startupQueries string, defaultAccess Permission, tokenExpiryDuration, statsWriterInterval time.Duration) (*Manager, error) {
|
||||
func newManager(filename, startupQueries string, defaultAccess Permission, statsWriterInterval time.Duration) (*Manager, error) {
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -212,7 +225,6 @@ func newManager(filename, startupQueries string, defaultAccess Permission, token
|
|||
db: db,
|
||||
defaultAccess: defaultAccess,
|
||||
statsQueue: make(map[string]*User),
|
||||
tokenExpiryInterval: tokenExpiryDuration,
|
||||
}
|
||||
go manager.userStatsQueueWriter(statsWriterInterval)
|
||||
return manager, nil
|
||||
|
@ -253,10 +265,38 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) {
|
|||
}
|
||||
|
||||
// CreateToken generates a random token for the given user and returns it. The token expires
|
||||
// after a fixed duration unless ExtendToken is called.
|
||||
// after a fixed duration unless ExtendToken is called. This function also prunes tokens for the
|
||||
// given user, if there are too many of them.
|
||||
func (a *Manager) CreateToken(user *User) (*Token, error) {
|
||||
token, expires := util.RandomString(tokenLength), time.Now().Add(userTokenExpiryDuration)
|
||||
if _, err := a.db.Exec(insertTokenQuery, user.Name, token, expires.Unix()); err != nil {
|
||||
token, expires := util.RandomString(tokenLength), time.Now().Add(tokenExpiryDuration)
|
||||
tx, err := a.db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(insertTokenQuery, user.Name, token, expires.Unix()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rows, err := tx.Query(selectTokenCountQuery, user.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return nil, errNoRows
|
||||
}
|
||||
var tokenCount int
|
||||
if err := rows.Scan(&tokenCount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tokenCount >= tokenMaxCount {
|
||||
// This pruning logic is done in two queries for efficiency. The SELECT above is a lookup
|
||||
// on two indices, whereas the query below is a full table scan.
|
||||
if _, err := tx.Exec(deleteExcessTokensQuery, user.Name, tokenMaxCount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Token{
|
||||
|
@ -270,7 +310,7 @@ func (a *Manager) ExtendToken(user *User) (*Token, error) {
|
|||
if user.Token == "" {
|
||||
return nil, errNoTokenProvided
|
||||
}
|
||||
newExpires := time.Now().Add(userTokenExpiryDuration)
|
||||
newExpires := time.Now().Add(tokenExpiryDuration)
|
||||
if _, err := a.db.Exec(updateTokenExpiryQuery, newExpires.Unix(), user.Name, user.Token); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -600,7 +640,7 @@ func (a *Manager) CheckAllowAccess(username string, topic string) error {
|
|||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return errors.New("no rows found")
|
||||
return errNoRows
|
||||
}
|
||||
var otherCount int
|
||||
if err := rows.Scan(&otherCount); err != nil {
|
||||
|
|
|
@ -369,8 +369,51 @@ func TestManager_Token_Extend(t *testing.T) {
|
|||
require.True(t, token.Expires.Unix() < extendedToken.Expires.Unix())
|
||||
}
|
||||
|
||||
func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
|
||||
a := newTestManager(t, PermissionDenyAll)
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
||||
|
||||
// Try to extend token for user without token
|
||||
u, err := a.User("ben")
|
||||
require.Nil(t, err)
|
||||
|
||||
// Tokens
|
||||
baseTime := time.Now().Add(24 * time.Hour)
|
||||
tokens := make([]string, 0)
|
||||
for i := 0; i < 12; i++ {
|
||||
token, err := a.CreateToken(u)
|
||||
require.Nil(t, err)
|
||||
require.NotEmpty(t, token.Value)
|
||||
tokens = append(tokens, token.Value)
|
||||
|
||||
// Manually modify expiry date to avoid sorting issues (this is a hack)
|
||||
_, err = a.db.Exec(`UPDATE user_token SET expires=? WHERE token=?`, baseTime.Add(time.Duration(i)*time.Minute).Unix(), token.Value)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
_, err = a.AuthenticateToken(tokens[0])
|
||||
require.Equal(t, ErrUnauthenticated, err)
|
||||
|
||||
_, err = a.AuthenticateToken(tokens[1])
|
||||
require.Equal(t, ErrUnauthenticated, err)
|
||||
|
||||
for i := 2; i < 12; i++ {
|
||||
userWithToken, err := a.AuthenticateToken(tokens[i])
|
||||
require.Nil(t, err, "token[%d]=%s failed", i, tokens[i])
|
||||
require.Equal(t, "ben", userWithToken.Name)
|
||||
require.Equal(t, tokens[i], userWithToken.Token)
|
||||
}
|
||||
|
||||
var count int
|
||||
rows, err := a.db.Query(`SELECT COUNT(*) FROM user_token`)
|
||||
require.Nil(t, err)
|
||||
require.True(t, rows.Next())
|
||||
require.Nil(t, rows.Scan(&count))
|
||||
require.Equal(t, 10, count)
|
||||
}
|
||||
|
||||
func TestManager_EnqueueStats(t *testing.T) {
|
||||
a, err := newManager(filepath.Join(t.TempDir(), "db"), PermissionReadWrite, time.Hour, 1500*time.Millisecond)
|
||||
a, err := newManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond)
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
||||
|
||||
|
@ -400,7 +443,7 @@ func TestManager_EnqueueStats(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestManager_ChangeSettings(t *testing.T) {
|
||||
a, err := newManager(filepath.Join(t.TempDir(), "db"), PermissionReadWrite, time.Hour, 1500*time.Millisecond)
|
||||
a, err := newManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond)
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
||||
|
||||
|
@ -482,7 +525,7 @@ func TestSqliteCache_Migration_From1(t *testing.T) {
|
|||
require.Nil(t, err)
|
||||
|
||||
// Create manager to trigger migration
|
||||
a := newTestManagerFromFile(t, filename, PermissionDenyAll, userTokenExpiryDuration, userStatsQueueWriterInterval)
|
||||
a := newTestManagerFromFile(t, filename, "", PermissionDenyAll, userStatsQueueWriterInterval)
|
||||
checkSchemaVersion(t, a.db)
|
||||
|
||||
users, err := a.Users()
|
||||
|
@ -530,11 +573,11 @@ func checkSchemaVersion(t *testing.T, db *sql.DB) {
|
|||
}
|
||||
|
||||
func newTestManager(t *testing.T, defaultAccess Permission) *Manager {
|
||||
return newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), defaultAccess, userTokenExpiryDuration, userStatsQueueWriterInterval)
|
||||
return newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), "", defaultAccess, userStatsQueueWriterInterval)
|
||||
}
|
||||
|
||||
func newTestManagerFromFile(t *testing.T, filename string, defaultAccess Permission, tokenExpiryDuration, statsWriterInterval time.Duration) *Manager {
|
||||
a, err := newManager(filename, defaultAccess, tokenExpiryDuration, statsWriterInterval)
|
||||
func newTestManagerFromFile(t *testing.T, filename, startupQueries string, defaultAccess Permission, statsWriterInterval time.Duration) *Manager {
|
||||
a, err := newManager(filename, startupQueries, defaultAccess, statsWriterInterval)
|
||||
require.Nil(t, err)
|
||||
return a
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue