mirror of
				https://github.com/binwiederhier/ntfy.git
				synced 2025-10-31 13:02:24 +01:00 
			
		
		
		
	Add "auth-tokens"
This commit is contained in:
		
							parent
							
								
									149c13e9d8
								
							
						
					
					
						commit
						23ec7702fc
					
				
					 10 changed files with 263 additions and 88 deletions
				
			
		
							
								
								
									
										165
									
								
								user/manager.go
									
										
									
									
									
								
							
							
						
						
									
										165
									
								
								user/manager.go
									
										
									
									
									
								
							|  | @ -111,9 +111,11 @@ const ( | |||
| 			last_access INT NOT NULL, | ||||
| 			last_origin TEXT NOT NULL, | ||||
| 			expires INT NOT NULL, | ||||
| 			provisioned INT NOT NULL, | ||||
| 			PRIMARY KEY (user_id, token), | ||||
| 			FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE | ||||
| 		); | ||||
| 		CREATE UNIQUE INDEX idx_user_token ON user_token (token); | ||||
| 		CREATE TABLE IF NOT EXISTS user_phone ( | ||||
| 			user_id TEXT NOT NULL, | ||||
| 			phone_number TEXT NOT NULL, | ||||
|  | @ -181,16 +183,17 @@ const ( | |||
| 				ELSE 2 | ||||
| 			END, user | ||||
| 	` | ||||
| 	selectUserCountQuery         = `SELECT COUNT(*) FROM user` | ||||
| 	updateUserPassQuery          = `UPDATE user SET pass = ? WHERE user = ?` | ||||
| 	updateUserRoleQuery          = `UPDATE user SET role = ? WHERE user = ?` | ||||
| 	updateUserProvisionedQuery   = `UPDATE user SET provisioned = ? WHERE user = ?` | ||||
| 	updateUserPrefsQuery         = `UPDATE user SET prefs = ? WHERE id = ?` | ||||
| 	updateUserStatsQuery         = `UPDATE user SET stats_messages = ?, stats_emails = ?, stats_calls = ? WHERE id = ?` | ||||
| 	updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0, stats_calls = 0` | ||||
| 	updateUserDeletedQuery       = `UPDATE user SET deleted = ? WHERE id = ?` | ||||
| 	deleteUsersMarkedQuery       = `DELETE FROM user WHERE deleted < ?` | ||||
| 	deleteUserQuery              = `DELETE FROM user WHERE user = ?` | ||||
| 	selectUserCountQuery          = `SELECT COUNT(*) FROM user` | ||||
| 	selectUserIDFromUsernameQuery = `SELECT id FROM user WHERE user = ?` | ||||
| 	updateUserPassQuery           = `UPDATE user SET pass = ? WHERE user = ?` | ||||
| 	updateUserRoleQuery           = `UPDATE user SET role = ? WHERE user = ?` | ||||
| 	updateUserProvisionedQuery    = `UPDATE user SET provisioned = ? WHERE user = ?` | ||||
| 	updateUserPrefsQuery          = `UPDATE user SET prefs = ? WHERE id = ?` | ||||
| 	updateUserStatsQuery          = `UPDATE user SET stats_messages = ?, stats_emails = ?, stats_calls = ? WHERE id = ?` | ||||
| 	updateUserStatsResetAllQuery  = `UPDATE user SET stats_messages = 0, stats_emails = 0, stats_calls = 0` | ||||
| 	updateUserDeletedQuery        = `UPDATE user SET deleted = ? WHERE id = ?` | ||||
| 	deleteUsersMarkedQuery        = `DELETE FROM user WHERE deleted < ?` | ||||
| 	deleteUserQuery               = `DELETE FROM user WHERE user = ?` | ||||
| 
 | ||||
| 	upsertUserAccessQuery = ` | ||||
| 		INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned) | ||||
|  | @ -220,7 +223,7 @@ const ( | |||
| 	selectUserReservationsCountQuery = ` | ||||
| 		SELECT COUNT(*) | ||||
| 		FROM user_access | ||||
| 		WHERE user_id = owner_user_id  | ||||
| 		WHERE user_id = owner_user_id | ||||
| 		  AND owner_user_id = (SELECT id FROM user WHERE user = ?) | ||||
| 	` | ||||
| 	selectUserReservationsOwnerQuery = ` | ||||
|  | @ -255,17 +258,23 @@ const ( | |||
| 	   	  AND topic = ? | ||||
|   	` | ||||
| 
 | ||||
| 	selectTokenCountQuery      = `SELECT COUNT(*) FROM user_token WHERE user_id = ?` | ||||
| 	selectTokensQuery          = `SELECT token, label, last_access, last_origin, expires FROM user_token WHERE user_id = ?` | ||||
| 	selectTokenQuery           = `SELECT token, label, last_access, last_origin, expires FROM user_token WHERE user_id = ? AND token = ?` | ||||
| 	insertTokenQuery           = `INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires) VALUES (?, ?, ?, ?, ?, ?)` | ||||
| 	updateTokenExpiryQuery     = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?` | ||||
| 	updateTokenLabelQuery      = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?` | ||||
| 	updateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?` | ||||
| 	deleteTokenQuery           = `DELETE FROM user_token WHERE user_id = ? AND token = ?` | ||||
| 	deleteAllTokenQuery        = `DELETE FROM user_token WHERE user_id = ?` | ||||
| 	deleteExpiredTokensQuery   = `DELETE FROM user_token WHERE expires > 0 AND expires < ?` | ||||
| 	deleteExcessTokensQuery    = ` | ||||
| 	selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?` | ||||
| 	selectTokensQuery     = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ?` | ||||
| 	selectTokenQuery      = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ? AND token = ?` | ||||
| 	upsertTokenQuery      = ` | ||||
| 		INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned) | ||||
| 		VALUES (?, ?, ?, ?, ?, ?, ?) | ||||
| 		ON CONFLICT (user_id, token) | ||||
| 		DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned; | ||||
| 	` | ||||
| 	updateTokenExpiryQuery       = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?` | ||||
| 	updateTokenLabelQuery        = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?` | ||||
| 	updateTokenLastAccessQuery   = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?` | ||||
| 	deleteTokenQuery             = `DELETE FROM user_token WHERE user_id = ? AND token = ?` | ||||
| 	deleteAllTokenQuery          = `DELETE FROM user_token WHERE user_id = ?` | ||||
| 	deleteTokensProvisionedQuery = `DELETE FROM user_token WHERE provisioned = 1` | ||||
| 	deleteExpiredTokensQuery     = `DELETE FROM user_token WHERE expires > 0 AND expires < ?` | ||||
| 	deleteExcessTokensQuery      = ` | ||||
| 		DELETE FROM user_token | ||||
| 		WHERE user_id = ? | ||||
| 		  AND (user_id, token) NOT IN ( | ||||
|  | @ -470,7 +479,7 @@ const ( | |||
| 		    role, | ||||
| 		    prefs, | ||||
| 		    sync_topic, | ||||
| 		    0, | ||||
| 		    0, -- provisioned | ||||
| 		    stats_messages, | ||||
| 		    stats_emails, | ||||
| 		    stats_calls, | ||||
|  | @ -480,7 +489,8 @@ const ( | |||
| 		    stripe_subscription_interval, | ||||
| 		    stripe_subscription_paid_until, | ||||
| 		    stripe_subscription_cancel_at, | ||||
| 		    created, deleted | ||||
| 		    created, | ||||
| 		    deleted | ||||
| 		FROM user_old; | ||||
| 		DROP TABLE user_old; | ||||
| 
 | ||||
|  | @ -500,10 +510,27 @@ const ( | |||
| 		INSERT INTO user_access SELECT *, 0 FROM user_access_old; | ||||
| 		DROP TABLE user_access_old; | ||||
| 
 | ||||
| 		-- Alter user_token table: Add provisioned column | ||||
| 		ALTER TABLE user_token RENAME TO user_token_old; | ||||
| 		CREATE TABLE IF NOT EXISTS user_token ( | ||||
| 			user_id TEXT NOT NULL, | ||||
| 			token TEXT NOT NULL, | ||||
| 			label TEXT NOT NULL, | ||||
| 			last_access INT NOT NULL, | ||||
| 			last_origin TEXT NOT NULL, | ||||
| 			expires INT NOT NULL, | ||||
| 			provisioned INT NOT NULL, | ||||
| 			PRIMARY KEY (user_id, token), | ||||
| 			FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE | ||||
| 		); | ||||
| 		INSERT INTO user_token SELECT *, 0 FROM user_token_old; | ||||
| 		DROP TABLE user_token_old; | ||||
| 
 | ||||
| 		-- Recreate indices | ||||
| 		CREATE UNIQUE INDEX idx_user ON user (user); | ||||
| 		CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id); | ||||
| 		CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id); | ||||
| 		CREATE UNIQUE INDEX idx_user_token ON user_token (token); | ||||
| 
 | ||||
| 		-- Re-enable foreign keys | ||||
| 		PRAGMA foreign_keys=on; | ||||
|  | @ -537,7 +564,8 @@ type Config struct { | |||
| 	DefaultAccess       Permission          // Default permission if no ACL matches | ||||
| 	ProvisionEnabled    bool                // Enable auto-provisioning of users and access grants, disabled for "ntfy user" commands | ||||
| 	Users               []*User             // Predefined users to create on startup | ||||
| 	Access              map[string][]*Grant // Predefined access grants to create on startup | ||||
| 	Access              map[string][]*Grant // Predefined access grants to create on startup (username -> []*Grant) | ||||
| 	Tokens              map[string][]*Token // Predefined users to create on startup (username -> []*Token) | ||||
| 	QueueWriterInterval time.Duration       // Interval for the async queue writer to flush stats and token updates to the database | ||||
| 	BcryptCost          int                 // Cost of generated passwords; lowering makes testing faster | ||||
| } | ||||
|  | @ -623,15 +651,15 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) { | |||
| // CreateToken generates a random token for the given user and returns it. The token expires | ||||
| // after a fixed duration unless ChangeToken is called. This function also prunes tokens for the | ||||
| // given user, if there are too many of them. | ||||
| func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr) (*Token, error) { | ||||
| 	token := util.RandomLowerStringPrefix(tokenPrefix, tokenLength) // Lowercase only to support "<topic>+<token>@<domain>" email addresses | ||||
| 	tx, err := a.db.Begin() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	defer tx.Rollback() | ||||
| func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) { | ||||
| 	return queryTx(a.db, func(tx *sql.Tx) (*Token, error) { | ||||
| 		return a.createTokenTx(tx, userID, GenerateToken(), label, expires, origin, provisioned) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (a *Manager) createTokenTx(tx *sql.Tx, userID, token, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) { | ||||
| 	access := time.Now() | ||||
| 	if _, err := tx.Exec(insertTokenQuery, userID, token, label, access.Unix(), origin.String(), expires.Unix()); err != nil { | ||||
| 	if _, err := tx.Exec(upsertTokenQuery, userID, token, label, access.Unix(), origin.String(), expires.Unix(), provisioned); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	rows, err := tx.Query(selectTokenCountQuery, userID) | ||||
|  | @ -653,15 +681,13 @@ func (a *Manager) CreateToken(userID, label string, expires time.Time, origin ne | |||
| 			return nil, err | ||||
| 		} | ||||
| 	} | ||||
| 	if err := tx.Commit(); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &Token{ | ||||
| 		Value:      token, | ||||
| 		Label:      label, | ||||
| 		LastAccess: access, | ||||
| 		LastOrigin: origin, | ||||
| 		Expires:    expires, | ||||
| 		Value:       token, | ||||
| 		Label:       label, | ||||
| 		LastAccess:  access, | ||||
| 		LastOrigin:  origin, | ||||
| 		Expires:     expires, | ||||
| 		Provisioned: provisioned, | ||||
| 	}, nil | ||||
| } | ||||
| 
 | ||||
|  | @ -698,10 +724,11 @@ func (a *Manager) Token(userID, token string) (*Token, error) { | |||
| func (a *Manager) readToken(rows *sql.Rows) (*Token, error) { | ||||
| 	var token, label, lastOrigin string | ||||
| 	var lastAccess, expires int64 | ||||
| 	var provisioned bool | ||||
| 	if !rows.Next() { | ||||
| 		return nil, ErrTokenNotFound | ||||
| 	} | ||||
| 	if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires); err != nil { | ||||
| 	if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires, &provisioned); err != nil { | ||||
| 		return nil, err | ||||
| 	} else if err := rows.Err(); err != nil { | ||||
| 		return nil, err | ||||
|  | @ -711,11 +738,12 @@ func (a *Manager) readToken(rows *sql.Rows) (*Token, error) { | |||
| 		lastOriginIP = netip.IPv4Unspecified() | ||||
| 	} | ||||
| 	return &Token{ | ||||
| 		Value:      token, | ||||
| 		Label:      label, | ||||
| 		LastAccess: time.Unix(lastAccess, 0), | ||||
| 		LastOrigin: lastOriginIP, | ||||
| 		Expires:    time.Unix(expires, 0), | ||||
| 		Value:       token, | ||||
| 		Label:       label, | ||||
| 		LastAccess:  time.Unix(lastAccess, 0), | ||||
| 		LastOrigin:  lastOriginIP, | ||||
| 		Expires:     time.Unix(expires, 0), | ||||
| 		Provisioned: provisioned, | ||||
| 	}, nil | ||||
| } | ||||
| 
 | ||||
|  | @ -774,7 +802,7 @@ func (a *Manager) PhoneNumbers(userID string) ([]string, error) { | |||
| 	phoneNumbers := make([]string, 0) | ||||
| 	for { | ||||
| 		phoneNumber, err := a.readPhoneNumber(rows) | ||||
| 		if err == ErrPhoneNumberNotFound { | ||||
| 		if errors.Is(err, ErrPhoneNumberNotFound) { | ||||
| 			break | ||||
| 		} else if err != nil { | ||||
| 			return nil, err | ||||
|  | @ -1757,6 +1785,28 @@ func (a *Manager) maybeProvisionUsersAndAccess() error { | |||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 		// Remove and (re-)add provisioned tokens | ||||
| 		if _, err := tx.Exec(deleteTokensProvisionedQuery); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		for username, tokens := range a.config.Tokens { | ||||
| 			_, exists := util.Find(a.config.Users, func(u *User) bool { | ||||
| 				return u.Name == username | ||||
| 			}) | ||||
| 			if !exists && username != Everyone { | ||||
| 				return fmt.Errorf("user %s is not a provisioned user, refusing to add tokens", username) | ||||
| 			} | ||||
| 			var userID string | ||||
| 			row := tx.QueryRow(selectUserIDFromUsernameQuery, username) | ||||
| 			if err := row.Scan(&userID); err != nil { | ||||
| 				return fmt.Errorf("failed to find provisioned user %s for provisioned tokens", username) | ||||
| 			} | ||||
| 			for _, token := range tokens { | ||||
| 				if _, err = a.createTokenTx(tx, userID, token.Value, token.Label, time.Unix(0, 0), netip.IPv4Unspecified(), true); err != nil { | ||||
| 					return err | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 		return nil | ||||
| 	}) | ||||
| } | ||||
|  | @ -1974,3 +2024,22 @@ func execTx(db *sql.DB, f func(tx *sql.Tx) error) error { | |||
| 	} | ||||
| 	return tx.Commit() | ||||
| } | ||||
| 
 | ||||
| // queryTx executes a function in a transaction and returns the result. If the function | ||||
| // returns an error, the transaction is rolled back. | ||||
| func queryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) { | ||||
| 	tx, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		var zero T | ||||
| 		return zero, err | ||||
| 	} | ||||
| 	defer tx.Rollback() | ||||
| 	t, err := f(tx) | ||||
| 	if err != nil { | ||||
| 		return t, err | ||||
| 	} | ||||
| 	if err := tx.Commit(); err != nil { | ||||
| 		return t, err | ||||
| 	} | ||||
| 	return t, nil | ||||
| } | ||||
|  |  | |||
|  | @ -194,7 +194,7 @@ func TestManager_MarkUserRemoved_RemoveDeletedUsers(t *testing.T) { | |||
| 	require.Nil(t, err) | ||||
| 	require.False(t, u.Deleted) | ||||
| 
 | ||||
| 	token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified()) | ||||
| 	token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified(), false) | ||||
| 	require.Nil(t, err) | ||||
| 
 | ||||
| 	u, err = a.Authenticate("user", "pass") | ||||
|  | @ -241,7 +241,7 @@ func TestManager_CreateToken_Only_Lower(t *testing.T) { | |||
| 	u, err := a.User("user") | ||||
| 	require.Nil(t, err) | ||||
| 
 | ||||
| 	token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified()) | ||||
| 	token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified(), false) | ||||
| 	require.Nil(t, err) | ||||
| 	require.Equal(t, token.Value, strings.ToLower(token.Value)) | ||||
| } | ||||
|  | @ -523,7 +523,7 @@ func TestManager_Token_Valid(t *testing.T) { | |||
| 	require.Nil(t, err) | ||||
| 
 | ||||
| 	// Create token for user | ||||
| 	token, err := a.CreateToken(u.ID, "some label", time.Now().Add(72*time.Hour), netip.IPv4Unspecified()) | ||||
| 	token, err := a.CreateToken(u.ID, "some label", time.Now().Add(72*time.Hour), netip.IPv4Unspecified(), false) | ||||
| 	require.Nil(t, err) | ||||
| 	require.NotEmpty(t, token.Value) | ||||
| 	require.Equal(t, "some label", token.Label) | ||||
|  | @ -586,12 +586,12 @@ func TestManager_Token_Expire(t *testing.T) { | |||
| 	require.Nil(t, err) | ||||
| 
 | ||||
| 	// Create tokens for user | ||||
| 	token1, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified()) | ||||
| 	token1, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified(), false) | ||||
| 	require.Nil(t, err) | ||||
| 	require.NotEmpty(t, token1.Value) | ||||
| 	require.True(t, time.Now().Add(71*time.Hour).Unix() < token1.Expires.Unix()) | ||||
| 
 | ||||
| 	token2, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified()) | ||||
| 	token2, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified(), false) | ||||
| 	require.Nil(t, err) | ||||
| 	require.NotEmpty(t, token2.Value) | ||||
| 	require.NotEqual(t, token1.Value, token2.Value) | ||||
|  | @ -638,7 +638,7 @@ func TestManager_Token_Extend(t *testing.T) { | |||
| 	require.Equal(t, errNoTokenProvided, err) | ||||
| 
 | ||||
| 	// Create token for user | ||||
| 	token, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified()) | ||||
| 	token, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified(), false) | ||||
| 	require.Nil(t, err) | ||||
| 	require.NotEmpty(t, token.Value) | ||||
| 
 | ||||
|  | @ -668,12 +668,12 @@ func TestManager_Token_MaxCount_AutoDelete(t *testing.T) { | |||
| 
 | ||||
| 	// Create 2 tokens for phil | ||||
| 	philTokens := make([]string, 0) | ||||
| 	token, err := a.CreateToken(phil.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified()) | ||||
| 	token, err := a.CreateToken(phil.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified(), false) | ||||
| 	require.Nil(t, err) | ||||
| 	require.NotEmpty(t, token.Value) | ||||
| 	philTokens = append(philTokens, token.Value) | ||||
| 
 | ||||
| 	token, err = a.CreateToken(phil.ID, "", time.Unix(0, 0), netip.IPv4Unspecified()) | ||||
| 	token, err = a.CreateToken(phil.ID, "", time.Unix(0, 0), netip.IPv4Unspecified(), false) | ||||
| 	require.Nil(t, err) | ||||
| 	require.NotEmpty(t, token.Value) | ||||
| 	philTokens = append(philTokens, token.Value) | ||||
|  | @ -682,7 +682,7 @@ func TestManager_Token_MaxCount_AutoDelete(t *testing.T) { | |||
| 	baseTime := time.Now().Add(24 * time.Hour) | ||||
| 	benTokens := make([]string, 0) | ||||
| 	for i := 0; i < 62; i++ { // | ||||
| 		token, err := a.CreateToken(ben.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified()) | ||||
| 		token, err := a.CreateToken(ben.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified(), false) | ||||
| 		require.Nil(t, err) | ||||
| 		require.NotEmpty(t, token.Value) | ||||
| 		benTokens = append(benTokens, token.Value) | ||||
|  | @ -795,7 +795,7 @@ func TestManager_EnqueueTokenUpdate(t *testing.T) { | |||
| 	u, err := a.User("ben") | ||||
| 	require.Nil(t, err) | ||||
| 
 | ||||
| 	token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified()) | ||||
| 	token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified(), false) | ||||
| 	require.Nil(t, err) | ||||
| 
 | ||||
| 	// Queue token update | ||||
|  | @ -1112,6 +1112,11 @@ func TestManager_WithProvisionedUsers(t *testing.T) { | |||
| 				{TopicPattern: "secret", Permission: PermissionRead}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		Tokens: map[string][]*Token{ | ||||
| 			"philuser": { | ||||
| 				{Value: "tk_op56p8lz5bf3cxkz9je99v9oc37lo", Label: "Alerts token"}, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 	a, err := NewManager(conf) | ||||
| 	require.Nil(t, err) | ||||
|  | @ -1123,24 +1128,29 @@ func TestManager_WithProvisionedUsers(t *testing.T) { | |||
| 	users, err := a.Users() | ||||
| 	require.Nil(t, err) | ||||
| 	require.Len(t, users, 4) | ||||
| 
 | ||||
| 	require.Equal(t, "philadmin", users[0].Name) | ||||
| 	require.Equal(t, RoleAdmin, users[0].Role) | ||||
| 
 | ||||
| 	require.Equal(t, "philmanual", users[1].Name) | ||||
| 	require.Equal(t, RoleUser, users[1].Role) | ||||
| 	require.Equal(t, "philuser", users[2].Name) | ||||
| 	require.Equal(t, RoleUser, users[2].Role) | ||||
| 	require.Equal(t, "*", users[3].Name) | ||||
| 	provisionedUserID := users[2].ID // "philuser" is the provisioned user | ||||
| 
 | ||||
| 	grants, err := a.Grants("philuser") | ||||
| 	require.Nil(t, err) | ||||
| 	require.Equal(t, "philuser", users[2].Name) | ||||
| 	require.Equal(t, RoleUser, users[2].Role) | ||||
| 	require.Equal(t, 2, len(grants)) | ||||
| 	require.Equal(t, "secret", grants[0].TopicPattern) | ||||
| 	require.Equal(t, PermissionRead, grants[0].Permission) | ||||
| 	require.Equal(t, "stats", grants[1].TopicPattern) | ||||
| 	require.Equal(t, PermissionReadWrite, grants[1].Permission) | ||||
| 
 | ||||
| 	require.Equal(t, "*", users[3].Name) | ||||
| 	tokens, err := a.Tokens(provisionedUserID) | ||||
| 	require.Nil(t, err) | ||||
| 	require.Equal(t, 1, len(tokens)) | ||||
| 	require.Equal(t, "tk_op56p8lz5bf3cxkz9je99v9oc37lo", tokens[0].Value) | ||||
| 	require.Equal(t, "Alerts token", tokens[0].Label) | ||||
| 	require.True(t, tokens[0].Provisioned) | ||||
| 
 | ||||
| 	// Re-open the DB (second app start) | ||||
| 	require.Nil(t, a.db.Close()) | ||||
|  | @ -1153,6 +1163,11 @@ func TestManager_WithProvisionedUsers(t *testing.T) { | |||
| 			{TopicPattern: "secret12", Permission: PermissionRead}, | ||||
| 		}, | ||||
| 	} | ||||
| 	conf.Tokens = map[string][]*Token{ | ||||
| 		"philuser": { | ||||
| 			{Value: "tk_op56p8lz5bf3cxkz9je99v9oc3XXX", Label: "Alerts token updated"}, | ||||
| 		}, | ||||
| 	} | ||||
| 	a, err = NewManager(conf) | ||||
| 	require.Nil(t, err) | ||||
| 
 | ||||
|  | @ -1160,30 +1175,36 @@ func TestManager_WithProvisionedUsers(t *testing.T) { | |||
| 	users, err = a.Users() | ||||
| 	require.Nil(t, err) | ||||
| 	require.Len(t, users, 3) | ||||
| 
 | ||||
| 	require.Equal(t, "philmanual", users[0].Name) | ||||
| 	require.Equal(t, "philuser", users[1].Name) | ||||
| 	require.Equal(t, RoleUser, users[1].Role) | ||||
| 	require.Equal(t, RoleUser, users[0].Role) | ||||
| 	require.Equal(t, "*", users[2].Name) | ||||
| 
 | ||||
| 	grants, err = a.Grants("philuser") | ||||
| 	require.Nil(t, err) | ||||
| 	require.Equal(t, "philuser", users[1].Name) | ||||
| 	require.Equal(t, RoleUser, users[1].Role) | ||||
| 	require.Equal(t, 2, len(grants)) | ||||
| 	require.Equal(t, "secret12", grants[0].TopicPattern) | ||||
| 	require.Equal(t, PermissionRead, grants[0].Permission) | ||||
| 	require.Equal(t, "stats12", grants[1].TopicPattern) | ||||
| 	require.Equal(t, PermissionReadWrite, grants[1].Permission) | ||||
| 
 | ||||
| 	require.Equal(t, "*", users[2].Name) | ||||
| 	tokens, err = a.Tokens(provisionedUserID) | ||||
| 	require.Nil(t, err) | ||||
| 	require.Equal(t, 1, len(tokens)) | ||||
| 	require.Equal(t, "tk_op56p8lz5bf3cxkz9je99v9oc3XXX", tokens[0].Value) | ||||
| 	require.Equal(t, "Alerts token updated", tokens[0].Label) | ||||
| 	require.True(t, tokens[0].Provisioned) | ||||
| 
 | ||||
| 	// Re-open the DB again (third app start) | ||||
| 	require.Nil(t, a.db.Close()) | ||||
| 	conf.Users = []*User{} | ||||
| 	conf.Access = map[string][]*Grant{} | ||||
| 	conf.Tokens = map[string][]*Token{} | ||||
| 	a, err = NewManager(conf) | ||||
| 	require.Nil(t, err) | ||||
| 
 | ||||
| 	// Check that the provisioned users are there | ||||
| 	// Check that the provisioned users are all gone | ||||
| 	users, err = a.Users() | ||||
| 	require.Nil(t, err) | ||||
| 	require.Len(t, users, 2) | ||||
|  | @ -1191,6 +1212,14 @@ func TestManager_WithProvisionedUsers(t *testing.T) { | |||
| 	require.Equal(t, "philmanual", users[0].Name) | ||||
| 	require.Equal(t, RoleUser, users[0].Role) | ||||
| 	require.Equal(t, "*", users[1].Name) | ||||
| 
 | ||||
| 	grants, err = a.Grants("philuser") | ||||
| 	require.Nil(t, err) | ||||
| 	require.Equal(t, 0, len(grants)) | ||||
| 
 | ||||
| 	tokens, err = a.Tokens(provisionedUserID) | ||||
| 	require.Nil(t, err) | ||||
| 	require.Equal(t, 0, len(tokens)) | ||||
| } | ||||
| 
 | ||||
| func TestManager_UpdateNonProvisionedUsersToProvisionedUsers(t *testing.T) { | ||||
|  |  | |||
|  | @ -4,6 +4,7 @@ import ( | |||
| 	"errors" | ||||
| 	"github.com/stripe/stripe-go/v74" | ||||
| 	"heckel.io/ntfy/v2/log" | ||||
| 	"heckel.io/ntfy/v2/util" | ||||
| 	"net/netip" | ||||
| 	"regexp" | ||||
| 	"strings" | ||||
|  | @ -59,11 +60,12 @@ type Auther interface { | |||
| 
 | ||||
| // Token represents a user token, including expiry date | ||||
| type Token struct { | ||||
| 	Value      string | ||||
| 	Label      string | ||||
| 	LastAccess time.Time | ||||
| 	LastOrigin netip.Addr | ||||
| 	Expires    time.Time | ||||
| 	Value       string | ||||
| 	Label       string | ||||
| 	LastAccess  time.Time | ||||
| 	LastOrigin  netip.Addr | ||||
| 	Expires     time.Time | ||||
| 	Provisioned bool | ||||
| } | ||||
| 
 | ||||
| // TokenUpdate holds information about the last access time and origin IP address of a token | ||||
|  | @ -247,6 +249,7 @@ var ( | |||
| 	allowedTopicRegex        = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`)  // No '*' | ||||
| 	allowedTopicPatternRegex = regexp.MustCompile(`^[-_*A-Za-z0-9]{1,64}$`) // Adds '*' for wildcards! | ||||
| 	allowedTierRegex         = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`) | ||||
| 	allowedTokenRegex        = regexp.MustCompile(`^tk_[-_A-Za-z0-9]{29}$`) // Must be tokenLength-len(tokenPrefix) | ||||
| ) | ||||
| 
 | ||||
| // AllowedRole returns true if the given role can be used for new users | ||||
|  | @ -282,6 +285,17 @@ func AllowedPasswordHash(hash string) error { | |||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // AllowedToken returns true if the given token matches the naming convention | ||||
| func AllowedToken(token string) bool { | ||||
| 	return allowedTokenRegex.MatchString(token) | ||||
| } | ||||
| 
 | ||||
| // GenerateToken generates a new token with a prefix and a fixed length | ||||
| // Lowercase only to support "<topic>+<token>@<domain>" email addresses | ||||
| func GenerateToken() string { | ||||
| 	return util.RandomLowerStringPrefix(tokenPrefix, tokenLength) | ||||
| } | ||||
| 
 | ||||
| // Error constants used by the package | ||||
| var ( | ||||
| 	ErrUnauthenticated     = errors.New("unauthenticated") | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 binwiederhier
						binwiederhier