1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2025-11-04 23:10:35 +01:00

Add "auth-tokens"

This commit is contained in:
binwiederhier 2025-07-31 07:08:35 +02:00
parent 149c13e9d8
commit 23ec7702fc
10 changed files with 263 additions and 88 deletions

View file

@ -50,6 +50,7 @@ var flagsServe = append(
altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-default-access", Aliases: []string{"auth_default_access", "p"}, EnvVars: []string{"NTFY_AUTH_DEFAULT_ACCESS"}, Value: "read-write", Usage: "default permissions if no matching entries in the auth database are found"}),
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-users", Aliases: []string{"auth_users"}, EnvVars: []string{"NTFY_AUTH_USERS"}, Usage: "pre-provisioned declarative users"}),
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-access", Aliases: []string{"auth_access"}, EnvVars: []string{"NTFY_AUTH_ACCESS"}, Usage: "pre-provisioned declarative access control entries"}),
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-tokens", Aliases: []string{"auth_tokens"}, EnvVars: []string{"NTFY_AUTH_TOKENS"}, Usage: "pre-provisioned declarative access tokens"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-cache-dir", Aliases: []string{"attachment_cache_dir"}, EnvVars: []string{"NTFY_ATTACHMENT_CACHE_DIR"}, Usage: "cache directory for attached files"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-total-size-limit", Aliases: []string{"attachment_total_size_limit", "A"}, EnvVars: []string{"NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT"}, Value: util.FormatSize(server.DefaultAttachmentTotalSizeLimit), Usage: "limit of the on-disk attachment cache"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-file-size-limit", Aliases: []string{"attachment_file_size_limit", "Y"}, EnvVars: []string{"NTFY_ATTACHMENT_FILE_SIZE_LIMIT"}, Value: util.FormatSize(server.DefaultAttachmentFileSizeLimit), Usage: "per-file attachment size limit (e.g. 300k, 2M, 100M)"}),
@ -158,6 +159,7 @@ func execServe(c *cli.Context) error {
authDefaultAccess := c.String("auth-default-access")
authUsersRaw := c.StringSlice("auth-users")
authAccessRaw := c.StringSlice("auth-access")
authTokensRaw := c.StringSlice("auth-tokens")
attachmentCacheDir := c.String("attachment-cache-dir")
attachmentTotalSizeLimitStr := c.String("attachment-total-size-limit")
attachmentFileSizeLimitStr := c.String("attachment-file-size-limit")
@ -361,6 +363,10 @@ func execServe(c *cli.Context) error {
if err != nil {
return err
}
authTokens, err := parseTokens(authUsers, authTokensRaw)
if err != nil {
return err
}
// Special case: Unset default
if listenHTTP == "-" {
@ -418,6 +424,7 @@ func execServe(c *cli.Context) error {
conf.AuthDefault = authDefault
conf.AuthUsers = authUsers
conf.AuthAccess = authAccess
conf.AuthTokens = authTokens
conf.AttachmentCacheDir = attachmentCacheDir
conf.AttachmentTotalSizeLimit = attachmentTotalSizeLimit
conf.AttachmentFileSizeLimit = attachmentFileSizeLimit
@ -532,7 +539,7 @@ func parseIPHostPrefix(host string) (prefixes []netip.Prefix, err error) {
}
func parseUsers(usersRaw []string) ([]*user.User, error) {
provisionUsers := make([]*user.User, 0)
users := make([]*user.User, 0)
for _, userLine := range usersRaw {
parts := strings.Split(userLine, ":")
if len(parts) != 3 {
@ -548,19 +555,19 @@ func parseUsers(usersRaw []string) ([]*user.User, error) {
} else if !user.AllowedRole(role) {
return nil, fmt.Errorf("invalid auth-users: %s, role %s is not allowed, allowed roles are 'admin' or 'user'", userLine, role)
}
provisionUsers = append(provisionUsers, &user.User{
users = append(users, &user.User{
Name: username,
Hash: passwordHash,
Role: role,
Provisioned: true,
})
}
return provisionUsers, nil
return users, nil
}
func parseAccess(provisionUsers []*user.User, provisionAccessRaw []string) (map[string][]*user.Grant, error) {
func parseAccess(users []*user.User, accessRaw []string) (map[string][]*user.Grant, error) {
access := make(map[string][]*user.Grant)
for _, accessLine := range provisionAccessRaw {
for _, accessLine := range accessRaw {
parts := strings.Split(accessLine, ":")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid auth-access: %s, expected format: 'user:topic:permission'", accessLine)
@ -569,7 +576,7 @@ func parseAccess(provisionUsers []*user.User, provisionAccessRaw []string) (map[
if username == userEveryone {
username = user.Everyone
}
provisionUser, exists := util.Find(provisionUsers, func(u *user.User) bool {
u, exists := util.Find(users, func(u *user.User) bool {
return u.Name == username
})
if username != user.Everyone {
@ -577,7 +584,7 @@ func parseAccess(provisionUsers []*user.User, provisionAccessRaw []string) (map[
return nil, fmt.Errorf("invalid auth-access: %s, user %s is not provisioned", accessLine, username)
} else if !user.AllowedUsername(username) {
return nil, fmt.Errorf("invalid auth-access: %s, username %s invalid", accessLine, username)
} else if provisionUser.Role != user.RoleUser {
} else if u.Role != user.RoleUser {
return nil, fmt.Errorf("invalid auth-access: %s, user %s is not a regular user, only regular users can have ACL entries", accessLine, username)
}
}
@ -601,6 +608,41 @@ func parseAccess(provisionUsers []*user.User, provisionAccessRaw []string) (map[
return access, nil
}
func parseTokens(users []*user.User, tokensRaw []string) (map[string][]*user.Token, error) {
tokens := make(map[string][]*user.Token)
for _, tokenLine := range tokensRaw {
parts := strings.Split(tokenLine, ":")
if len(parts) < 2 || len(parts) > 3 {
return nil, fmt.Errorf("invalid auth-tokens: %s, expected format: 'user:token[:label]'", tokenLine)
}
username := strings.TrimSpace(parts[0])
_, exists := util.Find(users, func(u *user.User) bool {
return u.Name == username
})
if !exists {
return nil, fmt.Errorf("invalid auth-tokens: %s, user %s is not provisioned", tokenLine, username)
} else if !user.AllowedUsername(username) {
return nil, fmt.Errorf("invalid auth-tokens: %s, username %s invalid", tokenLine, username)
}
token := strings.TrimSpace(parts[1])
if !user.AllowedToken(token) {
return nil, fmt.Errorf("invalid auth-tokens: %s, token %s invalid, use 'ntfy token generate' to generate a random token", tokenLine, token)
}
var label string
if len(parts) > 2 {
label = parts[2]
}
if _, exists := tokens[username]; !exists {
tokens[username] = make([]*user.Token, 0)
}
tokens[username] = append(tokens[username], &user.Token{
Value: token,
Label: label,
})
}
return tokens, nil
}
func reloadLogLevel(inputSource altsrc.InputSourceContext) error {
newLevelStr, err := inputSource.String("log-level")
if err != nil {

View file

@ -72,6 +72,15 @@ Example:
This is a server-only command. It directly reads from user.db as defined in the server config
file server.yml. The command only works if 'auth-file' is properly defined.`,
},
{
Name: "generate",
Usage: "Generates a random token",
Action: execTokenGenerate,
Description: `Randomly generate a token to be used in provisioned tokens.
This command only generates the token value, but does not persist it anywhere.
The output can be used in the 'auth-tokens' config option.`,
},
},
Description: `Manage access tokens for individual users.
@ -112,12 +121,12 @@ func execTokenAdd(c *cli.Context) error {
return err
}
u, err := manager.User(username)
if err == user.ErrUserNotFound {
if errors.Is(err, user.ErrUserNotFound) {
return fmt.Errorf("user %s does not exist", username)
} else if err != nil {
return err
}
token, err := manager.CreateToken(u.ID, label, expires, netip.IPv4Unspecified())
token, err := manager.CreateToken(u.ID, label, expires, netip.IPv4Unspecified(), false)
if err != nil {
return err
}
@ -141,7 +150,7 @@ func execTokenDel(c *cli.Context) error {
return err
}
u, err := manager.User(username)
if err == user.ErrUserNotFound {
if errors.Is(err, user.ErrUserNotFound) {
return fmt.Errorf("user %s does not exist", username)
} else if err != nil {
return err
@ -165,7 +174,7 @@ func execTokenList(c *cli.Context) error {
var users []*user.User
if username != "" {
u, err := manager.User(username)
if err == user.ErrUserNotFound {
if errors.Is(err, user.ErrUserNotFound) {
return fmt.Errorf("user %s does not exist", username)
} else if err != nil {
return err
@ -191,7 +200,7 @@ func execTokenList(c *cli.Context) error {
usersWithTokens++
fmt.Fprintf(c.App.ErrWriter, "user %s\n", u.Name)
for _, t := range tokens {
var label, expires string
var label, expires, provisioned string
if t.Label != "" {
label = fmt.Sprintf(" (%s)", t.Label)
}
@ -200,7 +209,10 @@ func execTokenList(c *cli.Context) error {
} else {
expires = fmt.Sprintf("expires %s", t.Expires.Format(time.RFC822))
}
fmt.Fprintf(c.App.ErrWriter, "- %s%s, %s, accessed from %s at %s\n", t.Value, label, expires, t.LastOrigin.String(), t.LastAccess.Format(time.RFC822))
if t.Provisioned {
provisioned = " (server config)"
}
fmt.Fprintf(c.App.ErrWriter, "- %s%s, %s, accessed from %s at %s%s\n", t.Value, label, expires, t.LastOrigin.String(), t.LastAccess.Format(time.RFC822), provisioned)
}
}
if usersWithTokens == 0 {
@ -208,3 +220,8 @@ func execTokenList(c *cli.Context) error {
}
return nil
}
func execTokenGenerate(c *cli.Context) error {
fmt.Println(user.GenerateToken())
return nil
}

View file

@ -97,6 +97,7 @@ type Config struct {
AuthDefault user.Permission
AuthUsers []*user.User
AuthAccess map[string][]*user.Grant
AuthTokens map[string][]*user.Token
AuthBcryptCost int
AuthStatsQueueWriterInterval time.Duration
AttachmentCacheDir string

View file

@ -203,6 +203,7 @@ func New(conf *Config) (*Server, error) {
ProvisionEnabled: true, // Enable provisioning of users and access
Users: conf.AuthUsers,
Access: conf.AuthAccess,
Tokens: conf.AuthTokens,
BcryptCost: conf.AuthBcryptCost,
QueueWriterInterval: conf.AuthStatsQueueWriterInterval,
}

View file

@ -86,6 +86,8 @@
# Each entry is in the format "<username>:<password-hash>:<role>", e.g. "phil:$2a$10$YLiO8U21sX1uhZamTLJXHuxgVC0Z/GKISibrKCLohPgtG7yIxSk4C:user"
# - auth-access is a list of access control entries that are automatically created when the server starts.
# Each entry is in the format "<username>:<topic-pattern>:<access>", e.g. "phil:mytopic:rw" or "phil:phil-*:rw".
# - auth-tokens is a list of access tokens that are automatically created when the server starts.
# Each entry is in the format "<username>:<token>[:<label>]", e.g. "phil:tk_1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef:My token".
#
# Debian/RPM package users:
# Use /var/lib/ntfy/user.db as user database to avoid permission issues. The package

View file

@ -234,7 +234,7 @@ func (s *Server) handleAccountTokenCreate(w http.ResponseWriter, r *http.Request
"token_expires": expires,
}).
Debug("Creating token for user %s", u.Name)
token, err := s.userManager.CreateToken(u.ID, label, expires, v.IP())
token, err := s.userManager.CreateToken(u.ID, label, expires, v.IP(), false)
if err != nil {
return err
}

View file

@ -176,7 +176,7 @@ func TestAccount_ChangeSettings(t *testing.T) {
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
u, _ := s.userManager.User("phil")
token, _ := s.userManager.CreateToken(u.ID, "", time.Unix(0, 0), netip.IPv4Unspecified())
token, _ := s.userManager.CreateToken(u.ID, "", time.Unix(0, 0), netip.IPv4Unspecified(), false)
rr := request(t, s, "PATCH", "/v1/account/settings", `{"notification": {"sound": "juntos"},"ignored": true}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),

View file

@ -111,9 +111,11 @@ const (
last_access INT NOT NULL,
last_origin TEXT NOT NULL,
expires INT NOT NULL,
provisioned INT NOT NULL,
PRIMARY KEY (user_id, token),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE UNIQUE INDEX idx_user_token ON user_token (token);
CREATE TABLE IF NOT EXISTS user_phone (
user_id TEXT NOT NULL,
phone_number TEXT NOT NULL,
@ -181,16 +183,17 @@ const (
ELSE 2
END, user
`
selectUserCountQuery = `SELECT COUNT(*) FROM user`
updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
updateUserProvisionedQuery = `UPDATE user SET provisioned = ? WHERE user = ?`
updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE id = ?`
updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ?, stats_calls = ? WHERE id = ?`
updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
updateUserDeletedQuery = `UPDATE user SET deleted = ? WHERE id = ?`
deleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?`
deleteUserQuery = `DELETE FROM user WHERE user = ?`
selectUserCountQuery = `SELECT COUNT(*) FROM user`
selectUserIDFromUsernameQuery = `SELECT id FROM user WHERE user = ?`
updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
updateUserProvisionedQuery = `UPDATE user SET provisioned = ? WHERE user = ?`
updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE id = ?`
updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ?, stats_calls = ? WHERE id = ?`
updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
updateUserDeletedQuery = `UPDATE user SET deleted = ? WHERE id = ?`
deleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?`
deleteUserQuery = `DELETE FROM user WHERE user = ?`
upsertUserAccessQuery = `
INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned)
@ -255,17 +258,23 @@ const (
AND topic = ?
`
selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
selectTokensQuery = `SELECT token, label, last_access, last_origin, expires FROM user_token WHERE user_id = ?`
selectTokenQuery = `SELECT token, label, last_access, last_origin, expires FROM user_token WHERE user_id = ? AND token = ?`
insertTokenQuery = `INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires) VALUES (?, ?, ?, ?, ?, ?)`
updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
updateTokenLabelQuery = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
updateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
deleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
deleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
deleteExcessTokensQuery = `
selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
selectTokensQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ?`
selectTokenQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ? AND token = ?`
upsertTokenQuery = `
INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned)
VALUES (?, ?, ?, ?, ?, ?, ?)
ON CONFLICT (user_id, token)
DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned;
`
updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
updateTokenLabelQuery = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
updateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
deleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
deleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
deleteTokensProvisionedQuery = `DELETE FROM user_token WHERE provisioned = 1`
deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
deleteExcessTokensQuery = `
DELETE FROM user_token
WHERE user_id = ?
AND (user_id, token) NOT IN (
@ -470,7 +479,7 @@ const (
role,
prefs,
sync_topic,
0,
0, -- provisioned
stats_messages,
stats_emails,
stats_calls,
@ -480,7 +489,8 @@ const (
stripe_subscription_interval,
stripe_subscription_paid_until,
stripe_subscription_cancel_at,
created, deleted
created,
deleted
FROM user_old;
DROP TABLE user_old;
@ -500,10 +510,27 @@ const (
INSERT INTO user_access SELECT *, 0 FROM user_access_old;
DROP TABLE user_access_old;
-- Alter user_token table: Add provisioned column
ALTER TABLE user_token RENAME TO user_token_old;
CREATE TABLE IF NOT EXISTS user_token (
user_id TEXT NOT NULL,
token TEXT NOT NULL,
label TEXT NOT NULL,
last_access INT NOT NULL,
last_origin TEXT NOT NULL,
expires INT NOT NULL,
provisioned INT NOT NULL,
PRIMARY KEY (user_id, token),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
INSERT INTO user_token SELECT *, 0 FROM user_token_old;
DROP TABLE user_token_old;
-- Recreate indices
CREATE UNIQUE INDEX idx_user ON user (user);
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
CREATE UNIQUE INDEX idx_user_token ON user_token (token);
-- Re-enable foreign keys
PRAGMA foreign_keys=on;
@ -537,7 +564,8 @@ type Config struct {
DefaultAccess Permission // Default permission if no ACL matches
ProvisionEnabled bool // Enable auto-provisioning of users and access grants, disabled for "ntfy user" commands
Users []*User // Predefined users to create on startup
Access map[string][]*Grant // Predefined access grants to create on startup
Access map[string][]*Grant // Predefined access grants to create on startup (username -> []*Grant)
Tokens map[string][]*Token // Predefined users to create on startup (username -> []*Token)
QueueWriterInterval time.Duration // Interval for the async queue writer to flush stats and token updates to the database
BcryptCost int // Cost of generated passwords; lowering makes testing faster
}
@ -623,15 +651,15 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) {
// CreateToken generates a random token for the given user and returns it. The token expires
// after a fixed duration unless ChangeToken is called. This function also prunes tokens for the
// given user, if there are too many of them.
func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr) (*Token, error) {
token := util.RandomLowerStringPrefix(tokenPrefix, tokenLength) // Lowercase only to support "<topic>+<token>@<domain>" email addresses
tx, err := a.db.Begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) {
return queryTx(a.db, func(tx *sql.Tx) (*Token, error) {
return a.createTokenTx(tx, userID, GenerateToken(), label, expires, origin, provisioned)
})
}
func (a *Manager) createTokenTx(tx *sql.Tx, userID, token, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) {
access := time.Now()
if _, err := tx.Exec(insertTokenQuery, userID, token, label, access.Unix(), origin.String(), expires.Unix()); err != nil {
if _, err := tx.Exec(upsertTokenQuery, userID, token, label, access.Unix(), origin.String(), expires.Unix(), provisioned); err != nil {
return nil, err
}
rows, err := tx.Query(selectTokenCountQuery, userID)
@ -653,15 +681,13 @@ func (a *Manager) CreateToken(userID, label string, expires time.Time, origin ne
return nil, err
}
}
if err := tx.Commit(); err != nil {
return nil, err
}
return &Token{
Value: token,
Label: label,
LastAccess: access,
LastOrigin: origin,
Expires: expires,
Value: token,
Label: label,
LastAccess: access,
LastOrigin: origin,
Expires: expires,
Provisioned: provisioned,
}, nil
}
@ -698,10 +724,11 @@ func (a *Manager) Token(userID, token string) (*Token, error) {
func (a *Manager) readToken(rows *sql.Rows) (*Token, error) {
var token, label, lastOrigin string
var lastAccess, expires int64
var provisioned bool
if !rows.Next() {
return nil, ErrTokenNotFound
}
if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires); err != nil {
if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires, &provisioned); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
@ -711,11 +738,12 @@ func (a *Manager) readToken(rows *sql.Rows) (*Token, error) {
lastOriginIP = netip.IPv4Unspecified()
}
return &Token{
Value: token,
Label: label,
LastAccess: time.Unix(lastAccess, 0),
LastOrigin: lastOriginIP,
Expires: time.Unix(expires, 0),
Value: token,
Label: label,
LastAccess: time.Unix(lastAccess, 0),
LastOrigin: lastOriginIP,
Expires: time.Unix(expires, 0),
Provisioned: provisioned,
}, nil
}
@ -774,7 +802,7 @@ func (a *Manager) PhoneNumbers(userID string) ([]string, error) {
phoneNumbers := make([]string, 0)
for {
phoneNumber, err := a.readPhoneNumber(rows)
if err == ErrPhoneNumberNotFound {
if errors.Is(err, ErrPhoneNumberNotFound) {
break
} else if err != nil {
return nil, err
@ -1757,6 +1785,28 @@ func (a *Manager) maybeProvisionUsersAndAccess() error {
}
}
}
// Remove and (re-)add provisioned tokens
if _, err := tx.Exec(deleteTokensProvisionedQuery); err != nil {
return err
}
for username, tokens := range a.config.Tokens {
_, exists := util.Find(a.config.Users, func(u *User) bool {
return u.Name == username
})
if !exists && username != Everyone {
return fmt.Errorf("user %s is not a provisioned user, refusing to add tokens", username)
}
var userID string
row := tx.QueryRow(selectUserIDFromUsernameQuery, username)
if err := row.Scan(&userID); err != nil {
return fmt.Errorf("failed to find provisioned user %s for provisioned tokens", username)
}
for _, token := range tokens {
if _, err = a.createTokenTx(tx, userID, token.Value, token.Label, time.Unix(0, 0), netip.IPv4Unspecified(), true); err != nil {
return err
}
}
}
return nil
})
}
@ -1974,3 +2024,22 @@ func execTx(db *sql.DB, f func(tx *sql.Tx) error) error {
}
return tx.Commit()
}
// queryTx executes a function in a transaction and returns the result. If the function
// returns an error, the transaction is rolled back.
func queryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) {
tx, err := db.Begin()
if err != nil {
var zero T
return zero, err
}
defer tx.Rollback()
t, err := f(tx)
if err != nil {
return t, err
}
if err := tx.Commit(); err != nil {
return t, err
}
return t, nil
}

View file

@ -194,7 +194,7 @@ func TestManager_MarkUserRemoved_RemoveDeletedUsers(t *testing.T) {
require.Nil(t, err)
require.False(t, u.Deleted)
token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified())
token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified(), false)
require.Nil(t, err)
u, err = a.Authenticate("user", "pass")
@ -241,7 +241,7 @@ func TestManager_CreateToken_Only_Lower(t *testing.T) {
u, err := a.User("user")
require.Nil(t, err)
token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified())
token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified(), false)
require.Nil(t, err)
require.Equal(t, token.Value, strings.ToLower(token.Value))
}
@ -523,7 +523,7 @@ func TestManager_Token_Valid(t *testing.T) {
require.Nil(t, err)
// Create token for user
token, err := a.CreateToken(u.ID, "some label", time.Now().Add(72*time.Hour), netip.IPv4Unspecified())
token, err := a.CreateToken(u.ID, "some label", time.Now().Add(72*time.Hour), netip.IPv4Unspecified(), false)
require.Nil(t, err)
require.NotEmpty(t, token.Value)
require.Equal(t, "some label", token.Label)
@ -586,12 +586,12 @@ func TestManager_Token_Expire(t *testing.T) {
require.Nil(t, err)
// Create tokens for user
token1, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified())
token1, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified(), false)
require.Nil(t, err)
require.NotEmpty(t, token1.Value)
require.True(t, time.Now().Add(71*time.Hour).Unix() < token1.Expires.Unix())
token2, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified())
token2, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified(), false)
require.Nil(t, err)
require.NotEmpty(t, token2.Value)
require.NotEqual(t, token1.Value, token2.Value)
@ -638,7 +638,7 @@ func TestManager_Token_Extend(t *testing.T) {
require.Equal(t, errNoTokenProvided, err)
// Create token for user
token, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified())
token, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified(), false)
require.Nil(t, err)
require.NotEmpty(t, token.Value)
@ -668,12 +668,12 @@ func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
// Create 2 tokens for phil
philTokens := make([]string, 0)
token, err := a.CreateToken(phil.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified())
token, err := a.CreateToken(phil.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified(), false)
require.Nil(t, err)
require.NotEmpty(t, token.Value)
philTokens = append(philTokens, token.Value)
token, err = a.CreateToken(phil.ID, "", time.Unix(0, 0), netip.IPv4Unspecified())
token, err = a.CreateToken(phil.ID, "", time.Unix(0, 0), netip.IPv4Unspecified(), false)
require.Nil(t, err)
require.NotEmpty(t, token.Value)
philTokens = append(philTokens, token.Value)
@ -682,7 +682,7 @@ func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
baseTime := time.Now().Add(24 * time.Hour)
benTokens := make([]string, 0)
for i := 0; i < 62; i++ { //
token, err := a.CreateToken(ben.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified())
token, err := a.CreateToken(ben.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified(), false)
require.Nil(t, err)
require.NotEmpty(t, token.Value)
benTokens = append(benTokens, token.Value)
@ -795,7 +795,7 @@ func TestManager_EnqueueTokenUpdate(t *testing.T) {
u, err := a.User("ben")
require.Nil(t, err)
token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified())
token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified(), false)
require.Nil(t, err)
// Queue token update
@ -1112,6 +1112,11 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
{TopicPattern: "secret", Permission: PermissionRead},
},
},
Tokens: map[string][]*Token{
"philuser": {
{Value: "tk_op56p8lz5bf3cxkz9je99v9oc37lo", Label: "Alerts token"},
},
},
}
a, err := NewManager(conf)
require.Nil(t, err)
@ -1123,24 +1128,29 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
users, err := a.Users()
require.Nil(t, err)
require.Len(t, users, 4)
require.Equal(t, "philadmin", users[0].Name)
require.Equal(t, RoleAdmin, users[0].Role)
require.Equal(t, "philmanual", users[1].Name)
require.Equal(t, RoleUser, users[1].Role)
require.Equal(t, "philuser", users[2].Name)
require.Equal(t, RoleUser, users[2].Role)
require.Equal(t, "*", users[3].Name)
provisionedUserID := users[2].ID // "philuser" is the provisioned user
grants, err := a.Grants("philuser")
require.Nil(t, err)
require.Equal(t, "philuser", users[2].Name)
require.Equal(t, RoleUser, users[2].Role)
require.Equal(t, 2, len(grants))
require.Equal(t, "secret", grants[0].TopicPattern)
require.Equal(t, PermissionRead, grants[0].Permission)
require.Equal(t, "stats", grants[1].TopicPattern)
require.Equal(t, PermissionReadWrite, grants[1].Permission)
require.Equal(t, "*", users[3].Name)
tokens, err := a.Tokens(provisionedUserID)
require.Nil(t, err)
require.Equal(t, 1, len(tokens))
require.Equal(t, "tk_op56p8lz5bf3cxkz9je99v9oc37lo", tokens[0].Value)
require.Equal(t, "Alerts token", tokens[0].Label)
require.True(t, tokens[0].Provisioned)
// Re-open the DB (second app start)
require.Nil(t, a.db.Close())
@ -1153,6 +1163,11 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
{TopicPattern: "secret12", Permission: PermissionRead},
},
}
conf.Tokens = map[string][]*Token{
"philuser": {
{Value: "tk_op56p8lz5bf3cxkz9je99v9oc3XXX", Label: "Alerts token updated"},
},
}
a, err = NewManager(conf)
require.Nil(t, err)
@ -1160,30 +1175,36 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
users, err = a.Users()
require.Nil(t, err)
require.Len(t, users, 3)
require.Equal(t, "philmanual", users[0].Name)
require.Equal(t, "philuser", users[1].Name)
require.Equal(t, RoleUser, users[1].Role)
require.Equal(t, RoleUser, users[0].Role)
require.Equal(t, "*", users[2].Name)
grants, err = a.Grants("philuser")
require.Nil(t, err)
require.Equal(t, "philuser", users[1].Name)
require.Equal(t, RoleUser, users[1].Role)
require.Equal(t, 2, len(grants))
require.Equal(t, "secret12", grants[0].TopicPattern)
require.Equal(t, PermissionRead, grants[0].Permission)
require.Equal(t, "stats12", grants[1].TopicPattern)
require.Equal(t, PermissionReadWrite, grants[1].Permission)
require.Equal(t, "*", users[2].Name)
tokens, err = a.Tokens(provisionedUserID)
require.Nil(t, err)
require.Equal(t, 1, len(tokens))
require.Equal(t, "tk_op56p8lz5bf3cxkz9je99v9oc3XXX", tokens[0].Value)
require.Equal(t, "Alerts token updated", tokens[0].Label)
require.True(t, tokens[0].Provisioned)
// Re-open the DB again (third app start)
require.Nil(t, a.db.Close())
conf.Users = []*User{}
conf.Access = map[string][]*Grant{}
conf.Tokens = map[string][]*Token{}
a, err = NewManager(conf)
require.Nil(t, err)
// Check that the provisioned users are there
// Check that the provisioned users are all gone
users, err = a.Users()
require.Nil(t, err)
require.Len(t, users, 2)
@ -1191,6 +1212,14 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
require.Equal(t, "philmanual", users[0].Name)
require.Equal(t, RoleUser, users[0].Role)
require.Equal(t, "*", users[1].Name)
grants, err = a.Grants("philuser")
require.Nil(t, err)
require.Equal(t, 0, len(grants))
tokens, err = a.Tokens(provisionedUserID)
require.Nil(t, err)
require.Equal(t, 0, len(tokens))
}
func TestManager_UpdateNonProvisionedUsersToProvisionedUsers(t *testing.T) {

View file

@ -4,6 +4,7 @@ import (
"errors"
"github.com/stripe/stripe-go/v74"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/util"
"net/netip"
"regexp"
"strings"
@ -59,11 +60,12 @@ type Auther interface {
// Token represents a user token, including expiry date
type Token struct {
Value string
Label string
LastAccess time.Time
LastOrigin netip.Addr
Expires time.Time
Value string
Label string
LastAccess time.Time
LastOrigin netip.Addr
Expires time.Time
Provisioned bool
}
// TokenUpdate holds information about the last access time and origin IP address of a token
@ -247,6 +249,7 @@ var (
allowedTopicRegex = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`) // No '*'
allowedTopicPatternRegex = regexp.MustCompile(`^[-_*A-Za-z0-9]{1,64}$`) // Adds '*' for wildcards!
allowedTierRegex = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`)
allowedTokenRegex = regexp.MustCompile(`^tk_[-_A-Za-z0-9]{29}$`) // Must be tokenLength-len(tokenPrefix)
)
// AllowedRole returns true if the given role can be used for new users
@ -282,6 +285,17 @@ func AllowedPasswordHash(hash string) error {
return nil
}
// AllowedToken returns true if the given token matches the naming convention
func AllowedToken(token string) bool {
return allowedTokenRegex.MatchString(token)
}
// GenerateToken generates a new token with a prefix and a fixed length
// Lowercase only to support "<topic>+<token>@<domain>" email addresses
func GenerateToken() string {
return util.RandomLowerStringPrefix(tokenPrefix, tokenLength)
}
// Error constants used by the package
var (
ErrUnauthenticated = errors.New("unauthenticated")