mirror of
				https://github.com/binwiederhier/ntfy.git
				synced 2025-10-31 13:02:24 +01:00 
			
		
		
		
	Prune excess tokens per user
This commit is contained in:
		
							parent
							
								
									60f1882bec
								
							
						
					
					
						commit
						7fa63c8e19
					
				
					 3 changed files with 108 additions and 29 deletions
				
			
		|  | @ -45,10 +45,6 @@ import ( | |||
| 		"account topic" sync mechanism | ||||
| 		purge accounts that were not logged into in X | ||||
| 		reset daily limits for users | ||||
| 		max token issue limit | ||||
| 		user db startup queries -> foreign keys | ||||
| 		UI | ||||
| 		- Feature flag for "reserve topic" feature | ||||
| 		Sync: | ||||
| 			- "mute" setting | ||||
| 			- figure out what settings are "web" or "phone" | ||||
|  |  | |||
|  | @ -15,16 +15,18 @@ import ( | |||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	tokenLength                  = 32 | ||||
| 	bcryptCost                   = 10 | ||||
| 	intentionalSlowDownHash      = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match bcryptCost | ||||
| 	userStatsQueueWriterInterval = 33 * time.Second | ||||
| 	userTokenExpiryDuration      = 72 * time.Hour | ||||
| 	tokenLength                  = 32 | ||||
| 	tokenExpiryDuration          = 72 * time.Hour // Extend tokens by this much | ||||
| 	tokenMaxCount                = 10             // Only keep this many tokens in the table per user | ||||
| ) | ||||
| 
 | ||||
| var ( | ||||
| 	errNoTokenProvided    = errors.New("no token provided") | ||||
| 	errTopicOwnedByOthers = errors.New("topic owned by others") | ||||
| 	errNoRows             = errors.New("no rows found") | ||||
| ) | ||||
| 
 | ||||
| // Manager-related queries | ||||
|  | @ -139,7 +141,7 @@ const ( | |||
| 		ORDER BY a_user.topic | ||||
| 	` | ||||
| 	selectOtherAccessCountQuery = ` | ||||
| 		SELECT count(*) | ||||
| 		SELECT COUNT(*) | ||||
| 		FROM user_access | ||||
| 		WHERE (topic = ? OR ? LIKE topic) | ||||
| 		  AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?)) | ||||
|  | @ -148,10 +150,22 @@ const ( | |||
| 	deleteUserAccessQuery  = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?)` | ||||
| 	deleteTopicAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?) AND topic = ?` | ||||
| 
 | ||||
| 	selectTokenCountQuery    = `SELECT COUNT(*) FROM user_token WHERE (SELECT id FROM user WHERE user = ?)` | ||||
| 	insertTokenQuery         = `INSERT INTO user_token (user_id, token, expires) VALUES ((SELECT id FROM user WHERE user = ?), ?, ?)` | ||||
| 	updateTokenExpiryQuery   = `UPDATE user_token SET expires = ? WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?` | ||||
| 	deleteTokenQuery         = `DELETE FROM user_token WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?` | ||||
| 	deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires < ?` | ||||
| 	deleteExcessTokensQuery  = ` | ||||
| 		DELETE FROM user_token | ||||
| 		WHERE (user_id, token) NOT IN ( | ||||
| 			SELECT user_id, token | ||||
| 			FROM user_token | ||||
| 			WHERE user_id = (SELECT id FROM user WHERE user = ?) | ||||
| 			ORDER BY expires DESC  | ||||
| 			LIMIT ? | ||||
| 		) | ||||
| ; | ||||
| 	` | ||||
| ) | ||||
| 
 | ||||
| // Schema management queries | ||||
|  | @ -182,22 +196,21 @@ const ( | |||
| // Manager is an implementation of Manager. It stores users and access control list | ||||
| // in a SQLite database. | ||||
| type Manager struct { | ||||
| 	db                  *sql.DB | ||||
| 	defaultAccess       Permission       // Default permission if no ACL matches | ||||
| 	statsQueue          map[string]*User // Username -> User, for "unimportant" user updates | ||||
| 	tokenExpiryInterval time.Duration    // Duration after which tokens expire, and by which tokens are extended | ||||
| 	mu                  sync.Mutex | ||||
| 	db            *sql.DB | ||||
| 	defaultAccess Permission       // Default permission if no ACL matches | ||||
| 	statsQueue    map[string]*User // Username -> User, for "unimportant" user updates | ||||
| 	mu            sync.Mutex | ||||
| } | ||||
| 
 | ||||
| var _ Auther = (*Manager)(nil) | ||||
| 
 | ||||
| // NewManager creates a new Manager instance | ||||
| func NewManager(filename, startupQueries string, defaultAccess Permission) (*Manager, error) { | ||||
| 	return newManager(filename, startupQueries, defaultAccess, userTokenExpiryDuration, userStatsQueueWriterInterval) | ||||
| 	return newManager(filename, startupQueries, defaultAccess, userStatsQueueWriterInterval) | ||||
| } | ||||
| 
 | ||||
| // NewManager creates a new Manager instance | ||||
| func newManager(filename, startupQueries string, defaultAccess Permission, tokenExpiryDuration, statsWriterInterval time.Duration) (*Manager, error) { | ||||
| func newManager(filename, startupQueries string, defaultAccess Permission, statsWriterInterval time.Duration) (*Manager, error) { | ||||
| 	db, err := sql.Open("sqlite3", filename) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
|  | @ -209,10 +222,9 @@ func newManager(filename, startupQueries string, defaultAccess Permission, token | |||
| 		return nil, err | ||||
| 	} | ||||
| 	manager := &Manager{ | ||||
| 		db:                  db, | ||||
| 		defaultAccess:       defaultAccess, | ||||
| 		statsQueue:          make(map[string]*User), | ||||
| 		tokenExpiryInterval: tokenExpiryDuration, | ||||
| 		db:            db, | ||||
| 		defaultAccess: defaultAccess, | ||||
| 		statsQueue:    make(map[string]*User), | ||||
| 	} | ||||
| 	go manager.userStatsQueueWriter(statsWriterInterval) | ||||
| 	return manager, nil | ||||
|  | @ -253,10 +265,38 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) { | |||
| } | ||||
| 
 | ||||
| // CreateToken generates a random token for the given user and returns it. The token expires | ||||
| // after a fixed duration unless ExtendToken is called. | ||||
| // after a fixed duration unless ExtendToken is called. This function also prunes tokens for the | ||||
| // given user, if there are too many of them. | ||||
| func (a *Manager) CreateToken(user *User) (*Token, error) { | ||||
| 	token, expires := util.RandomString(tokenLength), time.Now().Add(userTokenExpiryDuration) | ||||
| 	if _, err := a.db.Exec(insertTokenQuery, user.Name, token, expires.Unix()); err != nil { | ||||
| 	token, expires := util.RandomString(tokenLength), time.Now().Add(tokenExpiryDuration) | ||||
| 	tx, err := a.db.Begin() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	defer tx.Rollback() | ||||
| 	if _, err := tx.Exec(insertTokenQuery, user.Name, token, expires.Unix()); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	rows, err := tx.Query(selectTokenCountQuery, user.Name) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	defer rows.Close() | ||||
| 	if !rows.Next() { | ||||
| 		return nil, errNoRows | ||||
| 	} | ||||
| 	var tokenCount int | ||||
| 	if err := rows.Scan(&tokenCount); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if tokenCount >= tokenMaxCount { | ||||
| 		// This pruning logic is done in two queries for efficiency. The SELECT above is a lookup | ||||
| 		// on two indices, whereas the query below is a full table scan. | ||||
| 		if _, err := tx.Exec(deleteExcessTokensQuery, user.Name, tokenMaxCount); err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 	} | ||||
| 	if err := tx.Commit(); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &Token{ | ||||
|  | @ -270,7 +310,7 @@ func (a *Manager) ExtendToken(user *User) (*Token, error) { | |||
| 	if user.Token == "" { | ||||
| 		return nil, errNoTokenProvided | ||||
| 	} | ||||
| 	newExpires := time.Now().Add(userTokenExpiryDuration) | ||||
| 	newExpires := time.Now().Add(tokenExpiryDuration) | ||||
| 	if _, err := a.db.Exec(updateTokenExpiryQuery, newExpires.Unix(), user.Name, user.Token); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | @ -600,7 +640,7 @@ func (a *Manager) CheckAllowAccess(username string, topic string) error { | |||
| 	} | ||||
| 	defer rows.Close() | ||||
| 	if !rows.Next() { | ||||
| 		return errors.New("no rows found") | ||||
| 		return errNoRows | ||||
| 	} | ||||
| 	var otherCount int | ||||
| 	if err := rows.Scan(&otherCount); err != nil { | ||||
|  |  | |||
|  | @ -369,8 +369,51 @@ func TestManager_Token_Extend(t *testing.T) { | |||
| 	require.True(t, token.Expires.Unix() < extendedToken.Expires.Unix()) | ||||
| } | ||||
| 
 | ||||
| func TestManager_Token_MaxCount_AutoDelete(t *testing.T) { | ||||
| 	a := newTestManager(t, PermissionDenyAll) | ||||
| 	require.Nil(t, a.AddUser("ben", "ben", RoleUser)) | ||||
| 
 | ||||
| 	// Try to extend token for user without token | ||||
| 	u, err := a.User("ben") | ||||
| 	require.Nil(t, err) | ||||
| 
 | ||||
| 	// Tokens | ||||
| 	baseTime := time.Now().Add(24 * time.Hour) | ||||
| 	tokens := make([]string, 0) | ||||
| 	for i := 0; i < 12; i++ { | ||||
| 		token, err := a.CreateToken(u) | ||||
| 		require.Nil(t, err) | ||||
| 		require.NotEmpty(t, token.Value) | ||||
| 		tokens = append(tokens, token.Value) | ||||
| 
 | ||||
| 		// Manually modify expiry date to avoid sorting issues (this is a hack) | ||||
| 		_, err = a.db.Exec(`UPDATE user_token SET expires=? WHERE token=?`, baseTime.Add(time.Duration(i)*time.Minute).Unix(), token.Value) | ||||
| 		require.Nil(t, err) | ||||
| 	} | ||||
| 
 | ||||
| 	_, err = a.AuthenticateToken(tokens[0]) | ||||
| 	require.Equal(t, ErrUnauthenticated, err) | ||||
| 
 | ||||
| 	_, err = a.AuthenticateToken(tokens[1]) | ||||
| 	require.Equal(t, ErrUnauthenticated, err) | ||||
| 
 | ||||
| 	for i := 2; i < 12; i++ { | ||||
| 		userWithToken, err := a.AuthenticateToken(tokens[i]) | ||||
| 		require.Nil(t, err, "token[%d]=%s failed", i, tokens[i]) | ||||
| 		require.Equal(t, "ben", userWithToken.Name) | ||||
| 		require.Equal(t, tokens[i], userWithToken.Token) | ||||
| 	} | ||||
| 
 | ||||
| 	var count int | ||||
| 	rows, err := a.db.Query(`SELECT COUNT(*) FROM user_token`) | ||||
| 	require.Nil(t, err) | ||||
| 	require.True(t, rows.Next()) | ||||
| 	require.Nil(t, rows.Scan(&count)) | ||||
| 	require.Equal(t, 10, count) | ||||
| } | ||||
| 
 | ||||
| func TestManager_EnqueueStats(t *testing.T) { | ||||
| 	a, err := newManager(filepath.Join(t.TempDir(), "db"), PermissionReadWrite, time.Hour, 1500*time.Millisecond) | ||||
| 	a, err := newManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond) | ||||
| 	require.Nil(t, err) | ||||
| 	require.Nil(t, a.AddUser("ben", "ben", RoleUser)) | ||||
| 
 | ||||
|  | @ -400,7 +443,7 @@ func TestManager_EnqueueStats(t *testing.T) { | |||
| } | ||||
| 
 | ||||
| func TestManager_ChangeSettings(t *testing.T) { | ||||
| 	a, err := newManager(filepath.Join(t.TempDir(), "db"), PermissionReadWrite, time.Hour, 1500*time.Millisecond) | ||||
| 	a, err := newManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond) | ||||
| 	require.Nil(t, err) | ||||
| 	require.Nil(t, a.AddUser("ben", "ben", RoleUser)) | ||||
| 
 | ||||
|  | @ -482,7 +525,7 @@ func TestSqliteCache_Migration_From1(t *testing.T) { | |||
| 	require.Nil(t, err) | ||||
| 
 | ||||
| 	// Create manager to trigger migration | ||||
| 	a := newTestManagerFromFile(t, filename, PermissionDenyAll, userTokenExpiryDuration, userStatsQueueWriterInterval) | ||||
| 	a := newTestManagerFromFile(t, filename, "", PermissionDenyAll, userStatsQueueWriterInterval) | ||||
| 	checkSchemaVersion(t, a.db) | ||||
| 
 | ||||
| 	users, err := a.Users() | ||||
|  | @ -530,11 +573,11 @@ func checkSchemaVersion(t *testing.T, db *sql.DB) { | |||
| } | ||||
| 
 | ||||
| func newTestManager(t *testing.T, defaultAccess Permission) *Manager { | ||||
| 	return newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), defaultAccess, userTokenExpiryDuration, userStatsQueueWriterInterval) | ||||
| 	return newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), "", defaultAccess, userStatsQueueWriterInterval) | ||||
| } | ||||
| 
 | ||||
| func newTestManagerFromFile(t *testing.T, filename string, defaultAccess Permission, tokenExpiryDuration, statsWriterInterval time.Duration) *Manager { | ||||
| 	a, err := newManager(filename, defaultAccess, tokenExpiryDuration, statsWriterInterval) | ||||
| func newTestManagerFromFile(t *testing.T, filename, startupQueries string, defaultAccess Permission, statsWriterInterval time.Duration) *Manager { | ||||
| 	a, err := newManager(filename, startupQueries, defaultAccess, statsWriterInterval) | ||||
| 	require.Nil(t, err) | ||||
| 	return a | ||||
| } | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 binwiederhier
						binwiederhier